Implement support to BatchAPIs to gather evidence #687
Changes from 30 commits
```diff
@@ -22,9 +22,10 @@
 )

 from paperqa.clients import DEFAULT_CLIENTS, DocMetadataClient
-from paperqa.core import llm_parse_json, map_fxn_summary
+from paperqa.core import gather_with_batch, llm_parse_json, map_fxn_summary
 from paperqa.llms import (
     EmbeddingModel,
+    LLMBatchModel,
     LLMModel,
     NumpyVectorStore,
     PromptRunner,
```
```diff
@@ -559,14 +560,14 @@ def get_evidence(
             )
         )

-    async def aget_evidence(
+    async def aget_evidence(  # noqa: PLR0912
         self,
         query: PQASession | str,
         exclude_text_filter: set[str] | None = None,
         settings: MaybeSettings = None,
         callbacks: list[Callable] | None = None,
         embedding_model: EmbeddingModel | None = None,
-        summary_llm_model: LLMModel | None = None,
+        summary_llm_model: LLMModel | LLMBatchModel | None = None,
     ) -> PQASession:

         evidence_settings = get_settings(settings)
```
```diff
@@ -629,28 +630,40 @@ async def aget_evidence(
             )

         with set_llm_session_ids(session.id):
-            results = await gather_with_concurrency(
-                answer_config.max_concurrent_requests,
-                [
-                    map_fxn_summary(
-                        text=m,
-                        question=session.question,
-                        prompt_runner=prompt_runner,
-                        extra_prompt_data={
-                            "summary_length": answer_config.evidence_summary_length,
-                            "citation": f"{m.name}: {m.doc.formatted_citation}",
-                        },
-                        parser=llm_parse_json if prompt_config.use_json else None,
-                        callbacks=callbacks,
-                    )
-                    for m in matches
-                ],
-            )
+            if evidence_settings.use_batch_in_summary:
+                results = await gather_with_batch(
+                    matches=matches,
+                    question=session.question,
+                    prompt_runner=prompt_runner,
+                    extra_prompt_data={
+                        "summary_length": answer_config.evidence_summary_length,
+                    },
+                    parser=llm_parse_json if prompt_config.use_json else None,
+                    callbacks=callbacks,
+                )
+            else:
+                results = await gather_with_concurrency(
+                    answer_config.max_concurrent_requests,
+                    [
+                        map_fxn_summary(
+                            text=m,
+                            question=session.question,
+                            prompt_runner=prompt_runner,
+                            extra_prompt_data={
+                                "summary_length": answer_config.evidence_summary_length,
+                                "citation": f"{m.name}: {m.doc.formatted_citation}",
+                            },
+                            parser=llm_parse_json if prompt_config.use_json else None,
+                            callbacks=callbacks,
+                        )
+                        for m in matches
+                    ],
+                )

         for _, llm_result in results:
             session.add_tokens(llm_result)

-        session.contexts += [r for r, _ in results if r is not None]
+        session.contexts += [r for r, _ in results]
```
Review comments on the `session.contexts += [r for r, _ in results]` change:

> why did we cut the `if r is not None`?

> This gets the […] Maybe that's an edge case that I didn't see?

> If we correctly type hinted `gather_with_concurrency`:
>
> ```python
> T = TypeVar("T")
>
> async def gather_with_concurrency(n: int, coros: Iterable[Awaitable[T]]) -> list[T]:
>     ...
> ```
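For context, a fuller version of the typed helper the reviewer is sketching might look like the following. This is an illustrative sketch only, not necessarily how paperqa's actual `gather_with_concurrency` is implemented:

```python
# Illustrative sketch of a typed, semaphore-bounded gather helper in the spirit
# of the reviewer's comment; the real paperqa utility may differ in details.
import asyncio
from collections.abc import Awaitable, Iterable
from typing import TypeVar

T = TypeVar("T")


async def gather_with_concurrency(n: int, coros: Iterable[Awaitable[T]]) -> list[T]:
    """Await all coroutines with at most n running at once, preserving result order."""
    semaphore = asyncio.Semaphore(n)

    async def bounded(coro: Awaitable[T]) -> T:
        async with semaphore:
            return await coro

    return await asyncio.gather(*(bounded(c) for c in coros))
```

With `T` bound at the call site, a type checker can infer from `map_fxn_summary`'s return annotation whether the gathered results can contain `None`, which makes dropping (or keeping) the `if r is not None` guard a checkable decision rather than a guess.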
The hunk then continues:

```diff
         return session

     def query(
```
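For orientation, here is a rough approximation of what a batch-style evidence gatherer along these lines could look like, inferred only from the call site in the hunk above. It is not the actual `paperqa.core.gather_with_batch`; the `Context`/`LLMResult` stand-ins and the `prompt_runner` contract are assumptions made for the sketch:

```python
# Illustrative approximation of a batch evidence gatherer (NOT the real
# paperqa.core.gather_with_batch). Types and the prompt_runner contract below
# are simplified stand-ins assumed for this sketch.
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any


@dataclass
class LLMResult:  # stand-in for paperqa's LLMResult
    text: str


@dataclass
class Context:  # stand-in for paperqa's Context
    text: Any
    context: str
    score: int = 5


async def gather_with_batch(
    matches: list[Any],
    question: str,
    prompt_runner: Callable[[list[dict[str, Any]]], Awaitable[list[LLMResult]]],
    extra_prompt_data: dict[str, Any] | None = None,
    parser: Callable[[str], dict[str, Any]] | None = None,
    callbacks: list[Callable[[str], None]] | None = None,
) -> list[tuple[Context, LLMResult]]:
    """Summarize all matches via one batched request instead of N concurrent calls."""
    # Build one prompt payload per match, then submit them together so a
    # provider batch API can process the whole set as a single job.
    payloads = [
        {"question": question, "text": m.text, **(extra_prompt_data or {})}
        for m in matches
    ]
    llm_results = await prompt_runner(payloads)

    results: list[tuple[Context, LLMResult]] = []
    for match, llm_result in zip(matches, llm_results):
        for callback in callbacks or []:
            callback(llm_result.text)
        fields = parser(llm_result.text) if parser else {"summary": llm_result.text}
        results.append(
            (
                Context(
                    text=match,
                    context=fields.get("summary", llm_result.text),
                    score=int(fields.get("relevance_score", 5)),
                ),
                llm_result,
            )
        )
    return results
```

One difference visible in the diff itself: the batch path's `extra_prompt_data` carries only `summary_length` and no longer includes the per-match `citation`, since that value differs per item rather than being shared across the batch.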
```diff
@@ -659,7 +672,7 @@ def query(
         settings: MaybeSettings = None,
         callbacks: list[Callable] | None = None,
         llm_model: LLMModel | None = None,
-        summary_llm_model: LLMModel | None = None,
+        summary_llm_model: LLMModel | LLMBatchModel | None = None,
         embedding_model: EmbeddingModel | None = None,
     ) -> PQASession:
         return get_loop().run_until_complete(
```
```diff
@@ -679,7 +692,7 @@ async def aquery(  # noqa: PLR0912
         settings: MaybeSettings = None,
         callbacks: list[Callable] | None = None,
         llm_model: LLMModel | None = None,
-        summary_llm_model: LLMModel | None = None,
+        summary_llm_model: LLMModel | LLMBatchModel | None = None,
         embedding_model: EmbeddingModel | None = None,
     ) -> PQASession:
```
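Downstream, opting into the batch path would presumably look something like the sketch below. The exact location of the `use_batch_in_summary` flag is an assumption, inferred from `evidence_settings.use_batch_in_summary` in the hunk above; verify against the PR before relying on it:

```python
# Hypothetical usage sketch: enable batch summarization when gathering evidence.
# Assumes use_batch_in_summary is a top-level Settings field, as suggested by
# evidence_settings.use_batch_in_summary in the diff.
from paperqa import Docs, Settings

docs = Docs()
docs.add("my_paper.pdf")  # hypothetical local file

settings = Settings(use_batch_in_summary=True)
session = docs.query(
    "How does batching affect evidence gathering?",
    settings=settings,
)
print(session.answer)
```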