Skip to content

Refactored TestGradablePaperQAEnvironment for DRY code #702

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 80 additions & 78 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from tantivy import Index

from paperqa.agents import SearchIndex, agent_query
from paperqa.agents.env import PaperQAEnvironment, settings_to_tools
from paperqa.agents.env import settings_to_tools
from paperqa.agents.main import FAKE_AGENT_TYPE
from paperqa.agents.models import AgentStatus, AnswerResponse, QueryRequest
from paperqa.agents.search import (
Expand Down Expand Up @@ -750,121 +750,86 @@ def test_answers_are_striped() -> None:
response.model_dump_json()


@pytest.mark.asyncio
async def test_sequential_tool_calls(agent_test_settings: Settings):

SLEEP_TIME = 2.0

async def fake_gather_evidence(*args, **kwargs) -> str: # noqa: ARG001
await asyncio.sleep(SLEEP_TIME)
return "fake evidence"

question = "How can you use XAI for chemical property prediction?"
env = PaperQAEnvironment(
query=QueryRequest(query=question, settings=agent_test_settings),
@pytest.fixture(name="stub_gradable_env")
def fixture_stub_gradable_env(
agent_test_settings: Settings,
) -> GradablePaperQAEnvironment:
return GradablePaperQAEnvironment(
query=QueryRequest(
query="How can you use XAI for chemical property prediction?",
settings=agent_test_settings,
),
docs=Docs(),
)
await env.reset()

gather_tool = next(
tool for tool in env.tools if tool.info.name == GatherEvidence.TOOL_FN_NAME
)

with patch.object(gather_tool, "_tool_fn", fake_gather_evidence):
tic = time.time()
await env.step(
ToolRequestMessage(
tool_calls=[
ToolCall.from_name(
"gather_evidence",
question="XAI for chemical property prediction",
),
ToolCall.from_name(
"gather_evidence",
question="XAI for chemical property prediction",
),
]
)
)

assert time.time() - tic > 2 * SLEEP_TIME # since they are sequential


class TestGradablePaperQAEnvironment:
@pytest.mark.flaky(reruns=2, only_rerun=["AssertionError"])
@pytest.mark.asyncio
async def test_deepcopy_env(self, agent_test_settings: Settings) -> None:
async def test_deepcopy_env(
self,
agent_test_settings: Settings,
stub_gradable_env: GradablePaperQAEnvironment,
) -> None:
await get_directory_index(settings=agent_test_settings) # Trigger build

question = "How can you use XAI for chemical property prediction?"
env = GradablePaperQAEnvironment(
query=QueryRequest(query=question, settings=agent_test_settings),
docs=Docs(),
)

# 1. Rollout until after gather evidence
await env.reset()
await stub_gradable_env.reset()
for tool_call in (
ToolCall.from_name(
"paper_search",
query="XAI for chemical property prediction",
min_year=2018,
max_year=2024,
),
ToolCall.from_name("gather_evidence", question=question),
ToolCall.from_name(
"gather_evidence", question=stub_gradable_env._query.query
),
):
await env.step(ToolRequestMessage(tool_calls=[tool_call]))
await stub_gradable_env.step(ToolRequestMessage(tool_calls=[tool_call]))

# 2. Now we deepcopy the environment
env_copy = deepcopy(env)
assert env.state == env_copy.state
stub_gradable_env_copy = deepcopy(stub_gradable_env)
assert stub_gradable_env.state == stub_gradable_env_copy.state

# 3. Generate an answer and complete for both, and confirm they are identical
gen_answer_action = ToolRequestMessage(
tool_calls=[ToolCall.from_name("gen_answer")]
)
await env.step(gen_answer_action)
_, _, done, _ = await env.step(
await stub_gradable_env.step(gen_answer_action)
_, _, done, _ = await stub_gradable_env.step(
ToolRequestMessage(tool_calls=[ToolCall.from_name("complete")])
)
assert done
assert not env.state.session.could_not_answer
assert env.state.session.used_contexts
await env_copy.step(gen_answer_action)
_, _, done, _ = await env.step(
assert not stub_gradable_env.state.session.could_not_answer
assert stub_gradable_env.state.session.used_contexts
await stub_gradable_env_copy.step(gen_answer_action)
_, _, done, _ = await stub_gradable_env.step(
ToolRequestMessage(tool_calls=[ToolCall.from_name("complete")])
)
assert done
assert not env_copy.state.session.could_not_answer
assert env_copy.state.session.used_contexts
assert sorted(env.state.session.used_contexts) == sorted(
env_copy.state.session.used_contexts
assert not stub_gradable_env_copy.state.session.could_not_answer
assert stub_gradable_env_copy.state.session.used_contexts
assert sorted(stub_gradable_env.state.session.used_contexts) == sorted(
stub_gradable_env_copy.state.session.used_contexts
)

@pytest.mark.asyncio
async def test_empty_tool_calls(self, agent_test_settings: Settings) -> None:
env = GradablePaperQAEnvironment(
query=QueryRequest(
query="How can you use XAI for chemical property prediction?",
settings=agent_test_settings,
),
docs=Docs(),
)

await env.reset()
obs, _, done, truncated = await env.step(ToolRequestMessage())
async def test_empty_tool_calls(
self, stub_gradable_env: GradablePaperQAEnvironment
) -> None:
await stub_gradable_env.reset()
obs, _, done, truncated = await stub_gradable_env.step(ToolRequestMessage())
assert len(obs) == 1
assert obs[0].content
assert "no tool calls" in obs[0].content.lower()
assert not done
assert not truncated

@pytest.mark.asyncio
async def test_unsure_answer(self, agent_test_settings: Settings) -> None:
env = GradablePaperQAEnvironment(
query=QueryRequest(query="stub", settings=agent_test_settings),
docs=Docs(),
)
async def test_unsure_answer(
self, stub_gradable_env: GradablePaperQAEnvironment
) -> None:
unsure_answer = "Based on the sources provided, it appears no one has done x."

async def emulate_answered_but_unsure(
Expand All @@ -873,9 +838,11 @@ async def emulate_answered_but_unsure(
query.answer = unsure_answer
return query

await env.reset()
with patch.object(type(env.state.docs), "aquery", emulate_answered_but_unsure):
obs, _, done, truncated = await env.step(
await stub_gradable_env.reset()
with patch.object(
type(stub_gradable_env.state.docs), "aquery", emulate_answered_but_unsure
):
obs, _, done, truncated = await stub_gradable_env.step(
ToolRequestMessage(tool_calls=[ToolCall.from_name("gen_answer")])
)
assert len(obs) == 1
Expand All @@ -884,3 +851,38 @@ async def emulate_answered_but_unsure(
assert unsure_answer in obs[0].content
assert not done
assert not truncated

@pytest.mark.asyncio
async def test_sequential_tool_calls(
self, stub_gradable_env: GradablePaperQAEnvironment
) -> None:
SLEEP_TIME = 2.0

async def fake_gather_evidence(*args, **kwargs) -> str: # noqa: ARG001
await asyncio.sleep(SLEEP_TIME)
return "fake evidence"

_, tools = await stub_gradable_env.reset()

gather_tool = next(
tool for tool in tools if tool.info.name == GatherEvidence.TOOL_FN_NAME
)

with patch.object(gather_tool, "_tool_fn", fake_gather_evidence):
tic = time.time()
await stub_gradable_env.step(
ToolRequestMessage(
tool_calls=[
ToolCall.from_name(
"gather_evidence",
question="XAI for chemical property prediction",
),
ToolCall.from_name(
"gather_evidence",
question="XAI for chemical property prediction",
),
]
)
)

assert time.time() - tic > 2 * SLEEP_TIME # since they are sequential