
Commit b1c01a7

jamesbraza authored and sidnarayanan committed
Exposed seeding of LitQA2 read and shuffling (#758)
1 parent 0130233 commit b1c01a7

File tree

4 files changed: +61 -14 lines changed

paperqa/agents/task.py

Lines changed: 8 additions & 4 deletions

@@ -196,6 +196,7 @@ def __init__(
         base_query: QueryRequest | dict | None = None,
         base_docs: Docs | dict | None = None,
         rewards: Mapping[str, float] = DEFAULT_REWARD_MAPPING,
+        question_kwargs: Mapping[str, Any] | None = None,
         eval_model: LLMModel | str = DEFAULT_EVAL_MODEL_NAME,
         **env_kwargs,
     ):
@@ -210,23 +211,23 @@ def __init__(
             base_docs = Docs(**base_docs)
         self._base_docs = base_docs
         self._rewards = rewards
-        self._env_kwargs = env_kwargs
+        self._question_kwargs = question_kwargs
         self._eval_model = eval_model
+        self._env_kwargs = env_kwargs

     def _make_gradable_environment(
         self,
         ideal: str,
         distractors: str | list[str],
         question: str,
-        use_unsure: bool = True,
         sources: str | list[str] | None = None,
     ) -> GradablePaperQAEnvironment:
         qa_prompt, evaluation_from_answer = LitQAEvaluation.from_question(
             ideal=ideal,
             distractors=distractors,
             question=question,
-            use_unsure=use_unsure,
             eval_model=self._eval_model,
+            **(self._question_kwargs or {}),
         )
         query = self._base_query.model_copy()
         query.query = qa_prompt
@@ -305,11 +306,14 @@ def __init__(
         self,
         *args,
         labbench_dataset: str = DEFAULT_LABBENCH_HF_HUB_NAME,
+        read_data_kwargs: Mapping[str, Any] | None = None,
         split: str | LitQAv2TaskSplit = LitQAv2TaskSplit.EVAL,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
-        train_df, eval_df = read_litqa_v2_from_hub(labbench_dataset)
+        train_df, eval_df = read_litqa_v2_from_hub(
+            labbench_dataset, **(read_data_kwargs or {})
+        )
         split = LitQAv2TaskSplit(split)
         if split == LitQAv2TaskSplit.TRAIN:
             self.data = train_df
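
With these hooks, LitQAv2TaskDataset now threads a seed into both the dataset read (read_data_kwargs, forwarded to read_litqa_v2_from_hub) and the multiple-choice shuffling (question_kwargs, forwarded to LitQAEvaluation.from_question). A minimal usage sketch, with import paths inferred from the files in this commit and arbitrary example seeds rather than defaults:

# Sketch of the newly exposed knobs; the seed values are illustrative, not defaults.
from paperqa.agents.task import LitQAv2TaskDataset, LitQAv2TaskSplit

dataset = LitQAv2TaskDataset(
    question_kwargs={"seed": 42},   # forwarded to LitQAEvaluation.from_question
    read_data_kwargs={"seed": 42},  # forwarded to read_litqa_v2_from_hub
    split=LitQAv2TaskSplit.TRAIN,
)

This mirrors the construction used in the updated test___len__ below, which pins both seeds to make the first training question reproducible.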

paperqa/litqa.py

Lines changed: 7 additions & 2 deletions

@@ -8,7 +8,7 @@
 from ast import literal_eval
 from collections.abc import Awaitable, Callable, Mapping, Sequence
 from enum import StrEnum
-from typing import TYPE_CHECKING, Self
+from typing import TYPE_CHECKING, Literal, Self

 try:
     from ldp.utils import discounted_returns
@@ -92,6 +92,7 @@ def make_mc_options(

 DEFAULT_EVAL_MODEL_NAME = "gpt-4-turbo-2024-04-09"
 DEFAULT_REWARD_MAPPING = {"correct": 1.0, "unsure": 0.1, "incorrect": -1.0}
+SEED_USING_QUESTION: Literal["SEED_USING_QUESTION"] = "SEED_USING_QUESTION"  # Sentinel


 class LitQAEvaluation(StrEnum):
@@ -161,7 +162,7 @@ def from_question(
         question: str,
         use_unsure: bool = True,
         eval_model: LLMModel | str = DEFAULT_EVAL_MODEL_NAME,
-        seed: int | None = None,
+        seed: int | Literal["SEED_USING_QUESTION"] | None = None,
     ) -> tuple[str, Callable[[PQASession | str], Awaitable[LitQAEvaluation]]]:
         """
         Create a LitQA question and an answer-to-evaluation function.
@@ -174,11 +175,15 @@ def from_question(
             eval_model: Evaluation model to use for multiple choice letter extraction
                 from a text answer.
             seed: Optional seed to use in randomization of multiple choice letters.
+                Optionally pass in the string literal "SEED_USING_QUESTION" to hash the
+                input question for the seed.

         Returns:
             Two-tuple of created LitQA question, function (that can be thought of as
                 stateless) to use to extract an evaluation result from an answer.
         """
+        if seed == SEED_USING_QUESTION:
+            seed = hash(question)
         text, ideal_answer, unsure_answer, distractor_answers = make_mc_options(
             ideal=ideal,
             distractors=distractors,
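
The SEED_USING_QUESTION sentinel resolves to hash(question) inside from_question, so each question gets a deterministic option ordering without pinning a single global seed. A minimal sketch, with a made-up question, ideal answer, and distractors:

# Sketch only: the question and answer options here are illustrative.
from paperqa.litqa import SEED_USING_QUESTION, LitQAEvaluation

qa_prompt, evaluate_answer = LitQAEvaluation.from_question(
    ideal="42",
    distractors=["32", "84"],
    question="What is the meaning of life?",
    seed=SEED_USING_QUESTION,  # resolved to hash(question) before shuffling
)
# Repeated calls with the same question produce the same prompt within one Python
# process; across processes, string hashing varies unless PYTHONHASHSEED is set.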

tests/test_litqa.py

Lines changed: 28 additions & 6 deletions

@@ -3,7 +3,7 @@

 import pytest

-from paperqa.litqa import LitQAEvaluation, read_litqa_v2_from_hub
+from paperqa.litqa import SEED_USING_QUESTION, LitQAEvaluation, read_litqa_v2_from_hub
 from tests.conftest import VCR_DEFAULT_MATCH_ON


@@ -140,16 +140,38 @@ def test_consistent_mc_options(self) -> None:
         """Tests that creating multiple evaluations with the same seed results in the same prompt."""
         question, ideal, distractors = self.MEANING_OF_LIFE_QUESTION_IDEAL_DISTRACTORS

-        qa_prompt_1, _ = LitQAEvaluation.from_question(
+        qa_prompt_1a, _ = LitQAEvaluation.from_question(
             ideal=ideal, distractors=distractors, question=question, seed=0
         )
-        self._assert_prompt_is_valid(qa_prompt_1, question, ideal, distractors)
+        self._assert_prompt_is_valid(qa_prompt_1a, question, ideal, distractors)

-        qa_prompt_2, _ = LitQAEvaluation.from_question(
+        qa_prompt_1b, _ = LitQAEvaluation.from_question(
             ideal=ideal, distractors=distractors, question=question, seed=0
         )
-        self._assert_prompt_is_valid(qa_prompt_1, question, ideal, distractors)
-        assert qa_prompt_1 == qa_prompt_2
+        self._assert_prompt_is_valid(qa_prompt_1b, question, ideal, distractors)
+        assert qa_prompt_1a == qa_prompt_1b, "Same seeding should lead to same prompts"
+
+        qa_prompt_2a, _ = LitQAEvaluation.from_question(
+            ideal=ideal,
+            distractors=distractors,
+            question=question,
+            seed=SEED_USING_QUESTION,
+        )
+        self._assert_prompt_is_valid(qa_prompt_2a, question, ideal, distractors)
+
+        qa_prompt_2b, _ = LitQAEvaluation.from_question(
+            ideal=ideal,
+            distractors=distractors,
+            question=question,
+            seed=SEED_USING_QUESTION,
+        )
+        self._assert_prompt_is_valid(qa_prompt_2b, question, ideal, distractors)
+        assert (
+            qa_prompt_2a == qa_prompt_2b
+        ), "Same seeding strategy should lead to same prompts"
+        assert (
+            qa_prompt_2a != qa_prompt_1a
+        ), "Different seeding strategies should lead to different prompts"

     def test_creating_litqa_questions(self) -> None:
         """Test making LitQA eval questions after downloading from Hugging Face Hub."""
tests/test_task.py

Lines changed: 18 additions & 2 deletions

@@ -20,7 +20,7 @@
     LitQAv2TaskSplit,
 )
 from paperqa.agents.tools import GenerateAnswer
-from paperqa.litqa import DEFAULT_REWARD_MAPPING, LitQAEvaluation
+from paperqa.litqa import DEFAULT_REWARD_MAPPING, SEED_USING_QUESTION, LitQAEvaluation


 @pytest.fixture(name="base_query_request")
@@ -103,12 +103,27 @@ async def test___len__(
         expected_length: int,
         base_query_request: QueryRequest,
     ) -> None:
-        task_dataset = LitQAv2TaskDataset(base_query=base_query_request, split=split)
+        task_dataset = LitQAv2TaskDataset(
+            base_query=base_query_request,
+            question_kwargs={"seed": 42},
+            read_data_kwargs={"seed": 42},
+            split=split,
+        )
         assert len(task_dataset) == expected_length

         # Now let's check we could use the sources in a validation
         for i in range(len(task_dataset)):
             env = task_dataset.get_new_env_by_idx(i)
+            if i == 0 and split == LitQAv2TaskSplit.TRAIN:
+                # Yes this assertion is somewhat brittle, but it reliably
+                # checks the seeding's behavior so we keep it
+                obs, _ = await env.reset()
+                assert (
+                    "Q: SLC14A1 been identified as a specific marker for endothelial"
+                    " cells in which organ?\n\nOptions:\nA) heart\nB) eye\nC)"
+                    " prostate\nD) Insufficient information to answer this question\nE)"
+                    " liver" in (obs[0].content or "")
+                )
             assert env.sources, "Sources need to be accessible"
             assert isinstance(
                 env.sources, Iterable
@@ -144,6 +159,7 @@ async def test_evaluation(
                         "deleted_dockeys",
                     }
                 ),
+                "question_kwargs": {"seed": SEED_USING_QUESTION},
             },
         )
         # NOTE: set base_query after construction of the TaskConfig. because in