diff --git a/paperqa/agents/env.py b/paperqa/agents/env.py
index b6d5da562..517a0d433 100644
--- a/paperqa/agents/env.py
+++ b/paperqa/agents/env.py
@@ -47,7 +47,6 @@ def settings_to_tools(
     summary_llm_model = summary_llm_model or settings.get_summary_llm()
     embedding_model = embedding_model or settings.get_embedding_model()
     tools: list[Tool] = []
-    has_answer_tool = False
     for tool_type in (
         (PaperSearch, GatherEvidence, GenerateAnswer)
         if settings.agent.tool_names is None
@@ -86,14 +85,9 @@ def settings_to_tools(
         else:
             raise NotImplementedError(f"Didn't handle tool type {tool_type}.")
         if tool.info.name == GenerateAnswer.gen_answer.__name__:
-            tools.append(tool)
-            has_answer_tool = True
+            tools.append(tool)  # Place at the end
         else:
             tools.insert(0, tool)
-    if not has_answer_tool:
-        raise ValueError(
-            f"{GenerateAnswer.gen_answer.__name__} must be one of the tools."
-        )
     return tools


diff --git a/paperqa/settings.py b/paperqa/settings.py
index 910ac3685..09a994067 100644
--- a/paperqa/settings.py
+++ b/paperqa/settings.py
@@ -523,21 +523,6 @@ class AgentSettings(BaseModel):
         exclude=True,
     )

-    @field_validator("tool_names")
-    @classmethod
-    def validate_tool_names(cls, v: set[str] | None) -> set[str] | None:
-        if v is None:
-            return None
-        # imported here to avoid circular imports
-        from paperqa.agents.tools import GenerateAnswer
-
-        answer_tool_name = GenerateAnswer.TOOL_FN_NAME
-        if answer_tool_name not in v:
-            raise ValueError(
-                f"If using an override, must contain at least the {answer_tool_name}."
-            )
-        return v
-
     @model_validator(mode="after")
     def _deprecated_field(self) -> Self:
         for deprecated_field_name, new_name in (("index_concurrency", "concurrency"),):
diff --git a/tests/test_agents.py b/tests/test_agents.py
index cd95b7a57..5636e8cdb 100644
--- a/tests/test_agents.py
+++ b/tests/test_agents.py
@@ -10,7 +10,7 @@
 import time
 from copy import deepcopy
 from pathlib import Path
-from typing import Any, cast
+from typing import cast
 from unittest.mock import AsyncMock, patch
 from uuid import uuid4

@@ -21,7 +21,6 @@
 from ldp.graph.memory import Memory, UIndexMemoryModel
 from ldp.graph.ops import OpResult
 from ldp.llms import EmbeddingModel, MultipleCompletionLLMModel
-from pydantic import ValidationError
 from pytest_subtests import SubTests
 from tantivy import Index

@@ -402,8 +401,8 @@ async def test_propagate_options(agent_test_settings: Settings) -> None:
 async def test_gather_evidence_rejects_empty_docs(
     agent_test_settings: Settings,
 ) -> None:
-    # Patch GenerateAnswerTool._arun so that if this tool is chosen first, we
-    # don't give a 'cannot answer' response. A 'cannot answer' response can
+    # Patch GenerateAnswerTool.gen_answer so that if this tool is chosen first,
+    # we don't give a 'cannot answer' response. A 'cannot answer' response can
     # lead to an unsure status, which will break this test's assertions. Since
     # this test is about a GatherEvidenceTool edge case, defeating
     # GenerateAnswerTool is fine
@@ -726,25 +725,6 @@ def test_answers_are_striped() -> None:
     response.model_dump_json()


-@pytest.mark.parametrize(
-    ("kwargs", "result"),
-    [
-        ({}, None),
-        ({"tool_names": {GenerateAnswer.TOOL_FN_NAME}}, None),
-        ({"tool_names": set()}, ValidationError),
-        ({"tool_names": {PaperSearch.TOOL_FN_NAME}}, ValidationError),
-    ],
-)
-def test_agent_prompt_collection_validations(
-    kwargs: dict[str, Any], result: type[Exception] | None
-) -> None:
-    if result is None:
-        AgentSettings(**kwargs)
-    else:
-        with pytest.raises(result):
-            AgentSettings(**kwargs)
-
-
 class TestGradablePaperQAEnvironment:
     @pytest.mark.flaky(reruns=2, only_rerun=["AssertionError"])
     @pytest.mark.asyncio
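
After this patch, gen_answer is optional: settings_to_tools appends it last only when
it is selected, and AgentSettings.tool_names accepts overrides that omit it. Below is a
minimal sketch of the newly permitted usage, assuming only the import paths, the
TOOL_FN_NAME constants, and the tool.info.name attribute referenced in this diff
(Settings is the top-level settings model from paperqa.settings):

    from paperqa.agents.env import settings_to_tools
    from paperqa.agents.tools import GenerateAnswer, PaperSearch
    from paperqa.settings import AgentSettings, Settings

    # A tool override without the answer tool. Before this patch, the removed
    # validate_tool_names field validator rejected this with a ValidationError;
    # now it builds a tool list that simply omits gen_answer.
    settings = Settings(agent=AgentSettings(tool_names={PaperSearch.TOOL_FN_NAME}))
    tools = settings_to_tools(settings)
    assert GenerateAnswer.TOOL_FN_NAME not in {tool.info.name for tool in tools}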