Skip to content

All Ruff ANN autofixes #341

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paperqa/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def save_settings(
logger.info(f"Settings saved to: {full_settings_path}")


def main():
def main() -> None:
parser = argparse.ArgumentParser(description="PaperQA CLI")

parser.add_argument(
Expand Down
2 changes: 1 addition & 1 deletion paperqa/agents/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def start(self, name: str) -> None:
except RuntimeError: # No running event loop (not in async)
self.running_timers[name] = TimerData(start_time=time.time())

def stop(self, name: str):
def stop(self, name: str) -> None:
timer_data = self.running_timers.pop(name, None)
if timer_data:
try:
Expand Down
2 changes: 1 addition & 1 deletion paperqa/agents/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ async def filecheck(self, filename: str, body: str | None = None):

async def add_document(
self, index_doc: dict, document: Any | None = None, max_retries: int = 1000
):
) -> None:
@retry(
stop=stop_after_attempt(max_retries),
wait=wait_random_exponential(multiplier=0.25, max=60),
Expand Down
2 changes: 1 addition & 1 deletion paperqa/clients/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
class DOINotFoundError(Exception):
def __init__(self, message="DOI not found"):
def __init__(self, message="DOI not found") -> None:
self.message = message
super().__init__(self.message)
8 changes: 4 additions & 4 deletions paperqa/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@


# this is just to reduce None checks/type checks
async def empty_callback(result: LLMResult):
async def empty_callback(result: LLMResult) -> None:
pass


async def print_callback(result: LLMResult):
async def print_callback(result: LLMResult) -> None:
pass


Expand All @@ -85,7 +85,7 @@ def handle_default(cls, value: Path | None, info: ValidationInfo) -> Path | None
return PAPERQA_DIR / info.data["name"]
return value

def clear_docs(self):
def clear_docs(self) -> None:
self.texts = []
self.docs = {}
self.docnames = set()
Expand Down Expand Up @@ -451,7 +451,7 @@ def delete(
self.deleted_dockeys.add(dockey)
self.texts = list(filter(lambda x: x.doc.dockey != dockey, self.texts))

def _build_texts_index(self):
def _build_texts_index(self) -> None:
texts = [t for t in self.texts if t not in self.texts_index]
self.texts_index.add_texts_and_embeddings(texts)

Expand Down
6 changes: 3 additions & 3 deletions paperqa/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,10 +488,10 @@ class VectorStore(BaseModel, ABC):
model_config = ConfigDict(extra="forbid")
texts_hashes: set[int] = Field(default_factory=set)

def __contains__(self, item):
def __contains__(self, item) -> bool:
return hash(item) in self.texts_hashes

def __len__(self):
def __len__(self) -> int:
return len(self.texts_hashes)

@abstractmethod
Expand Down Expand Up @@ -614,7 +614,7 @@ class LangchainVectorStore(VectorStore):
class_type: type[Embeddable] = Field(default=Embeddable)
model_config = ConfigDict(extra="forbid")

def __init__(self, **data):
def __init__(self, **data) -> None:
raise NotImplementedError(
"Langchain has updated vectorstore internals and this is not yet supported"
)
Expand Down
4 changes: 2 additions & 2 deletions paperqa/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class LLMResult(BaseModel):
default=0.0, description="Delta time (sec) to last response token's arrival."
)

def __str__(self):
def __str__(self) -> str:
return self.text

@computed_field # type: ignore[prop-decorator]
Expand Down Expand Up @@ -198,7 +198,7 @@ def get_citation(self, name: str) -> str:
raise ValueError(f"Could not find docname {name} in contexts.") from exc
return doc.citation

def add_tokens(self, result: LLMResult):
def add_tokens(self, result: LLMResult) -> None:
"""Update the token counts for the given result."""
if result.model not in self.token_counts:
self.token_counts[result.model] = [
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@


@pytest.fixture(autouse=True, scope="session")
def _load_env():
def _load_env() -> None:
load_dotenv()


@pytest.fixture(autouse=True)
def _setup_default_logs():
def _setup_default_logs() -> None:
setup_default_logs()


Expand Down
10 changes: 5 additions & 5 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def test_get_directory_index(agent_test_settings) -> None:
@pytest.mark.asyncio
async def test_get_directory_index_w_manifest(
agent_test_settings, reset_log_levels, caplog # noqa: ARG001
):
) -> None:
agent_test_settings.manifest_file = "stub_manifest.csv"
index = await get_directory_index(settings=agent_test_settings)
assert index.fields == [
Expand All @@ -71,7 +71,7 @@ async def test_get_directory_index_w_manifest(
@pytest.mark.flaky(reruns=2, only_rerun=["AssertionError", "httpx.RemoteProtocolError"])
@pytest.mark.parametrize("agent_type", ["fake", "OpenAIFunctionsAgent"])
@pytest.mark.asyncio
async def test_agent_types(agent_test_settings, agent_type):
async def test_agent_types(agent_test_settings, agent_type) -> None:

question = "How can you use XAI for chemical property prediction?"

Expand All @@ -87,7 +87,7 @@ async def test_agent_types(agent_test_settings, agent_type):


@pytest.mark.asyncio
async def test_timeout(agent_test_settings):
async def test_timeout(agent_test_settings) -> None:
agent_test_settings.prompts.pre = None
agent_test_settings.agent.timeout = 0.001
agent_test_settings.llm = "gpt-4o-mini"
Expand Down Expand Up @@ -287,7 +287,7 @@ def test_functions() -> None:
]


def test_query_request_docs_name_serialized():
def test_query_request_docs_name_serialized() -> None:
"""Test that the query request has a docs_name property."""
request = QueryRequest(query="Are COVID-19 vaccines effective?")
request_data = json.loads(request.model_dump_json())
Expand All @@ -298,7 +298,7 @@ def test_query_request_docs_name_serialized():
assert request_data["docs_name"] == "my_doc"


def test_answers_are_striped():
def test_answers_are_striped() -> None:
"""Test that answers are striped."""
answer = Answer(
question="What is the meaning of life?",
Expand Down
8 changes: 5 additions & 3 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
pytest.skip("agents module is not installed", allow_module_level=True)


def test_can_modify_settings():
def test_can_modify_settings() -> None:
old_argv = sys.argv
old_stdout = sys.stdout
captured_output = io.StringIO()
Expand All @@ -37,7 +37,7 @@ def test_can_modify_settings():
os.unlink(pqa_directory("settings") / "unit_test.json")


def test_cli_ask(agent_index_dir: Path, stub_data_dir: Path):
def test_cli_ask(agent_index_dir: Path, stub_data_dir: Path) -> None:
settings = Settings.from_name("debug")
settings.index_directory = agent_index_dir
settings.paper_directory = stub_data_dir
Expand All @@ -56,7 +56,9 @@ def test_cli_ask(agent_index_dir: Path, stub_data_dir: Path):
assert found_answer.model_dump_json() == response.model_dump_json()


def test_cli_can_build_and_search_index(agent_index_dir: Path, stub_data_dir: Path):
def test_cli_can_build_and_search_index(
agent_index_dir: Path, stub_data_dir: Path
) -> None:
settings = Settings.from_name("debug")
settings.index_directory = agent_index_dir
index_name = "test"
Expand Down
28 changes: 14 additions & 14 deletions tests/test_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
],
)
@pytest.mark.asyncio
async def test_title_search(paper_attributes: dict[str, str]):
async def test_title_search(paper_attributes: dict[str, str]) -> None:
async with aiohttp.ClientSession() as session:
client = DocMetadataClient(session, clients=ALL_CLIENTS)
details = await client.query(title=paper_attributes["title"])
Expand Down Expand Up @@ -192,7 +192,7 @@ async def test_title_search(paper_attributes: dict[str, str]):
],
)
@pytest.mark.asyncio
async def test_doi_search(paper_attributes: dict[str, str]):
async def test_doi_search(paper_attributes: dict[str, str]) -> None:
async with aiohttp.ClientSession() as session:
client = DocMetadataClient(session, clients=ALL_CLIENTS)
details = await client.query(doi=paper_attributes["doi"])
Expand All @@ -210,7 +210,7 @@ async def test_doi_search(paper_attributes: dict[str, str]):

