Including paper and evidence counts in metrics #435

Merged: 1 commit, Sep 18, 2024
Changes from all commits
35 changes: 32 additions & 3 deletions paperqa/agents/task.py
@@ -7,6 +7,8 @@
     "LitQAv2TaskSplit",
 ]
 
+import logging
+import re
 from abc import ABC
 from collections.abc import Awaitable, Callable, Sequence
 from enum import StrEnum
@@ -42,6 +44,8 @@ class ComputeTrajectoryMetricsMixin: # type: ignore[no-redef]
 if TYPE_CHECKING:
     from ldp.data_structures import Trajectory
 
+logger = logging.getLogger(__name__)
+
 
 class GradablePaperQAEnvironment(PaperQAEnvironment):
     """Extended environment that can grade answers."""
@@ -158,13 +162,38 @@ def _make_gradable_environment(
     def compute_trajectory_metrics(
         self, trajectories: "Sequence[Trajectory]"
     ) -> dict[str, list[float]]:
+        total_paper_count: list[float] = []
+        relevant_paper_count: list[float] = []
+        evidence_count: list[float] = []
+        for t in trajectories:
+            split_answers = [
+                re.split(
+                    pattern=GenerateAnswer.ANSWER_SPLIT_REGEX_PATTERN,
+                    string=obs.content,
+                )
+                for obs in t.steps[-1].next_observation
+                if (
+                    isinstance(obs, ToolResponseMessage)
+                    and obs.name == GenerateAnswer.TOOL_FN_NAME
+                )
+            ]
+            for i, metric_list in enumerate(
+                (total_paper_count, relevant_paper_count, evidence_count),
+                start=1,  # Regex extraction of status starts after answer
+            ):
+                metric_list.append(  # Use mean to allow for multiple answers
+                    sum(int(sa[i]) for sa in split_answers) / len(split_answers)
+                )
         return super().compute_trajectory_metrics(trajectories) | {
+            "total_paper_count": total_paper_count,
+            "relevant_paper_count": relevant_paper_count,
+            "evidence_count": evidence_count,
             "correct": [
-                int(traj.steps[-1].reward == self._rewards[0]) for traj in trajectories
+                int(t.steps[-1].reward == self._rewards[0]) for t in trajectories
             ],
             "correct_unsure": [
-                int(traj.steps[-1].reward in {self._rewards[0], self._rewards[1]})
-                for traj in trajectories
+                int(t.steps[-1].reward in {self._rewards[0], self._rewards[1]})
+                for t in trajectories
             ],
         }

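The start=1 index mapping above relies on re.split returning the free-text answer first, followed by the three captured status counts. Below is a minimal sketch of that expected behavior; the status-line wording and the STATUS_PATTERN stand-in are assumptions for illustration (the real pattern is GenerateAnswer.ANSWER_SPLIT_REGEX_PATTERN in paperqa and may differ), and only the 1/2/3 index mapping mirrors the loop in compute_trajectory_metrics.

# Illustrative sketch (not part of this PR) of how the regex split is expected to
# behave. The status format and STATUS_PATTERN below are assumptions standing in
# for GenerateAnswer.ANSWER_SPLIT_REGEX_PATTERN; only the index mapping
# (1=total papers, 2=relevant papers, 3=evidence) mirrors the loop above.
import re

STATUS_PATTERN = (
    r"\n\nStatus: Paper Count=(\d+) \| Relevant Papers=(\d+) \| Current Evidence=(\d+)"
)

tool_response = (
    "Sunlight drives photosynthesis in plants."
    "\n\nStatus: Paper Count=12 | Relevant Papers=4 | Current Evidence=7 | Cost=$0.03"
)

# re.split with capture groups returns [answer_text, g1, g2, g3, remainder],
# so indices 1-3 hold the counts that feed the metric lists.
split_answer = re.split(STATUS_PATTERN, tool_response)
total_papers, relevant_papers, evidence = (int(split_answer[i]) for i in (1, 2, 3))
print(total_papers, relevant_papers, evidence)  # -> 12 4 7

Averaging sum(...) / len(split_answers) over all gen_answer responses in a trajectory's final step then yields one value per trajectory, which is what the "Use mean to allow for multiple answers" comment refers to.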
1 change: 1 addition & 0 deletions tests/test_task.py
@@ -94,3 +94,4 @@ async def test_evaluation(self, base_query_request: QueryRequest) -> None:
         ), "Should not have mutated query in base request"
         assert not docs.docs, "Should not have mutated docs in base docs"
         assert isinstance(metrics_callback.eval_means["reward"], float)
+        assert isinstance(metrics_callback.eval_means["total_paper_count"], float)
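With this one-line addition the new count metrics are exercised end to end: after evaluation they should appear in the callback's eval_means alongside reward. A hedged usage sketch, assuming the same metrics_callback object as in the test above:

# Hypothetical follow-on to the assertion above: each per-trajectory metric is
# averaged into eval_means, like the existing "reward" entry.
for key in ("total_paper_count", "relevant_paper_count", "evidence_count"):
    print(key, metrics_callback.eval_means[key])  # each value is a float mean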