From 2385dce5096ff09c09858707a4f30476de14f808 Mon Sep 17 00:00:00 2001 From: Mayk Caldas Date: Thu, 14 Nov 2024 14:26:49 -0800 Subject: [PATCH 01/26] Implements OpenAIBatchLLMModel This class is used to submit batch calls to the OpenAI batch API --- paperqa/__init__.py | 2 + paperqa/llms.py | 209 +++++++++++++++++++++++++++++++++++++++++++- paperqa/settings.py | 25 +++++- tests/test_llms.py | 63 +++++++++++++ 4 files changed, 297 insertions(+), 2 deletions(-) diff --git a/paperqa/__init__.py b/paperqa/__init__.py index 008b18255..4844a753f 100644 --- a/paperqa/__init__.py +++ b/paperqa/__init__.py @@ -15,6 +15,7 @@ HybridEmbeddingModel, LiteLLMEmbeddingModel, LiteLLMModel, + OpenAIBatchLLMModel, LLMModel, LLMResult, NumpyVectorStore, @@ -38,6 +39,7 @@ "LLMResult", "LiteLLMEmbeddingModel", "LiteLLMModel", + "OpenAIBatchLLMModel" "NumpyVectorStore", "PQASession", "QueryRequest", diff --git a/paperqa/llms.py b/paperqa/llms.py index ac2092f54..8884e4695 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -19,6 +19,12 @@ from typing import Any, TypeVar, cast import litellm + +import openai +import json +import os +import tempfile + import numpy as np import tiktoken from pydantic import ( @@ -325,7 +331,7 @@ def count_tokens(self, text: str) -> int: async def run_prompt( self, prompt: str, - data: dict, + data: dict | list[dict[str, str]], callbacks: list[Callable] | None = None, name: str | None = None, skip_system: bool = False, @@ -761,6 +767,207 @@ def infer_llm_type(self) -> str: def count_tokens(self, text: str) -> int: return litellm.token_counter(model=self.name, text=text) +class OpenAIBatchLLMModel(LLMModel): + """A wrapper around the OpenAI library to use the batch API.""" + name: str = "gpt-4o-mini" + config: dict = Field( + default_factory=dict, + description="Configuration dictionary for this model. Currently supported keys are `model` and `max_token`.", + ) + + def write_jsonl(self, + data: list[dict[str, str]], + filename: str): + + batch_template = { + "custom_id": None, + "method": "POST", + "url": self.config.get('endpoint'), + "body": { + "model": None, + "messages": None, + "max_tokens": None + } + } + with open(filename, "w") as f: + for i, d in enumerate(data): + batch_template["custom_id"] = str(i) + batch_template["body"]["model"] = self.config.get('model') + batch_template["body"]["messages"] = d + batch_template["body"]["max_tokens"] = self.config.get('max_tokens') + f.write(json.dumps(batch_template) + "\n") + + @rate_limited + async def acomplete(self): + raise NotImplementedError("Only chat models are supported by openAI batch API.") + + @rate_limited + async def acomplete_iter(self): + raise NotImplementedError("Async generator not supported for batch calls and nly chat models are supported by openAI batch API.") + + async def _run_chat( + self, + prompt: str, + data: list[dict[str,str]], + callbacks: list[Callable] | None = None, + name: str | None = None, + skip_system: bool = False, + system_prompt: str = default_system_prompt, + ) -> list[LLMResult]: + if callbacks: + sync_callbacks = [f for f in callbacks if not is_coroutine_callable(f)] + async_callbacks = [f for f in callbacks if is_coroutine_callable(f)] + + system_message_prompt = {"role": "system", "content": system_prompt} + human_message_prompt = {"role": "user", "content": prompt} + + batch = [] + for d in data: + messages = [ + {"role": m["role"], "content": m["content"].format(**d)} + for m in ( + [human_message_prompt] + if skip_system + else [system_message_prompt, human_message_prompt] + ) + ] + batch.append(messages) + + start_clock = asyncio.get_running_loop().time() + chunks = await self.achat(batch) + batch_time = asyncio.get_running_loop().time() - start_clock + + if callbacks: + for chunk in chunks: + await do_callbacks( + async_callbacks, sync_callbacks, chunk.text, name + ) + + results = [ + LLMResult( + model=self.name, + name=name, + prompt=messages, + prompt_count=chunk.prompt_tokens, + text=chunk.text, + completion_count=chunk.completion_tokens, + seconds_to_first_token=batch_time, + seconds_to_last_token=batch_time, + ) for messages, chunk in zip(batch, chunks) + ] + + return results + + @rate_limited + async def achat(self, + messages: list[dict[str, str]] + ) -> list[Chunk]: + client = openai.OpenAI() + + with tempfile.NamedTemporaryFile(suffix=".jsonl", delete=True) as tmp_file: + tmp_filename = tmp_file.name + self.write_jsonl(messages, tmp_filename) + file = client.files.create( + file=open(tmp_filename, "rb"), + purpose="batch" + ) + + batch = client.batches.create( + input_file_id=file.id, + endpoint="/v1/chat/completions", + completion_window="24h", + metadata={ + "description": "" + } + ) + + while batch.status != "completed": + batch = client.batches.retrieve(batch.id) + if batch.status == "failed": + raise Exception("Batch failed. \n\nReason: \n" + "\n".join([k.message for k in batch.errors.data])) + await asyncio.sleep(5) + + responses = client.files.content(batch.output_file_id) + response_lines = responses.read().decode('utf-8').splitlines() + responses = [json.loads(line) for line in response_lines] + sorted_responses = sorted(responses, key=lambda x: int(x["custom_id"])) # The batchAPI doesn't guarantee the order of the responses + + chunks = [ + Chunk( + text=response["response"]["body"]["choices"][0]["message"]["content"], + prompt_tokens=response["response"]["body"]["usage"]["prompt_tokens"], + completion_tokens=response["response"]["body"]["usage"]["completion_tokens"], + ) for response in sorted_responses + ] + + return chunks + + @rate_limited + async def achat_iter(self): + raise NotImplementedError("Async generator not supported for batch calls. Use achat instead.") + + def infer_llm_type(self): + self.config['endpoint'] = "/v1/chat/completions" + return "chat" + + def count_tokens(self, text: str) -> int: + return len(text) // 4 + + async def check_rate_limit(self, token_count: float, **kwargs) -> None: + if "rate_limit" in self.config: + await GLOBAL_LIMITER.try_acquire( + ("client", self.name), + self.config["rate_limit"].get(self.name, None), + weight=max(int(token_count), 1), + **kwargs, + ) + + +class AnthropicBatchLLMModel(LLMModel): + # TODO: This class is not implemented yet. + + @rate_limited + async def acomplete(self): + raise NotImplementedError("Completion models are not supported yet") + + @rate_limited + async def acomplete_iter(self): + raise NotImplementedError("Completion models are not supported yet") + + async def _run_chat(sellf): + '''Processes the batch and call the chat completion method''' + ... + + @rate_limited + async def achat(self, messages): + ... + + @rate_limited + async def achat_iter(self): + raise NotImplementedError("support to callbacks is not implemented yet") + + def infer_llm_type(self): + return "chat" #TODO: Support completion models + + def count_tokens(self, text: str) -> int: + return len(text) // 4 #TODO: Check if OpenAI has a method for that. Currently it's not being used. The token usage is directly retrieved from the response. + + def __getstate__(self): + # Prevent _router from being pickled, SEE: https://stackoverflow.com/a/2345953 + state = super().__getstate__() + state["__dict__"] = state["__dict__"].copy() + state["__dict__"].pop("_router", None) + return state + + async def check_rate_limit(self, token_count: float, **kwargs) -> None: + if "rate_limit" in self.config: + await GLOBAL_LIMITER.try_acquire( + ("client", self.name), + self.config["rate_limit"].get(self.name, None), + weight=max(int(token_count), 1), + **kwargs, + ) + def cosine_similarity(a, b): norm_product = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1) diff --git a/paperqa/settings.py b/paperqa/settings.py index 495f7fc85..f31221d56 100644 --- a/paperqa/settings.py +++ b/paperqa/settings.py @@ -40,7 +40,7 @@ except ImportError: HAS_LDP_INSTALLED = False -from paperqa.llms import EmbeddingModel, LiteLLMModel, embedding_model_factory +from paperqa.llms import EmbeddingModel, LiteLLMModel, OpenAIBatchLLMModel, embedding_model_factory from paperqa.prompts import ( CONTEXT_INNER_PROMPT, CONTEXT_OUTER_PROMPT, @@ -577,6 +577,15 @@ def make_default_litellm_model_list_settings( ] } +def make_default_openai_batch_llm_settings( + llm: str, temperature: float = 0.0 +) -> dict: + return { + "model": llm, + "temperature": temperature, + "max_tokens": 2048, + + } class Settings(BaseSettings): model_config = SettingsConfigDict(extra="ignore") @@ -609,6 +618,10 @@ class Settings(BaseSettings): " router_kwargs key with router kwargs as values." ), ) + use_batch_in_summary: bool = Field( + default=False, + description="Whether to use batch API for LLMs in summarization", + ) embedding: str = Field( default="text-embedding-3-small", description="Default embedding model for texts", @@ -793,6 +806,16 @@ def get_llm(self) -> LiteLLMModel: ) def get_summary_llm(self) -> LiteLLMModel: + if self.use_batch_in_summary: + # TODO: support other LLM providers as well. + # TODO: Make it fail if we don't support the batchAPI for the LLM being used + return OpenAIBatchLLMModel( + name=self.summary_llm, + config=self.summary_llm_config + or make_default_openai_batch_llm_settings( + self.summary_llm, self.temperature + ), + ) return LiteLLMModel( name=self.summary_llm, config=self.summary_llm_config diff --git a/tests/test_llms.py b/tests/test_llms.py index 9cf827bdc..7dbd4ac0e 100644 --- a/tests/test_llms.py +++ b/tests/test_llms.py @@ -10,6 +10,7 @@ HybridEmbeddingModel, LiteLLMEmbeddingModel, LiteLLMModel, + OpenAIBatchLLMModel, SentenceTransformerEmbeddingModel, SparseEmbeddingModel, embedding_model_factory, @@ -158,6 +159,68 @@ def test_pickling(self, tmp_path: pathlib.Path) -> None: assert llm.config == rehydrated_llm.config assert llm.router.deployment_names == rehydrated_llm.router.deployment_names +class TestOpenAIBatchLLMModel: + @pytest.fixture(scope="class") + def config(self, request) -> dict[str, Any]: + model_name = request.param + return { + "model": model_name, + "temperature": 0.0, + "max_tokens": 64, + } + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "config",[ + pytest.param("gpt-4o-mini", id="chat-model"), + pytest.param("gpt-3.5-turbo-instruct", id="completion-model") + ], indirect=True + ) + async def test_run_prompt(self, config: dict[str, Any], request) -> None: + llm = OpenAIBatchLLMModel(name="gpt-4o-mini", config=config) + + data = [ + {"animal": "duck"}, + {"animal": "dog"}, + {"animal": "cat"} + ] + + if request.node.name == "test_run_prompt[completion-model]": + with pytest.raises(Exception) as e_info: + completion = await llm.run_prompt( + prompt="The {animal} says", + data=data, + skip_system=True, + ) + assert "Batch failed" in str(e_info.value) + assert "not supported" in str(e_info.value) + + if request.node.name == "test_run_prompt[chat-model]": + completion = await llm.run_prompt( + prompt="The {animal} says", + data=data, + skip_system=True, + ) + + assert all([completion[k].model == config['model'] for k, _ in enumerate(data)]) + assert all([completion[k].seconds_to_first_token > 0 for k, _ in enumerate(data)]) + assert all([completion[k].prompt_count > 0 for k, _ in enumerate(data)]) + assert all([completion[k].completion_count > 0 for k, _ in enumerate(data)]) + assert all([completion[k].completion_count < config['max_tokens'] for k, _ in enumerate(data)]) + assert sum([completion[k].cost for k, _ in enumerate(data)]) > 0 + + def test_pickling(self, tmp_path: pathlib.Path, config: dict[str,Any]) -> None: + pickle_path = tmp_path / "llm_model.pickle" + llm = OpenAIBatchLLMModel( + name="gpt-4o-mini", + config=config, + ) + with pickle_path.open("wb") as f: + pickle.dump(llm, f) + with pickle_path.open("rb") as f: + rehydrated_llm = pickle.load(f) + assert llm.name == rehydrated_llm.name + assert llm.config == rehydrated_llm.config @pytest.mark.asyncio async def test_embedding_model_factory_sentence_transformer() -> None: From 8a21055890ba3a5ccd2bbf414fc9a0d98da7510d Mon Sep 17 00:00:00 2001 From: Mayk Caldas Date: Thu, 14 Nov 2024 14:30:33 -0800 Subject: [PATCH 02/26] Incorporates OpenAIBatchLLMModel to get_evidence --- paperqa/docs.py | 78 ++++++++++++++++++++++++++++++++++--------- tests/test_paperqa.py | 11 ++++-- 2 files changed, 70 insertions(+), 19 deletions(-) diff --git a/paperqa/docs.py b/paperqa/docs.py index ea7fa6071..f8164f5d4 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -40,6 +40,7 @@ LLMResult, PQASession, Text, + Context, set_llm_session_ids, ) from paperqa.utils import ( @@ -50,6 +51,8 @@ maybe_is_text, md5sum, name_in_text, + extract_score, + strip_citations ) logger = logging.getLogger(__name__) @@ -600,23 +603,66 @@ async def aget_evidence( ) with set_llm_session_ids(session.id): - results = await gather_with_concurrency( - answer_config.max_concurrent_requests, - [ - map_fxn_summary( - text=m, - question=session.question, - prompt_runner=prompt_runner, - extra_prompt_data={ - "summary_length": answer_config.evidence_summary_length, - "citation": f"{m.name}: {m.doc.formatted_citation}", - }, - parser=llm_parse_json if prompt_config.use_json else None, - callbacks=callbacks, - ) + if evidence_settings.use_batch_in_summary: + # TODO: Should we implement a `gather_with_batch` function that receives `matches` and return results to keep this dry? + + data = [ + {"question": session.question, + "citation": m.name + ": " + m.doc.formatted_citation, + "text": m.text} | + {"summary_length": answer_config.evidence_summary_length, + "citation": f"{m.name}: {m.doc.formatted_citation}", + "evidence": m.name} for m in matches - ], - ) + ] + + llm_results = await prompt_runner( + data, + callbacks, + ) + + results_data = [] + scores = [] + for r in llm_results: + try: + results_data.append(llm_parse_json(r.text)) + scores.append(r.pop("relevance_score")) + # just in case question was present + r.pop("question", None) + except ValueError: + results_data.append({}) + scores.append(extract_score(r.text)) + + results = [ + ( + Context( + context=strip_citations(llm_result.text), + text=m, + model_extra={}, + score=score, + **r, + ), + llm_result, + ) for r, m, llm_result, score in zip(results_data, matches, llm_results, scores) + ] + else: + results = await gather_with_concurrency( + answer_config.max_concurrent_requests, + [ + map_fxn_summary( + text=m, + question=session.question, + prompt_runner=prompt_runner, + extra_prompt_data={ + "summary_length": answer_config.evidence_summary_length, + "citation": f"{m.name}: {m.doc.formatted_citation}", + }, + parser=llm_parse_json if prompt_config.use_json else None, + callbacks=callbacks, + ) + for m in matches + ], + ) for _, llm_result in results: session.add_tokens(llm_result) diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index d40959237..2a92bbac2 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -505,10 +505,15 @@ async def test_docs_lifecycle(subtests: SubTests, stub_data_dir: Path) -> None: assert docs.texts assert all(t not in docs.texts_index for t in docs.texts) - -def test_evidence(docs_fixture) -> None: +@pytest.mark.parametrize("use_batch", [ + pytest.param(True, id="using-batch"), + pytest.param(False, id="not-using-batch") + ] + ) +def test_evidence(docs_fixture, use_batch) -> None: debug_settings = Settings.from_name("debug") - evidence = docs_fixture.get_evidence( + debug_settings.use_batch_in_summary = use_batch + evidence = docs_fixture.get_evidence( PQASession(question="What does XAI stand for?"), settings=debug_settings, ).contexts From e8dc0d0f50e9cf2d86686e65f55a9f7f927cc2a8 Mon Sep 17 00:00:00 2001 From: Mayk Caldas Date: Fri, 15 Nov 2024 09:46:50 -0800 Subject: [PATCH 03/26] Started anthropic batch api support implementation also added a dependency group in pyproject.toml to install openai and anthropic only if the user wants to use batches, refactored the logic of sumarizing evidences in batch and moved the code to core.py --- paperqa/__init__.py | 4 +- paperqa/core.py | 47 +++++++++++++++++++ paperqa/docs.py | 57 +++++++---------------- paperqa/llms.py | 102 +++++++++++++++++++++++++++++++++++------- paperqa/settings.py | 53 +++++++++++++++++----- pyproject.toml | 4 ++ tests/test_llms.py | 39 +++++++++++++++- tests/test_paperqa.py | 2 + 8 files changed, 237 insertions(+), 71 deletions(-) diff --git a/paperqa/__init__.py b/paperqa/__init__.py index 4844a753f..2346f0b0c 100644 --- a/paperqa/__init__.py +++ b/paperqa/__init__.py @@ -16,6 +16,7 @@ LiteLLMEmbeddingModel, LiteLLMModel, OpenAIBatchLLMModel, + AnthropicBatchLLMModel, LLMModel, LLMResult, NumpyVectorStore, @@ -39,7 +40,8 @@ "LLMResult", "LiteLLMEmbeddingModel", "LiteLLMModel", - "OpenAIBatchLLMModel" + "OpenAIBatchLLMModel", + "AnthropicBatchLLMModel", "NumpyVectorStore", "PQASession", "QueryRequest", diff --git a/paperqa/core.py b/paperqa/core.py index 5ceb00602..786e26ddd 100644 --- a/paperqa/core.py +++ b/paperqa/core.py @@ -115,3 +115,50 @@ async def map_fxn_summary( ), llm_result, ) + +async def gather_with_batch( + matches: list[Text], + question: str, + prompt_runner: PromptRunner | None, + extra_prompt_data: dict[str, str] | None = None, + parser: Callable[[str], dict[str, Any]] | None = None, + callbacks: list[Callable[[str], None]] | None = None, + ) -> list[tuple[Context, LLMResult]]: + """Gathers a batch of results for a given text.""" + data = [ + {"question": question, + "citation": m.name + ": " + m.doc.formatted_citation, + "text": m.text} | + extra_prompt_data or {} + for m in matches + ] + + llm_results = await prompt_runner( + data, + callbacks, + ) + + results_data = [] + scores = [] + for r in llm_results: + try: + results_data.append(parser(r.text)) + scores.append(r.pop("relevance_score")) + # just in case question was present + r.pop("question", None) + except: + results_data.append({}) + scores.append(extract_score(r.text)) + + return [ + ( + Context( + context=strip_citations(llm_result.text), + text=m, + model_extra={}, + score=score, + **r, + ), + llm_result, + ) for r, m, llm_result, score in zip(results_data, matches, llm_results, scores) + ] \ No newline at end of file diff --git a/paperqa/docs.py b/paperqa/docs.py index 046077adf..0f5d8832a 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -22,7 +22,11 @@ ) from paperqa.clients import DEFAULT_CLIENTS, DocMetadataClient -from paperqa.core import llm_parse_json, map_fxn_summary +from paperqa.core import ( + llm_parse_json, + map_fxn_summary, + gather_with_batch +) from paperqa.llms import ( EmbeddingModel, LLMModel, @@ -604,47 +608,16 @@ async def aget_evidence( with set_llm_session_ids(session.id): if evidence_settings.use_batch_in_summary: - # TODO: Should we implement a `gather_with_batch` function that receives `matches` and return results to keep this dry? - - data = [ - {"question": session.question, - "citation": m.name + ": " + m.doc.formatted_citation, - "text": m.text} | - {"summary_length": answer_config.evidence_summary_length, - "citation": f"{m.name}: {m.doc.formatted_citation}", - "evidence": m.name} - for m in matches - ] - - llm_results = await prompt_runner( - data, - callbacks, - ) - - results_data = [] - scores = [] - for r in llm_results: - try: - results_data.append(llm_parse_json(r.text)) - scores.append(r.pop("relevance_score")) - # just in case question was present - r.pop("question", None) - except ValueError: - results_data.append({}) - scores.append(extract_score(r.text)) - - results = [ - ( - Context( - context=strip_citations(llm_result.text), - text=m, - model_extra={}, - score=score, - **r, - ), - llm_result, - ) for r, m, llm_result, score in zip(results_data, matches, llm_results, scores) - ] + results = await gather_with_batch( + matches = matches, + question = session.question, + prompt_runner=prompt_runner, + extra_prompt_data={ + "summary_length": answer_config.evidence_summary_length, + }, + parser=llm_parse_json if prompt_config.use_json else None, + callbacks=callbacks, + ) else: results = await gather_with_concurrency( answer_config.max_concurrent_requests, diff --git a/paperqa/llms.py b/paperqa/llms.py index 857b12beb..fc28c2ed0 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -20,7 +20,6 @@ import litellm -import openai import json import os import tempfile @@ -759,6 +758,7 @@ def infer_llm_type(self) -> str: def count_tokens(self, text: str) -> int: return litellm.token_counter(model=self.name, text=text) + class OpenAIBatchLLMModel(LLMModel): """A wrapper around the OpenAI library to use the batch API.""" name: str = "gpt-4o-mini" @@ -854,9 +854,17 @@ async def _run_chat( async def achat(self, messages: list[dict[str, str]] ) -> list[Chunk]: + try: + import openai + except ImportError as exc: + raise ImportError( + "Please install paper-qa[batch] to use" + " OpenAIBatchLLMModel." + ) + client = openai.OpenAI() - with tempfile.NamedTemporaryFile(suffix=".jsonl", delete=True) as tmp_file: + with tempfile.NamedTemporaryFile(suffix=".jsonl") as tmp_file: tmp_filename = tmp_file.name self.write_jsonl(messages, tmp_filename) file = client.files.create( @@ -916,7 +924,12 @@ async def check_rate_limit(self, token_count: float, **kwargs) -> None: class AnthropicBatchLLMModel(LLMModel): - # TODO: This class is not implemented yet. + """A wrapper around the anthropic library to use the batch API.""" + name: str = "claude-3-5-sonnet-20241022" + config: dict = Field( + default_factory=dict, + description="Configuration dictionary for this model. Currently supported keys are `model` and `max_token`.", + ) @rate_limited async def acomplete(self): @@ -926,13 +939,77 @@ async def acomplete(self): async def acomplete_iter(self): raise NotImplementedError("Completion models are not supported yet") - async def _run_chat(sellf): - '''Processes the batch and call the chat completion method''' - ... + async def _run_chat( + self, + prompt: str, + data: list[dict[str,str]], + callbacks: list[Callable] | None = None, + name: str | None = None, + skip_system: bool = False, + system_prompt: str = default_system_prompt, + ) -> list[LLMResult]: + if callbacks: + sync_callbacks = [f for f in callbacks if not is_coroutine_callable(f)] + async_callbacks = [f for f in callbacks if is_coroutine_callable(f)] + + system_message_prompt = {"role": "system", "content": system_prompt} + human_message_prompt = {"role": "user", "content": prompt} + + batch = [] + for d in data: + messages = [ + {"role": m["role"], "content": m["content"].format(**d)} + for m in ( + [human_message_prompt] + if skip_system + else [system_message_prompt, human_message_prompt] + ) + ] + batch.append(messages) + + start_clock = asyncio.get_running_loop().time() + chunks = await self.achat(batch) + batch_time = asyncio.get_running_loop().time() - start_clock @rate_limited - async def achat(self, messages): - ... + async def achat(self, messages: list[dict[str, str]]) -> list[Chunk]: + try: + import anthropic + from anthropic.types.beta.message_create_params import MessageCreateParamsNonStreaming + from anthropic.types.beta.messages.batch_create_params import Request + except ImportError as exc: + raise ImportError( + "Please install paper-qa[batch] to use" + " AnthropicBatchLLMModel." + ) + + client = anthropic.Anthropic() + + requests = [ + Request( + custom_id=str(i), + params=MessageCreateParamsNonStreaming( + model=self.config.get('model'), + max_tokens=self.config.get('max_tokens'), + messages=m + ) + ) for i, m in enumerate(messages) + ] + + batch = client.beta.messages.batches.create( + requests=requests + ) + + while batch.processing_status != "ended": + batch = client.beta.messages.batches.retrieve(batch.id) + print(batch.processing_status) + await asyncio.sleep(5) + + responses = client.beta.messages.batches.results(batch.id) + + # TODO: [WIP] Extract the completions from response. But I am having a bad time waiting for the API to return the results. + return + @rate_limited async def achat_iter(self): @@ -942,14 +1019,7 @@ def infer_llm_type(self): return "chat" #TODO: Support completion models def count_tokens(self, text: str) -> int: - return len(text) // 4 #TODO: Check if OpenAI has a method for that. Currently it's not being used. The token usage is directly retrieved from the response. - - def __getstate__(self): - # Prevent _router from being pickled, SEE: https://stackoverflow.com/a/2345953 - state = super().__getstate__() - state["__dict__"] = state["__dict__"].copy() - state["__dict__"].pop("_router", None) - return state + return len(text) // 4 async def check_rate_limit(self, token_count: float, **kwargs) -> None: if "rate_limit" in self.config: diff --git a/paperqa/settings.py b/paperqa/settings.py index 6c309026d..0e91e9b18 100644 --- a/paperqa/settings.py +++ b/paperqa/settings.py @@ -40,7 +40,13 @@ except ImportError: HAS_LDP_INSTALLED = False -from paperqa.llms import EmbeddingModel, LiteLLMModel, OpenAIBatchLLMModel, embedding_model_factory +from paperqa.llms import ( + EmbeddingModel, + LiteLLMModel, + OpenAIBatchLLMModel, + AnthropicBatchLLMModel, + embedding_model_factory +) from paperqa.prompts import ( CONTEXT_INNER_PROMPT, CONTEXT_OUTER_PROMPT, @@ -607,7 +613,15 @@ class Settings(BaseSettings): ) use_batch_in_summary: bool = Field( default=False, - description="Whether to use batch API for LLMs in summarization", + description=( + "Whether to use batch API for LLMs in summarization." + "This option requests a batch of summaries from the LLM on the `GatherEvidence` step." + "It uses all the candidate papers found in the `GatherEvidence` step" + "to generate a list of prompts that are formatted accordingly to the" + "requirements of the llm provider." + "This option is only available for Claude(https://docs.anthropic.com/en/api/creating-message-batches)" + "and OpenAI (https://platform.openai.com/docs/guides/batch) chat models." + ), ) embedding: str = Field( default="text-embedding-3-small", @@ -794,15 +808,32 @@ def get_llm(self) -> LiteLLMModel: def get_summary_llm(self) -> LiteLLMModel: if self.use_batch_in_summary: - # TODO: support other LLM providers as well. - # TODO: Make it fail if we don't support the batchAPI for the LLM being used - return OpenAIBatchLLMModel( - name=self.summary_llm, - config=self.summary_llm_config - or make_default_openai_batch_llm_settings( - self.summary_llm, self.temperature - ), - ) + import openai + client = openai.OpenAI() + openai_models = [k.id for _, k in enumerate(client.models.list().data) + if k.owned_by in ['system', "openai"]] + if self.summary_llm.startswith("claude-"): + return AnthropicBatchLLMModel( + name=self.summary_llm, + config=self.summary_llm_config + or make_default_openai_batch_llm_settings( + self.summary_llm, self.temperature + ), + ) + elif self.summary_llm in openai_models: + return OpenAIBatchLLMModel( + name=self.summary_llm, + config=self.summary_llm_config + or make_default_openai_batch_llm_settings( + self.summary_llm, self.temperature + ), + ) + else: + raise NotImplementedError( + "`use_batch_in_summary` is set to True, but the summary LLM is not supported" + "for batch processing.\nEither use a Claude or an OpenAI chat model or set " + "`use_batch_in_summary` to False." + ) return LiteLLMModel( name=self.summary_llm, config=self.summary_llm_config diff --git a/pyproject.toml b/pyproject.toml index 41e07aaf3..9a30ca6c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,6 +88,10 @@ typing = [ zotero = [ "pyzotero", ] +batch = [ + "openai", + "anthropic", +] [project.scripts] pqa = "paperqa.agents:main" diff --git a/tests/test_llms.py b/tests/test_llms.py index 91a9caf24..eef4a7792 100644 --- a/tests/test_llms.py +++ b/tests/test_llms.py @@ -11,6 +11,7 @@ LiteLLMEmbeddingModel, LiteLLMModel, OpenAIBatchLLMModel, + AnthropicBatchLLMModel, SentenceTransformerEmbeddingModel, SparseEmbeddingModel, embedding_model_factory, @@ -177,7 +178,7 @@ def config(self, request) -> dict[str, Any]: ], indirect=True ) async def test_run_prompt(self, config: dict[str, Any], request) -> None: - llm = OpenAIBatchLLMModel(name="gpt-4o-mini", config=config) + llm = OpenAIBatchLLMModel(name=config['model'], config=config) data = [ {"animal": "duck"}, @@ -209,6 +210,11 @@ async def test_run_prompt(self, config: dict[str, Any], request) -> None: assert all([completion[k].completion_count < config['max_tokens'] for k, _ in enumerate(data)]) assert sum([completion[k].cost for k, _ in enumerate(data)]) > 0 + @pytest.mark.parametrize( + "config",[ + pytest.param("gpt-4o-mini"), + ], indirect=True + ) def test_pickling(self, tmp_path: pathlib.Path, config: dict[str,Any]) -> None: pickle_path = tmp_path / "llm_model.pickle" llm = OpenAIBatchLLMModel( @@ -222,6 +228,37 @@ def test_pickling(self, tmp_path: pathlib.Path, config: dict[str,Any]) -> None: assert llm.name == rehydrated_llm.name assert llm.config == rehydrated_llm.config +class TestAnthropicBatchLLMModel: + @pytest.fixture(scope="class") + def config(self, request) -> dict[str, Any]: + model_name = request.param + return { + "model": model_name, + "temperature": 0.0, + "max_tokens": 64, + } + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "config",[ + pytest.param("claude-3-haiku-20240307", id="chat-model"), + ], indirect=True + ) + async def test_run_prompt(self, config: dict[str, Any], request) -> None: + llm = AnthropicBatchLLMModel(name=config['model'], config=config) + + data = [ + {"animal": "duck"}, + {"animal": "dog"}, + {"animal": "cat"} + ] + + completion = await llm.run_prompt( + prompt="The {animal} says", + data=data, + skip_system=True, + ) + @pytest.mark.asyncio async def test_embedding_model_factory_sentence_transformer() -> None: """Test that the factory creates a SentenceTransformerEmbeddingModel when given an 'st-' prefix.""" diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index c31ca1be1..f5642104a 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -513,6 +513,8 @@ async def test_docs_lifecycle(subtests: SubTests, stub_data_dir: Path) -> None: def test_evidence(docs_fixture, use_batch) -> None: debug_settings = Settings.from_name("debug") debug_settings.use_batch_in_summary = use_batch + if use_batch: + debug_settings.summary_llm = "gpt-3.5-turbo" evidence = docs_fixture.get_evidence( PQASession(question="What does XAI stand for?"), settings=debug_settings, From 899de43005ed2bd92c52413d388b02a3420bd3af Mon Sep 17 00:00:00 2001 From: Mayk Caldas Date: Fri, 15 Nov 2024 09:59:03 -0800 Subject: [PATCH 04/26] Removed the skip_system argument from the new classes and tests to make it compatible with #680 --- paperqa/llms.py | 14 ++++++-------- tests/test_llms.py | 3 --- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/paperqa/llms.py b/paperqa/llms.py index fc28c2ed0..ec127346d 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -803,7 +803,6 @@ async def _run_chat( data: list[dict[str,str]], callbacks: list[Callable] | None = None, name: str | None = None, - skip_system: bool = False, system_prompt: str = default_system_prompt, ) -> list[LLMResult]: if callbacks: @@ -818,9 +817,9 @@ async def _run_chat( messages = [ {"role": m["role"], "content": m["content"].format(**d)} for m in ( - [human_message_prompt] - if skip_system - else [system_message_prompt, human_message_prompt] + [{"role": "system", "content": system_prompt}, human_message_prompt] + if system_prompt + else [human_message_prompt] ) ] batch.append(messages) @@ -945,7 +944,6 @@ async def _run_chat( data: list[dict[str,str]], callbacks: list[Callable] | None = None, name: str | None = None, - skip_system: bool = False, system_prompt: str = default_system_prompt, ) -> list[LLMResult]: if callbacks: @@ -960,9 +958,9 @@ async def _run_chat( messages = [ {"role": m["role"], "content": m["content"].format(**d)} for m in ( - [human_message_prompt] - if skip_system - else [system_message_prompt, human_message_prompt] + [{"role": "system", "content": system_prompt}, human_message_prompt] + if system_prompt + else [human_message_prompt] ) ] batch.append(messages) diff --git a/tests/test_llms.py b/tests/test_llms.py index eef4a7792..efbc23f02 100644 --- a/tests/test_llms.py +++ b/tests/test_llms.py @@ -191,7 +191,6 @@ async def test_run_prompt(self, config: dict[str, Any], request) -> None: completion = await llm.run_prompt( prompt="The {animal} says", data=data, - skip_system=True, ) assert "Batch failed" in str(e_info.value) assert "not supported" in str(e_info.value) @@ -200,7 +199,6 @@ async def test_run_prompt(self, config: dict[str, Any], request) -> None: completion = await llm.run_prompt( prompt="The {animal} says", data=data, - skip_system=True, ) assert all([completion[k].model == config['model'] for k, _ in enumerate(data)]) @@ -256,7 +254,6 @@ async def test_run_prompt(self, config: dict[str, Any], request) -> None: completion = await llm.run_prompt( prompt="The {animal} says", data=data, - skip_system=True, ) @pytest.mark.asyncio From 16c398847b25ec90da800b77b49d691993f1bd0f Mon Sep 17 00:00:00 2001 From: Mayk Caldas Date: Fri, 15 Nov 2024 16:27:23 -0800 Subject: [PATCH 05/26] Switched to async OpenAI client Also bugfix in tests and created Enums to avoid hardcoding the batch status identifiers --- paperqa/llms.py | 55 +++++++++++++++++++++++++++++++++------------ paperqa/settings.py | 8 +++---- tests/test_llms.py | 15 +++++++++++-- 3 files changed, 57 insertions(+), 21 deletions(-) diff --git a/paperqa/llms.py b/paperqa/llms.py index ec127346d..09ec55512 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -74,6 +74,24 @@ class EmbeddingModes(StrEnum): QUERY = "query" +class OpenAIBatchStatus(StrEnum): + COMPLETE = "completed" + PROGRESS = "in_progress" + SUCESS = "completed" + FAILURE = "failed" + EXPIRE = "expired" + CANCEL = "cancelled" + + +class AnthropicBatchStatus(StrEnum): + COMPLETE = "ended" + PROGRESS = "in_progress" + SUCESS = "succeeded" + FAILURE = "errored" + EXPIRE = "expired" + CANCEL = "canceled" + + # Estimate from OpenAI's FAQ # https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them CHARACTERS_PER_TOKEN_ASSUMPTION: float = 4.0 @@ -766,6 +784,10 @@ class OpenAIBatchLLMModel(LLMModel): default_factory=dict, description="Configuration dictionary for this model. Currently supported keys are `model` and `max_token`.", ) + status: OpenAIBatchStatus = Field( + default=OpenAIBatchStatus, + description="Statuses used to report the status of the API request.", + ) def write_jsonl(self, data: list[dict[str, str]], @@ -809,7 +831,6 @@ async def _run_chat( sync_callbacks = [f for f in callbacks if not is_coroutine_callable(f)] async_callbacks = [f for f in callbacks if is_coroutine_callable(f)] - system_message_prompt = {"role": "system", "content": system_prompt} human_message_prompt = {"role": "user", "content": prompt} batch = [] @@ -861,17 +882,17 @@ async def achat(self, " OpenAIBatchLLMModel." ) - client = openai.OpenAI() + client = openai.AsyncOpenAI() with tempfile.NamedTemporaryFile(suffix=".jsonl") as tmp_file: tmp_filename = tmp_file.name self.write_jsonl(messages, tmp_filename) - file = client.files.create( - file=open(tmp_filename, "rb"), - purpose="batch" - ) + file = await client.files.create( + file=open(tmp_filename, "rb"), + purpose="batch" + ) - batch = client.batches.create( + batch = await client.batches.create( input_file_id=file.id, endpoint="/v1/chat/completions", completion_window="24h", @@ -880,13 +901,15 @@ async def achat(self, } ) - while batch.status != "completed": - batch = client.batches.retrieve(batch.id) - if batch.status == "failed": + while batch.status != self.status.COMPLETE: + batch = await client.batches.retrieve(batch.id) + if batch.status == self.status.FAILURE: raise Exception("Batch failed. \n\nReason: \n" + "\n".join([k.message for k in batch.errors.data])) + elif batch.status == self.status.CANCEL: + raise Exception("Batch was cancelled.") await asyncio.sleep(5) - responses = client.files.content(batch.output_file_id) + responses = await client.files.content(batch.output_file_id) response_lines = responses.read().decode('utf-8').splitlines() responses = [json.loads(line) for line in response_lines] sorted_responses = sorted(responses, key=lambda x: int(x["custom_id"])) # The batchAPI doesn't guarantee the order of the responses @@ -929,6 +952,10 @@ class AnthropicBatchLLMModel(LLMModel): default_factory=dict, description="Configuration dictionary for this model. Currently supported keys are `model` and `max_token`.", ) + status: AnthropicBatchStatus = Field( + default=AnthropicBatchStatus, + description="Statuses used to report the status of the API request.", + ) @rate_limited async def acomplete(self): @@ -950,7 +977,6 @@ async def _run_chat( sync_callbacks = [f for f in callbacks if not is_coroutine_callable(f)] async_callbacks = [f for f in callbacks if is_coroutine_callable(f)] - system_message_prompt = {"role": "system", "content": system_prompt} human_message_prompt = {"role": "user", "content": prompt} batch = [] @@ -998,12 +1024,13 @@ async def achat(self, messages: list[dict[str, str]]) -> list[Chunk]: requests=requests ) - while batch.processing_status != "ended": + while batch.processing_status != self.status.COMPLETE: batch = client.beta.messages.batches.retrieve(batch.id) print(batch.processing_status) await asyncio.sleep(5) responses = client.beta.messages.batches.results(batch.id) + # TODO: [WIP] Extract the completions from response. But I am having a bad time waiting for the API to return the results. return @@ -1014,7 +1041,7 @@ async def achat_iter(self): raise NotImplementedError("support to callbacks is not implemented yet") def infer_llm_type(self): - return "chat" #TODO: Support completion models + return "chat" def count_tokens(self, text: str) -> int: return len(text) // 4 diff --git a/paperqa/settings.py b/paperqa/settings.py index 0e91e9b18..46a35e9e0 100644 --- a/paperqa/settings.py +++ b/paperqa/settings.py @@ -614,11 +614,9 @@ class Settings(BaseSettings): use_batch_in_summary: bool = Field( default=False, description=( - "Whether to use batch API for LLMs in summarization." - "This option requests a batch of summaries from the LLM on the `GatherEvidence` step." - "It uses all the candidate papers found in the `GatherEvidence` step" - "to generate a list of prompts that are formatted accordingly to the" - "requirements of the llm provider." + "Whether to use batch API for LLMs in summarization, " + "which means multiple messages are sent in one API request " + "to the LLM provider's batch API." "This option is only available for Claude(https://docs.anthropic.com/en/api/creating-message-batches)" "and OpenAI (https://platform.openai.com/docs/guides/batch) chat models." ), diff --git a/tests/test_llms.py b/tests/test_llms.py index efbc23f02..1ea9e0214 100644 --- a/tests/test_llms.py +++ b/tests/test_llms.py @@ -170,16 +170,24 @@ def config(self, request) -> dict[str, Any]: "max_tokens": 64, } - @pytest.mark.asyncio + # @pytest.mark.vcr(match_on=[*VCR_DEFAULT_MATCH_ON])# , "body"]) @pytest.mark.parametrize( "config",[ pytest.param("gpt-4o-mini", id="chat-model"), pytest.param("gpt-3.5-turbo-instruct", id="completion-model") ], indirect=True ) + @pytest.mark.asyncio async def test_run_prompt(self, config: dict[str, Any], request) -> None: llm = OpenAIBatchLLMModel(name=config['model'], config=config) + outputs = [] + def accum(x) -> None: + outputs.append(x) + + async def ac(x) -> None: + pass + data = [ {"animal": "duck"}, {"animal": "dog"}, @@ -199,14 +207,16 @@ async def test_run_prompt(self, config: dict[str, Any], request) -> None: completion = await llm.run_prompt( prompt="The {animal} says", data=data, + callbacks=[accum, ac], ) assert all([completion[k].model == config['model'] for k, _ in enumerate(data)]) assert all([completion[k].seconds_to_first_token > 0 for k, _ in enumerate(data)]) assert all([completion[k].prompt_count > 0 for k, _ in enumerate(data)]) assert all([completion[k].completion_count > 0 for k, _ in enumerate(data)]) - assert all([completion[k].completion_count < config['max_tokens'] for k, _ in enumerate(data)]) + assert all([completion[k].completion_count <= config['max_tokens'] for k, _ in enumerate(data)]) assert sum([completion[k].cost for k, _ in enumerate(data)]) > 0 + assert all([str(completion[k]) == outputs[k] for k, _ in enumerate(data)]) @pytest.mark.parametrize( "config",[ @@ -236,6 +246,7 @@ def config(self, request) -> dict[str, Any]: "max_tokens": 64, } + @pytest.mark.vcr @pytest.mark.asyncio @pytest.mark.parametrize( "config",[ From d10a268c67c15388ffee7056beb8b15c6fc7cf30 Mon Sep 17 00:00:00 2001 From: Mayk Caldas Date: Fri, 15 Nov 2024 16:54:15 -0800 Subject: [PATCH 06/26] Added logging to the batch processing The timelimit and the pooling time for the batches are now in the Settings --- paperqa/llms.py | 12 +++++++++++- paperqa/settings.py | 8 ++++++++ tests/test_llms.py | 2 ++ 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/paperqa/llms.py b/paperqa/llms.py index 09ec55512..c77446456 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -23,6 +23,7 @@ import json import os import tempfile +import logging import numpy as np import tiktoken @@ -40,6 +41,8 @@ from paperqa.types import Embeddable, LLMResult from paperqa.utils import is_coroutine_callable +logger = logging.getLogger(__name__) + PromptRunner = Callable[ [dict, list[Callable[[str], None]] | None, str | None], Awaitable[LLMResult], @@ -901,13 +904,20 @@ async def achat(self, } ) + start_clock = asyncio.get_running_loop().time() while batch.status != self.status.COMPLETE: batch = await client.batches.retrieve(batch.id) if batch.status == self.status.FAILURE: raise Exception("Batch failed. \n\nReason: \n" + "\n".join([k.message for k in batch.errors.data])) elif batch.status == self.status.CANCEL: raise Exception("Batch was cancelled.") - await asyncio.sleep(5) + + batch_time = asyncio.get_running_loop().time() - start_clock + if batch_time > self.config.get('batch_summary_timelimit'): + raise Exception("Batch took too long to complete.") + + logger.info(f"Summary batch status: {batch.status} | Time elapsed: {batch_time}") + await asyncio.sleep(self.config.get('batch_polling_interval')) responses = await client.files.content(batch.output_file_id) response_lines = responses.read().decode('utf-8').splitlines() diff --git a/paperqa/settings.py b/paperqa/settings.py index 46a35e9e0..f0789d72d 100644 --- a/paperqa/settings.py +++ b/paperqa/settings.py @@ -621,6 +621,14 @@ class Settings(BaseSettings): "and OpenAI (https://platform.openai.com/docs/guides/batch) chat models." ), ) + batch_summary_timelimit: int = Field( + default=24*60*60, + description="Time limit for batch summarization in seconds", + ) + batch_polling_interval: int = Field( + default=30, + description="Polling interval for batch summarization in seconds", + ) embedding: str = Field( default="text-embedding-3-small", description="Default embedding model for texts", diff --git a/tests/test_llms.py b/tests/test_llms.py index 1ea9e0214..ec531a24f 100644 --- a/tests/test_llms.py +++ b/tests/test_llms.py @@ -168,6 +168,8 @@ def config(self, request) -> dict[str, Any]: "model": model_name, "temperature": 0.0, "max_tokens": 64, + "batch_summary_timelimit": 24*60*60, + "batch_polling_interval": 5, } # @pytest.mark.vcr(match_on=[*VCR_DEFAULT_MATCH_ON])# , "body"]) From 0fe9aa1e4b730283e6d6b1354e122371738b4a8a Mon Sep 17 00:00:00 2001 From: Mayk Caldas Date: Sun, 17 Nov 2024 18:20:20 -0800 Subject: [PATCH 07/26] Created mock server to test openAI batch API --- paperqa/settings.py | 4 +- tests/test_llms.py | 156 ++++++++++++++++++++++++++++++++++---------- 2 files changed, 125 insertions(+), 35 deletions(-) diff --git a/paperqa/settings.py b/paperqa/settings.py index f0789d72d..72143653d 100644 --- a/paperqa/settings.py +++ b/paperqa/settings.py @@ -816,8 +816,8 @@ def get_summary_llm(self) -> LiteLLMModel: if self.use_batch_in_summary: import openai client = openai.OpenAI() - openai_models = [k.id for _, k in enumerate(client.models.list().data) - if k.owned_by in ['system', "openai"]] + openai_models = [k.id for k in client.models.list().data + if k.owned_by in ('system', "openai")] if self.summary_llm.startswith("claude-"): return AnthropicBatchLLMModel( name=self.summary_llm, diff --git a/tests/test_llms.py b/tests/test_llms.py index ec531a24f..9e793221a 100644 --- a/tests/test_llms.py +++ b/tests/test_llms.py @@ -1,10 +1,11 @@ import pathlib import pickle from typing import Any -from unittest.mock import patch +from unittest.mock import AsyncMock, MagicMock, patch import litellm import pytest +import json from paperqa import ( HybridEmbeddingModel, @@ -16,7 +17,11 @@ SparseEmbeddingModel, embedding_model_factory, ) -from paperqa.llms import Chunk +from paperqa.llms import ( + Chunk, + OpenAIBatchStatus, + AnthropicBatchStatus, +) from tests.conftest import VCR_DEFAULT_MATCH_ON @@ -160,6 +165,7 @@ def test_pickling(self, tmp_path: pathlib.Path) -> None: assert llm.config == rehydrated_llm.config assert llm.router.deployment_names == rehydrated_llm.router.deployment_names + class TestOpenAIBatchLLMModel: @pytest.fixture(scope="class") def config(self, request) -> dict[str, Any]: @@ -172,7 +178,6 @@ def config(self, request) -> dict[str, Any]: "batch_polling_interval": 5, } - # @pytest.mark.vcr(match_on=[*VCR_DEFAULT_MATCH_ON])# , "body"]) @pytest.mark.parametrize( "config",[ pytest.param("gpt-4o-mini", id="chat-model"), @@ -180,45 +185,130 @@ def config(self, request) -> dict[str, Any]: ], indirect=True ) @pytest.mark.asyncio - async def test_run_prompt(self, config: dict[str, Any], request) -> None: - llm = OpenAIBatchLLMModel(name=config['model'], config=config) - - outputs = [] - def accum(x) -> None: - outputs.append(x) + async def test_run_prompt(self, monkeypatch, config: dict[str, Any], request) -> None: + ############################## + # # + # Create a mock batch client # + # # + ############################## + + mock_client = AsyncMock() + + # Define mock methods for the client + mock_files_create = AsyncMock() + mock_batches_create = AsyncMock() + mock_batches_retrieve = AsyncMock() + mock_files_content = AsyncMock() + + mock_client.files.create = mock_files_create + mock_client.batches.create = mock_batches_create + mock_client.batches.retrieve = mock_batches_retrieve + mock_client.files.content = mock_files_content - async def ac(x) -> None: - pass + mock_file_id = 'file-123' + mock_files_create.return_value = MagicMock(id=mock_file_id) + + mock_batch_id = 'batch_123' + mock_batches_create.return_value = MagicMock(id=mock_batch_id, status=OpenAIBatchStatus.PROGRESS) - data = [ - {"animal": "duck"}, - {"animal": "dog"}, - {"animal": "cat"} + if request.node.name == "test_run_prompt[completion-model]": + batch_retrieve_calls = [ + MagicMock(id=mock_batch_id, status=OpenAIBatchStatus.FAILURE, + errors=MagicMock( + data=[ + MagicMock(message="Batch failed: The model gpt-3.5-turbo-instruct is not supported for batch completions.")] + ) + ), ] - + elif request.node.name == "test_run_prompt[chat-model]": + batch_retrieve_calls = [ + MagicMock(id=mock_batch_id, status=OpenAIBatchStatus.PROGRESS), + MagicMock(id=mock_batch_id, status=OpenAIBatchStatus.COMPLETE, output_file_id='file-789') + ] + mock_batches_retrieve.side_effect = batch_retrieve_calls + if request.node.name == "test_run_prompt[completion-model]": - with pytest.raises(Exception) as e_info: + sample_responses = [] + elif request.node.name == "test_run_prompt[chat-model]": + sample_responses = [ + { + 'id': 'file-789', 'custom_id': '0', + 'response': { + 'body': { + 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': 'The duck says "quack." This vocalization is characteristic of the species Anas platyrhynchos, commonly known as the mallard duck, which is often used as a representative example for the duck family, Anatidae.', 'refusal': None}, 'logprobs': None, 'finish_reason': 'stop'}], 'usage': { + 'prompt_tokens': 46, 'completion_tokens': 47, 'total_tokens': 93, 'prompt_tokens_details': {'cached_tokens': 0, 'audio_tokens': 0}, 'completion_tokens_details': {'reasoning_tokens': 0, 'audio_tokens': 0, 'accepted_prediction_tokens': 0, 'rejected_prediction_tokens': 0} + }, + } + }, + }, + { + 'id': 'file-789', 'custom_id': '1', + 'response': { + 'body': { + 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': 'The dog says "bark." This is a vocalization commonly associated with canines, used for communication purposes such as alerting, expressing excitement, or seeking attention.', 'refusal': None}, 'logprobs': None, 'finish_reason': 'stop'}], + 'usage': { + 'prompt_tokens': 46, 'completion_tokens': 34, 'total_tokens': 80, 'prompt_tokens_details': {'cached_tokens': 0, 'audio_tokens': 0}, 'completion_tokens_details': {'reasoning_tokens': 0, 'audio_tokens': 0, 'accepted_prediction_tokens': 0, 'rejected_prediction_tokens': 0} + }, + } + }, + }, + { + 'id': 'file-789', 'custom_id': '2', + 'response': { + 'body': { + 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': 'It seems you\'re quoting or referencing "the cat says." If you\'re looking for a specific context, such as a phrase, a song, or a scientific observation (like feline vocalizations), please provide more details for a precise response.', 'refusal': None}, 'logprobs': None, 'finish_reason': 'stop'}], + 'usage': { + 'prompt_tokens': 46, 'completion_tokens': 46, 'total_tokens': 92, 'prompt_tokens_details': {'cached_tokens': 0, 'audio_tokens': 0}, 'completion_tokens_details': {'reasoning_tokens': 0, 'audio_tokens': 0, 'accepted_prediction_tokens': 0, 'rejected_prediction_tokens': 0} + }, + } + }, + } + ] + + response_data = '\n'.join(json.dumps(resp) for resp in sample_responses) + mock_response_content = MagicMock() + mock_response_content.read.return_value = response_data.encode('utf-8') + mock_files_content.return_value = mock_response_content + + with patch('openai.AsyncOpenAI', return_value=mock_client): + llm = OpenAIBatchLLMModel(name=config['model'], config=config) + + outputs = [] + def accum(x) -> None: + outputs.append(x) + + async def ac(x) -> None: + pass + + data = [ + {"animal": "duck"}, + {"animal": "dog"}, + {"animal": "cat"} + ] + + if request.node.name == "test_run_prompt[completion-model]": + with pytest.raises(Exception) as e_info: + completion = await llm.run_prompt( + prompt="The {animal} says", + data=data, + ) + assert "Batch failed" in str(e_info.value) + assert "not supported" in str(e_info.value) + + if request.node.name == "test_run_prompt[chat-model]": completion = await llm.run_prompt( prompt="The {animal} says", data=data, + callbacks=[accum, ac], ) - assert "Batch failed" in str(e_info.value) - assert "not supported" in str(e_info.value) - - if request.node.name == "test_run_prompt[chat-model]": - completion = await llm.run_prompt( - prompt="The {animal} says", - data=data, - callbacks=[accum, ac], - ) - assert all([completion[k].model == config['model'] for k, _ in enumerate(data)]) - assert all([completion[k].seconds_to_first_token > 0 for k, _ in enumerate(data)]) - assert all([completion[k].prompt_count > 0 for k, _ in enumerate(data)]) - assert all([completion[k].completion_count > 0 for k, _ in enumerate(data)]) - assert all([completion[k].completion_count <= config['max_tokens'] for k, _ in enumerate(data)]) - assert sum([completion[k].cost for k, _ in enumerate(data)]) > 0 - assert all([str(completion[k]) == outputs[k] for k, _ in enumerate(data)]) + assert all([completion[k].model == config['model'] for k in range(len(data))]) + assert all([completion[k].seconds_to_first_token > 0 for k in range(len(data))]) + assert all([completion[k].prompt_count > 0 for k in range(len(data))]) + assert all([completion[k].completion_count > 0 for k in range(len(data))]) + assert all([completion[k].completion_count <= config['max_tokens'] for k in range(len(data))]) + assert sum([completion[k].cost for k in range(len(data))]) > 0 + assert all([str(completion[k]) == outputs[k] for k in range(len(data))]) @pytest.mark.parametrize( "config",[ From a9ad540ca618392e87404ab5e69a2116746f6fb9 Mon Sep 17 00:00:00 2001 From: Mayk Caldas Date: Mon, 18 Nov 2024 11:15:47 -0800 Subject: [PATCH 08/26] Implemented batch support to Anthropic --- paperqa/llms.py | 46 ++++++++++++-- tests/test_llms.py | 147 +++++++++++++++++++++++++++++++++++++++------ 2 files changed, 168 insertions(+), 25 deletions(-) diff --git a/paperqa/llms.py b/paperqa/llms.py index c77446456..5299de71b 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -1005,6 +1005,27 @@ async def _run_chat( chunks = await self.achat(batch) batch_time = asyncio.get_running_loop().time() - start_clock + if callbacks: + for chunk in chunks: + await do_callbacks( + async_callbacks, sync_callbacks, chunk.text, name + ) + + results = [ + LLMResult( + model=self.name, + name=name, + prompt=messages, + prompt_count=chunk.prompt_tokens, + text=chunk.text, + completion_count=chunk.completion_tokens, + seconds_to_first_token=batch_time, + seconds_to_last_token=batch_time, + ) for messages, chunk in zip(batch, chunks) + ] + + return results + @rate_limited async def achat(self, messages: list[dict[str, str]]) -> list[Chunk]: try: @@ -1034,16 +1055,29 @@ async def achat(self, messages: list[dict[str, str]]) -> list[Chunk]: requests=requests ) + start_clock = asyncio.get_running_loop().time() while batch.processing_status != self.status.COMPLETE: batch = client.beta.messages.batches.retrieve(batch.id) - print(batch.processing_status) - await asyncio.sleep(5) + + batch_time = asyncio.get_running_loop().time() - start_clock + if batch_time > self.config.get('batch_summary_timelimit'): + raise Exception("Batch took too long to complete.") - responses = client.beta.messages.batches.results(batch.id) - + logger.info(f"Summary batch status: {batch.processing_status} | Time elapsed: {batch_time}") + await asyncio.sleep(self.config.get('batch_polling_interval')) + + responses = [r for r in client.beta.messages.batches.results(batch.id)] + sorted_responses = sorted(responses, key=lambda x: int(x.custom_id)) # The batchAPI doesn't guarantee the order of the responses - # TODO: [WIP] Extract the completions from response. But I am having a bad time waiting for the API to return the results. - return + chunks = [ + Chunk( + text=response.result.message.content[0].text, + prompt_tokens=response.result.message.usage.input_tokens, + completion_tokens=response.result.message.usage.output_tokens, + ) for response in sorted_responses + ] + + return chunks @rate_limited diff --git a/tests/test_llms.py b/tests/test_llms.py index 9e793221a..1e4b4281b 100644 --- a/tests/test_llms.py +++ b/tests/test_llms.py @@ -186,15 +186,9 @@ def config(self, request) -> dict[str, Any]: ) @pytest.mark.asyncio async def test_run_prompt(self, monkeypatch, config: dict[str, Any], request) -> None: - ############################## - # # - # Create a mock batch client # - # # - ############################## - + mock_client = AsyncMock() - # Define mock methods for the client mock_files_create = AsyncMock() mock_batches_create = AsyncMock() mock_batches_retrieve = AsyncMock() @@ -336,9 +330,10 @@ def config(self, request) -> dict[str, Any]: "model": model_name, "temperature": 0.0, "max_tokens": 64, + "batch_summary_timelimit": 24*60*60, + "batch_polling_interval": 5, } - @pytest.mark.vcr @pytest.mark.asyncio @pytest.mark.parametrize( "config",[ @@ -346,18 +341,132 @@ def config(self, request) -> dict[str, Any]: ], indirect=True ) async def test_run_prompt(self, config: dict[str, Any], request) -> None: - llm = AnthropicBatchLLMModel(name=config['model'], config=config) - data = [ - {"animal": "duck"}, - {"animal": "dog"}, - {"animal": "cat"} - ] - - completion = await llm.run_prompt( - prompt="The {animal} says", - data=data, - ) + mock_client = AsyncMock() + + # Define mock methods for the client + mock_client = MagicMock() + mock_batches = MagicMock() + mock_messages = MagicMock() + mock_beta = MagicMock() + + mock_client.beta = mock_beta + mock_beta.messages = mock_messages + mock_messages.batches = mock_batches + + mock_batches.create = MagicMock() + mock_batches.retrieve = MagicMock() + mock_batches.results = MagicMock() + + mock_batch_id = 'msgbatch_123' + mock_batch = MagicMock(id=mock_batch_id, processing_status=AnthropicBatchStatus.PROGRESS) + mock_batches.create.return_value = mock_batch + + batch_retrieve_call = [ + MagicMock(id=mock_batch_id, processing_status=AnthropicBatchStatus.PROGRESS), + MagicMock(id=mock_batch_id, processing_status=AnthropicBatchStatus.COMPLETE, ) + ] + mock_batches.retrieve.side_effect = batch_retrieve_call + + mock_responses = [ + MagicMock( + custom_id='0', + result=MagicMock( + message=MagicMock( + id='msg_0143L9rPswgaUyENkHkPJLcn', + content=[ + MagicMock( + text="I don't actually hear any ducks saying anything. As an AI assistant, I don't have the ability to hear or interpret sounds from the physical world. I can only respond based on the text you provide to me through this chat interface. If you'd like, you can tell me what you think the duck is", + ) + ], + model='claude-3-haiku-20240307', + role='assistant', + stop_reason='max_tokens', + stop_sequence=None, + type='message', + usage=MagicMock( + input_tokens=10, + output_tokens=64 + ) + ), + type='succeeded' + ) + ), + MagicMock( + custom_id='1', + result=MagicMock( + message=MagicMock( + id='msg_01KujiHEB5S8pfRUCmrbabu4', + content=[ + MagicMock( + text="Unfortunately, I don't actually hear a dog speaking. As an AI assistant without physical senses, I can't directly perceive animals making sounds. Could you please provide more context about what the dog is saying, or what you would like me to respond to regarding the dog? I'd be happy to try to assist", + ) + ], + model='claude-3-haiku-20240307', + role='assistant', + stop_reason='max_tokens', + stop_sequence=None, + type='message', + usage=MagicMock( + input_tokens=10, + output_tokens=64 + ) + ), + type='succeeded' + ) + ), + MagicMock( + custom_id='2', + result=MagicMock( + message=MagicMock( + id='msg_01Pf2LqV7wjnwqerkZubbofA', + content=[ + MagicMock( + text="I'm afraid I don't actually hear a cat speaking. As an AI assistant, I don't have the ability to hear or communicate with animals directly. I can only respond based on the text you provide to me. If you'd like, you can tell me what you imagine the cat is saying, and I'll", + ) + ], + model='claude-3-haiku-20240307', + role='assistant', + stop_reason='max_tokens', + stop_sequence=None, + type='message', + usage=MagicMock( + input_tokens=10, + output_tokens=64 + ) + ), + type='succeeded' + ) + ), + ] + + # Create a generator function + def mock_results_generator(batch_id): + for response in mock_responses: + yield response + + mock_batches.results.return_value = mock_results_generator(mock_batch_id) + + with patch('anthropic.Anthropic', return_value=mock_client): + llm = AnthropicBatchLLMModel(name=config['model'], config=config) + + data = [ + {"animal": "duck"}, + {"animal": "dog"}, + {"animal": "cat"} + ] + + completion = await llm.run_prompt( + prompt="The {animal} says", + data=data, + ) + + assert all([completion[k].model == config['model'] for k in range(len(data))]) + assert all([completion[k].seconds_to_first_token > 0 for k in range(len(data))]) + assert all([completion[k].prompt_count > 0 for k in range(len(data))]) + assert all([completion[k].completion_count > 0 for k in range(len(data))]) + assert all([completion[k].completion_count <= config['max_tokens'] for k in range(len(data))]) + assert sum([completion[k].cost for k in range(len(data))]) > 0 @pytest.mark.asyncio async def test_embedding_model_factory_sentence_transformer() -> None: From 723650dfa82ec46aa4100ddbc332b270fb8d424e Mon Sep 17 00:00:00 2001 From: Mayk Caldas Date: Mon, 18 Nov 2024 14:31:00 -0800 Subject: [PATCH 09/26] Updated uv.lock to include imports for the batch API --- uv.lock | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/uv.lock b/uv.lock index a45f5a877..7f8c43f23 100644 --- a/uv.lock +++ b/uv.lock @@ -110,6 +110,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643 }, ] +[[package]] +name = "anthropic" +version = "0.39.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "distro" }, + { name = "httpx" }, + { name = "jiter" }, + { name = "pydantic" }, + { name = "sniffio" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/79/02/2ea51930009d7537c4648f51d1bb3202ec76704cbb39a2a863ab38bee3dd/anthropic-0.39.0.tar.gz", hash = "sha256:94671cc80765f9ce693f76d63a97ee9bef4c2d6063c044e983d21a2e262f63ba", size = 189339 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/61/2580eaa171cab20708d59d39cadd15f78a6c617759e8d0a12e18fe3302d1/anthropic-0.39.0-py3-none-any.whl", hash = "sha256:ea17093ae0ce0e1768b0c46501d6086b5bcd74ff39d68cd2d6396374e9de7c09", size = 198392 }, +] + [[package]] name = "anyio" version = "4.6.2.post1" @@ -1511,7 +1529,7 @@ wheels = [ [[package]] name = "paper-qa" -version = "5.4.1.dev16+ga004d22" +version = "5.4.1.dev28+g9a0a6c4" source = { editable = "." } dependencies = [ { name = "aiohttp" }, @@ -1535,6 +1553,10 @@ dependencies = [ ] [package.optional-dependencies] +batch = [ + { name = "anthropic" }, + { name = "openai" }, +] datasets = [ { name = "datasets" }, ] @@ -1583,6 +1605,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "aiohttp" }, + { name = "anthropic", marker = "extra == 'batch'" }, { name = "anyio" }, { name = "coredis" }, { name = "datasets", marker = "extra == 'datasets'" }, @@ -1593,6 +1616,7 @@ requires-dist = [ { name = "limits" }, { name = "litellm", specifier = ">=1.44" }, { name = "numpy" }, + { name = "openai", marker = "extra == 'batch'" }, { name = "pandas-stubs", marker = "extra == 'typing'" }, { name = "pybtex" }, { name = "pydantic", specifier = "~=2.0" }, From 660bfa0b570785673adef6493fa0a3adb7a1a8b1 Mon Sep 17 00:00:00 2001 From: Mayk Caldas Date: Mon, 18 Nov 2024 14:32:20 -0800 Subject: [PATCH 10/26] Implements tests with a mocked server for anthropic --- paperqa/core.py | 2 +- paperqa/llms.py | 9 +++--- tests/test_llms.py | 73 ++++++++++++++++++++++--------------------- tests/test_paperqa.py | 10 +----- 4 files changed, 44 insertions(+), 50 deletions(-) diff --git a/paperqa/core.py b/paperqa/core.py index 786e26ddd..a3ffc1157 100644 --- a/paperqa/core.py +++ b/paperqa/core.py @@ -124,7 +124,7 @@ async def gather_with_batch( parser: Callable[[str], dict[str, Any]] | None = None, callbacks: list[Callable[[str], None]] | None = None, ) -> list[tuple[Context, LLMResult]]: - """Gathers a batch of results for a given text.""" + """Gathers evidence considering a batch of texts. The completions are obtained using a batch API.""" data = [ {"question": question, "citation": m.name + ": " + m.doc.formatted_citation, diff --git a/paperqa/llms.py b/paperqa/llms.py index 6da1d6695..ea753763f 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -1048,7 +1048,7 @@ async def achat(self, messages: list[dict[str, str]]) -> list[Chunk]: " AnthropicBatchLLMModel." ) - client = anthropic.Anthropic() + client = anthropic.AsyncAnthropic() requests = [ Request( @@ -1061,13 +1061,13 @@ async def achat(self, messages: list[dict[str, str]]) -> list[Chunk]: ) for i, m in enumerate(messages) ] - batch = client.beta.messages.batches.create( + batch = await client.beta.messages.batches.create( requests=requests ) start_clock = asyncio.get_running_loop().time() while batch.processing_status != self.status.COMPLETE: - batch = client.beta.messages.batches.retrieve(batch.id) + batch = await client.beta.messages.batches.retrieve(batch.id) batch_time = asyncio.get_running_loop().time() - start_clock if batch_time > self.config.get('batch_summary_timelimit'): @@ -1076,7 +1076,8 @@ async def achat(self, messages: list[dict[str, str]]) -> list[Chunk]: logger.info(f"Summary batch status: {batch.processing_status} | Time elapsed: {batch_time}") await asyncio.sleep(self.config.get('batch_polling_interval')) - responses = [r for r in client.beta.messages.batches.results(batch.id)] + api_responses = await client.beta.messages.batches.results(batch.id) + responses = [r for r in api_responses] sorted_responses = sorted(responses, key=lambda x: int(x.custom_id)) # The batchAPI doesn't guarantee the order of the responses chunks = [ diff --git a/tests/test_llms.py b/tests/test_llms.py index 1e4b4281b..b86fe44a5 100644 --- a/tests/test_llms.py +++ b/tests/test_llms.py @@ -4,6 +4,8 @@ from unittest.mock import AsyncMock, MagicMock, patch import litellm +import openai +import anthropic import pytest import json @@ -187,23 +189,21 @@ def config(self, request) -> dict[str, Any]: @pytest.mark.asyncio async def test_run_prompt(self, monkeypatch, config: dict[str, Any], request) -> None: - mock_client = AsyncMock() + mock_client = AsyncMock(spec_set=openai.AsyncOpenAI()) - mock_files_create = AsyncMock() - mock_batches_create = AsyncMock() - mock_batches_retrieve = AsyncMock() - mock_files_content = AsyncMock() - - mock_client.files.create = mock_files_create - mock_client.batches.create = mock_batches_create - mock_client.batches.retrieve = mock_batches_retrieve - mock_client.files.content = mock_files_content - mock_file_id = 'file-123' - mock_files_create.return_value = MagicMock(id=mock_file_id) + mock_client.files.create = AsyncMock( + return_value=MagicMock( + id=mock_file_id + ) + ) mock_batch_id = 'batch_123' - mock_batches_create.return_value = MagicMock(id=mock_batch_id, status=OpenAIBatchStatus.PROGRESS) + mock_client.batches.create = AsyncMock( + return_value=MagicMock( + id=mock_batch_id, + status=OpenAIBatchStatus.PROGRESS) + ) if request.node.name == "test_run_prompt[completion-model]": batch_retrieve_calls = [ @@ -219,12 +219,11 @@ async def test_run_prompt(self, monkeypatch, config: dict[str, Any], request) -> MagicMock(id=mock_batch_id, status=OpenAIBatchStatus.PROGRESS), MagicMock(id=mock_batch_id, status=OpenAIBatchStatus.COMPLETE, output_file_id='file-789') ] - mock_batches_retrieve.side_effect = batch_retrieve_calls + mock_client.batches.retrieve = AsyncMock( + side_effect=batch_retrieve_calls + ) - if request.node.name == "test_run_prompt[completion-model]": - sample_responses = [] - elif request.node.name == "test_run_prompt[chat-model]": - sample_responses = [ + sample_responses = [ { 'id': 'file-789', 'custom_id': '0', 'response': { @@ -262,7 +261,9 @@ async def test_run_prompt(self, monkeypatch, config: dict[str, Any], request) -> response_data = '\n'.join(json.dumps(resp) for resp in sample_responses) mock_response_content = MagicMock() mock_response_content.read.return_value = response_data.encode('utf-8') - mock_files_content.return_value = mock_response_content + mock_client.files.content = AsyncMock( + return_value = mock_response_content + ) with patch('openai.AsyncOpenAI', return_value=mock_client): llm = OpenAIBatchLLMModel(name=config['model'], config=config) @@ -342,31 +343,29 @@ def config(self, request) -> dict[str, Any]: ) async def test_run_prompt(self, config: dict[str, Any], request) -> None: - mock_client = AsyncMock() + mock_client = AsyncMock(spec_set=anthropic.AsyncAnthropic()) # Define mock methods for the client mock_client = MagicMock() mock_batches = MagicMock() - mock_messages = MagicMock() - mock_beta = MagicMock() - - mock_client.beta = mock_beta - mock_beta.messages = mock_messages - mock_messages.batches = mock_batches - - mock_batches.create = MagicMock() - mock_batches.retrieve = MagicMock() - mock_batches.results = MagicMock() + # mock_client.beta = MagicMock() + # mock_client.beta.messages = MagicMock() + mock_client.beta.messages.batches = mock_batches mock_batch_id = 'msgbatch_123' - mock_batch = MagicMock(id=mock_batch_id, processing_status=AnthropicBatchStatus.PROGRESS) - mock_batches.create.return_value = mock_batch + mock_batches.create = AsyncMock( + return_value=MagicMock( + id=mock_batch_id, + processing_status=AnthropicBatchStatus.PROGRESS), + ) batch_retrieve_call = [ MagicMock(id=mock_batch_id, processing_status=AnthropicBatchStatus.PROGRESS), - MagicMock(id=mock_batch_id, processing_status=AnthropicBatchStatus.COMPLETE, ) + MagicMock(id=mock_batch_id, processing_status=AnthropicBatchStatus.COMPLETE) ] - mock_batches.retrieve.side_effect = batch_retrieve_call + mock_batches.retrieve = AsyncMock( + side_effect=batch_retrieve_call + ) mock_responses = [ MagicMock( @@ -445,9 +444,11 @@ def mock_results_generator(batch_id): for response in mock_responses: yield response - mock_batches.results.return_value = mock_results_generator(mock_batch_id) + mock_batches.results = AsyncMock( + return_value=mock_results_generator(mock_batch_id) + ) - with patch('anthropic.Anthropic', return_value=mock_client): + with patch('anthropic.AsyncAnthropic', return_value=mock_client): llm = AnthropicBatchLLMModel(name=config['model'], config=config) data = [ diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index f5642104a..a9d4c0718 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -505,16 +505,8 @@ async def test_docs_lifecycle(subtests: SubTests, stub_data_dir: Path) -> None: assert docs.texts assert all(t not in docs.texts_index for t in docs.texts) -@pytest.mark.parametrize("use_batch", [ - pytest.param(True, id="using-batch"), - pytest.param(False, id="not-using-batch") - ] - ) -def test_evidence(docs_fixture, use_batch) -> None: +def test_evidence(docs_fixture) -> None: debug_settings = Settings.from_name("debug") - debug_settings.use_batch_in_summary = use_batch - if use_batch: - debug_settings.summary_llm = "gpt-3.5-turbo" evidence = docs_fixture.get_evidence( PQASession(question="What does XAI stand for?"), settings=debug_settings, From 977a025d85ea77162a54539db8c708b30cd4a1c4 Mon Sep 17 00:00:00 2001 From: Mayk Caldas Date: Tue, 19 Nov 2024 08:07:34 -0800 Subject: [PATCH 11/26] Added type hints to satisfy the pre-commit --- paperqa/__init__.py | 8 +- paperqa/agents/env.py | 8 +- paperqa/core.py | 104 ++++++---- paperqa/docs.py | 24 +-- paperqa/llms.py | 400 ++++++++++++++++++++------------------ paperqa/settings.py | 43 +++-- pyproject.toml | 8 +- tests/test_llms.py | 436 ++++++++++++++++++++++++++---------------- tests/test_paperqa.py | 3 +- uv.lock | 2 +- 10 files changed, 604 insertions(+), 432 deletions(-) diff --git a/paperqa/__init__.py b/paperqa/__init__.py index 2346f0b0c..ab8ea15ba 100644 --- a/paperqa/__init__.py +++ b/paperqa/__init__.py @@ -11,15 +11,15 @@ from paperqa.agents.models import QueryRequest # noqa: E402 from paperqa.docs import Docs, PQASession, print_callback # noqa: E402 from paperqa.llms import ( # noqa: E402 + AnthropicBatchLLMModel, EmbeddingModel, HybridEmbeddingModel, LiteLLMEmbeddingModel, LiteLLMModel, - OpenAIBatchLLMModel, - AnthropicBatchLLMModel, LLMModel, LLMResult, NumpyVectorStore, + OpenAIBatchLLMModel, SentenceTransformerEmbeddingModel, SparseEmbeddingModel, embedding_model_factory, @@ -30,6 +30,7 @@ __all__ = [ "Answer", + "AnthropicBatchLLMModel", "Context", "Doc", "DocDetails", @@ -40,9 +41,8 @@ "LLMResult", "LiteLLMEmbeddingModel", "LiteLLMModel", - "OpenAIBatchLLMModel", - "AnthropicBatchLLMModel", "NumpyVectorStore", + "OpenAIBatchLLMModel", "PQASession", "QueryRequest", "SentenceTransformerEmbeddingModel", diff --git a/paperqa/agents/env.py b/paperqa/agents/env.py index 9660eef99..29ef53c01 100644 --- a/paperqa/agents/env.py +++ b/paperqa/agents/env.py @@ -13,7 +13,11 @@ ) from paperqa.docs import Docs -from paperqa.llms import EmbeddingModel, LiteLLMModel +from paperqa.llms import ( + EmbeddingModel, + LiteLLMModel, + LLMBatchModel, +) from paperqa.settings import Settings from paperqa.types import PQASession from paperqa.utils import get_year @@ -36,7 +40,7 @@ def settings_to_tools( settings: Settings, llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS, - summary_llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS, + summary_llm_model: LiteLLMModel | LLMBatchModel | None = POPULATE_FROM_SETTINGS, embedding_model: EmbeddingModel | None = POPULATE_FROM_SETTINGS, ) -> list[Tool]: """ diff --git a/paperqa/core.py b/paperqa/core.py index a3ffc1157..18a327cd8 100644 --- a/paperqa/core.py +++ b/paperqa/core.py @@ -68,12 +68,13 @@ async def map_fxn_summary( success = False if prompt_runner: - llm_result = await prompt_runner( + result = await prompt_runner( {"question": question, "citation": citation, "text": text.text} | (extra_prompt_data or {}), callbacks, "evidence:" + text.name, ) + llm_result = result if isinstance(result, LLMResult) else result[0] context = llm_result.text result_data = parser(context) if parser else {} success = bool(result_data) @@ -116,6 +117,7 @@ async def map_fxn_summary( llm_result, ) + async def gather_with_batch( matches: list[Text], question: str, @@ -123,42 +125,66 @@ async def gather_with_batch( extra_prompt_data: dict[str, str] | None = None, parser: Callable[[str], dict[str, Any]] | None = None, callbacks: list[Callable[[str], None]] | None = None, - ) -> list[tuple[Context, LLMResult]]: - """Gathers evidence considering a batch of texts. The completions are obtained using a batch API.""" - data = [ - {"question": question, - "citation": m.name + ": " + m.doc.formatted_citation, - "text": m.text} | - extra_prompt_data or {} - for m in matches - ] - - llm_results = await prompt_runner( - data, - callbacks, - ) +) -> list[tuple[Context, LLMResult]]: + """ + Gathers evidence considering a batch of texts. The completions are obtained using a batch API. - results_data = [] - scores = [] - for r in llm_results: - try: - results_data.append(parser(r.text)) - scores.append(r.pop("relevance_score")) - # just in case question was present - r.pop("question", None) - except: - results_data.append({}) - scores.append(extract_score(r.text)) - - return [ - ( - Context( - context=strip_citations(llm_result.text), - text=m, - model_extra={}, - score=score, - **r, - ), - llm_result, - ) for r, m, llm_result, score in zip(results_data, matches, llm_results, scores) - ] \ No newline at end of file + Args: + matches (list[Text]): A list of text matches to gather evidence from. + question (str): The question to be answered. + prompt_runner (PromptRunner | None): The prompt runner to use for obtaining completions. + extra_prompt_data (dict[str, str] | None, optional): Additional data to include in the prompt. + parser (Callable[[str], dict[str, Any]] | None, optional): A function to parse the LLM result text. + callbacks (list[Callable[[str], None]] | None, optional): A list of callback functions to be called + with the LLM result text. + + Returns: + list[tuple[Context, LLMResult]]: A list of tuples containing the context and LLM result for each match. + """ + data = [ + { + "question": question, + "citation": m.name + ": " + m.doc.formatted_citation, + "text": m.text, + } + | (extra_prompt_data or {}) + for m in matches + ] + + llm_results : list[LLMResult] = [] + if prompt_runner: + result = await prompt_runner( + data, + callbacks, + "evidence:" + matches[0].name, + ) + llm_results = result if isinstance(result, list) else [result] + + results_data = [] + scores = [] + for r in llm_results: + if parser: + res = parser(r.text) + results_data.append(res) + scores.append(res.pop("relevance_score")) + # just in case question was present + res.pop("question", None) + else: + results_data.append({}) + scores.append(extract_score(r.text)) + + return [ + ( + Context( + context=strip_citations(llm_result.text), + text=m, + model_extra={}, + score=score, + **r, + ), + llm_result, + ) + for r, m, llm_result, score in zip( + results_data, matches, llm_results, scores, strict=True + ) + ] diff --git a/paperqa/docs.py b/paperqa/docs.py index 0f5d8832a..592f63471 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -22,13 +22,10 @@ ) from paperqa.clients import DEFAULT_CLIENTS, DocMetadataClient -from paperqa.core import ( - llm_parse_json, - map_fxn_summary, - gather_with_batch -) +from paperqa.core import gather_with_batch, llm_parse_json, map_fxn_summary from paperqa.llms import ( EmbeddingModel, + LLMBatchModel, LLMModel, NumpyVectorStore, PromptRunner, @@ -44,7 +41,6 @@ LLMResult, PQASession, Text, - Context, set_llm_session_ids, ) from paperqa.utils import ( @@ -55,8 +51,6 @@ maybe_is_text, md5sum, name_in_text, - extract_score, - strip_citations ) logger = logging.getLogger(__name__) @@ -537,14 +531,14 @@ def get_evidence( ) ) - async def aget_evidence( + async def aget_evidence( # noqa: PLR0912 self, query: PQASession | str, exclude_text_filter: set[str] | None = None, settings: MaybeSettings = None, callbacks: list[Callable] | None = None, embedding_model: EmbeddingModel | None = None, - summary_llm_model: LLMModel | None = None, + summary_llm_model: LLMModel | LLMBatchModel | None = None, ) -> PQASession: evidence_settings = get_settings(settings) @@ -609,8 +603,8 @@ async def aget_evidence( with set_llm_session_ids(session.id): if evidence_settings.use_batch_in_summary: results = await gather_with_batch( - matches = matches, - question = session.question, + matches=matches, + question=session.question, prompt_runner=prompt_runner, extra_prompt_data={ "summary_length": answer_config.evidence_summary_length, @@ -640,7 +634,7 @@ async def aget_evidence( for _, llm_result in results: session.add_tokens(llm_result) - session.contexts += [r for r, _ in results if r is not None] + session.contexts += [r for r, _ in results] return session def query( @@ -649,7 +643,7 @@ def query( settings: MaybeSettings = None, callbacks: list[Callable] | None = None, llm_model: LLMModel | None = None, - summary_llm_model: LLMModel | None = None, + summary_llm_model: LLMModel | LLMBatchModel | None = None, embedding_model: EmbeddingModel | None = None, ) -> PQASession: return get_loop().run_until_complete( @@ -669,7 +663,7 @@ async def aquery( # noqa: PLR0912 settings: MaybeSettings = None, callbacks: list[Callable] | None = None, llm_model: LLMModel | None = None, - summary_llm_model: LLMModel | None = None, + summary_llm_model: LLMModel | LLMBatchModel | None = None, embedding_model: EmbeddingModel | None = None, ) -> PQASession: diff --git a/paperqa/llms.py b/paperqa/llms.py index ea753763f..f5621eef1 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -3,6 +3,9 @@ import asyncio import contextlib import functools +import json +import logging +import tempfile from abc import ABC, abstractmethod from collections.abc import ( AsyncGenerator, @@ -16,15 +19,9 @@ from enum import StrEnum from inspect import isasyncgenfunction, signature from sys import version_info -from typing import Any, TypeVar, cast +from typing import Any, TypedDict, TypeVar, cast import litellm - -import json -import os -import tempfile -import logging - import numpy as np import tiktoken from aviary.core import ToolRequestMessage, ToolSelector @@ -45,8 +42,8 @@ logger = logging.getLogger(__name__) PromptRunner = Callable[ - [dict, list[Callable[[str], None]] | None, str | None], - Awaitable[LLMResult], + [dict | list[dict], list[Callable[[str], None]] | None, str | None], + Awaitable[LLMResult | list[LLMResult]], ] MODEL_COST_MAP = litellm.get_model_cost_map("") @@ -81,19 +78,19 @@ class EmbeddingModes(StrEnum): class OpenAIBatchStatus(StrEnum): COMPLETE = "completed" PROGRESS = "in_progress" - SUCESS = "completed" - FAILURE = "failed" - EXPIRE = "expired" - CANCEL = "cancelled" + SUCCESS = "completed" + FAILURE = "failed" + EXPIRE = "expired" + CANCEL = "cancelled" class AnthropicBatchStatus(StrEnum): COMPLETE = "ended" PROGRESS = "in_progress" - SUCESS = "succeeded" - FAILURE = "errored" - EXPIRE = "expired" - CANCEL = "canceled" + SUCCESS = "succeeded" + FAILURE = "errored" + EXPIRE = "expired" + CANCEL = "canceled" # Estimate from OpenAI's FAQ @@ -352,7 +349,7 @@ def count_tokens(self, text: str) -> int: async def run_prompt( self, prompt: str, - data: dict | list[dict[str, str]], + data: dict, callbacks: list[Callable] | None = None, name: str | None = None, system_prompt: str | None = default_system_prompt, @@ -790,55 +787,121 @@ async def select_tool( return await tool_selector(*selection_args, **selection_kwargs) -class OpenAIBatchLLMModel(LLMModel): +class LLMBatchModel(ABC, BaseModel): + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + llm_type: str | None = None + name: str + llm_result_callback: ( + Callable[[LLMResult], None] | Callable[[LLMResult], Awaitable[None]] | None + ) = Field( + default=None, + description=( + "An async callback that will be executed on each" + " LLMResult (different than callbacks that execute on each chunk)" + ), + exclude=True, + ) + config: dict = Field(default_factory=dict) + + async def run_prompt( + self, + prompt: str, + data: list[dict], + callbacks: list[Callable] | None = None, + name: str | None = None, + system_prompt: str | None = default_system_prompt, + ) -> list[LLMResult]: + if self.llm_type is None: + self.llm_type = self.infer_llm_type() + if self.llm_type == "chat": + return await self._run_chat(prompt, data, callbacks, name, system_prompt) + if self.llm_type == "completion": + return await self._run_completion( + prompt, data, callbacks, name, system_prompt + ) + raise ValueError(f"Unknown llm_type {self.llm_type!r}.") + + async def _run_chat( + self, + prompt: str, + data: list[dict], + callbacks: list[Callable] | None = None, + name: str | None = None, + system_prompt: str | None = default_system_prompt, + ) -> list[LLMResult]: + raise NotImplementedError + + async def _run_completion( + self, + prompt: str, + data: list[dict], + callbacks: list[Callable] | None = None, + name: str | None = None, + system_prompt: str | None = default_system_prompt, + ) -> list[LLMResult]: + raise NotImplementedError + + def infer_llm_type(self) -> str: + return "chat" + + def count_tokens(self, text: str) -> int: + return len(text) // 4 + + +class Body(TypedDict): + model: str | None + messages: list[dict[str, str]] | None + max_tokens: int | None + + +class BatchTemplate(TypedDict): + custom_id: str | None + method: str + url: str + body: Body + + +class OpenAIBatchLLMModel(LLMBatchModel): """A wrapper around the OpenAI library to use the batch API.""" + name: str = "gpt-4o-mini" config: dict = Field( default_factory=dict, description="Configuration dictionary for this model. Currently supported keys are `model` and `max_token`.", - ) - status: OpenAIBatchStatus = Field( - default=OpenAIBatchStatus, - description="Statuses used to report the status of the API request.", - ) + ) - def write_jsonl(self, - data: list[dict[str, str]], - filename: str): - - batch_template = { + def write_jsonl(self, data: list[list[dict[str, str]]], filename: str | Any): + + batch_template: BatchTemplate = { "custom_id": None, "method": "POST", - "url": self.config.get('endpoint'), - "body": { - "model": None, - "messages": None, - "max_tokens": None - } + "url": "/v1/chat/completions", + "body": {"model": None, "messages": None, "max_tokens": None}, } - with open(filename, "w") as f: + with open(filename, "w") as tmp_file: for i, d in enumerate(data): batch_template["custom_id"] = str(i) - batch_template["body"]["model"] = self.config.get('model') + batch_template["body"]["model"] = self.config.get("model") batch_template["body"]["messages"] = d - batch_template["body"]["max_tokens"] = self.config.get('max_tokens') - f.write(json.dumps(batch_template) + "\n") + batch_template["body"]["max_tokens"] = self.config.get("max_tokens") + tmp_file.write(json.dumps(batch_template) + "\n") - @rate_limited async def acomplete(self): raise NotImplementedError("Only chat models are supported by openAI batch API.") - @rate_limited async def acomplete_iter(self): - raise NotImplementedError("Async generator not supported for batch calls and nly chat models are supported by openAI batch API.") + raise NotImplementedError( + "Async generator not supported for batch calls and nly chat models are supported by openAI batch API." + ) async def _run_chat( - self, + self, prompt: str, - data: list[dict[str,str]], + data: list[dict], callbacks: list[Callable] | None = None, name: str | None = None, - system_prompt: str = default_system_prompt, + system_prompt: str | None = default_system_prompt, ) -> list[LLMResult]: if callbacks: sync_callbacks = [f for f in callbacks if not is_coroutine_callable(f)] @@ -849,15 +912,15 @@ async def _run_chat( batch = [] for d in data: messages = [ - {"role": m["role"], "content": m["content"].format(**d)} - for m in ( - [{"role": "system", "content": system_prompt}, human_message_prompt] - if system_prompt - else [human_message_prompt] - ) + {"role": m["role"], "content": m["content"].format(**d)} + for m in ( + [{"role": "system", "content": system_prompt}, human_message_prompt] + if system_prompt + else [human_message_prompt] + ) ] batch.append(messages) - + start_clock = asyncio.get_running_loop().time() chunks = await self.achat(batch) batch_time = asyncio.get_running_loop().time() - start_clock @@ -865,10 +928,10 @@ async def _run_chat( if callbacks: for chunk in chunks: await do_callbacks( - async_callbacks, sync_callbacks, chunk.text, name - ) + async_callbacks, sync_callbacks, chunk.text or "", name + ) - results = [ + return [ LLMResult( model=self.name, name=name, @@ -878,139 +941,128 @@ async def _run_chat( completion_count=chunk.completion_tokens, seconds_to_first_token=batch_time, seconds_to_last_token=batch_time, - ) for messages, chunk in zip(batch, chunks) + ) + for messages, chunk in zip(batch, chunks, strict=True) ] - return results - - @rate_limited - async def achat(self, - messages: list[dict[str, str]] - ) -> list[Chunk]: + async def achat(self, messages: list[list[dict]]) -> list[Chunk]: try: import openai except ImportError as exc: raise ImportError( - "Please install paper-qa[batch] to use" - " OpenAIBatchLLMModel." - ) + "Please install paper-qa[batch] to use OpenAIBatchLLMModel." + ) from exc client = openai.AsyncOpenAI() with tempfile.NamedTemporaryFile(suffix=".jsonl") as tmp_file: tmp_filename = tmp_file.name self.write_jsonl(messages, tmp_filename) - file = await client.files.create( - file=open(tmp_filename, "rb"), - purpose="batch" - ) + file = await client.files.create(file=tmp_file, purpose="batch") batch = await client.batches.create( input_file_id=file.id, endpoint="/v1/chat/completions", completion_window="24h", - metadata={ - "description": "" - } + metadata={"description": ""}, ) start_clock = asyncio.get_running_loop().time() - while batch.status != self.status.COMPLETE: + while batch.status != OpenAIBatchStatus.COMPLETE: batch = await client.batches.retrieve(batch.id) - if batch.status == self.status.FAILURE: - raise Exception("Batch failed. \n\nReason: \n" + "\n".join([k.message for k in batch.errors.data])) - elif batch.status == self.status.CANCEL: - raise Exception("Batch was cancelled.") - + if batch.status == OpenAIBatchStatus.FAILURE: + error_messages = [] + if batch.errors and hasattr(batch.errors, "data") and batch.errors.data: + error_messages = [ + str(k.message) + for k in batch.errors.data + if k.message is not None + ] + raise RuntimeError( + "Batch failed. \n\nReason: \n" + "\n".join(error_messages) + ) + if batch.status == OpenAIBatchStatus.CANCEL: + raise ConnectionError("Batch was cancelled.") + batch_time = asyncio.get_running_loop().time() - start_clock - if batch_time > self.config.get('batch_summary_timelimit'): - raise Exception("Batch took too long to complete.") - - logger.info(f"Summary batch status: {batch.status} | Time elapsed: {batch_time}") - await asyncio.sleep(self.config.get('batch_polling_interval')) - - responses = await client.files.content(batch.output_file_id) - response_lines = responses.read().decode('utf-8').splitlines() - responses = [json.loads(line) for line in response_lines] - sorted_responses = sorted(responses, key=lambda x: int(x["custom_id"])) # The batchAPI doesn't guarantee the order of the responses - - chunks = [ + if batch_time > self.config.get("batch_summary_timelimit", 24 * 60 * 60): + raise TimeoutError("Batch took too long to complete.") + + logger.info( + f"Summary batch status: {batch.status} | Time elapsed: {batch_time}" + ) + await asyncio.sleep(self.config.get("batch_polling_interval", 30)) + + if batch.output_file_id: + api_responses = await client.files.content(batch.output_file_id) + else: + raise RuntimeError("Batch failed to generate output file.") + sorted_responses = sorted( + [ + json.loads(line) + for line in api_responses.read().decode("utf-8").splitlines() + ], + key=lambda x: int(x["custom_id"]), + ) # The batchAPI doesn't guarantee the order of the responses + + return [ Chunk( text=response["response"]["body"]["choices"][0]["message"]["content"], prompt_tokens=response["response"]["body"]["usage"]["prompt_tokens"], - completion_tokens=response["response"]["body"]["usage"]["completion_tokens"], - ) for response in sorted_responses + completion_tokens=response["response"]["body"]["usage"][ + "completion_tokens" + ], + ) + for response in sorted_responses ] - return chunks - - @rate_limited async def achat_iter(self): - raise NotImplementedError("Async generator not supported for batch calls. Use achat instead.") - - def infer_llm_type(self): - self.config['endpoint'] = "/v1/chat/completions" - return "chat" - - def count_tokens(self, text: str) -> int: - return len(text) // 4 - - async def check_rate_limit(self, token_count: float, **kwargs) -> None: - if "rate_limit" in self.config: - await GLOBAL_LIMITER.try_acquire( - ("client", self.name), - self.config["rate_limit"].get(self.name, None), - weight=max(int(token_count), 1), - **kwargs, - ) + raise NotImplementedError( + "Async generator not supported for batch calls. Use achat instead." + ) -class AnthropicBatchLLMModel(LLMModel): +class AnthropicBatchLLMModel(LLMBatchModel): """A wrapper around the anthropic library to use the batch API.""" + name: str = "claude-3-5-sonnet-20241022" config: dict = Field( default_factory=dict, description="Configuration dictionary for this model. Currently supported keys are `model` and `max_token`.", - ) - status: AnthropicBatchStatus = Field( - default=AnthropicBatchStatus, - description="Statuses used to report the status of the API request.", - ) + ) - @rate_limited async def acomplete(self): raise NotImplementedError("Completion models are not supported yet") - @rate_limited async def acomplete_iter(self): raise NotImplementedError("Completion models are not supported yet") - + async def _run_chat( - self, + self, prompt: str, - data: list[dict[str,str]], + data: list[dict], callbacks: list[Callable] | None = None, name: str | None = None, - system_prompt: str = default_system_prompt, + system_prompt: str | None = default_system_prompt, ) -> list[LLMResult]: if callbacks: sync_callbacks = [f for f in callbacks if not is_coroutine_callable(f)] async_callbacks = [f for f in callbacks if is_coroutine_callable(f)] - + human_message_prompt = {"role": "user", "content": prompt} batch = [] for d in data: messages = [ - {"role": m["role"], "content": m["content"].format(**d)} - for m in ( - [{"role": "system", "content": system_prompt}, human_message_prompt] - if system_prompt - else [human_message_prompt] - ) + {"role": m["role"], "content": m["content"].format(**d)} + for m in ( + [{"role": "system", "content": system_prompt}, human_message_prompt] + if system_prompt + else [human_message_prompt] + ) ] batch.append(messages) - + start_clock = asyncio.get_running_loop().time() chunks = await self.achat(batch) batch_time = asyncio.get_running_loop().time() - start_clock @@ -1018,10 +1070,10 @@ async def _run_chat( if callbacks: for chunk in chunks: await do_callbacks( - async_callbacks, sync_callbacks, chunk.text, name - ) + async_callbacks, sync_callbacks, chunk.text or "", name + ) - results = [ + return [ LLMResult( model=self.name, name=name, @@ -1031,85 +1083,69 @@ async def _run_chat( completion_count=chunk.completion_tokens, seconds_to_first_token=batch_time, seconds_to_last_token=batch_time, - ) for messages, chunk in zip(batch, chunks) + ) + for messages, chunk in zip(batch, chunks, strict=True) ] - return results - - @rate_limited - async def achat(self, messages: list[dict[str, str]]) -> list[Chunk]: + async def achat(self, messages: list[list[dict[str, str]]]) -> list[Chunk]: try: import anthropic - from anthropic.types.beta.message_create_params import MessageCreateParamsNonStreaming + from anthropic.types.beta.message_create_params import ( + MessageCreateParamsNonStreaming, + ) from anthropic.types.beta.messages.batch_create_params import Request except ImportError as exc: raise ImportError( - "Please install paper-qa[batch] to use" - " AnthropicBatchLLMModel." - ) - + "Please install paper-qa[batch] to use AnthropicBatchLLMModel." + ) from exc + client = anthropic.AsyncAnthropic() requests = [ Request( custom_id=str(i), params=MessageCreateParamsNonStreaming( - model=self.config.get('model'), - max_tokens=self.config.get('max_tokens'), - messages=m - ) - ) for i, m in enumerate(messages) + model=self.config.get("model"), + max_tokens=self.config.get("max_tokens"), + messages=m, + ), + ) + for i, m in enumerate(messages) ] - batch = await client.beta.messages.batches.create( - requests=requests - ) + batch = await client.beta.messages.batches.create(requests=requests) start_clock = asyncio.get_running_loop().time() - while batch.processing_status != self.status.COMPLETE: + while batch.processing_status != AnthropicBatchStatus.COMPLETE: batch = await client.beta.messages.batches.retrieve(batch.id) - + batch_time = asyncio.get_running_loop().time() - start_clock - if batch_time > self.config.get('batch_summary_timelimit'): - raise Exception("Batch took too long to complete.") + if batch_time > self.config.get("batch_summary_timelimit", 24 * 60 * 60): + raise TimeoutError("Batch took too long to complete.") - logger.info(f"Summary batch status: {batch.processing_status} | Time elapsed: {batch_time}") - await asyncio.sleep(self.config.get('batch_polling_interval')) + logger.info( + f"Summary batch status: {batch.processing_status} | Time elapsed: {batch_time}" + ) + await asyncio.sleep(self.config.get("batch_polling_interval", 30)) api_responses = await client.beta.messages.batches.results(batch.id) - responses = [r for r in api_responses] - sorted_responses = sorted(responses, key=lambda x: int(x.custom_id)) # The batchAPI doesn't guarantee the order of the responses + responses = list(api_responses) + sorted_responses = sorted( + responses, key=lambda x: int(x.custom_id) + ) # The batchAPI doesn't guarantee the order of the responses - chunks = [ + return [ Chunk( text=response.result.message.content[0].text, prompt_tokens=response.result.message.usage.input_tokens, completion_tokens=response.result.message.usage.output_tokens, - ) for response in sorted_responses + ) + for response in sorted_responses ] - - return chunks - - @rate_limited async def achat_iter(self): raise NotImplementedError("support to callbacks is not implemented yet") - def infer_llm_type(self): - return "chat" - - def count_tokens(self, text: str) -> int: - return len(text) // 4 - - async def check_rate_limit(self, token_count: float, **kwargs) -> None: - if "rate_limit" in self.config: - await GLOBAL_LIMITER.try_acquire( - ("client", self.name), - self.config["rate_limit"].get(self.name, None), - weight=max(int(token_count), 1), - **kwargs, - ) - def cosine_similarity(a, b): norm_product = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1) diff --git a/paperqa/settings.py b/paperqa/settings.py index 72143653d..7c71dd4db 100644 --- a/paperqa/settings.py +++ b/paperqa/settings.py @@ -41,11 +41,12 @@ HAS_LDP_INSTALLED = False from paperqa.llms import ( - EmbeddingModel, - LiteLLMModel, - OpenAIBatchLLMModel, - AnthropicBatchLLMModel, - embedding_model_factory + AnthropicBatchLLMModel, + EmbeddingModel, + LiteLLMModel, + LLMBatchModel, + OpenAIBatchLLMModel, + embedding_model_factory, ) from paperqa.prompts import ( CONTEXT_INNER_PROMPT, @@ -570,16 +571,15 @@ def make_default_litellm_model_list_settings( ] } -def make_default_openai_batch_llm_settings( - llm: str, temperature: float = 0.0 -) -> dict: + +def make_default_openai_batch_llm_settings(llm: str, temperature: float = 0.0) -> dict: return { "model": llm, "temperature": temperature, "max_tokens": 2048, - } + class Settings(BaseSettings): model_config = SettingsConfigDict(extra="ignore") @@ -622,7 +622,7 @@ class Settings(BaseSettings): ), ) batch_summary_timelimit: int = Field( - default=24*60*60, + default=24 * 60 * 60, description="Time limit for batch summarization in seconds", ) batch_polling_interval: int = Field( @@ -812,12 +812,16 @@ def get_llm(self) -> LiteLLMModel: or make_default_litellm_model_list_settings(self.llm, self.temperature), ) - def get_summary_llm(self) -> LiteLLMModel: + def get_summary_llm(self) -> LiteLLMModel | LLMBatchModel: if self.use_batch_in_summary: import openai + client = openai.OpenAI() - openai_models = [k.id for k in client.models.list().data - if k.owned_by in ('system', "openai")] + openai_models = [ + k.id + for k in client.models.list().data + if k.owned_by in {"system", "openai"} + ] if self.summary_llm.startswith("claude-"): return AnthropicBatchLLMModel( name=self.summary_llm, @@ -826,7 +830,7 @@ def get_summary_llm(self) -> LiteLLMModel: self.summary_llm, self.temperature ), ) - elif self.summary_llm in openai_models: + if self.summary_llm in openai_models: return OpenAIBatchLLMModel( name=self.summary_llm, config=self.summary_llm_config @@ -834,12 +838,11 @@ def get_summary_llm(self) -> LiteLLMModel: self.summary_llm, self.temperature ), ) - else: - raise NotImplementedError( - "`use_batch_in_summary` is set to True, but the summary LLM is not supported" - "for batch processing.\nEither use a Claude or an OpenAI chat model or set " - "`use_batch_in_summary` to False." - ) + raise NotImplementedError( + "`use_batch_in_summary` is set to True, but the summary LLM is not supported" + "for batch processing.\nEither use a Claude or an OpenAI chat model or set " + "`use_batch_in_summary` to False." + ) return LiteLLMModel( name=self.summary_llm, config=self.summary_llm_config diff --git a/pyproject.toml b/pyproject.toml index 9a30ca6c9..5b9c7375c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,10 @@ readme = "README.md" requires-python = ">=3.11" [project.optional-dependencies] +batch = [ + "anthropic", + "openai", +] datasets = [ "datasets", ] @@ -88,10 +92,6 @@ typing = [ zotero = [ "pyzotero", ] -batch = [ - "openai", - "anthropic", -] [project.scripts] pqa = "paperqa.agents:main" diff --git a/tests/test_llms.py b/tests/test_llms.py index b86fe44a5..9971afb69 100644 --- a/tests/test_llms.py +++ b/tests/test_llms.py @@ -1,28 +1,28 @@ +import json import pathlib import pickle from typing import Any from unittest.mock import AsyncMock, MagicMock, patch +import anthropic import litellm import openai -import anthropic import pytest -import json from paperqa import ( + AnthropicBatchLLMModel, HybridEmbeddingModel, LiteLLMEmbeddingModel, LiteLLMModel, OpenAIBatchLLMModel, - AnthropicBatchLLMModel, SentenceTransformerEmbeddingModel, SparseEmbeddingModel, embedding_model_factory, ) from paperqa.llms import ( - Chunk, - OpenAIBatchStatus, - AnthropicBatchStatus, + AnthropicBatchStatus, + Chunk, + OpenAIBatchStatus, ) from tests.conftest import VCR_DEFAULT_MATCH_ON @@ -176,113 +176,201 @@ def config(self, request) -> dict[str, Any]: "model": model_name, "temperature": 0.0, "max_tokens": 64, - "batch_summary_timelimit": 24*60*60, + "batch_summary_timelimit": 24 * 60 * 60, "batch_polling_interval": 5, } @pytest.mark.parametrize( - "config",[ + "config", + [ pytest.param("gpt-4o-mini", id="chat-model"), - pytest.param("gpt-3.5-turbo-instruct", id="completion-model") - ], indirect=True + pytest.param("gpt-3.5-turbo-instruct", id="completion-model"), + ], + indirect=True, ) @pytest.mark.asyncio - async def test_run_prompt(self, monkeypatch, config: dict[str, Any], request) -> None: - + async def test_run_prompt(self, config: dict[str, Any], request) -> None: + mock_client = AsyncMock(spec_set=openai.AsyncOpenAI()) - - mock_file_id = 'file-123' - mock_client.files.create = AsyncMock( - return_value=MagicMock( - id=mock_file_id - ) - ) - - mock_batch_id = 'batch_123' + + mock_file_id = "file-123" + mock_client.files.create = AsyncMock(return_value=MagicMock(id=mock_file_id)) + + mock_batch_id = "batch_123" mock_client.batches.create = AsyncMock( - return_value=MagicMock( - id=mock_batch_id, - status=OpenAIBatchStatus.PROGRESS) - ) + return_value=MagicMock(id=mock_batch_id, status=OpenAIBatchStatus.PROGRESS) + ) if request.node.name == "test_run_prompt[completion-model]": batch_retrieve_calls = [ - MagicMock(id=mock_batch_id, status=OpenAIBatchStatus.FAILURE, - errors=MagicMock( - data=[ - MagicMock(message="Batch failed: The model gpt-3.5-turbo-instruct is not supported for batch completions.")] + MagicMock( + id=mock_batch_id, + status=OpenAIBatchStatus.FAILURE, + errors=MagicMock( + data=[ + MagicMock( + message=( + "Batch failed: The model gpt-3.5-turbo-instruct " + "is not supported for batch completions." + ) ) - ), + ] + ), + ), ] elif request.node.name == "test_run_prompt[chat-model]": batch_retrieve_calls = [ MagicMock(id=mock_batch_id, status=OpenAIBatchStatus.PROGRESS), - MagicMock(id=mock_batch_id, status=OpenAIBatchStatus.COMPLETE, output_file_id='file-789') + MagicMock( + id=mock_batch_id, + status=OpenAIBatchStatus.COMPLETE, + output_file_id="file-789", + ), ] - mock_client.batches.retrieve = AsyncMock( - side_effect=batch_retrieve_calls - ) + mock_client.batches.retrieve = AsyncMock(side_effect=batch_retrieve_calls) sample_responses = [ - { - 'id': 'file-789', 'custom_id': '0', - 'response': { - 'body': { - 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': 'The duck says "quack." This vocalization is characteristic of the species Anas platyrhynchos, commonly known as the mallard duck, which is often used as a representative example for the duck family, Anatidae.', 'refusal': None}, 'logprobs': None, 'finish_reason': 'stop'}], 'usage': { - 'prompt_tokens': 46, 'completion_tokens': 47, 'total_tokens': 93, 'prompt_tokens_details': {'cached_tokens': 0, 'audio_tokens': 0}, 'completion_tokens_details': {'reasoning_tokens': 0, 'audio_tokens': 0, 'accepted_prediction_tokens': 0, 'rejected_prediction_tokens': 0} + { + "id": "file-789", + "custom_id": "0", + "response": { + "body": { + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": ( + 'The duck says "quack." This vocalization is characteristic of the species ' + "Anas platyrhynchos, commonly known as the mallard duck, which is often used " + "as a representative example for the duck family, Anatidae." + ), + "refusal": None, + }, + "logprobs": None, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 46, + "completion_tokens": 47, + "total_tokens": 93, + "prompt_tokens_details": { + "cached_tokens": 0, + "audio_tokens": 0, }, - } - }, - }, - { - 'id': 'file-789', 'custom_id': '1', - 'response': { - 'body': { - 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': 'The dog says "bark." This is a vocalization commonly associated with canines, used for communication purposes such as alerting, expressing excitement, or seeking attention.', 'refusal': None}, 'logprobs': None, 'finish_reason': 'stop'}], - 'usage': { - 'prompt_tokens': 46, 'completion_tokens': 34, 'total_tokens': 80, 'prompt_tokens_details': {'cached_tokens': 0, 'audio_tokens': 0}, 'completion_tokens_details': {'reasoning_tokens': 0, 'audio_tokens': 0, 'accepted_prediction_tokens': 0, 'rejected_prediction_tokens': 0} + "completion_tokens_details": { + "reasoning_tokens": 0, + "audio_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0, }, - } - }, - }, - { - 'id': 'file-789', 'custom_id': '2', - 'response': { - 'body': { - 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': 'It seems you\'re quoting or referencing "the cat says." If you\'re looking for a specific context, such as a phrase, a song, or a scientific observation (like feline vocalizations), please provide more details for a precise response.', 'refusal': None}, 'logprobs': None, 'finish_reason': 'stop'}], - 'usage': { - 'prompt_tokens': 46, 'completion_tokens': 46, 'total_tokens': 92, 'prompt_tokens_details': {'cached_tokens': 0, 'audio_tokens': 0}, 'completion_tokens_details': {'reasoning_tokens': 0, 'audio_tokens': 0, 'accepted_prediction_tokens': 0, 'rejected_prediction_tokens': 0} + }, + } + }, + }, + { + "id": "file-789", + "custom_id": "1", + "response": { + "body": { + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": ( + 'The dog says "bark." This is a vocalization ' + "commonly associated with canines, used for " + "communication purposes such as alerting, expressing " + "excitement, or seeking attention." + ), + "refusal": None, + }, + "logprobs": None, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 46, + "completion_tokens": 34, + "total_tokens": 80, + "prompt_tokens_details": { + "cached_tokens": 0, + "audio_tokens": 0, }, - } - }, - } - ] + "completion_tokens_details": { + "reasoning_tokens": 0, + "audio_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0, + }, + }, + } + }, + }, + { + "id": "file-789", + "custom_id": "2", + "response": { + "body": { + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": ( + 'It seems you\'re quoting or referencing "the cat says." ' + "If you're looking for a specific context, such as a phrase, a song, " + "or a scientific observation (like feline vocalizations), please provide " + "more details for a precise response." + ), + "refusal": None, + }, + "logprobs": None, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 46, + "completion_tokens": 46, + "total_tokens": 92, + "prompt_tokens_details": { + "cached_tokens": 0, + "audio_tokens": 0, + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "audio_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0, + }, + }, + } + }, + }, + ] - response_data = '\n'.join(json.dumps(resp) for resp in sample_responses) + response_data = "\n".join(json.dumps(resp) for resp in sample_responses) mock_response_content = MagicMock() - mock_response_content.read.return_value = response_data.encode('utf-8') - mock_client.files.content = AsyncMock( - return_value = mock_response_content - ) + mock_response_content.read.return_value = response_data.encode() + mock_client.files.content = AsyncMock(return_value=mock_response_content) - with patch('openai.AsyncOpenAI', return_value=mock_client): - llm = OpenAIBatchLLMModel(name=config['model'], config=config) + with patch("openai.AsyncOpenAI", return_value=mock_client): + llm = OpenAIBatchLLMModel(name=config["model"], config=config) outputs = [] + def accum(x) -> None: outputs.append(x) async def ac(x) -> None: pass - data = [ - {"animal": "duck"}, - {"animal": "dog"}, - {"animal": "cat"} - ] + data = [{"animal": "duck"}, {"animal": "dog"}, {"animal": "cat"}] if request.node.name == "test_run_prompt[completion-model]": - with pytest.raises(Exception) as e_info: + with pytest.raises(RuntimeError) as e_info: completion = await llm.run_prompt( prompt="The {animal} says", data=data, @@ -297,20 +385,29 @@ async def ac(x) -> None: callbacks=[accum, ac], ) - assert all([completion[k].model == config['model'] for k in range(len(data))]) - assert all([completion[k].seconds_to_first_token > 0 for k in range(len(data))]) - assert all([completion[k].prompt_count > 0 for k in range(len(data))]) - assert all([completion[k].completion_count > 0 for k in range(len(data))]) - assert all([completion[k].completion_count <= config['max_tokens'] for k in range(len(data))]) - assert sum([completion[k].cost for k in range(len(data))]) > 0 - assert all([str(completion[k]) == outputs[k] for k in range(len(data))]) + assert all( + completion[k].model == config["model"] for k in range(len(data)) + ) + assert all( + completion[k].seconds_to_first_token > 0 for k in range(len(data)) + ) + assert all(completion[k].prompt_count > 0 for k in range(len(data))) + assert all(completion[k].completion_count > 0 for k in range(len(data))) + assert all( + completion[k].completion_count <= config["max_tokens"] + for k in range(len(data)) + ) + assert sum(comp.cost for comp in completion) > 0 + assert all(str(completion[k]) == outputs[k] for k in range(len(data))) @pytest.mark.parametrize( - "config",[ + "config", + [ pytest.param("gpt-4o-mini"), - ], indirect=True + ], + indirect=True, ) - def test_pickling(self, tmp_path: pathlib.Path, config: dict[str,Any]) -> None: + def test_pickling(self, tmp_path: pathlib.Path, config: dict[str, Any]) -> None: pickle_path = tmp_path / "llm_model.pickle" llm = OpenAIBatchLLMModel( name="gpt-4o-mini", @@ -323,6 +420,7 @@ def test_pickling(self, tmp_path: pathlib.Path, config: dict[str,Any]) -> None: assert llm.name == rehydrated_llm.name assert llm.config == rehydrated_llm.config + class TestAnthropicBatchLLMModel: @pytest.fixture(scope="class") def config(self, request) -> dict[str, Any]: @@ -331,143 +429,153 @@ def config(self, request) -> dict[str, Any]: "model": model_name, "temperature": 0.0, "max_tokens": 64, - "batch_summary_timelimit": 24*60*60, + "batch_summary_timelimit": 24 * 60 * 60, "batch_polling_interval": 5, } @pytest.mark.asyncio @pytest.mark.parametrize( - "config",[ + "config", + [ pytest.param("claude-3-haiku-20240307", id="chat-model"), - ], indirect=True + ], + indirect=True, ) - async def test_run_prompt(self, config: dict[str, Any], request) -> None: + async def test_run_prompt(self, config: dict[str, Any]) -> None: mock_client = AsyncMock(spec_set=anthropic.AsyncAnthropic()) - # Define mock methods for the client mock_client = MagicMock() mock_batches = MagicMock() - # mock_client.beta = MagicMock() - # mock_client.beta.messages = MagicMock() mock_client.beta.messages.batches = mock_batches - mock_batch_id = 'msgbatch_123' + mock_batch_id = "msgbatch_123" mock_batches.create = AsyncMock( return_value=MagicMock( - id=mock_batch_id, - processing_status=AnthropicBatchStatus.PROGRESS), - ) + id=mock_batch_id, processing_status=AnthropicBatchStatus.PROGRESS + ), + ) batch_retrieve_call = [ - MagicMock(id=mock_batch_id, processing_status=AnthropicBatchStatus.PROGRESS), - MagicMock(id=mock_batch_id, processing_status=AnthropicBatchStatus.COMPLETE) + MagicMock( + id=mock_batch_id, processing_status=AnthropicBatchStatus.PROGRESS + ), + MagicMock( + id=mock_batch_id, processing_status=AnthropicBatchStatus.COMPLETE + ), ] - mock_batches.retrieve = AsyncMock( - side_effect=batch_retrieve_call - ) + mock_batches.retrieve = AsyncMock(side_effect=batch_retrieve_call) mock_responses = [ MagicMock( - custom_id='0', + custom_id="0", result=MagicMock( message=MagicMock( - id='msg_0143L9rPswgaUyENkHkPJLcn', + id="msg_0143L9rPswgaUyENkHkPJLcn", content=[ MagicMock( - text="I don't actually hear any ducks saying anything. As an AI assistant, I don't have the ability to hear or interpret sounds from the physical world. I can only respond based on the text you provide to me through this chat interface. If you'd like, you can tell me what you think the duck is", + text=( + "I don't actually hear any ducks saying anything. " + "As an AI assistant, I don't have the ability to hear or interpret " + "sounds from the physical world. I can only respond based on the text " + "you provide to me through this chat interface. " + "If you'd like, you can tell me what you think the duck is" + ), ) ], - model='claude-3-haiku-20240307', - role='assistant', - stop_reason='max_tokens', + model="claude-3-haiku-20240307", + role="assistant", + stop_reason="max_tokens", stop_sequence=None, - type='message', - usage=MagicMock( - input_tokens=10, - output_tokens=64 - ) + type="message", + usage=MagicMock(input_tokens=10, output_tokens=64), ), - type='succeeded' - ) + type="succeeded", + ), ), MagicMock( - custom_id='1', + custom_id="1", result=MagicMock( message=MagicMock( - id='msg_01KujiHEB5S8pfRUCmrbabu4', + id="msg_01KujiHEB5S8pfRUCmrbabu4", content=[ MagicMock( - text="Unfortunately, I don't actually hear a dog speaking. As an AI assistant without physical senses, I can't directly perceive animals making sounds. Could you please provide more context about what the dog is saying, or what you would like me to respond to regarding the dog? I'd be happy to try to assist", + text=( + "Unfortunately, I don't actually hear a dog speaking. " + "As an AI assistant without physical senses, I" + "can't directly perceive animals making sounds. " + "Could you please provide more context about what the " + "dog is saying, or what you would like me to respond to " + "regarding the dog? I'd be happy to try to assist" + ), ) ], - model='claude-3-haiku-20240307', - role='assistant', - stop_reason='max_tokens', + model="claude-3-haiku-20240307", + role="assistant", + stop_reason="max_tokens", stop_sequence=None, - type='message', - usage=MagicMock( - input_tokens=10, - output_tokens=64 - ) + type="message", + usage=MagicMock(input_tokens=10, output_tokens=64), ), - type='succeeded' - ) + type="succeeded", + ), ), MagicMock( - custom_id='2', + custom_id="2", result=MagicMock( message=MagicMock( - id='msg_01Pf2LqV7wjnwqerkZubbofA', + id="msg_01Pf2LqV7wjnwqerkZubbofA", content=[ MagicMock( - text="I'm afraid I don't actually hear a cat speaking. As an AI assistant, I don't have the ability to hear or communicate with animals directly. I can only respond based on the text you provide to me. If you'd like, you can tell me what you imagine the cat is saying, and I'll", + text=( + "I'm afraid I don't actually hear a cat speaking. " + "As an AI assistant, I don't have the ability to hear " + "or communicate with animals directly. I can only respond " + "based on the text you provide to me. If you'd " + "like, you can tell me what you imagine the cat is saying, and I'll" + ), ) ], - model='claude-3-haiku-20240307', - role='assistant', - stop_reason='max_tokens', + model="claude-3-haiku-20240307", + role="assistant", + stop_reason="max_tokens", stop_sequence=None, - type='message', - usage=MagicMock( - input_tokens=10, - output_tokens=64 - ) + type="message", + usage=MagicMock(input_tokens=10, output_tokens=64), ), - type='succeeded' - ) + type="succeeded", + ), ), ] # Create a generator function - def mock_results_generator(batch_id): - for response in mock_responses: - yield response + def mock_results_generator(_batch_id): + + yield from mock_responses mock_batches.results = AsyncMock( return_value=mock_results_generator(mock_batch_id) + ) + + with patch("anthropic.AsyncAnthropic", return_value=mock_client): + llm = AnthropicBatchLLMModel(name=config["model"], config=config) + + data = [{"animal": "duck"}, {"animal": "dog"}, {"animal": "cat"}] + + completions = await llm.run_prompt( + prompt="The {animal} says", + data=data, ) - with patch('anthropic.AsyncAnthropic', return_value=mock_client): - llm = AnthropicBatchLLMModel(name=config['model'], config=config) + assert all(comp.model == config["model"] for comp in completions) + assert all(comp.seconds_to_first_token > 0 for comp in completions) + assert all(comp.prompt_count > 0 for comp in completions) + assert all(comp.completion_count > 0 for comp in completions) + assert all( + comp.completion_count <= config["max_tokens"] for comp in completions + ) + assert sum(comp.cost for comp in completions) > 0 - data = [ - {"animal": "duck"}, - {"animal": "dog"}, - {"animal": "cat"} - ] - - completion = await llm.run_prompt( - prompt="The {animal} says", - data=data, - ) - - assert all([completion[k].model == config['model'] for k in range(len(data))]) - assert all([completion[k].seconds_to_first_token > 0 for k in range(len(data))]) - assert all([completion[k].prompt_count > 0 for k in range(len(data))]) - assert all([completion[k].completion_count > 0 for k in range(len(data))]) - assert all([completion[k].completion_count <= config['max_tokens'] for k in range(len(data))]) - assert sum([completion[k].cost for k in range(len(data))]) > 0 @pytest.mark.asyncio async def test_embedding_model_factory_sentence_transformer() -> None: diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index a9d4c0718..9dc8dcf3c 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -505,9 +505,10 @@ async def test_docs_lifecycle(subtests: SubTests, stub_data_dir: Path) -> None: assert docs.texts assert all(t not in docs.texts_index for t in docs.texts) + def test_evidence(docs_fixture) -> None: debug_settings = Settings.from_name("debug") - evidence = docs_fixture.get_evidence( + evidence = docs_fixture.get_evidence( PQASession(question="What does XAI stand for?"), settings=debug_settings, ).contexts diff --git a/uv.lock b/uv.lock index 7f8c43f23..02a240782 100644 --- a/uv.lock +++ b/uv.lock @@ -1529,7 +1529,7 @@ wheels = [ [[package]] name = "paper-qa" -version = "5.4.1.dev28+g9a0a6c4" +version = "5.4.1.dev30+g660bfa0.d20241119" source = { editable = "." } dependencies = [ { name = "aiohttp" }, From 293658ab7aa92424161be98af5a9b6deb64d7191 Mon Sep 17 00:00:00 2001 From: Mayk Caldas Date: Tue, 19 Nov 2024 09:13:26 -0800 Subject: [PATCH 12/26] Updates uv on github actions to include extra requirements --- .github/workflows/tests.yml | 4 ++-- paperqa/core.py | 2 +- uv.lock | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9a493feab..d6a02414d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -35,7 +35,7 @@ jobs: enable-cache: true - run: uv python pin ${{ matrix.python-version }} - uses: hynek/build-and-inspect-python-package@v2 - - run: uv sync --python-preference=only-managed + - run: uv sync --python-preference=only-managed --all-extras - run: uv run refurb paperqa tests - run: uv run pylint paperqa test: @@ -49,7 +49,7 @@ jobs: with: enable-cache: true - run: uv python pin ${{ matrix.python-version }} - - run: uv sync --python-preference=only-managed + - run: uv sync --python-preference=only-managed --all-extras - name: Cache datasets uses: actions/cache@v4 with: diff --git a/paperqa/core.py b/paperqa/core.py index 18a327cd8..465446178 100644 --- a/paperqa/core.py +++ b/paperqa/core.py @@ -151,7 +151,7 @@ async def gather_with_batch( for m in matches ] - llm_results : list[LLMResult] = [] + llm_results: list[LLMResult] = [] if prompt_runner: result = await prompt_runner( data, diff --git a/uv.lock b/uv.lock index 02a240782..6b34f5aa1 100644 --- a/uv.lock +++ b/uv.lock @@ -1529,7 +1529,7 @@ wheels = [ [[package]] name = "paper-qa" -version = "5.4.1.dev30+g660bfa0.d20241119" +version = "5.4.1.dev34+gee351f2.d20241119" source = { editable = "." } dependencies = [ { name = "aiohttp" }, From 1ad1c7cbc17ded659ee00793a00f8f8dabf44296 Mon Sep 17 00:00:00 2001 From: Mayk Caldas Date: Tue, 19 Nov 2024 11:24:43 -0800 Subject: [PATCH 13/26] Removed the --all-extras flag from uv in github workflow A more general solution is to include it i the field [dependency-groups] of pyproject.toml --- .github/workflows/tests.yml | 4 ++-- pyproject.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d6a02414d..9a493feab 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -35,7 +35,7 @@ jobs: enable-cache: true - run: uv python pin ${{ matrix.python-version }} - uses: hynek/build-and-inspect-python-package@v2 - - run: uv sync --python-preference=only-managed --all-extras + - run: uv sync --python-preference=only-managed - run: uv run refurb paperqa tests - run: uv run pylint paperqa test: @@ -49,7 +49,7 @@ jobs: with: enable-cache: true - run: uv python pin ${{ matrix.python-version }} - - run: uv sync --python-preference=only-managed --all-extras + - run: uv sync --python-preference=only-managed - name: Cache datasets uses: actions/cache@v4 with: diff --git a/pyproject.toml b/pyproject.toml index 5b9c7375c..3d499ac01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ requires = ["setuptools>=64", "setuptools_scm>=8"] dev = [ "ipython>=8", # Pin to keep recent "mypy>=1.8", # Pin for mutable-override - "paper-qa[datasets,ldp,typing,zotero,local]", + "paper-qa[datasets,batch,ldp,typing,zotero,local]", "pre-commit>=3.4", # Pin to keep recent "pydantic~=2.0", "pylint-pydantic", From af32005ad8ecd22c04717956bfc99e45f43fca8c Mon Sep 17 00:00:00 2001 From: Mayk Caldas Date: Tue, 19 Nov 2024 11:25:32 -0800 Subject: [PATCH 14/26] Refactored OpenAiBatchStatus and AnthropicBatchStatus to make the code DRYer --- paperqa/llms.py | 57 ++++++++++++++++++++++++++++++---------------- tests/test_llms.py | 17 +++++++------- 2 files changed, 45 insertions(+), 29 deletions(-) diff --git a/paperqa/llms.py b/paperqa/llms.py index ab3db73ea..3ccadd7f8 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -75,22 +75,37 @@ class EmbeddingModes(StrEnum): QUERY = "query" -class OpenAIBatchStatus(StrEnum): - COMPLETE = "completed" - PROGRESS = "in_progress" - SUCCESS = "completed" - FAILURE = "failed" - EXPIRE = "expired" - CANCEL = "cancelled" - - -class AnthropicBatchStatus(StrEnum): - COMPLETE = "ended" - PROGRESS = "in_progress" - SUCCESS = "succeeded" - FAILURE = "errored" - EXPIRE = "expired" - CANCEL = "canceled" +class BatchStatus(StrEnum): + COMPLETE = "complete" + PROGRESS = "progress" + SUCCESS = "success" + FAILURE = "failure" + EXPIRE = "expire" + CANCEL = "cancel" + + def from_openai(self) -> str: + """Convert BatchStatus to OpenAI status.""" + mapping = { + BatchStatus.COMPLETE: "completed", + BatchStatus.PROGRESS: "in_progress", + BatchStatus.SUCCESS: "completed", # Assuming OpenAI uses "completed" for success + BatchStatus.FAILURE: "failed", + BatchStatus.EXPIRE: "expired", + BatchStatus.CANCEL: "cancelled", + } + return mapping[self] + + def from_anthropic(self) -> str: + """Convert BatchStatus to Anthropic status.""" + mapping = { + BatchStatus.COMPLETE: "ended", + BatchStatus.PROGRESS: "in_progress", + BatchStatus.SUCCESS: "succeeded", + BatchStatus.FAILURE: "errored", + BatchStatus.EXPIRE: "expired", + BatchStatus.CANCEL: "canceled", + } + return mapping[self] # Estimate from OpenAI's FAQ @@ -968,9 +983,9 @@ async def achat(self, messages: list[list[dict]]) -> list[Chunk]: ) start_clock = asyncio.get_running_loop().time() - while batch.status != OpenAIBatchStatus.COMPLETE: + while batch.status != BatchStatus.COMPLETE.from_openai(): batch = await client.batches.retrieve(batch.id) - if batch.status == OpenAIBatchStatus.FAILURE: + if batch.status == BatchStatus.FAILURE.from_openai(): error_messages = [] if batch.errors and hasattr(batch.errors, "data") and batch.errors.data: error_messages = [ @@ -981,9 +996,11 @@ async def achat(self, messages: list[list[dict]]) -> list[Chunk]: raise RuntimeError( "Batch failed. \n\nReason: \n" + "\n".join(error_messages) ) - if batch.status == OpenAIBatchStatus.CANCEL: + if batch.status == BatchStatus.CANCEL.from_openai(): raise ConnectionError("Batch was cancelled.") + # if batch.stats == OpenAIBatchStats.PROGRESS: + batch_time = asyncio.get_running_loop().time() - start_clock if batch_time > self.config.get("batch_summary_timelimit", 24 * 60 * 60): raise TimeoutError("Batch took too long to complete.") @@ -1116,7 +1133,7 @@ async def achat(self, messages: list[list[dict[str, str]]]) -> list[Chunk]: batch = await client.beta.messages.batches.create(requests=requests) start_clock = asyncio.get_running_loop().time() - while batch.processing_status != AnthropicBatchStatus.COMPLETE: + while batch.processing_status != BatchStatus.COMPLETE.from_anthropic(): batch = await client.beta.messages.batches.retrieve(batch.id) batch_time = asyncio.get_running_loop().time() - start_clock diff --git a/tests/test_llms.py b/tests/test_llms.py index 9971afb69..31972be1d 100644 --- a/tests/test_llms.py +++ b/tests/test_llms.py @@ -20,9 +20,8 @@ embedding_model_factory, ) from paperqa.llms import ( - AnthropicBatchStatus, + BatchStatus, Chunk, - OpenAIBatchStatus, ) from tests.conftest import VCR_DEFAULT_MATCH_ON @@ -198,14 +197,14 @@ async def test_run_prompt(self, config: dict[str, Any], request) -> None: mock_batch_id = "batch_123" mock_client.batches.create = AsyncMock( - return_value=MagicMock(id=mock_batch_id, status=OpenAIBatchStatus.PROGRESS) + return_value=MagicMock(id=mock_batch_id, status=BatchStatus.PROGRESS.from_openai()) ) if request.node.name == "test_run_prompt[completion-model]": batch_retrieve_calls = [ MagicMock( id=mock_batch_id, - status=OpenAIBatchStatus.FAILURE, + status=BatchStatus.FAILURE.from_openai(), errors=MagicMock( data=[ MagicMock( @@ -220,10 +219,10 @@ async def test_run_prompt(self, config: dict[str, Any], request) -> None: ] elif request.node.name == "test_run_prompt[chat-model]": batch_retrieve_calls = [ - MagicMock(id=mock_batch_id, status=OpenAIBatchStatus.PROGRESS), + MagicMock(id=mock_batch_id, status=BatchStatus.PROGRESS.from_openai()), MagicMock( id=mock_batch_id, - status=OpenAIBatchStatus.COMPLETE, + status=BatchStatus.COMPLETE.from_openai(), output_file_id="file-789", ), ] @@ -452,16 +451,16 @@ async def test_run_prompt(self, config: dict[str, Any]) -> None: mock_batch_id = "msgbatch_123" mock_batches.create = AsyncMock( return_value=MagicMock( - id=mock_batch_id, processing_status=AnthropicBatchStatus.PROGRESS + id=mock_batch_id, processing_status=BatchStatus.PROGRESS.from_anthropic() ), ) batch_retrieve_call = [ MagicMock( - id=mock_batch_id, processing_status=AnthropicBatchStatus.PROGRESS + id=mock_batch_id, processing_status=BatchStatus.PROGRESS.from_anthropic() ), MagicMock( - id=mock_batch_id, processing_status=AnthropicBatchStatus.COMPLETE + id=mock_batch_id, processing_status=BatchStatus.COMPLETE.from_anthropic() ), ] mock_batches.retrieve = AsyncMock(side_effect=batch_retrieve_call) From 63e4b3988c5e3af3296193c73ace7c4de3499923 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Tue, 19 Nov 2024 19:30:50 +0000 Subject: [PATCH 15/26] [pre-commit.ci lite] apply automatic fixes --- tests/test_llms.py | 13 +++++++++---- uv.lock | 6 ++++-- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/tests/test_llms.py b/tests/test_llms.py index 31972be1d..199d5bc2e 100644 --- a/tests/test_llms.py +++ b/tests/test_llms.py @@ -197,7 +197,9 @@ async def test_run_prompt(self, config: dict[str, Any], request) -> None: mock_batch_id = "batch_123" mock_client.batches.create = AsyncMock( - return_value=MagicMock(id=mock_batch_id, status=BatchStatus.PROGRESS.from_openai()) + return_value=MagicMock( + id=mock_batch_id, status=BatchStatus.PROGRESS.from_openai() + ) ) if request.node.name == "test_run_prompt[completion-model]": @@ -451,16 +453,19 @@ async def test_run_prompt(self, config: dict[str, Any]) -> None: mock_batch_id = "msgbatch_123" mock_batches.create = AsyncMock( return_value=MagicMock( - id=mock_batch_id, processing_status=BatchStatus.PROGRESS.from_anthropic() + id=mock_batch_id, + processing_status=BatchStatus.PROGRESS.from_anthropic(), ), ) batch_retrieve_call = [ MagicMock( - id=mock_batch_id, processing_status=BatchStatus.PROGRESS.from_anthropic() + id=mock_batch_id, + processing_status=BatchStatus.PROGRESS.from_anthropic(), ), MagicMock( - id=mock_batch_id, processing_status=BatchStatus.COMPLETE.from_anthropic() + id=mock_batch_id, + processing_status=BatchStatus.COMPLETE.from_anthropic(), ), ] mock_batches.retrieve = AsyncMock(side_effect=batch_retrieve_call) diff --git a/uv.lock b/uv.lock index 6b34f5aa1..ffc50f707 100644 --- a/uv.lock +++ b/uv.lock @@ -1529,7 +1529,7 @@ wheels = [ [[package]] name = "paper-qa" -version = "5.4.1.dev34+gee351f2.d20241119" +version = "5.4.1.dev40+g6a7418c.d20241119" source = { editable = "." } dependencies = [ { name = "aiohttp" }, @@ -1577,10 +1577,12 @@ zotero = [ [package.dev-dependencies] dev = [ + { name = "anthropic" }, { name = "datasets" }, { name = "ipython" }, { name = "ldp" }, { name = "mypy" }, + { name = "openai" }, { name = "pandas-stubs" }, { name = "pre-commit" }, { name = "pydantic" }, @@ -1637,7 +1639,7 @@ requires-dist = [ dev = [ { name = "ipython", specifier = ">=8" }, { name = "mypy", specifier = ">=1.8" }, - { name = "paper-qa", extras = ["datasets", "ldp", "typing", "zotero", "local"] }, + { name = "paper-qa", extras = ["datasets", "batch", "ldp", "typing", "zotero", "local"] }, { name = "pre-commit", specifier = ">=3.4" }, { name = "pydantic", specifier = "~=2.0" }, { name = "pylint-pydantic" }, From d7dbd729204082145830d58be615fca07060c4d0 Mon Sep 17 00:00:00 2001 From: Mayk Caldas Date: Tue, 19 Nov 2024 12:51:45 -0800 Subject: [PATCH 16/26] Cleaned unneeded comments --- paperqa/llms.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/paperqa/llms.py b/paperqa/llms.py index 3ccadd7f8..fe4336caf 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -88,7 +88,7 @@ def from_openai(self) -> str: mapping = { BatchStatus.COMPLETE: "completed", BatchStatus.PROGRESS: "in_progress", - BatchStatus.SUCCESS: "completed", # Assuming OpenAI uses "completed" for success + BatchStatus.SUCCESS: "completed", BatchStatus.FAILURE: "failed", BatchStatus.EXPIRE: "expired", BatchStatus.CANCEL: "cancelled", @@ -999,8 +999,6 @@ async def achat(self, messages: list[list[dict]]) -> list[Chunk]: if batch.status == BatchStatus.CANCEL.from_openai(): raise ConnectionError("Batch was cancelled.") - # if batch.stats == OpenAIBatchStats.PROGRESS: - batch_time = asyncio.get_running_loop().time() - start_clock if batch_time > self.config.get("batch_summary_timelimit", 24 * 60 * 60): raise TimeoutError("Batch took too long to complete.") From 7c37f6d64f4e673e88721578fa177d7d908da2d8 Mon Sep 17 00:00:00 2001 From: Mayk Caldas Date: Tue, 19 Nov 2024 13:45:51 -0800 Subject: [PATCH 17/26] Updated the way the system message is passed to anthropic --- paperqa/llms.py | 9 ++++++++- uv.lock | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/paperqa/llms.py b/paperqa/llms.py index fe4336caf..260c3063c 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -1122,7 +1122,14 @@ async def achat(self, messages: list[list[dict[str, str]]]) -> list[Chunk]: params=MessageCreateParamsNonStreaming( model=self.config.get("model"), max_tokens=self.config.get("max_tokens"), - messages=m, + system="".join( + [ + user_m["content"] + for user_m in messages[0] + if user_m["role"] == "system" + ] + ), + messages=[user_m for user_m in m if user_m["role"] == "user"], ), ) for i, m in enumerate(messages) diff --git a/uv.lock b/uv.lock index ffc50f707..6efb9abca 100644 --- a/uv.lock +++ b/uv.lock @@ -1529,7 +1529,7 @@ wheels = [ [[package]] name = "paper-qa" -version = "5.4.1.dev40+g6a7418c.d20241119" +version = "5.4.1.dev42+gd7dbd72.d20241119" source = { editable = "." } dependencies = [ { name = "aiohttp" }, From de18907fbb43552dd1d295e6ebcca9bfcc5ba742 Mon Sep 17 00:00:00 2001 From: Mayk Caldas Date: Tue, 19 Nov 2024 17:13:51 -0800 Subject: [PATCH 18/26] changed how the file is passed to openai --- paperqa/llms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paperqa/llms.py b/paperqa/llms.py index 260c3063c..34dd8caf7 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -973,7 +973,7 @@ async def achat(self, messages: list[list[dict]]) -> list[Chunk]: with tempfile.NamedTemporaryFile(suffix=".jsonl") as tmp_file: tmp_filename = tmp_file.name self.write_jsonl(messages, tmp_filename) - file = await client.files.create(file=tmp_file, purpose="batch") + file = await client.files.create(file=open(tmp_filename,"rb"), purpose="batch") batch = await client.batches.create( input_file_id=file.id, From 3e72bd4be737281c5f0760903590ec02ad8cbcdf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Wed, 20 Nov 2024 01:15:57 +0000 Subject: [PATCH 19/26] [pre-commit.ci lite] apply automatic fixes --- paperqa/llms.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/paperqa/llms.py b/paperqa/llms.py index 34dd8caf7..5ec14d7e6 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -973,7 +973,9 @@ async def achat(self, messages: list[list[dict]]) -> list[Chunk]: with tempfile.NamedTemporaryFile(suffix=".jsonl") as tmp_file: tmp_filename = tmp_file.name self.write_jsonl(messages, tmp_filename) - file = await client.files.create(file=open(tmp_filename,"rb"), purpose="batch") + file = await client.files.create( + file=open(tmp_filename, "rb"), purpose="batch" + ) batch = await client.batches.create( input_file_id=file.id, From 7c7f4b8248f4e2f2c5c7e7447c0b29fdc7f07a99 Mon Sep 17 00:00:00 2001 From: Mayk Caldas Date: Wed, 20 Nov 2024 00:09:49 -0800 Subject: [PATCH 20/26] Avoided writing to a file when sending the batch to openAi --- paperqa/llms.py | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/paperqa/llms.py b/paperqa/llms.py index 5ec14d7e6..fc57303d2 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -3,6 +3,7 @@ import asyncio import contextlib import functools +import io import json import logging import tempfile @@ -886,21 +887,26 @@ class OpenAIBatchLLMModel(LLMBatchModel): description="Configuration dictionary for this model. Currently supported keys are `model` and `max_token`.", ) - def write_jsonl(self, data: list[list[dict[str, str]]], filename: str | Any): - + async def write_jsonl( + self, + data: list[list[dict[str, str]]], + tmp_file: io.BytesIO, + ): batch_template: BatchTemplate = { "custom_id": None, "method": "POST", "url": "/v1/chat/completions", "body": {"model": None, "messages": None, "max_tokens": None}, } - with open(filename, "w") as tmp_file: - for i, d in enumerate(data): - batch_template["custom_id"] = str(i) - batch_template["body"]["model"] = self.config.get("model") - batch_template["body"]["messages"] = d - batch_template["body"]["max_tokens"] = self.config.get("max_tokens") - tmp_file.write(json.dumps(batch_template) + "\n") + + for i, d in enumerate(data): + batch_template["custom_id"] = str(i) + batch_template["body"]["model"] = self.config.get("model") + batch_template["body"]["messages"] = d + batch_template["body"]["max_tokens"] = self.config.get("max_tokens") + serialized_data = json.dumps(batch_template) + "\n" + tmp_file.write(serialized_data.encode("utf-8")) + # tmp_file.write(json.dumps(batch_template) + "\n") async def acomplete(self): raise NotImplementedError("Only chat models are supported by openAI batch API.") @@ -970,12 +976,10 @@ async def achat(self, messages: list[list[dict]]) -> list[Chunk]: client = openai.AsyncOpenAI() - with tempfile.NamedTemporaryFile(suffix=".jsonl") as tmp_file: - tmp_filename = tmp_file.name - self.write_jsonl(messages, tmp_filename) - file = await client.files.create( - file=open(tmp_filename, "rb"), purpose="batch" - ) + with io.BytesIO() as tmp_file: + await self.write_jsonl(messages, tmp_file) + tmp_file.seek(0) + file = await client.files.create(file=tmp_file, purpose="batch") batch = await client.batches.create( input_file_id=file.id, From 6c8f186c6bac7af1bccd3f2091889f21106b040c Mon Sep 17 00:00:00 2001 From: Mayk Caldas Date: Wed, 20 Nov 2024 09:16:45 -0800 Subject: [PATCH 21/26] Skipped writing a file. Instead, the content is directly passed to the API --- paperqa/llms.py | 1 - 1 file changed, 1 deletion(-) diff --git a/paperqa/llms.py b/paperqa/llms.py index fc57303d2..4786e12e2 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -6,7 +6,6 @@ import io import json import logging -import tempfile from abc import ABC, abstractmethod from collections.abc import ( AsyncGenerator, From 17c26ebca60c01002d3f371c7793ba8aa7bfab99 Mon Sep 17 00:00:00 2001 From: Mayk Caldas Date: Wed, 20 Nov 2024 09:35:36 -0800 Subject: [PATCH 22/26] Fixed lint error --- paperqa/llms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paperqa/llms.py b/paperqa/llms.py index 4aa4ff210..c543e151a 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -904,7 +904,7 @@ async def write_jsonl( batch_template["body"]["messages"] = d batch_template["body"]["max_tokens"] = self.config.get("max_tokens") serialized_data = json.dumps(batch_template) + "\n" - tmp_file.write(serialized_data.encode("utf-8")) + tmp_file.write(serialized_data.encode()) # tmp_file.write(json.dumps(batch_template) + "\n") async def acomplete(self): From c258306db18b6c1158593eab26a16d789bd84d42 Mon Sep 17 00:00:00 2001 From: Mayk Caldas Date: Wed, 20 Nov 2024 10:40:47 -0800 Subject: [PATCH 23/26] Updated the batch time limit settings name --- paperqa/llms.py | 4 ++-- paperqa/settings.py | 7 +++++-- tests/test_llms.py | 4 ++-- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/paperqa/llms.py b/paperqa/llms.py index c543e151a..57d06e720 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -1005,7 +1005,7 @@ async def achat(self, messages: list[list[dict]]) -> list[Chunk]: raise ConnectionError("Batch was cancelled.") batch_time = asyncio.get_running_loop().time() - start_clock - if batch_time > self.config.get("batch_summary_timelimit", 24 * 60 * 60): + if batch_time > self.config.get("batch_summary_time_limit", 24 * 60 * 60): raise TimeoutError("Batch took too long to complete.") logger.info( @@ -1147,7 +1147,7 @@ async def achat(self, messages: list[list[dict[str, str]]]) -> list[Chunk]: batch = await client.beta.messages.batches.retrieve(batch.id) batch_time = asyncio.get_running_loop().time() - start_clock - if batch_time > self.config.get("batch_summary_timelimit", 24 * 60 * 60): + if batch_time > self.config.get("batch_summary_time_limit", 24 * 60 * 60): raise TimeoutError("Batch took too long to complete.") logger.info( diff --git a/paperqa/settings.py b/paperqa/settings.py index 38334d487..90978dd5f 100644 --- a/paperqa/settings.py +++ b/paperqa/settings.py @@ -637,9 +637,12 @@ class Settings(BaseSettings): "and OpenAI (https://platform.openai.com/docs/guides/batch) chat models." ), ) - batch_summary_timelimit: int = Field( + batch_summary_time_limit: int = Field( default=24 * 60 * 60, - description="Time limit for batch summarization in seconds", + description=( + "Time limit for batch summarization in seconds. " + "Default is set to 24 hours to match OpenAI's and Anthropic's limit." + ), ) batch_polling_interval: int = Field( default=30, diff --git a/tests/test_llms.py b/tests/test_llms.py index 199d5bc2e..ff12efeb3 100644 --- a/tests/test_llms.py +++ b/tests/test_llms.py @@ -175,7 +175,7 @@ def config(self, request) -> dict[str, Any]: "model": model_name, "temperature": 0.0, "max_tokens": 64, - "batch_summary_timelimit": 24 * 60 * 60, + "batch_summary_time_limit": 24 * 60 * 60, "batch_polling_interval": 5, } @@ -430,7 +430,7 @@ def config(self, request) -> dict[str, Any]: "model": model_name, "temperature": 0.0, "max_tokens": 64, - "batch_summary_timelimit": 24 * 60 * 60, + "batch_summary_time_limit": 24 * 60 * 60, "batch_polling_interval": 5, } From 4b8e1c372014dae5de6af5c13b3f7b07265a84bf Mon Sep 17 00:00:00 2001 From: Mayk Caldas Date: Wed, 20 Nov 2024 10:41:29 -0800 Subject: [PATCH 24/26] Removed type hints from docstrings in gather_with_batch --- paperqa/core.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/paperqa/core.py b/paperqa/core.py index 465446178..eff8907d6 100644 --- a/paperqa/core.py +++ b/paperqa/core.py @@ -130,16 +130,16 @@ async def gather_with_batch( Gathers evidence considering a batch of texts. The completions are obtained using a batch API. Args: - matches (list[Text]): A list of text matches to gather evidence from. - question (str): The question to be answered. - prompt_runner (PromptRunner | None): The prompt runner to use for obtaining completions. - extra_prompt_data (dict[str, str] | None, optional): Additional data to include in the prompt. - parser (Callable[[str], dict[str, Any]] | None, optional): A function to parse the LLM result text. - callbacks (list[Callable[[str], None]] | None, optional): A list of callback functions to be called + matches: A list of text matches to gather evidence from. + question: The question to be answered. + prompt_runner: The prompt runner to use for obtaining completions. + extra_prompt_data: Additional data to include in the prompt. + parser: A function to parse the LLM result text. + callbacks: A list of callback functions to be called with the LLM result text. Returns: - list[tuple[Context, LLMResult]]: A list of tuples containing the context and LLM result for each match. + list: A list of tuples containing the context and LLM result for each match. """ data = [ { @@ -178,7 +178,6 @@ async def gather_with_batch( Context( context=strip_citations(llm_result.text), text=m, - model_extra={}, score=score, **r, ), From 8b5c1fa0d7445ac33494a528a4af9418ef25cef3 Mon Sep 17 00:00:00 2001 From: Mayk Caldas Date: Wed, 20 Nov 2024 11:41:13 -0800 Subject: [PATCH 25/26] Added exception in map_fxn_summary to treat multiple reponses --- paperqa/core.py | 11 +++++++++-- paperqa/docs.py | 2 ++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/paperqa/core.py b/paperqa/core.py index eff8907d6..a545da234 100644 --- a/paperqa/core.py +++ b/paperqa/core.py @@ -2,7 +2,7 @@ import json import re -from collections.abc import Callable +from collections.abc import Callable, Sequence from typing import Any from paperqa.llms import PromptRunner @@ -74,6 +74,12 @@ async def map_fxn_summary( callbacks, "evidence:" + text.name, ) + + if isinstance(result, Sequence) and len(result) != 1: + raise NotImplementedError( + f"Expected a single LLMResult, got {len(result)}. : {result}" + ) + llm_result = result if isinstance(result, LLMResult) else result[0] context = llm_result.text result_data = parser(context) if parser else {} @@ -139,7 +145,7 @@ async def gather_with_batch( with the LLM result text. Returns: - list: A list of tuples containing the context and LLM result for each match. + List of tuples containing the context and LLM result for each match. """ data = [ { @@ -158,6 +164,7 @@ async def gather_with_batch( callbacks, "evidence:" + matches[0].name, ) + llm_results = result if isinstance(result, list) else [result] results_data = [] diff --git a/paperqa/docs.py b/paperqa/docs.py index 5adf8c239..1d2a4bf7d 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -637,6 +637,8 @@ async def aget_evidence( # noqa: PLR0912 prompt_runner=prompt_runner, extra_prompt_data={ "summary_length": answer_config.evidence_summary_length, + # citations are formatted inside the function + # for each text in matches }, parser=llm_parse_json if prompt_config.use_json else None, callbacks=callbacks, From ab40b54e432c2d8d01045fe7be36cd2f5c094f48 Mon Sep 17 00:00:00 2001 From: Mayk Caldas Date: Wed, 20 Nov 2024 12:22:01 -0800 Subject: [PATCH 26/26] Added a description explaining the llm_type attribute --- paperqa/llms.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/paperqa/llms.py b/paperqa/llms.py index 57d06e720..f1a42d44f 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -313,7 +313,10 @@ def __str__(self): class LLMModel(ABC, BaseModel): model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) - llm_type: str | None = None + llm_type: str | StrEnum | None = Field( + default=None, + description="A string indicating the type of LLM model (e.g., 'chat' or 'completion').", + ) name: str llm_result_callback: ( Callable[[LLMResult], None] | Callable[[LLMResult], Awaitable[None]] | None @@ -805,7 +808,10 @@ async def select_tool( class LLMBatchModel(ABC, BaseModel): model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) - llm_type: str | None = None + llm_type: str | StrEnum | None = Field( + default=None, + description="A string indicating the type of LLM model (e.g., 'chat' or 'completion').", + ) name: str llm_result_callback: ( Callable[[LLMResult], None] | Callable[[LLMResult], Awaitable[None]] | None @@ -889,7 +895,7 @@ class OpenAIBatchLLMModel(LLMBatchModel): async def write_jsonl( self, data: list[list[dict[str, str]]], - tmp_file: io.BytesIO, + mem_buffer: io.BytesIO, ): batch_template: BatchTemplate = { "custom_id": None, @@ -904,8 +910,7 @@ async def write_jsonl( batch_template["body"]["messages"] = d batch_template["body"]["max_tokens"] = self.config.get("max_tokens") serialized_data = json.dumps(batch_template) + "\n" - tmp_file.write(serialized_data.encode()) - # tmp_file.write(json.dumps(batch_template) + "\n") + mem_buffer.write(serialized_data.encode()) async def acomplete(self): raise NotImplementedError("Only chat models are supported by openAI batch API.") @@ -975,10 +980,10 @@ async def achat(self, messages: list[list[dict]]) -> list[Chunk]: client = openai.AsyncOpenAI() - with io.BytesIO() as tmp_file: - await self.write_jsonl(messages, tmp_file) - tmp_file.seek(0) - file = await client.files.create(file=tmp_file, purpose="batch") + with io.BytesIO() as mem_buffer: + await self.write_jsonl(messages, mem_buffer) + mem_buffer.seek(0) + file = await client.files.create(file=mem_buffer, purpose="batch") batch = await client.batches.create( input_file_id=file.id, @@ -1045,7 +1050,7 @@ async def achat_iter(self): class AnthropicBatchLLMModel(LLMBatchModel): """A wrapper around the anthropic library to use the batch API.""" - name: str = "claude-3-5-sonnet-20241022" + name: str = "claude-3-5-sonnet-latest" config: dict = Field( default_factory=dict, description="Configuration dictionary for this model. Currently supported keys are `model` and `max_token`.",