Refactor to breakout config from rest of code #289

Merged (40 commits, Sep 8, 2024)

Changes shown are from 26 of the 40 commits.

Commits
c9124b2
Extracted retrieve method
whitead Jun 20, 2024
b7574d0
Need to decide how to DRY this
whitead Jun 21, 2024
48efa06
Merge branch 'main' into issue-283
whitead Jun 25, 2024
b5a5b81
Merge branch 'september-2024-release' into issue-283
whitead Aug 29, 2024
a32fe4f
First draft of mapping
whitead Aug 30, 2024
6343cfd
Switched to new map function in gather evidence
whitead Aug 30, 2024
3ff2c4a
Reran pre-commit
whitead Aug 30, 2024
daecc4b
Merge branch 'september-2024-release' into issue-283
whitead Aug 30, 2024
02ba597
Stashing progress
whitead Aug 31, 2024
c3bd35c
Stashing progress
whitead Sep 1, 2024
d3706f4
Finished refactor
whitead Sep 3, 2024
d0fbdf5
Fixed ruff errors
whitead Sep 3, 2024
ad3a2da
Fixed all type hinting
whitead Sep 3, 2024
5aeb270
Made it possible to load named configs
whitead Sep 3, 2024
2038c37
Making progress on tests
whitead Sep 3, 2024
c2fe115
halfway through tests
whitead Sep 3, 2024
bc01d59
Added back all unit tests
whitead Sep 4, 2024
cf9f670
Fixed linting errors
whitead Sep 4, 2024
07d3cf8
Reenable CI
whitead Sep 4, 2024
20dd67e
Got indexes working again
whitead Sep 4, 2024
b183e2c
Stashing progress on agent rewrite
whitead Sep 5, 2024
3eb8e90
Agent tests finally pass
whitead Sep 5, 2024
c0783d3
Finished agent tests
whitead Sep 5, 2024
b2a9c8b
More work on tests
whitead Sep 5, 2024
e69ce1e
Merge branch september-2024-release into issue-283
mskarlin Sep 5, 2024
b9b66c9
removed unused imports and remove python label from docstring to avoi…
mskarlin Sep 5, 2024
ca66124
Rewrote CLI to use settings objects
whitead Sep 6, 2024
ac7833f
Stashing progress
whitead Sep 6, 2024
33a0d79
Moving to `uv` for installation/CI, parallel `pre-commit` in CI (#316)
jamesbraza Sep 6, 2024
b68fd91
Fixing `pybtex` import by requiring `setuptools` (#318)
jamesbraza Sep 6, 2024
0ed9c69
Fixing test installation in CI by specifying missing dependencies (#319)
jamesbraza Sep 6, 2024
9ec067e
Got CLI to work nicely
whitead Sep 6, 2024
2dcb659
Merge branch 'issue-283' of github.com:whitead/paper-qa into issue-283
whitead Sep 6, 2024
b8998c6
LiteLLM integration (#315)
mskarlin Sep 7, 2024
5ed57b4
Can now save and load settings
whitead Sep 7, 2024
e2aa2d0
Merge branch 'issue-283' of github.com:whitead/paper-qa into issue-283
whitead Sep 7, 2024
51499ad
Removed old CLI tests
whitead Sep 7, 2024
c3c314e
Fixed logging and tests
whitead Sep 8, 2024
2ac3f1e
Addressed some PR comments
whitead Sep 8, 2024
bc91c03
More PR Comments
whitead Sep 8, 2024
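
Taken together, these commits consolidate configuration that was previously scattered across keyword arguments into a single settings object. A minimal sketch of the resulting usage, inferred from the diffs below (the exact signatures, field names, and values here are assumptions, not verbatim from this PR):

```python
from paperqa import Settings, get_settings  # re-exported in paperqa/__init__.py below

# Load settings once, tweak a few fields, and pass the one object around;
# downstream code reads request.settings.<field> instead of loose kwargs.
settings = get_settings()  # assumed to return a default (or named) Settings
settings.llm = "gpt-4-turbo"  # hypothetical override
settings.temperature = 0.5
```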
4 changes: 3 additions & 1 deletion .gitignore
@@ -100,6 +100,7 @@ fabric.properties
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets
.vscode

# Local History for Visual Studio Code
.history/
@@ -301,5 +302,6 @@ env

# Matching pyproject.toml
paperqa/version.py
tests/example*
tests/*txt
tests/*html
tests/test_index/*
10 changes: 1 addition & 9 deletions README.md
@@ -173,7 +173,6 @@ local_client = AsyncOpenAI(

docs = Docs(
client=local_client,
docs_index=NumpyVectorStore(embedding_model=LlamaEmbeddingModel()),
texts_index=NumpyVectorStore(embedding_model=LlamaEmbeddingModel()),
llm_model=OpenAILLMModel(
config=dict(
@@ -201,15 +200,12 @@ docs = Docs(embedding="text-embedding-3-large")
- `"hybrid-<model_name>"` i.e. `"hybrid-text-embedding-3-small"` to use a hybrid sparse keyword (based on a token modulo embedding) and dense vector embedding, any OpenAI or VoyageAI model can be used in the dense model name
- `"sparse"` to use a sparse keyword embedding only

For deeper embedding customization, embedding models and vector stores can be built separately and passed into the `Docs` object. Embedding models are used to create both paper-qa's index of document citation embedding vectors (`docs_index` argument) as well as the full-text embedding vectors (`texts_index` argument). They can both be specified as arguments when you create a new `Docs` object. You can use use any embedding model which implements paper-qa's `EmbeddingModel` class. For example, to use `text-embedding-3-large`:
For deeper embedding customization, embedding models and vector stores can be built separately and passed into the `Docs` object. Embedding models are used to create paper-qa's index of the full-text embedding vectors (`texts_index` argument). It can be specified as an argument when you create a new `Docs` object. You can use any embedding model which implements paper-qa's `EmbeddingModel` class. For example, to use `text-embedding-3-large`:

```python
from paperqa import Docs, NumpyVectorStore, OpenAIEmbeddingModel

docs = Docs(
docs_index=NumpyVectorStore(
embedding_model=OpenAIEmbeddingModel(name="text-embedding-3-large")
),
texts_index=NumpyVectorStore(
embedding_model=OpenAIEmbeddingModel(name="text-embedding-3-large")
),
@@ -224,7 +220,6 @@ from langchain_openai import OpenAIEmbeddings
from paperqa import Docs, LangchainVectorStore

docs = Docs(
docs_index=LangchainVectorStore(cls=FAISS, embedding_model=OpenAIEmbeddings()),
texts_index=LangchainVectorStore(cls=FAISS, embedding_model=OpenAIEmbeddings()),
)
```
@@ -243,7 +238,6 @@ local_client = AsyncOpenAI(

docs = Docs(
client=local_client,
docs_index=NumpyVectorStore(embedding_model=SentenceTransformerEmbeddingModel()),
texts_index=NumpyVectorStore(embedding_model=SentenceTransformerEmbeddingModel()),
llm_model=OpenAILLMModel(
config=dict(
@@ -260,7 +254,6 @@ from paperqa import Docs, HybridEmbeddingModel, SparseEmbeddingModel, NumpyVecto

model = HybridEmbeddingModel(models=[OpenAIEmbeddingModel(), SparseEmbeddingModel()])
docs = Docs(
docs_index=NumpyVectorStore(embedding_model=model),
texts_index=NumpyVectorStore(embedding_model=model),
)
```
@@ -318,7 +311,6 @@ from langchain_openai import OpenAIEmbeddings

docs = Docs(
texts_index=LangchainVectorStore(cls=FAISS, embedding_model=OpenAIEmbeddings()),
docs_index=LangchainVectorStore(cls=FAISS, embedding_model=OpenAIEmbeddings()),
)
```

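The net effect of these README edits: only one vector store is configured now. A condensed restatement of the surviving example (same API as shown in this diff):

```python
from paperqa import Docs, NumpyVectorStore, OpenAIEmbeddingModel

# The separate docs_index argument is removed; only the full-text
# texts_index remains.
docs = Docs(
    texts_index=NumpyVectorStore(
        embedding_model=OpenAIEmbeddingModel(name="text-embedding-3-large")
    ),
)
```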
8 changes: 5 additions & 3 deletions paperqa/__init__.py
@@ -1,4 +1,5 @@
from .docs import Answer, Context, Doc, Docs, PromptCollection, Text, print_callback
from .config import Settings, get_settings
from .docs import Answer, Docs, print_callback
from .llms import (
AnthropicLLMModel,
EmbeddingModel,
@@ -18,7 +19,7 @@
llm_model_factory,
vector_store_factory,
)
from .types import DocDetails
from .types import Context, Doc, DocDetails, Text
from .version import __version__

__all__ = [
@@ -39,12 +40,13 @@
"NumpyVectorStore",
"OpenAIEmbeddingModel",
"OpenAILLMModel",
"PromptCollection",
"SentenceTransformerEmbeddingModel",
"Settings",
"SparseEmbeddingModel",
"Text",
"__version__",
"embedding_model_factory",
"get_settings",
"llm_model_factory",
"print_callback",
"vector_store_factory",
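The practical import change, per the hunks above (`Context`, `Doc`, and `Text` move to `paperqa.types`; `Settings` and `get_settings` come from the new config module, with root re-exports kept in `__all__`):

```python
# Root-level imports still work for the public API, per the __all__ edits...
from paperqa import Answer, Docs, Settings, get_settings

# ...while the type classes now live in paperqa.types, per this diff.
from paperqa.types import Context, Doc, DocDetails, Text
```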
25 changes: 3 additions & 22 deletions paperqa/agents/__init__.py
@@ -7,16 +7,14 @@
import shutil
from datetime import datetime
from pathlib import Path
from typing import Any
from typing import Annotated, Any

import yaml
from typing_extensions import Annotated

from .. import __version__
from ..utils import get_loop, pqa_directory

try:
import anyio
import typer
from rich.console import Console
from rich.logging import RichHandler
@@ -105,7 +103,7 @@ def parse_dot_to_dict(str_w_dots: str, value: str) -> dict[str, Any]:
if not parsed:
try:
eval_value = ast.literal_eval(value)
if isinstance(eval_value, (set, list)):
if isinstance(eval_value, set | list):
parsed[key] = eval_value
else:
parsed[key] = value
@@ -410,7 +408,6 @@ def ask(
docs=None,
verbosity=verbosity,
agent_type=agent_type,
index_directory=request.agent_tools.index_directory,
)
)

@@ -511,23 +508,7 @@ def build_index(
loop = get_loop()

return loop.run_until_complete(
get_directory_index(
directory=anyio.Path(request_settings.agent_tools.paper_directory),
index_directory=request_settings.agent_tools.index_directory,
index_name=request_settings.get_index_name(
request_settings.agent_tools.paper_directory,
request_settings.embedding,
request_settings.parsing_configuration,
),
manifest_file=(
anyio.Path(request_settings.agent_tools.manifest_file)
if request_settings.agent_tools.manifest_file
else None
),
embedding=request_settings.embedding,
chunk_chars=request_settings.parsing_configuration.chunksize,
overlap=request_settings.parsing_configuration.overlap,
)
get_directory_index(settings=request_settings.settings)
)


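For context on the `parse_dot_to_dict` hunk above, a hedged sketch of how these dot-notation CLI overrides appear to behave (the nesting is inferred from the function name, and only the literal-eval branch is visible in this diff, so treat both calls as illustrative):

```python
# Illustrative, assumed behavior: dotted keys expand into nested dicts,
# and per the visible fragment only set/list values are literal-eval'd.
parse_dot_to_dict("agent.search_count", "[5, 10]")
# -> {"agent": {"search_count": [5, 10]}}  (list value parsed)

parse_dot_to_dict("agent.timeout", "300.0")
# -> {"agent": {"timeout": "300.0"}}  (scalar kept as a string)
```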
35 changes: 15 additions & 20 deletions paperqa/agents/helpers.py
@@ -169,26 +169,28 @@ def update_doc_models(doc: Docs, request: QueryRequest | None = None):
request = QueryRequest()
client: Any = None

if request.llm.startswith("gemini"):
doc.llm_model = LangchainLLMModel(name=request.llm)
doc.summary_llm_model = LangchainLLMModel(name=request.summary_llm)
if request.settings.llm.startswith("gemini"):
doc.llm_model = LangchainLLMModel(name=request.settings.llm)
doc.summary_llm_model = LangchainLLMModel(name=request.settings.summary_llm)
else:
doc.llm_model = llm_model_factory(request.llm)
doc.summary_llm_model = llm_model_factory(request.summary_llm)
doc.llm_model = llm_model_factory(request.settings.llm)
doc.summary_llm_model = llm_model_factory(request.settings.summary_llm)

# set temperatures
doc.llm_model.config["temperature"] = request.temperature
doc.summary_llm_model.config["temperature"] = request.temperature
doc.llm_model.config["temperature"] = request.settings.temperature
doc.summary_llm_model.config["temperature"] = request.settings.temperature

if isinstance(doc.llm_model, OpenAILLMModel):
if request.llm.startswith(
if request.settings.llm.startswith(
("meta-llama/Meta-Llama-3-", "mistralai/Mistral-", "mistralai/Mixtral-")
):
client = AsyncOpenAI(
base_url=os.environ.get("ANYSCALE_BASE_URL"),
api_key=os.environ.get("ANYSCALE_API_KEY"),
)
logger.info(f"Using Anyscale (via OpenAI client) for {request.llm}")
logger.info(
f"Using Anyscale (via OpenAI client) for {request.settings.llm}"
)
else:
client = AsyncOpenAI()
elif isinstance(doc.llm_model, AnthropicLLMModel):
@@ -203,7 +205,7 @@ def update_doc_models(doc: Docs, request: QueryRequest | None = None):
# we have to convert system to human because system is unsupported
# Also we do get blocked content, so adjust thresholds
client = ChatVertexAI(
model=request.llm,
model=request.settings.llm,
safety_settings={
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
@@ -219,20 +221,13 @@ def update_doc_models(doc: Docs, request: QueryRequest | None = None):
doc._embedding_client = AsyncOpenAI() # hard coded to OpenAI for now

doc.texts_index.embedding_model = embedding_model_factory(
request.embedding, **(request.texts_index_embedding_config or {})
)
doc.docs_index.embedding_model = embedding_model_factory(
request.embedding, **(request.docs_index_embedding_config or {})
request.settings.embedding, **(request.settings.embedding_config or {})
)
doc.texts_index.mmr_lambda = request.texts_index_mmr_lambda
doc.docs_index.mmr_lambda = request.docs_index_mmr_lambda
doc.embedding = request.embedding
doc.max_concurrent = request.max_concurrent
doc.prompts = request.prompts
doc.texts_index.mmr_lambda = request.settings.texts_index_mmr_lambda
doc.embedding = request.settings.embedding
Docs.make_llm_names_consistent(doc)

logger.debug(
f"update_doc_models: {doc.name}"
f" | {(doc.llm_model.config)} | {(doc.summary_llm_model.config)}"
f" | {doc.docs_index.__class__}"
)
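The pattern running through these hunks is flat request fields becoming one nested settings object. A minimal sketch of the shape change, assuming pydantic models (the actual class definitions are not in this diff, so field names and defaults here are illustrative):

```python
from pydantic import BaseModel, Field


class AgentSettings(BaseModel):
    """Nested agent options (illustrative fields, inferred from usage)."""

    agent_llm: str = "gpt-4-turbo"
    timeout: float = 500.0
    index_directory: str = "indexes"


class Settings(BaseModel):
    """Consolidated configuration object (illustrative)."""

    llm: str = "gpt-4-turbo"
    summary_llm: str = "gpt-4-turbo"
    temperature: float = 0.0
    embedding: str = "text-embedding-3-small"
    texts_index_mmr_lambda: float = 1.0
    agent: AgentSettings = Field(default_factory=AgentSettings)


class QueryRequest(BaseModel):
    """Requests now carry one Settings instead of flat config fields."""

    query: str = ""
    settings: Settings = Field(default_factory=Settings)


# Call sites change from request.llm to request.settings.llm, as in the
# update_doc_models hunks above.
```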
43 changes: 21 additions & 22 deletions paperqa/agents/main.py
@@ -45,17 +45,13 @@ async def agent_query(
docs: Docs | None = None,
agent_type: str = "OpenAIFunctionsAgent",
verbosity: int = 0,
index_directory: str | os.PathLike | None = None,
) -> AnswerResponse:
if isinstance(query, str):
query = QueryRequest(query=query)

if docs is None:
docs = Docs()

if index_directory is None:
index_directory = pqa_directory("indexes")

# in-place modification of the docs object to match query
update_doc_models(
docs,
Expand All @@ -65,7 +61,7 @@ async def agent_query(
search_index = SearchIndex(
fields=[*SearchIndex.REQUIRED_FIELDS, "question"],
index_name="answers",
index_directory=index_directory,
index_directory=query.settings.agent.index_directory,
storage=SearchDocumentStorage.JSON_MODEL_DUMP,
)

@@ -110,7 +106,7 @@ async def run_agent(
Tuple of resultant answer, token counts, and agent status.
"""
profiler = SimpleProfiler()
outer_profile_name = f"agent-{agent_type}-{query.agent_llm}"
outer_profile_name = f"agent-{agent_type}-{query.settings.agent.agent_llm}"
profiler.start(outer_profile_name)

logger.info(
@@ -141,8 +137,10 @@ async def run_fake_agent(
query: QueryRequest,
docs: Docs,
) -> tuple[Answer, AgentStatus]:
answer = Answer(question=query.query, dockey_filter=set(), id=query.id)
tools = query_to_tools(query, state=SharedToolState(docs=docs, answer=answer))
answer = Answer(question=query.query, id=query.id)
tools = query_to_tools(
query, state=SharedToolState(docs=docs, answer=answer, settings=query.settings)
)
search_tool = cast(
PaperSearchTool,
next(
@@ -170,7 +168,7 @@ async def run_fake_agent(
)
# seed docs with keyword search
for search in await openai_get_search_query(
answer.question, llm=query.llm, count=3
answer.question, llm=query.settings.llm, count=3
):
await search_tool.arun(search)

@@ -194,15 +192,17 @@ async def run_langchain_agent(
profiler: SimpleProfiler,
timeout: float | None = None, # noqa: ASYNC109
) -> tuple[Answer, AgentStatus]:
answer = Answer(question=query.query, dockey_filter=set(), id=query.id)
answer = Answer(question=query.query, id=query.id)
shared_callbacks: list[BaseCallbackHandler] = [
AgentCallback(
profiler, name=f"step-{agent_type}-{query.agent_llm}", answer_id=answer.id
profiler,
name=f"step-{agent_type}-{query.settings.agent.agent_llm}",
answer_id=answer.id,
),
]
tools = query_to_tools(
query,
state=SharedToolState(docs=docs, answer=answer),
state=SharedToolState(docs=docs, answer=answer, settings=query.settings),
callbacks=shared_callbacks,
)
try:
@@ -223,25 +223,25 @@
)

# optionally use the search tool before the agent
if search_tool is not None and query.agent_tools.should_pre_search:
if search_tool is not None and query.settings.agent.should_pre_search:
logger.debug("Running search tool before agent choice.")
await search_tool.arun(answer.question)
else:
logger.debug("Skipping search tool before agent choice.")

llm = ChatOpenAI(
model=query.agent_llm,
request_timeout=timeout or query.agent_tools.timeout / 2.0,
temperature=query.temperature,
model=query.settings.agent.agent_llm,
request_timeout=timeout or query.settings.agent.timeout / 2.0,
temperature=query.settings.temperature,
)
agent_status = AgentStatus.SUCCESS
cost_callback = OpenAICallbackHandler()
agent_instance = LANGCHAIN_AGENT_TYPES[agent_type].from_llm_and_tools(
llm,
tools,
system_message=(
SystemMessage(content=query.agent_tools.agent_system_prompt)
if query.agent_tools.agent_system_prompt
SystemMessage(content=query.settings.agent.agent_system_prompt)
if query.settings.agent.agent_system_prompt
else None
),
)
@@ -251,9 +251,8 @@
agent=agent_instance,
return_intermediate_steps=True,
handle_parsing_errors=True,
max_execution_time=query.agent_tools.timeout,
max_execution_time=query.settings.agent.timeout,
callbacks=[*shared_callbacks, cost_callback],
**(query.agent_tools.agent_config or {}),
)

async def aplan_with_injected_callbacks(
@@ -276,7 +275,7 @@ async def aplan_with_injected_callbacks(
input={
# NOTE: str.format still works even if the prompt doesn't have
# template fields like 'status' or 'gen_answer_tool_name'
"input": query.agent_tools.agent_prompt.format(
"input": query.settings.agent.agent_prompt.format(
question=answer.question,
status=await status(docs, answer),
gen_answer_tool_name=answer_tool.name,
@@ -297,7 +296,7 @@ async def aplan_with_injected_callbacks(
if "Agent stopped" in call_response["output"]:
# Log that this agent has gone over timeout, and then answer directly
logger.warning(
f"Agent timeout after {query.agent_tools.timeout}-sec, just answering."
f"Agent timeout after {query.settings.agent.timeout}-sec, just answering."
)
await answer_tool.arun(answer.question)
agent_status = AgentStatus.TIMEOUT