Skip to content

Commit d5c4688

Browse files
authored
Adding filters to paper-qa Docs (#707)
This adds a new filter mechanism to exclude papers from the `Docs` object via settings. For example, to exclude a specific DOI ```py settings = Settings() settings.parsing.doc_filters = [{"!doi": "xxxx/xxxxxx"}] ``` Or to only consider years 2020 and 2018: ```py settings.parsing.doc_filters = [ {"year": "2020"}, {"year": "2018"}] ``` Description: Optional filters to only allow documents that match this filter. This is a dictionary where the keys are the fields from DocDetails or Docs to filter on, and the values are the values to filter for. to invert filter, prefix the key with a '!'. If the key is not found, by default the Doc is rejected. To change this behavior, prefix the key with a '?' to allow the Doc to pass if the key is not found. For example, {'!title': 'bad title', '?year': '2022'} would only allow Docs with a title that is not 'bad title' and a year of 2022 or no year at all.
1 parent 11f2727 commit d5c4688

File tree

4 files changed

+112
-1
lines changed

4 files changed

+112
-1
lines changed

paperqa/docs.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ async def aadd_texts(
436436
"""
437437
Add chunked texts to the collection.
438438
439-
NOTE: this is useful if you have already chunked the texts yourself.
439+
This is useful to use if you have already chunked the texts yourself.
440440
441441
Returns:
442442
True if the doc was added, otherwise False if already in the collection.
@@ -451,6 +451,11 @@ async def aadd_texts(
451451
# want to embed now!
452452
embedding_model = all_settings.get_embedding_model()
453453

454+
# 0. Short-circuit if it is caught by a filter
455+
for doc_filter in all_settings.parsing.doc_filters or []:
456+
if not doc.matches_filter_criteria(doc_filter):
457+
return False
458+
454459
# 1. Calculate text embeddings if not already present
455460
if embedding_model and texts[0].embedding is None:
456461
for t, t_embedding in zip(

paperqa/settings.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,17 @@ class ParsingSettings(BaseModel):
196196
),
197197
)
198198
chunking_algorithm: ChunkingOptions = ChunkingOptions.SIMPLE_OVERLAP
199+
doc_filters: list[dict] | None = Field(
200+
default=None,
201+
description=(
202+
"Optional filters to only allow documents that match this filter. This is a dictionary where the keys"
203+
" are the fields from DocDetails or Docs to filter on, and the values are the values to filter for."
204+
" To invert filter, prefix the key with a '!'. If the key is not found, by default the Doc is rejected."
205+
" To change this behavior, prefix the key with a '?' to allow the Doc to pass if the key is not found."
206+
" For example, {'!title': 'bad title', '?year': '2022'} would only allow Docs with a title that is not"
207+
" 'bad title' and a year of 2022 or no year at all."
208+
),
209+
)
199210

200211
def chunk_type(self, chunking_selection: ChunkingOptions | None = None) -> str:
201212
"""Future chunking implementations (i.e. by section) will get an elif clause here."""

paperqa/types.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,24 @@ def __hash__(self) -> int:
134134
def formatted_citation(self) -> str:
135135
return self.citation
136136

137+
def matches_filter_criteria(self, filter_criteria: dict) -> bool:
138+
"""Returns True if the doc matches the filter criteria, False otherwise."""
139+
data_dict = self.model_dump()
140+
for key, value in filter_criteria.items():
141+
invert = key.startswith("!")
142+
relaxed = key.startswith("?")
143+
key = key.lstrip("!?")
144+
# we check if missing or sentinel/unset
145+
if relaxed and (key not in data_dict or data_dict[key] is None):
146+
continue
147+
if key not in data_dict:
148+
return False
149+
if invert and data_dict[key] == value:
150+
return False
151+
if not invert and data_dict[key] != value:
152+
return False
153+
return True
154+
137155

138156
class Text(Embeddable):
139157
text: str

tests/test_configs.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
get_formatted_variables,
1515
get_settings,
1616
)
17+
from paperqa.types import Doc, DocDetails
1718

1819

1920
def test_prompt_settings_validation() -> None:
@@ -90,3 +91,79 @@ def test_o1_requires_temp_equals_1() -> None:
9091
warnings.simplefilter("always")
9192
_ = Settings(llm="o1-thismodeldoesnotexist", temperature=1)
9293
assert not w
94+
95+
96+
@pytest.mark.parametrize(
97+
("doc_class", "doc_data", "filter_criteria", "expected_result"),
98+
[
99+
pytest.param(
100+
Doc,
101+
{
102+
"docname": "Test Paper",
103+
"citation": "Test Citation",
104+
"dockey": "key1",
105+
},
106+
{"docname": "Test Paper"},
107+
True,
108+
id="Doc-matching-docname",
109+
),
110+
pytest.param(
111+
Doc,
112+
{
113+
"docname": "Test Paper",
114+
"citation": "Test Citation",
115+
"dockey": "key1",
116+
},
117+
{"docname": "Another Paper"},
118+
False,
119+
id="Doc-nonmatching-docname",
120+
),
121+
pytest.param(
122+
DocDetails,
123+
{
124+
"title": "Test Paper",
125+
"authors": ["Alice", "Bob"],
126+
"year": 2020,
127+
},
128+
{"title": "Test Paper"},
129+
True,
130+
id="DocDetails-matching-title",
131+
),
132+
pytest.param(
133+
DocDetails,
134+
{
135+
"title": "Test Paper",
136+
"authors": ["Alice", "Bob"],
137+
"year": 2020,
138+
},
139+
{"!year": 2020, "?foo": "bar"},
140+
False,
141+
id="DocDetails-inverted-matching-year",
142+
),
143+
pytest.param(
144+
DocDetails,
145+
{
146+
"title": "Test Paper",
147+
"authors": ["Alice", "Bob"],
148+
"year": 2020,
149+
},
150+
{"year": 2020, "foo": "bar"},
151+
False,
152+
id="DocDetails-missing-param-fail",
153+
),
154+
pytest.param(
155+
DocDetails,
156+
{
157+
"title": "Test Paper",
158+
"authors": ["Alice", "Bob"],
159+
"year": 2020,
160+
},
161+
{"?volume": "10", "!title": "Another Paper"},
162+
True,
163+
id="DocDetails-relaxed-missing-volume",
164+
),
165+
],
166+
)
167+
def test_matches_filter_criteria(doc_class, doc_data, filter_criteria, expected_result):
168+
doc = doc_class(**doc_data)
169+
assert doc.matches_filter_criteria(filter_criteria) == expected_result

0 commit comments

Comments
 (0)