Commit 49273f2

Creating LitQAv2TaskDataset for agent training/evaluation (#401)
1 parent: b1745b0 · commit: 49273f2

File tree

4 files changed: +316 −8 lines

    paperqa/agents/env.py
    paperqa/agents/main.py
    paperqa/agents/task.py
    tests/test_task.py

paperqa/agents/env.py

Lines changed: 3 additions & 4 deletions
@@ -1,8 +1,7 @@
 import logging
 from typing import cast
 
-from aviary.env import Environment as _Environment
-from aviary.env import Frame
+from aviary.env import Environment, Frame
 from aviary.message import Message
 from aviary.tools import Tool, ToolRequestMessage, ToolResponseMessage
 
@@ -91,8 +90,8 @@ def settings_to_tools(
     return tools
 
 
-class Environment(_Environment[EnvironmentState]):
-    """Environment to connect agents with paper-qa."""
+class PaperQAEnvironment(Environment[EnvironmentState]):
+    """Environment connecting paper-qa's tools with state."""
 
     def __init__(
         self,
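
For context, this rename frees the name Environment to refer to aviary's base class, while paper-qa's subclass is now imported as PaperQAEnvironment. A minimal sketch of the resulting usage, mirroring the run_* helpers updated in paperqa/agents/main.py below (illustrative only; the default Settings() and the demo function are assumptions, not part of this diff):

from paperqa import Docs, QueryRequest, Settings
from paperqa.agents.env import PaperQAEnvironment  # paper-qa's renamed subclass


async def demo() -> None:
    # Hypothetical construction matching env = PaperQAEnvironment(query, docs, **env_kwargs)
    # in the agent runners; Settings() stands in for a real configuration.
    env = PaperQAEnvironment(QueryRequest(settings=Settings()), Docs())
    _, tools = await env.reset()  # initial observation messages and tool schemas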

paperqa/agents/main.py

Lines changed: 4 additions & 4 deletions
@@ -46,7 +46,7 @@
 from paperqa.types import Answer
 from paperqa.utils import pqa_directory
 
-from .env import Environment
+from .env import PaperQAEnvironment
 from .helpers import litellm_get_search_query, table_formatter
 from .models import AgentStatus, AnswerResponse, QueryRequest, SimpleProfiler
 from .search import SearchDocumentStorage, SearchIndex
@@ -235,7 +235,7 @@ async def run_fake_agent(
     ) = None,
     **env_kwargs,
 ) -> tuple[Answer, AgentStatus]:
-    env = Environment(query, docs, **env_kwargs)
+    env = PaperQAEnvironment(query, docs, **env_kwargs)
     _, tools = await env.reset()
     if on_env_reset_callback:
         await on_env_reset_callback(env.state)
@@ -281,7 +281,7 @@ async def run_aviary_agent(
     ) = None,
     **env_kwargs,
 ) -> tuple[Answer, AgentStatus]:
-    env = Environment(query, docs, **env_kwargs)
+    env = PaperQAEnvironment(query, docs, **env_kwargs)
     done = False
 
     try:
@@ -345,7 +345,7 @@ async def run_ldp_agent(
     ) = None,
     **env_kwargs,
 ) -> tuple[Answer, AgentStatus]:
-    env = Environment(query, docs, **env_kwargs)
+    env = PaperQAEnvironment(query, docs, **env_kwargs)
     done = False
 
     try:

paperqa/agents/task.py

Lines changed: 209 additions & 0 deletions
@@ -0,0 +1,209 @@
+__all__ = [
+    "ENV_NAME",
+    "TASK_DATASET_NAME",
+    "GradablePaperQAEnvironment",
+    "LitQATaskDataset",
+    "LitQAv2TaskDataset",
+    "LitQAv2TaskSplit",
+]
+
+from abc import ABC
+from collections.abc import Awaitable, Callable, Sequence
+from enum import StrEnum
+from typing import TYPE_CHECKING, assert_never
+
+from aviary.env import ENV_REGISTRY, TASK_DATASET_REGISTRY, Frame, TaskDataset
+from aviary.message import Message
+from aviary.tools import ToolRequestMessage, ToolResponseMessage
+
+try:
+    from ldp.alg.callbacks import ComputeTrajectoryMetricsMixin
+except ImportError:
+
+    class ComputeTrajectoryMetricsMixin:  # type: ignore[no-redef]
+        """Placeholder for when ldp isn't installed."""
+
+
+from paperqa.docs import Docs
+from paperqa.litqa import (
+    DEFAULT_EVAL_MODEL_NAME,
+    DEFAULT_LABBENCH_HF_HUB_NAME,
+    DEFAULT_REWARD_DISTRIBUTION,
+    LitQAEvaluation,
+    read_litqa_v2_from_hub,
+)
+from paperqa.llms import EmbeddingModel, LiteLLMModel, LLMModel
+from paperqa.types import Answer
+
+from .env import POPULATE_FROM_SETTINGS, PaperQAEnvironment
+from .models import QueryRequest
+from .tools import GenerateAnswer
+
+if TYPE_CHECKING:
+    from ldp.data_structures import Trajectory
+
+
+class GradablePaperQAEnvironment(PaperQAEnvironment):
+    """Extended environment that can grade answers."""
+
+    def __init__(
+        self,
+        query: QueryRequest,
+        docs: Docs,
+        llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
+        summary_llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
+        embedding_model: EmbeddingModel | None = POPULATE_FROM_SETTINGS,
+        evaluation_from_answer: (
+            Callable[[Answer | str], Awaitable[LitQAEvaluation]] | None
+        ) = None,
+        rewards: Sequence[float] = DEFAULT_REWARD_DISTRIBUTION,
+        evaluation_callback: Callable[[LitQAEvaluation], Awaitable] | None = None,
+        **env_kwargs,
+    ):
+        super().__init__(
+            query, docs, llm_model, summary_llm_model, embedding_model, **env_kwargs
+        )
+        self._evaluation_from_answer = evaluation_from_answer
+        self._evaluation_callback = evaluation_callback
+        self._rewards = rewards
+
+    async def step(
+        self, action: ToolRequestMessage
+    ) -> tuple[list[Message], float, bool, bool]:
+        messages, reward, done, truncated = await super().step(action)
+        if not done or not self._evaluation_from_answer:
+            return messages, reward, done, truncated
+        # Filter out non-answer messages (in case parallel tool calls)
+        answer_tool_messages = [
+            m
+            for m in messages
+            if isinstance(m, ToolResponseMessage)
+            and m.name == GenerateAnswer.gen_answer.__name__
+        ]
+        if not answer_tool_messages:  # No answer, so no positive reward
+            return messages, reward, done, truncated
+        if len(answer_tool_messages) != 1:
+            raise NotImplementedError(
+                f"Expected just one answer message, got {messages}."
+            )
+        answer = GenerateAnswer.extract_answer_from_message(
+            content=answer_tool_messages[0].content
+        )
+        if not answer:
+            return messages, reward, done, truncated
+        evaluation = await self._evaluation_from_answer(answer)
+        if evaluation_callback := self._evaluation_callback:
+            await evaluation_callback(evaluation)
+        return messages, reward + self._rewards[evaluation.value], done, truncated
+
+    def export_frame(self) -> Frame:
+        raise NotImplementedError("Didn't yet need to export a frame.")
+
+
+ENV_NAME = "paperqa-local"
+ENV_REGISTRY[ENV_NAME] = (
+    GradablePaperQAEnvironment.__module__,
+    GradablePaperQAEnvironment.__name__,
+)
+
+
+class LitQATaskDataset(
+    TaskDataset[GradablePaperQAEnvironment], ComputeTrajectoryMetricsMixin, ABC
+):
+    """
+    Abstract base class for a task dataset of LitQA v1 or v2 questions.
+
+    This is an ABC because it's non-specific to a LitQA version.
+    Examples include LitQA v1, v2, or a test stub version of LitQA.
+    """
+
+    def __init__(
+        self,
+        base_query_request: QueryRequest,
+        rewards: Sequence[float] = DEFAULT_REWARD_DISTRIBUTION,
+        eval_model: LLMModel | str = DEFAULT_EVAL_MODEL_NAME,
+        **env_kwargs,
+    ):
+        self._base_query_request = base_query_request
+        self._rewards = rewards
+        self._env_kwargs = env_kwargs
+        self._eval_model = eval_model
+
+    def _make_gradable_environment(
+        self,
+        ideal: str,
+        distractors: str | list[str],
+        question: str,
+        use_unsure: bool = True,
+    ) -> GradablePaperQAEnvironment:
+        qa_prompt, evaluation_from_answer = LitQAEvaluation.from_question(
+            ideal=ideal,
+            distractors=distractors,
+            question=question,
+            use_unsure=use_unsure,
+            eval_model=self._eval_model,
+        )
+        query_request = self._base_query_request.model_copy()
+        query_request.query = qa_prompt
+        return GradablePaperQAEnvironment(
+            query=query_request,
+            evaluation_from_answer=evaluation_from_answer,
+            rewards=self._rewards,
+            **self._env_kwargs,
+        )
+
+    def compute_trajectory_metrics(
+        self, trajectories: "Sequence[Trajectory]"
+    ) -> dict[str, list[float]]:
+        return super().compute_trajectory_metrics(trajectories) | {
+            "correct": [
+                int(traj.steps[-1].reward == self._rewards[0]) for traj in trajectories
+            ],
+            "correct_unsure": [
+                int(traj.steps[-1].reward in {self._rewards[0], self._rewards[1]})
+                for traj in trajectories
+            ],
+        }
+
+
+class LitQAv2TaskSplit(StrEnum):
+    TRAIN = "train"
+    EVAL = "eval"
+
+
+class LitQAv2TaskDataset(LitQATaskDataset):
+    """Task dataset of LitQA v2 questions."""
+
+    def __init__(
+        self,
+        *args,
+        labbench_dataset: str = DEFAULT_LABBENCH_HF_HUB_NAME,
+        split: str | LitQAv2TaskSplit = LitQAv2TaskSplit.EVAL,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        train_df, eval_df = read_litqa_v2_from_hub(labbench_dataset)
+        split = LitQAv2TaskSplit(split)
+        if split == LitQAv2TaskSplit.TRAIN:
+            self.data = train_df
+        elif split == LitQAv2TaskSplit.EVAL:
+            self.data = eval_df
+        else:
+            assert_never(split)
+
+    def get_new_env_by_idx(self, idx: int) -> GradablePaperQAEnvironment:
+        return self._make_gradable_environment(
+            ideal=self.data.iloc[idx].ideal,
+            distractors=self.data.iloc[idx].distractors,
+            question=self.data.iloc[idx].question,
+        )
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+
+TASK_DATASET_NAME = "litqa-v2"
+TASK_DATASET_REGISTRY[TASK_DATASET_NAME] = (
+    LitQAv2TaskDataset.__module__,
+    LitQAv2TaskDataset.__name__,
+)
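
Taken together, the new module registers GradablePaperQAEnvironment under "paperqa-local" and LitQAv2TaskDataset under "litqa-v2", so the dataset can be built directly or looked up through aviary's registry. A minimal usage sketch under stated assumptions (the default Settings() and empty Docs() are illustrative placeholders; constructing the dataset reads LitQA v2 from the LAB-Bench Hugging Face Hub dataset):

from aviary.env import TaskDataset

from paperqa import Docs, QueryRequest, Settings
from paperqa.agents.task import LitQAv2TaskDataset, LitQAv2TaskSplit

# The base request is copied per question; each copy gets that question's QA prompt.
base_request = QueryRequest(settings=Settings())

# Direct construction, choosing the training split (the default split is "eval").
train_dataset = LitQAv2TaskDataset(
    base_query_request=base_request,
    split=LitQAv2TaskSplit.TRAIN,
    docs=Docs(),  # forwarded via **env_kwargs to every environment
)

# Equivalent lookup through the registry name added in this commit.
eval_dataset = TaskDataset.from_name(
    "litqa-v2", base_query_request=base_request, docs=Docs()
)

env = train_dataset.get_new_env_by_idx(0)  # a GradablePaperQAEnvironment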

tests/test_task.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+import pytest
+from aviary.env import TASK_DATASET_REGISTRY, TaskDataset
+from ldp.agent import SimpleAgent
+from ldp.alg.callbacks import MeanMetricsCallback
+from ldp.alg.runners import Evaluator, EvaluatorConfig
+
+from paperqa import Docs, QueryRequest, Settings
+from paperqa.agents.task import (
+    GradablePaperQAEnvironment,
+    LitQATaskDataset,
+    LitQAv2TaskDataset,
+    LitQAv2TaskSplit,
+)
+
+
+@pytest.fixture(name="base_query_request")
+def fixture_base_query_request(agent_test_settings: Settings) -> QueryRequest:
+    return QueryRequest(settings=agent_test_settings)
+
+
+class StubLitQADataset(LitQATaskDataset):
+    """Made up dataset of questions answerable from this repo's stub_data."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.data: list[tuple[str, str | list[str], str]] = [
+            ("Politician", ["Technologist", "Plumber"], "Who is Frederick Bates?"),
+            (
+                "Make molecular counterfactuals",
+                [
+                    "Generating images of cats",
+                    "Simple explanations of internet searches",
+                ],
+                "How can you use XAI for chemical property prediction?",
+            ),
+            (
+                "Maple Leaf",
+                ["The Stars and Stripes", "The Blue and Yellow", "The Southern Cross"],
+                "What is the national flag of Canada?",
+            ),
+        ]
+
+    def get_new_env_by_idx(self, idx: int) -> GradablePaperQAEnvironment:
+        return self._make_gradable_environment(
+            ideal=self.data[idx][0],
+            distractors=self.data[idx][1],
+            question=self.data[idx][2],
+        )
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+
+STUB_TASK_DATASET_NAME = "stub-litqa"
+TASK_DATASET_REGISTRY[STUB_TASK_DATASET_NAME] = (
+    StubLitQADataset.__module__,
+    StubLitQADataset.__name__,
+)
+
+
+class TestTaskDataset:
+    @pytest.mark.parametrize(
+        ("split", "expected_length"),
+        [(LitQAv2TaskSplit.TRAIN, 159), (LitQAv2TaskSplit.EVAL, 40)],
+    )
+    def test___len__(
+        self,
+        split: str | LitQAv2TaskSplit,
+        expected_length: int,
+        base_query_request: QueryRequest,
+    ) -> None:
+        task_dataset = LitQAv2TaskDataset(
+            base_query_request=base_query_request, split=split
+        )
+        assert len(task_dataset) == expected_length
+
+    @pytest.mark.asyncio
+    async def test_evaluation(self, base_query_request: QueryRequest) -> None:
+        agent = SimpleAgent()
+        docs = Docs()
+        dataset = TaskDataset.from_name(
+            STUB_TASK_DATASET_NAME,
+            base_query_request=base_query_request,
+            docs=docs,
+        )
+        metrics_callback = MeanMetricsCallback(eval_dataset=dataset)
+
+        evaluator = Evaluator(
+            config=EvaluatorConfig(batch_size=3),
+            agent=agent,
+            dataset=dataset,
+            callbacks=[metrics_callback],
+        )
+        await evaluator.evaluate()
+
+        assert (
+            not base_query_request.query
+        ), "Should not have mutated query in base request"
+        assert docs.docs, "Expected to have added content"
+        assert isinstance(metrics_callback.eval_means["reward"], float)
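
One note on reading results from a run like test_evaluation: because LitQATaskDataset mixes in ComputeTrajectoryMetricsMixin, the per-trajectory "correct" and "correct_unsure" metrics defined in task.py should be aggregated alongside the mean "reward" asserted above. A hedged sketch of inspecting them (whether MeanMetricsCallback surfaces these exact keys depends on the installed ldp version):

# Assumes `metrics_callback` is the MeanMetricsCallback from the test above,
# after `await evaluator.evaluate()` has completed.
mean_reward = metrics_callback.eval_means["reward"]
# "correct" / "correct_unsure" come from LitQATaskDataset.compute_trajectory_metrics;
# fall back to None if this ldp version does not expose them.
accuracy = metrics_callback.eval_means.get("correct")
accuracy_with_unsure = metrics_callback.eval_means.get("correct_unsure")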
