Refactor to remove skip_system from LLMModel.run_prompt #680

Merged 1 commit on Nov 13, 2024
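
This PR swaps the boolean `skip_system` flag on `LLMModel.run_prompt` for an optional `system_prompt` argument: callers that previously passed `skip_system=True` now pass `system_prompt=None` (or an empty string) to send no system message. A minimal before/after sketch of the call-site migration, assuming an `llm_model` instance inside an async function (prompt and data values are borrowed from the tests below):

```python
# Old API (removed by this PR): a separate flag suppressed the system prompt.
result = await llm_model.run_prompt(
    prompt="The {animal} says",
    data={"animal": "duck"},
    skip_system=True,
)

# New API: the system prompt itself is the knob; None (or "") means no system message.
result = await llm_model.run_prompt(
    prompt="The {animal} says",
    data={"animal": "duck"},
    system_prompt=None,
)
```
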
2 changes: 1 addition & 1 deletion paperqa/agents/helpers.py
@@ -61,7 +61,7 @@ async def litellm_get_search_query(
result = await model.run_prompt(
prompt=search_prompt,
data={"question": question, "count": count},
- skip_system=True,
+ system_prompt=None,
)
search_query = result.text
queries = [s for s in search_query.split("\n") if len(s) > 3] # noqa: PLR2004
4 changes: 2 additions & 2 deletions paperqa/docs.py
@@ -276,7 +276,7 @@ async def aadd( # noqa: PLR0912
result = await llm_model.run_prompt(
prompt=parse_config.citation_prompt,
data={"text": texts[0].text},
- skip_system=True, # skip system because it's too hesitant to answer
+ system_prompt=None, # skip system because it's too hesitant to answer
)
citation = result.text
if (
@@ -313,7 +313,7 @@ async def aadd( # noqa: PLR0912
result = await llm_model.run_prompt(
prompt=parse_config.structured_citation_prompt,
data={"citation": citation},
- skip_system=True,
+ system_prompt=None,
)
# This code below tries to isolate the JSON
# based on observed messages from LLMs
30 changes: 11 additions & 19 deletions paperqa/llms.py
@@ -328,18 +328,15 @@ async def run_prompt(
data: dict,
callbacks: list[Callable] | None = None,
name: str | None = None,
- skip_system: bool = False,
- system_prompt: str = default_system_prompt,
+ system_prompt: str | None = default_system_prompt,
) -> LLMResult:
if self.llm_type is None:
self.llm_type = self.infer_llm_type()
if self.llm_type == "chat":
- return await self._run_chat(
- prompt, data, callbacks, name, skip_system, system_prompt
- )
+ return await self._run_chat(prompt, data, callbacks, name, system_prompt)
if self.llm_type == "completion":
return await self._run_completion(
- prompt, data, callbacks, name, skip_system, system_prompt
+ prompt, data, callbacks, name, system_prompt
)
raise ValueError(f"Unknown llm_type {self.llm_type!r}.")

@@ -349,8 +346,7 @@ async def _run_chat(
data: dict,
callbacks: list[Callable] | None = None,
name: str | None = None,
- skip_system: bool = False,
- system_prompt: str = default_system_prompt,
+ system_prompt: str | None = default_system_prompt,
) -> LLMResult:
"""Run a chat prompt.

@@ -359,20 +355,18 @@ async def _run_chat(
data: Keys for the input variables that will be formatted into prompt.
callbacks: Optional functions to call with each chunk of the completion.
name: Optional name for the result.
- skip_system: Set True to skip the system prompt.
- system_prompt: System prompt to use.
+ system_prompt: System prompt to use, or None/empty string to not use one.

Returns:
Result of the chat.
"""
- system_message_prompt = {"role": "system", "content": system_prompt}
human_message_prompt = {"role": "user", "content": prompt}
messages = [
{"role": m["role"], "content": m["content"].format(**data)}
for m in (
- [human_message_prompt]
- if skip_system
- else [system_message_prompt, human_message_prompt]
+ [{"role": "system", "content": system_prompt}, human_message_prompt]
+ if system_prompt
+ else [human_message_prompt]
)
]
result = LLMResult(
@@ -425,8 +419,7 @@ async def _run_completion(
data: dict,
callbacks: Iterable[Callable] | None = None,
name: str | None = None,
- skip_system: bool = False,
- system_prompt: str = default_system_prompt,
+ system_prompt: str | None = default_system_prompt,
) -> LLMResult:
"""Run a completion prompt.

@@ -435,14 +428,13 @@ async def _run_completion(
data: Keys for the input variables that will be formatted into prompt.
callbacks: Optional functions to call with each chunk of the completion.
name: Optional name for the result.
- skip_system: Set True to skip the system prompt.
- system_prompt: System prompt to use.
+ system_prompt: System prompt to use, or None/empty string to not use one.

Returns:
Result of the completion.
"""
formatted_prompt: str = (
- prompt if skip_system else system_prompt + "\n\n" + prompt
+ system_prompt + "\n\n" + prompt if system_prompt else prompt
).format(**data)
result = LLMResult(
model=self.name,
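
With `skip_system` gone, both code paths in `LLMModel` branch on the truthiness of `system_prompt`: a non-empty string is sent as the system message (chat) or prepended to the prompt (completion), while `None` or `""` yields a user-only request. A standalone sketch of the chat-side message construction, mirroring the new `_run_chat` body above (illustrative only, not the library's API):

```python
def build_messages(prompt: str, data: dict, system_prompt: str | None) -> list[dict]:
    """Build chat messages the way the new _run_chat does (simplified)."""
    human_message_prompt = {"role": "user", "content": prompt}
    return [
        {"role": m["role"], "content": m["content"].format(**data)}
        for m in (
            [{"role": "system", "content": system_prompt}, human_message_prompt]
            if system_prompt  # None and "" both fall through to the user-only case
            else [human_message_prompt]
        )
    ]

# build_messages("The {animal} says", {"animal": "duck"}, None)
#   -> [{"role": "user", "content": "The duck says"}]
# build_messages("The {animal} says", {"animal": "duck"}, "Answer concisely.")
#   -> [{"role": "system", "content": "Answer concisely."},
#       {"role": "user", "content": "The duck says"}]
```
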
6 changes: 3 additions & 3 deletions tests/test_llms.py
@@ -59,7 +59,7 @@ def accum(x) -> None:
completion = await llm.run_prompt(
prompt="The {animal} says",
data={"animal": "duck"},
- skip_system=True,
+ system_prompt=None,
callbacks=[accum],
)
assert completion.model == "gpt-4o-mini"
@@ -72,7 +72,7 @@ def accum(x) -> None:
completion = await llm.run_prompt(
prompt="The {animal} says",
data={"animal": "duck"},
- skip_system=True,
+ system_prompt=None,
)
assert completion.seconds_to_first_token == 0
assert completion.seconds_to_last_token > 0
@@ -85,7 +85,7 @@ async def ac(x) -> None:
completion = await llm.run_prompt(
prompt="The {animal} says",
data={"animal": "duck"},
- skip_system=True,
+ system_prompt=None,
callbacks=[accum, ac],
)
assert completion.cost > 0
8 changes: 4 additions & 4 deletions tests/test_paperqa.py
@@ -423,7 +423,7 @@ def accum(x) -> None:
completion = await llm.run_prompt(
prompt="The {animal} says",
data={"animal": "duck"},
- skip_system=True,
+ system_prompt=None,
callbacks=[accum],
)
assert completion.seconds_to_first_token > 0
@@ -432,7 +432,7 @@ def accum(x) -> None:
assert str(completion) == "".join(outputs)

completion = await llm.run_prompt(
prompt="The {animal} says", data={"animal": "duck"}, skip_system=True
prompt="The {animal} says", data={"animal": "duck"}, system_prompt=None
)
assert completion.seconds_to_first_token == 0
assert completion.seconds_to_last_token > 0
@@ -453,7 +453,7 @@ def accum(x) -> None:
completion = await llm.run_prompt(
prompt="The {animal} says",
data={"animal": "duck"},
- skip_system=True,
+ system_prompt=None,
callbacks=[accum],
)
assert completion.seconds_to_first_token > 0
@@ -464,7 +464,7 @@ def accum(x) -> None:
assert completion.cost > 0

completion = await llm.run_prompt(
prompt="The {animal} says", data={"animal": "duck"}, skip_system=True
prompt="The {animal} says", data={"animal": "duck"}, system_prompt=None
)
assert completion.seconds_to_first_token == 0
assert completion.seconds_to_last_token > 0
4 changes: 2 additions & 2 deletions tests/test_rate_limiter.py
@@ -165,7 +165,7 @@ def accum(x) -> None:
3,
prompt="The {animal} says",
data={"animal": "duck"},
- skip_system=True,
+ system_prompt=None,
callbacks=[accum],
)

@@ -192,7 +192,7 @@ def accum2(x) -> None:
use_gather=True,
prompt="The {animal} says",
data={"animal": "duck"},
- skip_system=True,
+ system_prompt=None,
callbacks=[accum2],
)
