Skip to content

CLI Fixes #322

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 24 additions & 24 deletions paperqa/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import argparse
import logging
import os
from datetime import datetime
from typing import Any

from pydantic_settings import CliSettingsSource
Expand All @@ -16,7 +15,7 @@
from rich.console import Console
from rich.logging import RichHandler

from .main import agent_query, search
from .main import agent_query, index_search
from .models import AnswerResponse, QueryRequest
from .search import SearchIndex, get_directory_index

Expand Down Expand Up @@ -44,6 +43,8 @@ def configure_cli_logging(verbosity: int = 0) -> None:
"paperqa.agents.models": logging.WARNING,
"paperqa.agents.search": logging.INFO,
"litellm": logging.WARNING,
"LiteLLM Router": logging.WARNING,
"LiteLLM Proxy": logging.WARNING,
}
}

Expand All @@ -59,6 +60,8 @@ def configure_cli_logging(verbosity: int = 0) -> None:
"paperqa.models": logging.DEBUG,
"paperqa.agents.search": logging.DEBUG,
"litellm": logging.INFO,
"LiteLLM Router": logging.INFO,
"LiteLLM Proxy": logging.INFO,
}

verbosity_map[3] = verbosity_map[2] | {
Expand Down Expand Up @@ -90,22 +93,6 @@ def configure_cli_logging(verbosity: int = 0) -> None:
print(f"PaperQA version: {__version__}")


def get_file_timestamps(path: os.PathLike | str) -> dict[str, str]:
# Get the stats for the file/directory
stats = os.stat(path)

# Get created time (ctime)
created_time = datetime.fromtimestamp(stats.st_ctime)

# Get modified time (mtime)
modified_time = datetime.fromtimestamp(stats.st_mtime)

return {
"created_at": created_time.strftime("%Y-%m-%d %H:%M:%S"),
"modified_at": modified_time.strftime("%Y-%m-%d %H:%M:%S"),
}


def ask(query: str, settings: Settings) -> AnswerResponse:
"""Query PaperQA via an agent."""
configure_cli_logging(verbosity=settings.verbosity)
Expand Down Expand Up @@ -138,7 +125,7 @@ def search_query(
index_name = settings.get_index_name()
loop = get_loop()
return loop.run_until_complete(
search(
index_search(
query,
index_name=index_name,
index_directory=settings.index_directory,
Expand All @@ -147,13 +134,20 @@ def search_query(


def build_index(
    index_name: str,
    directory: str | os.PathLike,
    settings: Settings,
) -> SearchIndex:
    """Build a PaperQA search index, this will also happen automatically upon using `ask`.

    Args:
        index_name: Name of the index to build. The literal ``"default"`` is
            replaced by ``settings.get_index_name()``.
        directory: Directory to index; assigned to ``settings.paper_directory``.
        settings: Application settings. Mutated in place (``paper_directory``).

    Returns:
        The search index produced by ``get_directory_index``.
    """
    if index_name == "default":
        # NOTE(review): the name is resolved before paper_directory is
        # overwritten below -- confirm get_index_name() does not depend on it.
        index_name = settings.get_index_name()
    configure_cli_logging(verbosity=settings.verbosity)
    settings.paper_directory = directory
    loop = get_loop()

    # Single return: the index name must be forwarded, otherwise the
    # requested name is silently ignored.
    return loop.run_until_complete(
        get_directory_index(index_name=index_name, settings=settings)
    )


def save_settings(
Expand Down Expand Up @@ -190,6 +184,10 @@ def main():
help="Named settings to use. Will search in local, pqa directory, and package last",
)

parser.add_argument(
"--index", "-i", default="default", help="Index name to search or create"
)

subparsers = parser.add_subparsers(
title="commands", dest="command", description="Available commands"
)
Expand All @@ -208,12 +206,14 @@ def main():
search_parser = subparsers.add_parser(
"search",
help="Search the index specified by --index."
" Pass --index answers to search previous answers.",
" Pass `--index answers` to search previous answers.",
)
search_parser.add_argument("query", help="Keyword search")
search_parser.add_argument(
"-i", dest="index", default="default", help="Index to search"

build_parser = subparsers.add_parser(
"index", help="Build a search index from given directory"
)
build_parser.add_argument("directory", help="Directory to build index from")

# Create CliSettingsSource instance
cli_settings = CliSettingsSource(Settings, root_parser=parser)
Expand All @@ -237,7 +237,7 @@ def main():
case "search":
search_query(args.query, args.index, settings)
case "index":
build_index(args.verbosity)
build_index(args.index, args.directory, settings)
case _:
commands = ", ".join({"view", "ask", "search", "index"})
brief_help = f"\nRun with commands: {{{commands}}}\n\n"
Expand Down
12 changes: 5 additions & 7 deletions paperqa/agents/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,6 @@ async def litellm_get_search_query(
f"The current year is {get_year()}."
)

if "gpt" not in llm:
raise ValueError(
f"Invalid llm: {llm}, note a GPT model must be used for the fake agent search."
)
model = LiteLLMModel(name=llm)
model.config["model_list"][0]["litellm_params"].update({"temperature": temperature})
chain = model.make_chain(prompt=search_prompt, skip_system=True)
Expand Down Expand Up @@ -91,9 +87,11 @@ def table_formatter(
table.add_column("Title", style="cyan")
table.add_column("File", style="magenta")
for obj, filename in objects:
table.add_row(
cast(Docs, obj).texts[0].doc.title[:max_chars_per_column], filename # type: ignore[attr-defined]
)
try:
display_name = cast(Docs, obj).texts[0].doc.title # type: ignore[attr-defined]
except AttributeError:
display_name = cast(Docs, obj).texts[0].doc.citation
table.add_row(display_name[:max_chars_per_column], filename)
return table
raise NotImplementedError(
f"Object type {type(example_object)} can not be converted to table."
Expand Down
16 changes: 12 additions & 4 deletions paperqa/agents/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,16 +298,23 @@ async def aplan_with_injected_callbacks(
return answer, agent_status


async def search(
async def index_search(
query: str,
index_name: str = "answers",
index_directory: str | os.PathLike | None = None,
) -> list[tuple[AnswerResponse, str] | tuple[Any, str]]:
fields = [*SearchIndex.REQUIRED_FIELDS]
if index_name == "answers":
fields.append("question")
search_index = SearchIndex(
["file_location", "body", "question"],
fields=fields,
index_name=index_name,
index_directory=index_directory or pqa_directory("indexes"),
storage=SearchDocumentStorage.JSON_MODEL_DUMP,
storage=(
SearchDocumentStorage.JSON_MODEL_DUMP
if index_name == "answers"
else SearchDocumentStorage.PICKLE_COMPRESSED
),
)

results = [
Expand All @@ -320,6 +327,7 @@ async def search(
# Render the table to a string
console.print(table_formatter(results))
else:
agent_logger.info("No results found.")
count = await search_index.count
agent_logger.info(f"No results found. Searched {count} docs")

return results
10 changes: 8 additions & 2 deletions paperqa/agents/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ async def searcher(self) -> Searcher:
self._searcher = index.searcher()
return self._searcher

@property
async def count(self) -> int:
    """Number of documents reported by this index's searcher (``num_docs``)."""
    active_searcher = await self.searcher
    return active_searcher.num_docs

@property
async def index_files(self) -> dict[str, str]:
if not self._index_files:
Expand Down Expand Up @@ -284,7 +288,7 @@ async def get_saved_object(
return None

def clean_query(self, query: str) -> str:
    """Strip characters that are special to the query syntax from ``query``.

    Removes every occurrence of ``* [ ] : ( ) { } ~`` and leaves all other
    characters untouched.

    Args:
        query: Raw user query.

    Returns:
        The query with the characters above removed.
    """
    # One pass over the string instead of one ``replace`` call per character.
    return query.translate(str.maketrans("", "", "*[]:(){}~"))

Expand Down Expand Up @@ -395,6 +399,7 @@ async def process_file(


async def get_directory_index(
index_name: str | None = None,
sync_index_w_directory: bool = True,
settings: MaybeSettings = None,
) -> SearchIndex:
Expand All @@ -403,6 +408,7 @@ async def get_directory_index(

Args:
sync_index_w_directory: Sync the index with the directory. (i.e. delete files not in directory)
index_name: Name of the index. If not given, the name will be taken from the settings
settings: Application settings.
"""
_settings = get_settings(settings)
Expand All @@ -415,7 +421,7 @@ async def get_directory_index(

search_index = SearchIndex(
fields=[*SearchIndex.REQUIRED_FIELDS, "title", "year"],
index_name=_settings.get_index_name(),
index_name=index_name or _settings.get_index_name(),
index_directory=_settings.index_directory,
)

Expand Down
2 changes: 0 additions & 2 deletions paperqa/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,6 @@ async def aadd( # noqa: PLR0912
doi = citation_doi
if citation_author := citation_json.get("authors"):
authors = citation_author

# see if we can upgrade to DocDetails
# if not, we can progress with a normal Doc
# if "overwrite_fields_from_metadata" is used:
Expand All @@ -349,7 +348,6 @@ async def aadd( # noqa: PLR0912
query_kwargs["authors"] = authors
if title:
query_kwargs["title"] = title

doc = await metadata_client.upgrade_doc_to_doc_details(
doc, **(query_kwargs | kwargs)
)
Expand Down
2 changes: 2 additions & 0 deletions paperqa/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,8 @@ def setup_default_logs() -> None:
"httpx": {"level": "WARNING"},
# SEE: https://github.com/BerriAI/litellm/issues/2256
"LiteLLM": {"level": "WARNING"},
"LiteLLM Router": {"level": "WARNING"},
"LiteLLM Proxy": {"level": "WARNING"},
},
}
)
65 changes: 65 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import io
import os
import sys
from pathlib import Path

import pytest

from paperqa.settings import Settings
from paperqa.utils import pqa_directory

try:
from paperqa.agents import ask, build_index, main, search_query
from paperqa.agents.models import AnswerResponse
except ImportError:
pytest.skip("agents module is not installed", allow_module_level=True)


def test_can_modify_settings():
    """Save settings via the CLI with an overridden llm, then view them back.

    The first CLI call saves a ``unit_test`` settings file; the second prints
    it, and the captured stdout must contain the overridden model name.
    """
    old_argv = sys.argv
    old_stdout = sys.stdout
    captured_output = io.StringIO()
    try:
        sys.argv = "paperqa -s debug --llm=my-model-foo save unit_test".split()
        main()

        # Capture stdout only from here on; the save step's output is not checked.
        sys.stdout = captured_output
        assert Settings.from_name("unit_test").llm == "my-model-foo"

        sys.argv = "paperqa -s unit_test view".split()
        main()

        output = captured_output.getvalue().strip()
        assert "my-model-foo" in output
    finally:
        sys.argv = old_argv
        sys.stdout = old_stdout
        # Tolerate a missing file: if the save step failed, a FileNotFoundError
        # raised here would hide the real failure.
        Path(pqa_directory("settings"), "unit_test.json").unlink(missing_ok=True)


def test_cli_ask(agent_index_dir: Path, stub_data_dir: Path):
    """Ask a question, then find the same answer again in the "answers" index."""
    settings = Settings.from_name("debug")
    settings.index_directory = agent_index_dir
    settings.paper_directory = stub_data_dir

    question = "How can you use XAI for chemical property prediction?"
    response = ask(question, settings=settings)
    assert response.answer.formatted_answer

    # Search with the first five words of the answer.
    leading_words = response.answer.formatted_answer.split()[:5]
    search_result = search_query(" ".join(leading_words), "answers", settings)

    found_answer = search_result[0][0]
    assert isinstance(found_answer, AnswerResponse)
    assert found_answer.model_dump_json() == response.model_dump_json()


def test_cli_can_build_and_search_index(agent_index_dir: Path, stub_data_dir: Path):
    """Build a named index from the stub data and check a keyword search hits it."""
    index_name = "test"
    settings = Settings.from_name("debug")
    settings.index_directory = agent_index_dir

    build_index(index_name, stub_data_dir, settings)

    search_result = search_query("XAI", index_name, settings)
    assert search_result
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading