Skip to content

Aviary agent max_timesteps and fixed test_gather_evidence_rejects_empty_docs #515

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions paperqa/agents/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,10 @@ async def run_agent(
else:
raise NotImplementedError(f"Didn't yet handle agent type {agent_type}.")

if "cannot answer" in answer.answer.lower() and agent_status != AgentStatus.TIMEOUT:
if (
"cannot answer" in answer.answer.lower()
and agent_status != AgentStatus.TRUNCATED
):
agent_status = AgentStatus.UNSURE
# stop after, so overall isn't reported as long-running step.
logger.info(
Expand All @@ -140,6 +143,11 @@ async def run_fake_agent(
) = None,
**env_kwargs,
) -> tuple[Answer, AgentStatus]:
if query.settings.agent.max_timesteps is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jc -- why are we calling the agent steps timesteps? I think of timesteps from physical models (i.e. molecular dynamics) where each iteration is a unit of time, like 4 picoseconds or something. This is more like actionsteps in my mind.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hear what you're saying.

I think one can also encounter vagaries with "step", for example one can wonder does "step" mean:

  • One step = agent selection + environment step, counted together
  • Two steps = agent selection as one step, environment step as another

I went with timestep to exactly match ldp.data_structures.Transition.timestep

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me know if you think we should name it otherwise

logger.warning(
f"Max timesteps (configured {query.settings.agent.max_timesteps}) is not"
" applicable with the fake agent, ignoring it."
)
env = PaperQAEnvironment(query, docs, **env_kwargs)
_, tools = await env.reset()
if on_env_reset_callback:
Expand Down Expand Up @@ -209,7 +217,14 @@ async def run_aviary_agent(
tools=tools,
)

timestep, max_timesteps = 0, query.settings.agent.max_timesteps
while not done:
if max_timesteps is not None and timestep >= max_timesteps:
logger.warning(
f"Agent didn't finish within {max_timesteps} timesteps, just answering."
)
await tools[-1]._tool_fn(question=query.query, state=env.state)
return env.state.answer, AgentStatus.TRUNCATED
agent_state.messages += obs
for attempt in Retrying(
stop=stop_after_attempt(5),
Expand All @@ -226,12 +241,13 @@ async def run_aviary_agent(
obs, reward, done, truncated = await env.step(action)
if on_env_step_callback:
await on_env_step_callback(obs, reward, done, truncated)
timestep += 1
status = AgentStatus.SUCCESS
except TimeoutError:
logger.warning(
f"Agent timeout after {query.settings.agent.timeout}-sec, just answering."
)
status = AgentStatus.TIMEOUT
status = AgentStatus.TRUNCATED
await tools[-1]._tool_fn(question=query.query, state=env.state)
except Exception:
logger.exception(f"Agent {agent} failed.")
Expand All @@ -250,6 +266,11 @@ async def run_ldp_agent(
) = None,
**env_kwargs,
) -> tuple[Answer, AgentStatus]:
if query.settings.agent.max_timesteps is not None:
logger.warning(
f"Max timesteps (configured {query.settings.agent.max_timesteps}) is not"
" yet implemented with the ldp agent, ignoring it."
)
env = PaperQAEnvironment(query, docs, **env_kwargs)
done = False

Expand All @@ -274,7 +295,7 @@ async def run_ldp_agent(
logger.warning(
f"Agent timeout after {query.settings.agent.timeout}-sec, just answering."
)
status = AgentStatus.TIMEOUT
status = AgentStatus.TRUNCATED
await tools[-1]._tool_fn(question=query.query, state=env.state)
except Exception:
logger.exception(f"Agent {agent} failed.")
Expand Down
5 changes: 3 additions & 2 deletions paperqa/agents/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ class AgentStatus(StrEnum):
FAIL = "fail"
# SUCCESS - answer was generated
SUCCESS = "success"
# TIMEOUT - agent took too long, but an answer was generated
TIMEOUT = "timeout"
# TRUNCATED - agent didn't finish naturally (e.g. timeout, too many actions),
# so we prematurely answered
TRUNCATED = "truncated"
# UNSURE - the agent was unsure, but an answer is present
UNSURE = "unsure"

Expand Down
4 changes: 4 additions & 0 deletions paperqa/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,10 @@ class AgentSettings(BaseModel):
" supplied."
),
)
max_timesteps: int | None = Field(
default=None,
description="Optional upper limit on the number of environment steps.",
)

index_concurrency: int = Field(
default=30,
Expand Down
22 changes: 16 additions & 6 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ async def test_timeout(agent_test_settings: Settings, agent_type: str | type) ->
agent_type=agent_type,
)
# ensure that GenerateAnswerTool was called
assert response.status == AgentStatus.TIMEOUT, "Agent did not timeout"
assert response.status == AgentStatus.TRUNCATED, "Agent did not timeout"
assert "I cannot answer" in response.answer.answer


Expand Down Expand Up @@ -287,18 +287,28 @@ async def test_gather_evidence_rejects_empty_docs() -> None:
# lead to an unsure status, which will break this test's assertions. Since
# this test is about a GatherEvidenceTool edge case, defeating
# GenerateAnswerTool is fine
original_doc = GenerateAnswer.gen_answer.__doc__
with patch.object(
GenerateAnswer, "gen_answer", return_value="Failed to answer question."
):
settings = Settings()
settings.agent.tool_names = {"gather_evidence", "gen_answer"}
GenerateAnswer,
"gen_answer",
return_value="Failed to answer question.",
autospec=True,
) as mock_gen_answer:
mock_gen_answer.__doc__ = original_doc
settings = Settings(
agent=AgentSettings(
tool_names={"gather_evidence", "gen_answer"}, max_timesteps=3
)
)
response = await agent_query(
query=QueryRequest(
query="Are COVID-19 vaccines effective?", settings=settings
),
docs=Docs(),
)
assert response.status == AgentStatus.FAIL, "Agent should have registered a failure"
assert (
response.status == AgentStatus.TRUNCATED
), "Agent should have hit its max timesteps"


@pytest.mark.flaky(reruns=3, only_rerun=["AssertionError", "EmptyDocsError"])
Expand Down