Commit 6cf3163

mskarlin authored and sidnarayanan committed
Add new unpaywall provider (#310)
* add unpaywall provider
* remove unused clean query method, update test cassettes to use [email protected]
1 parent 4aa1dbc commit 6cf3163

14 files changed: +2430 -290 lines

paperqa/clients/__init__.py

Lines changed: 10 additions & 2 deletions

@@ -13,10 +13,11 @@
 from .crossref import CrossrefProvider
 from .journal_quality import JournalQualityPostProcessor
 from .semantic_scholar import SemanticScholarProvider
+from .unpaywall import UnpaywallProvider

 logger = logging.getLogger(__name__)

-ALL_CLIENTS: (
+DEFAULT_CLIENTS: (
     Collection[type[MetadataPostProcessor | MetadataProvider]]
     | Sequence[Collection[type[MetadataPostProcessor | MetadataProvider]]]
 ) = {
@@ -25,6 +26,13 @@
     JournalQualityPostProcessor,
 }

+ALL_CLIENTS: (
+    Collection[type[MetadataPostProcessor | MetadataProvider]]
+    | Sequence[Collection[type[MetadataPostProcessor | MetadataProvider]]]
+) = DEFAULT_CLIENTS | {  # type: ignore[operator]
+    UnpaywallProvider,
+}
+

 class DocMetadataTask(BaseModel):
     """Holder for provider and processor tasks."""
@@ -59,7 +67,7 @@ def __init__(
         clients: (
             Collection[type[MetadataPostProcessor | MetadataProvider]]
             | Sequence[Collection[type[MetadataPostProcessor | MetadataProvider]]]
-        ) = ALL_CLIENTS,
+        ) = DEFAULT_CLIENTS,
     ) -> None:
         """Metadata client for querying multiple metadata providers and processors.
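For orientation, a minimal sketch of how the new DEFAULT_CLIENTS / ALL_CLIENTS split might be used downstream. The `query` call and its keyword arguments are assumptions about DocMetadataClient's existing interface (they are not part of this diff), and the DOI is a placeholder.

```python
import asyncio

from paperqa.clients import ALL_CLIENTS, DEFAULT_CLIENTS, DocMetadataClient


async def main() -> None:
    # DEFAULT_CLIENTS keeps the previous behavior (Crossref, Semantic Scholar,
    # journal-quality post-processing); ALL_CLIENTS additionally enables Unpaywall.
    client = DocMetadataClient(clients=ALL_CLIENTS)

    details = await client.query(doi="10.1234/placeholder")  # assumed query interface
    if details is not None:
        print(details.title, details.doi)


asyncio.run(main())
```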

paperqa/clients/unpaywall.py

Lines changed: 201 additions & 0 deletions

@@ -0,0 +1,201 @@
+from __future__ import annotations
+
+import os
+from datetime import datetime
+from http import HTTPStatus
+from urllib.parse import quote
+
+import aiohttp
+from pydantic import BaseModel, ConfigDict, ValidationError
+
+from ..types import DocDetails
+from ..utils import (
+    _get_with_retrying,
+    strings_similarity,
+)
+from .client_models import DOIOrTitleBasedProvider, DOIQuery, TitleAuthorQuery
+from .exceptions import DOINotFoundError
+
+UNPAYWALL_BASE_URL = "https://api.unpaywall.org/v2/"
+UNPAYWALL_TIMEOUT = float(os.environ.get("UNPAYWALL_TIMEOUT", "10.0"))  # seconds
+
+
+class Author(BaseModel):
+    family: str | None = None
+    given: str | None = None
+    sequence: str | None = None
+    affiliation: list[dict[str, str]] | None = None
+    model_config = ConfigDict(extra="allow")
+
+
+class BestOaLocation(BaseModel):
+    updated: datetime | None = None
+    url: str | None = None
+    url_for_pdf: str | None = None
+    url_for_landing_page: str | None = None
+    evidence: str | None = None
+    license: str | None = None
+    version: str | None = None
+    host_type: str | None = None
+    is_best: bool | None = None
+    pmh_id: str | None = None
+    endpoint_id: str | None = None
+    repository_institution: str | None = None
+    oa_date: str | None = None
+    model_config = ConfigDict(extra="allow")
+
+
+class UnpaywallResponse(BaseModel):
+    doi: str
+    doi_url: str | None = None
+    title: str | None = None
+    genre: str | None = None
+    is_paratext: bool | None = None
+    published_date: str | None = None
+    year: int | None = None
+    journal_name: str | None = None
+    journal_issns: str | None = None
+    journal_issn_l: str | None = None
+    journal_is_oa: bool | None = None
+    journal_is_in_doaj: bool | None = None
+    publisher: str | None = None
+    is_oa: bool
+    oa_status: str | None = None
+    has_repository_copy: bool | None = None
+    best_oa_location: BestOaLocation | None = None
+    updated: datetime | None = None
+    z_authors: list[Author] | None = None
+
+
+class SearchResponse(BaseModel):
+    response: UnpaywallResponse
+    score: float
+    snippet: str
+
+
+class SearchResults(BaseModel):
+    results: list[SearchResponse]
+    elapsed_seconds: float
+
+
+class UnpaywallProvider(DOIOrTitleBasedProvider):
+
+    async def get_doc_details(
+        self, doi: str, session: aiohttp.ClientSession
+    ) -> DocDetails:
+
+        try:
+            results = UnpaywallResponse(
+                **(
+                    await _get_with_retrying(
+                        url=f"{UNPAYWALL_BASE_URL}{doi}?email={os.environ.get('UNPAYWALL_EMAIL', '[email protected]')}",
+                        params={},
+                        session=session,
+                        timeout=UNPAYWALL_TIMEOUT,
+                        http_exception_mappings={
+                            HTTPStatus.NOT_FOUND: DOINotFoundError(
+                                f"Unpaywall could not find DOI for {doi}."
+                            )
+                        },
+                    )
+                )
+            )
+        except ValidationError as e:
+            raise DOINotFoundError(
+                f"Unpaywall results returned with a bad schema for DOI {doi!r}."
+            ) from e
+
+        return self._create_doc_details(results)
+
+    async def search_by_title(
+        self,
+        query: str,
+        session: aiohttp.ClientSession,
+        title_similarity_threshold: float = 0.75,
+    ) -> DocDetails:
+        try:
+            results = SearchResults(
+                **(
+                    await _get_with_retrying(
+                        url=(
+                            f"{UNPAYWALL_BASE_URL}search?query={quote(query)}"
+                            f'&email={os.environ.get("UNPAYWALL_EMAIL", "[email protected]")}'
+                        ),
+                        params={},
+                        session=session,
+                        timeout=UNPAYWALL_TIMEOUT,
+                        http_exception_mappings={
+                            HTTPStatus.NOT_FOUND: DOINotFoundError(
+                                f"Could not find DOI for {query}."
+                            )
+                        },
+                    )
+                )
+            ).results
+        except ValidationError as e:
+            raise DOINotFoundError(
+                f"Unpaywall results returned with a bad schema for title {query!r}."
+            ) from e
+
+        if not results:
+            raise DOINotFoundError(
+                f"Unpaywall results did not match for title {query!r}."
+            )
+
+        details = self._create_doc_details(results[0].response)
+
+        if (
+            strings_similarity(
+                details.title or "",
+                query,
+            )
+            < title_similarity_threshold
+        ):
+            raise DOINotFoundError(
+                f"Unpaywall results did not match for title {query!r}."
+            )
+        return details
+
+    def _create_doc_details(self, data: UnpaywallResponse) -> DocDetails:
+        return DocDetails(  # type: ignore[call-arg]
+            authors=[
+                f"{author.given} {author.family}" for author in (data.z_authors or [])
+            ],
+            publication_date=(
+                None
+                if not data.published_date
+                else datetime.strptime(data.published_date, "%Y-%m-%d")
+            ),
+            year=data.year,
+            journal=data.journal_name,
+            publisher=data.publisher,
+            url=None if not data.best_oa_location else data.best_oa_location.url,
+            title=data.title,
+            doi=data.doi,
+            doi_url=data.doi_url,
+            other={
+                "genre": data.genre,
+                "is_paratext": data.is_paratext,
+                "journal_issns": data.journal_issns,
+                "journal_issn_l": data.journal_issn_l,
+                "journal_is_oa": data.journal_is_oa,
+                "journal_is_in_doaj": data.journal_is_in_doaj,
+                "is_oa": data.is_oa,
+                "oa_status": data.oa_status,
+                "has_repository_copy": data.has_repository_copy,
+                "best_oa_location": (
+                    None
+                    if not data.best_oa_location
+                    else data.best_oa_location.model_dump()
+                ),
+            },
+        )
+
+    async def _query(self, query: TitleAuthorQuery | DOIQuery) -> DocDetails | None:
+        if isinstance(query, DOIQuery):
+            return await self.get_doc_details(doi=query.doi, session=query.session)
+        return await self.search_by_title(
+            query=query.title,
+            session=query.session,
+            title_similarity_threshold=query.title_similarity_threshold,
+        )
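A hedged sketch of exercising the new provider directly, outside DocMetadataClient. The no-argument construction of UnpaywallProvider, the placeholder DOI, and reading DocDetails.other as a plain dict are assumptions.

```python
import asyncio

import aiohttp

from paperqa.clients.unpaywall import UnpaywallProvider


async def main() -> None:
    # Optionally export UNPAYWALL_EMAIL so Unpaywall can associate requests with you;
    # otherwise the module-level default shown above is used.
    async with aiohttp.ClientSession() as session:
        provider = UnpaywallProvider()  # assumed no-arg construction
        details = await provider.get_doc_details(
            doi="10.1234/placeholder",  # placeholder DOI
            session=session,
        )
        # Open-access fields are stashed in DocDetails.other by _create_doc_details.
        print(details.title, details.other.get("oa_status"))


asyncio.run(main())
```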

paperqa/docs.py

Lines changed: 2 additions & 2 deletions

@@ -22,7 +22,7 @@
 except ImportError:
     USE_VOYAGE = False

-from .clients import ALL_CLIENTS, DocMetadataClient
+from .clients import DEFAULT_CLIENTS, DocMetadataClient
 from .llms import (
     HybridEmbeddingModel,
     LLMModel,
@@ -473,7 +473,7 @@ async def aadd(  # noqa: C901, PLR0912, PLR0915
         else:
             metadata_client = DocMetadataClient(
                 session=kwargs.pop("session", None),
-                clients=kwargs.pop("clients", ALL_CLIENTS),
+                clients=kwargs.pop("clients", DEFAULT_CLIENTS),
             )

             query_kwargs: dict[str, Any] = {}
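With this default change, callers that still want Unpaywall lookups during add would opt in explicitly. A rough sketch, assuming aadd takes a file path positionally and forwards a clients keyword through **kwargs as shown above.

```python
import asyncio

from paperqa import Docs
from paperqa.clients import ALL_CLIENTS


async def main() -> None:
    docs = Docs()
    # "paper.pdf" is a placeholder path; clients= overrides the new DEFAULT_CLIENTS default.
    await docs.aadd("paper.pdf", clients=ALL_CLIENTS)


asyncio.run(main())
```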

paperqa/types.py

Lines changed: 3 additions & 1 deletion

@@ -565,10 +565,12 @@ def populate_bibtex_key_citation(  # noqa: C901, PLR0912
            logger.warning(
                f"Failed to generate bibtex for {data.get('docname') or data.get('citation')}"
            )
-        if not data.get("citation"):
+        if not data.get("citation") and data.get("bibtex") is not None:
            data["citation"] = format_bibtex(
                data["bibtex"], clean=True, missing_replacements=CITATION_FALLBACK_DATA  # type: ignore[arg-type]
            )
+        elif not data.get("citation"):
+            data["citation"] = data.get("title") or CITATION_FALLBACK_DATA["title"]
        return data

    @model_validator(mode="before")
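The new elif branch gives records with neither a citation nor a generated bibtex entry a title-based fallback instead of attempting to format missing bibtex. A simplified, standalone sketch of that fallback logic; CITATION_FALLBACK_DATA here is a hypothetical stand-in for the real constant, and format_bibtex is mocked out.

```python
CITATION_FALLBACK_DATA = {"title": "Unknown title"}  # stand-in for the real fallback data


def fallback_citation(data: dict) -> dict:
    if not data.get("citation") and data.get("bibtex") is not None:
        # The real code calls format_bibtex(data["bibtex"], ...) here.
        data["citation"] = f"formatted({data['bibtex']})"
    elif not data.get("citation"):
        data["citation"] = data.get("title") or CITATION_FALLBACK_DATA["title"]
    return data


print(fallback_citation({"title": "An Example Paper"})["citation"])  # An Example Paper
print(fallback_citation({})["citation"])  # Unknown title
```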

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -45,6 +45,7 @@ urls = {repository = "https://github.com/whitead/paper-qa"}

 [project.optional-dependencies]
 agents = [
+    "anthropic",
     "anyio",
     "langchain-community",
     "langchain-openai",