-
Notifications
You must be signed in to change notification settings - Fork 749
Aviary agent `max_timesteps` and fixed `test_gather_evidence_rejects_empty_docs` #515
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -121,7 +121,10 @@ async def run_agent( | |
else: | ||
raise NotImplementedError(f"Didn't yet handle agent type {agent_type}.") | ||
|
||
if "cannot answer" in answer.answer.lower() and agent_status != AgentStatus.TIMEOUT: | ||
if ( | ||
"cannot answer" in answer.answer.lower() | ||
and agent_status != AgentStatus.TRUNCATED | ||
): | ||
agent_status = AgentStatus.UNSURE | ||
# stop after, so overall isn't reported as long-running step. | ||
logger.info( | ||
|
@@ -140,6 +143,11 @@ async def run_fake_agent( | |
) = None, | ||
**env_kwargs, | ||
) -> tuple[Answer, AgentStatus]: | ||
if query.settings.agent.max_timesteps is not None: | ||
Review comment: jc -- why are we calling the agent steps timesteps? I think of timesteps from physical models (e.g., molecular dynamics), where each iteration is a unit of time, like 4 picoseconds or something. This is more like actionsteps in my mind.
Reply: I hear what you're saying. I think one can also encounter vagaries with "step"; for example, one can wonder whether "step" means: [list of alternatives not captured in this page extract]. I went with timestep to exactly match [reference not captured in this page extract].
Reply: Let me know if you think we should name it otherwise. |
||
logger.warning( | ||
f"Max timesteps (configured {query.settings.agent.max_timesteps}) is not" | ||
" applicable with the fake agent, ignoring it." | ||
) | ||
env = PaperQAEnvironment(query, docs, **env_kwargs) | ||
_, tools = await env.reset() | ||
if on_env_reset_callback: | ||
|
@@ -209,7 +217,14 @@ async def run_aviary_agent( | |
tools=tools, | ||
) | ||
|
||
timestep, max_timesteps = 0, query.settings.agent.max_timesteps | ||
while not done: | ||
if max_timesteps is not None and timestep >= max_timesteps: | ||
logger.warning( | ||
f"Agent didn't finish within {max_timesteps} timesteps, just answering." | ||
) | ||
await tools[-1]._tool_fn(question=query.query, state=env.state) | ||
return env.state.answer, AgentStatus.TRUNCATED | ||
agent_state.messages += obs | ||
for attempt in Retrying( | ||
stop=stop_after_attempt(5), | ||
|
@@ -226,12 +241,13 @@ async def run_aviary_agent( | |
obs, reward, done, truncated = await env.step(action) | ||
if on_env_step_callback: | ||
await on_env_step_callback(obs, reward, done, truncated) | ||
timestep += 1 | ||
status = AgentStatus.SUCCESS | ||
except TimeoutError: | ||
logger.warning( | ||
f"Agent timeout after {query.settings.agent.timeout}-sec, just answering." | ||
) | ||
status = AgentStatus.TIMEOUT | ||
status = AgentStatus.TRUNCATED | ||
await tools[-1]._tool_fn(question=query.query, state=env.state) | ||
except Exception: | ||
logger.exception(f"Agent {agent} failed.") | ||
|
@@ -250,6 +266,11 @@ async def run_ldp_agent( | |
) = None, | ||
**env_kwargs, | ||
) -> tuple[Answer, AgentStatus]: | ||
if query.settings.agent.max_timesteps is not None: | ||
logger.warning( | ||
f"Max timesteps (configured {query.settings.agent.max_timesteps}) is not" | ||
" yet implemented with the ldp agent, ignoring it." | ||
) | ||
env = PaperQAEnvironment(query, docs, **env_kwargs) | ||
done = False | ||
|
||
|
@@ -274,7 +295,7 @@ async def run_ldp_agent( | |
logger.warning( | ||
f"Agent timeout after {query.settings.agent.timeout}-sec, just answering." | ||
) | ||
status = AgentStatus.TIMEOUT | ||
status = AgentStatus.TRUNCATED | ||
await tools[-1]._tool_fn(question=query.query, state=env.state) | ||
except Exception: | ||
logger.exception(f"Agent {agent} failed.") | ||
|
Uh oh!
There was an error while loading. Please reload this page.