|
| 1 | +"""Regression tests for symlink rejection in prompt/agent integrators. |
| 2 | +
|
| 3 | +Verifies that find_prompt_files() and find_agent_files() reject symlinks, |
| 4 | +preventing supply-chain file disclosure attacks via malicious APM packages. |
| 5 | +""" |
| 6 | + |
| 7 | +from __future__ import annotations |
| 8 | + |
| 9 | +import os |
| 10 | +from pathlib import Path |
| 11 | + |
| 12 | +import pytest |
| 13 | + |
| 14 | +from apm_cli.integration.agent_integrator import AgentIntegrator |
| 15 | +from apm_cli.integration.prompt_integrator import PromptIntegrator |
| 16 | + |
| 17 | + |
| 18 | +@pytest.fixture |
| 19 | +def package_with_symlinks(tmp_path: Path) -> Path: |
| 20 | + """Create a fixture package with symlinks under .apm/ directories.""" |
| 21 | + pkg = tmp_path / "pkg" |
| 22 | + (pkg / ".apm" / "prompts").mkdir(parents=True) |
| 23 | + (pkg / ".apm" / "agents").mkdir(parents=True) |
| 24 | + (pkg / ".apm" / "chatmodes").mkdir(parents=True) |
| 25 | + |
| 26 | + # Create a sentinel file outside the package |
| 27 | + sentinel = tmp_path / "sentinel.txt" |
| 28 | + sentinel.write_text("REGRESSION-SENTINEL-CONTENT") |
| 29 | + |
| 30 | + # Create legitimate files |
| 31 | + (pkg / ".apm" / "prompts" / "legit.prompt.md").write_text("legit prompt") |
| 32 | + (pkg / ".apm" / "agents" / "legit.agent.md").write_text("legit agent") |
| 33 | + (pkg / ".apm" / "chatmodes" / "legit.chatmode.md").write_text("legit chatmode") |
| 34 | + |
| 35 | + # Create symlinks pointing outside |
| 36 | + (pkg / ".apm" / "prompts" / "leak.prompt.md").symlink_to(sentinel) |
| 37 | + (pkg / ".apm" / "agents" / "leak.agent.md").symlink_to(sentinel) |
| 38 | + (pkg / ".apm" / "chatmodes" / "leak.chatmode.md").symlink_to(sentinel) |
| 39 | + |
| 40 | + # Create a symlink with absolute path target |
| 41 | + (pkg / "abs.agent.md").symlink_to(sentinel) |
| 42 | + |
| 43 | + return pkg |
| 44 | + |
| 45 | + |
| 46 | +class TestPromptIntegratorSymlinkRejection: |
| 47 | + """Verify PromptIntegrator rejects symlinked files.""" |
| 48 | + |
| 49 | + def test_find_prompt_files_excludes_symlinks(self, package_with_symlinks: Path) -> None: |
| 50 | + integrator = PromptIntegrator() |
| 51 | + result = integrator.find_prompt_files(package_with_symlinks) |
| 52 | + |
| 53 | + # Should find the legit file but not the symlink |
| 54 | + assert all(not p.is_symlink() for p in result) |
| 55 | + assert not any(p.name == "leak.prompt.md" for p in result) |
| 56 | + assert any(p.name == "legit.prompt.md" for p in result) |
| 57 | + |
| 58 | + def test_copy_prompt_rejects_symlink_source( |
| 59 | + self, package_with_symlinks: Path, tmp_path: Path |
| 60 | + ) -> None: |
| 61 | + integrator = PromptIntegrator() |
| 62 | + symlink_source = package_with_symlinks / ".apm" / "prompts" / "leak.prompt.md" |
| 63 | + target = tmp_path / "output.prompt.md" |
| 64 | + |
| 65 | + with pytest.raises(ValueError, match=r"symlink"): |
| 66 | + integrator.copy_prompt(symlink_source, target) |
| 67 | + |
| 68 | + |
| 69 | +class TestAgentIntegratorSymlinkRejection: |
| 70 | + """Verify AgentIntegrator rejects symlinked files.""" |
| 71 | + |
| 72 | + def test_find_agent_files_excludes_symlinks(self, package_with_symlinks: Path) -> None: |
| 73 | + integrator = AgentIntegrator() |
| 74 | + result = integrator.find_agent_files(package_with_symlinks) |
| 75 | + |
| 76 | + # Should find legit files but not symlinks |
| 77 | + assert all(not p.is_symlink() for p in result) |
| 78 | + assert not any(p.name == "leak.agent.md" for p in result) |
| 79 | + assert not any(p.name == "leak.chatmode.md" for p in result) |
| 80 | + assert not any(p.name == "abs.agent.md" for p in result) |
| 81 | + assert any(p.name == "legit.agent.md" for p in result) |
| 82 | + assert any(p.name == "legit.chatmode.md" for p in result) |
| 83 | + |
| 84 | + def test_copy_agent_rejects_symlink_source( |
| 85 | + self, package_with_symlinks: Path, tmp_path: Path |
| 86 | + ) -> None: |
| 87 | + integrator = AgentIntegrator() |
| 88 | + symlink_source = package_with_symlinks / ".apm" / "agents" / "leak.agent.md" |
| 89 | + target = tmp_path / "output.agent.md" |
| 90 | + |
| 91 | + with pytest.raises(ValueError, match=r"symlink"): |
| 92 | + integrator.copy_agent(symlink_source, target) |
| 93 | + |
| 94 | + |
| 95 | +class TestHardlinkRejection: |
| 96 | + """Verify integrators reject hardlinked files.""" |
| 97 | + |
| 98 | + @pytest.mark.skipif(os.name == "nt", reason="Hardlinks may require privileges on Windows") |
| 99 | + def test_find_prompt_files_excludes_hardlinks(self, tmp_path: Path) -> None: |
| 100 | + pkg = tmp_path / "pkg" |
| 101 | + (pkg / ".apm" / "prompts").mkdir(parents=True) |
| 102 | + |
| 103 | + # Create a file and a hardlink to it |
| 104 | + original = tmp_path / "original.txt" |
| 105 | + original.write_text("hardlink content") |
| 106 | + hardlink = pkg / ".apm" / "prompts" / "linked.prompt.md" |
| 107 | + os.link(original, hardlink) |
| 108 | + |
| 109 | + integrator = PromptIntegrator() |
| 110 | + result = integrator.find_prompt_files(pkg) |
| 111 | + |
| 112 | + # Hardlink has st_nlink > 1, should be rejected |
| 113 | + assert not any(p.name == "linked.prompt.md" for p in result) |
0 commit comments