@pytest.mark.vcr
@pytest.mark.asyncio
async def test_bulk_doi_search():
async def test_bulk_doi_search() -> None:
dois = [
"10.1063/1.4938384",
"10.48550/arxiv.2312.07559",
Expand All @@ -228,7 +228,7 @@ async def test_bulk_doi_search():

@pytest.mark.vcr
@pytest.mark.asyncio
async def test_bulk_title_search():
async def test_bulk_title_search() -> None:
titles = [
(
"Effect of native oxide layers on copper thin-film tensile properties: A"
Expand All @@ -253,7 +253,7 @@ async def test_bulk_title_search():

@pytest.mark.vcr
@pytest.mark.asyncio
async def test_bad_titles():
async def test_bad_titles() -> None:
async with aiohttp.ClientSession() as session:
client = DocMetadataClient(session)
details = await client.query(title="askldjrq3rjaw938h")
Expand All @@ -269,7 +269,7 @@ async def test_bad_titles():

@pytest.mark.vcr
@pytest.mark.asyncio
async def test_bad_dois():
async def test_bad_dois() -> None:
async with aiohttp.ClientSession() as session:
client = DocMetadataClient(session)
details = await client.query(title="abs12032jsdafn")
Expand All @@ -278,7 +278,7 @@ async def test_bad_dois():

@pytest.mark.vcr
@pytest.mark.asyncio
async def test_minimal_fields_filtering():
async def test_minimal_fields_filtering() -> None:
async with aiohttp.ClientSession() as session:
client = DocMetadataClient(session)
details = await client.query(
Expand All @@ -303,7 +303,7 @@ async def test_minimal_fields_filtering():

@pytest.mark.vcr
@pytest.mark.asyncio
async def test_s2_only_fields_filtering():
async def test_s2_only_fields_filtering() -> None:
async with aiohttp.ClientSession() as session:
# now get with authors just from one source
s2_client = DocMetadataClient(session, clients=[SemanticScholarProvider])
Expand Down Expand Up @@ -373,7 +373,7 @@ async def test_crossref_journalquality_fields_filtering() -> None:

@pytest.mark.vcr
@pytest.mark.asyncio
async def test_author_matching():
async def test_author_matching() -> None:
async with aiohttp.ClientSession() as session:
crossref_client = DocMetadataClient(session, clients=[CrossrefProvider])
s2_client = DocMetadataClient(session, clients=[SemanticScholarProvider])
Expand Down Expand Up @@ -402,7 +402,7 @@ async def test_author_matching():

@pytest.mark.vcr
@pytest.mark.asyncio
async def test_odd_client_requests():
async def test_odd_client_requests() -> None:
# try querying using an authors match, but not requesting authors back
async with aiohttp.ClientSession() as session:
client = DocMetadataClient(session)
Expand Down Expand Up @@ -449,7 +449,7 @@ async def test_odd_client_requests():


@pytest.mark.asyncio
async def test_ensure_robust_to_timeouts(monkeypatch):
async def test_ensure_robust_to_timeouts(monkeypatch) -> None:
# 0.15 should be short enough to not get a response in time.
monkeypatch.setattr(paperqa.clients.crossref, "CROSSREF_API_REQUEST_TIMEOUT", 0.05)
monkeypatch.setattr(
Expand All @@ -466,7 +466,7 @@ async def test_ensure_robust_to_timeouts(monkeypatch):


@pytest.mark.asyncio
async def test_bad_init():
async def test_bad_init() -> None:
with pytest.raises(
ValueError, match="At least one MetadataProvider must be provided."
):
Expand All @@ -475,7 +475,7 @@ async def test_bad_init():

@pytest.mark.vcr
@pytest.mark.asyncio
async def test_ensure_sequential_run(caplog, reset_log_levels): # noqa: ARG001
async def test_ensure_sequential_run(caplog, reset_log_levels) -> None: # noqa: ARG001
caplog.set_level(logging.DEBUG)
# were using a DOI that is NOT in crossref, but running the crossref client first
# we will ensure that both are run sequentially
Expand Down Expand Up @@ -512,7 +512,7 @@ async def test_ensure_sequential_run(caplog, reset_log_levels): # noqa: ARG001
@pytest.mark.asyncio
async def test_ensure_sequential_run_early_stop(
caplog, reset_log_levels # noqa: ARG001
):
) -> None:
caplog.set_level(logging.DEBUG)
# now we should stop after hitting s2
async with aiohttp.ClientSession() as session:
Expand Down
8 changes: 4 additions & 4 deletions tests/test_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)


def test_prompt_settings_validation():
def test_prompt_settings_validation() -> None:
with pytest.raises(ValidationError):
PromptSettings(summary="Invalid {variable}")

Expand All @@ -27,13 +27,13 @@ def test_prompt_settings_validation():
assert valid_pre_settings.pre == "{question}"


def test_get_formatted_variables():
def test_get_formatted_variables() -> None:
template = "This is a test {variable} with {another_variable}"
variables = get_formatted_variables(template)
assert variables == {"variable", "another_variable"}


def test_get_settings_with_valid_config():
def test_get_settings_with_valid_config() -> None:
settings = get_settings("fast")
assert not settings.parsing.use_doc_details

Expand All @@ -46,7 +46,7 @@ def test_get_settings_missing_file() -> None:
get_settings("missing_config")


def test_settings_default_instantiation():
def test_settings_default_instantiation() -> None:
settings = Settings()
assert "gpt-" in settings.llm
assert settings.answer.evidence_k == 10
Loading