Changed to rely on litellm for computing cost #321

Merged 3 commits on Sep 8, 2024. Changes shown are from 2 commits.
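This PR drops the hardcoded per-model price table in paperqa/agents/helpers.py and derives cost from litellm's pricing map instead (see the new LLMResult.cost property in paperqa/types.py below). A minimal standalone sketch of that lookup, assuming only that litellm.model_cost maps model names to input_cost_per_token and output_cost_per_token; the helper name here is illustrative, not part of the PR:

import litellm


def estimate_cost(model: str, prompt_tokens: int, completion_tokens: int) -> float:
    """Estimate the dollar cost of a call from litellm's bundled pricing table."""
    try:
        pricing = litellm.model_cost[model]
        return (
            pricing["input_cost_per_token"] * prompt_tokens
            + pricing["output_cost_per_token"] * completion_tokens
        )
    except KeyError:
        # Unknown model: fall back to zero rather than guessing a price.
        return 0.0


# Example: estimate_cost("gpt-4o-mini", prompt_tokens=1000, completion_tokens=200)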
56 changes: 0 additions & 56 deletions paperqa/agents/helpers.py
@@ -98,59 +98,3 @@ def table_formatter(
raise NotImplementedError(
f"Object type {type(example_object)} can not be converted to table."
)


# Index 0 is for prompt tokens, index 1 is for completion tokens
costs: dict[str, tuple[float, float]] = {
"claude-2": (11.02 / 10**6, 32.68 / 10**6),
"claude-instant-1": (1.63 / 10**6, 5.51 / 10**6),
"claude-3-sonnet-20240229": (3 / 10**6, 15 / 10**6),
"claude-3-5-sonnet-20240620": (3 / 10**6, 15 / 10**6),
"claude-3-opus-20240229": (15 / 10**6, 75 / 10**6),
"babbage-002": (0.0004 / 10**3, 0.0004 / 10**3),
"gpt-3.5-turbo": (0.0010 / 10**3, 0.0020 / 10**3),
"gpt-3.5-turbo-1106": (0.0010 / 10**3, 0.0020 / 10**3),
"gpt-3.5-turbo-0613": (0.0010 / 10**3, 0.0020 / 10**3),
"gpt-3.5-turbo-0301": (0.0010 / 10**3, 0.0020 / 10**3),
"gpt-3.5-turbo-0125": (0.0005 / 10**3, 0.0015 / 10**3),
"gpt-4-1106-preview": (0.010 / 10**3, 0.030 / 10**3),
"gpt-4-0125-preview": (0.010 / 10**3, 0.030 / 10**3),
"gpt-4-turbo-2024-04-09": (10 / 10**6, 30 / 10**6),
"gpt-4-turbo": (10 / 10**6, 30 / 10**6),
"gpt-4": (0.03 / 10**3, 0.06 / 10**3),
"gpt-4-0613": (0.03 / 10**3, 0.06 / 10**3),
"gpt-4-0314": (0.03 / 10**3, 0.06 / 10**3),
"gpt-4o": (2.5 / 10**6, 10 / 10**6),
"gpt-4o-2024-05-13": (5 / 10**6, 15 / 10**6),
"gpt-4o-2024-08-06": (2.5 / 10**6, 10 / 10**6),
"gpt-4o-mini": (0.15 / 10**6, 0.60 / 10**6),
"gemini-1.5-flash": (0.35 / 10**6, 0.35 / 10**6),
"gemini-1.5-pro": (3.5 / 10**6, 10.5 / 10**6),
# supported Anyscale models per
# https://docs.anyscale.com/endpoints/text-generation/query-a-model
"meta-llama/Meta-Llama-3-8B-Instruct": (0.15 / 10**6, 0.15 / 10**6),
"meta-llama/Meta-Llama-3-70B-Instruct": (1.0 / 10**6, 1.0 / 10**6),
"mistralai/Mistral-7B-Instruct-v0.1": (0.15 / 10**6, 0.15 / 10**6),
"mistralai/Mixtral-8x7B-Instruct-v0.1": (1.0 / 10**6, 1.0 / 10**6),
"mistralai/Mixtral-8x22B-Instruct-v0.1": (1.0 / 10**6, 1.0 / 10**6),
}


def compute_model_token_cost(model: str, tokens: int, is_completion: bool) -> float:
if model in costs: # Prefer our internal costs model
model_costs: tuple[float, float] = costs[model]
else:
logger.warning(f"Model {model} not found in costs.")
return 0.0
return tokens * model_costs[int(is_completion)]


def compute_total_model_token_cost(token_counts: dict[str, list[int]]) -> float:
"""Sum the token counts for each model and return the total cost."""
cost = 0.0
for model, tokens in token_counts.items():
if sum(tokens) > 0:
cost += compute_model_token_cost(
model, tokens=tokens[0], is_completion=False
) + compute_model_token_cost(model, tokens=tokens[1], is_completion=True)
return cost
3 changes: 1 addition & 2 deletions paperqa/agents/tools.py
@@ -16,7 +16,7 @@
Field as FieldV1,
)

from .helpers import compute_total_model_token_cost, get_year
from .helpers import get_year
from .search import get_directory_index
from .models import QueryRequest, SimpleProfiler
from ..settings import Settings
@@ -27,7 +27,6 @@

async def status(docs: Docs, answer: Answer, relevant_score_cutoff: int = 5) -> str:
"""Create a string that provides a summary of the input doc/answer."""
answer.cost = compute_total_model_token_cost(answer.token_counts)
return (
f"Status: Paper Count={len(docs.docs)}"
f" | Relevant Papers={len({c.text.doc.dockey for c in answer.contexts if c.score > relevant_score_cutoff})}"
76 changes: 50 additions & 26 deletions paperqa/llms.py
@@ -73,7 +73,7 @@ async def embed_documents(self, texts: list[str]) -> list[list[float]]:

class LiteLLMEmbeddingModel(EmbeddingModel):
name: str = Field(default="text-embedding-3-small")
embedding_kwargs: dict = Field(default={})
embedding_kwargs: dict = Field(default_factory=dict)

async def embed_documents(
self, texts: list[str], batch_size: int = 16
@@ -129,9 +129,9 @@ class LLMModel(ABC, BaseModel):
" LLMResult (different than callbacks that execute on each chunk)",
exclude=True,
)
config: dict = Field(default={})
config: dict = Field(default_factory=dict)

async def acomplete(self, prompt: str) -> str:
async def acomplete(self, prompt: str) -> tuple[str, tuple[int, int]]:
Collaborator:

Since this is sort of an abstractmethod, can you somehow document what the return type is? Can make a NamedTuple that is sort of self-documenting, or a docstring.

Collaborator (Author):

OK, added a new class and revised the code to use it.

Collaborator:

Nice, thanks.
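As a reference for the thread above, a minimal sketch of the kind of self-documenting NamedTuple being discussed; the actual class added in the follow-up commit is not part of this two-commit diff, so the name and fields below are hypothetical:

from typing import NamedTuple


class CompletionChunk(NamedTuple):  # hypothetical name; the real class lands in a later commit
    """Completion text plus the (prompt, completion) token counts reported by the provider."""

    text: str
    prompt_tokens: int
    completion_tokens: int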

raise NotImplementedError

async def acomplete_iter(self, prompt: str) -> Any:
@@ -141,7 +141,9 @@ async def acomplete_iter(self, prompt: str) -> Any:
"""
raise NotImplementedError

async def achat(self, messages: Iterable[dict[str, str]]) -> str:
async def achat(
self, messages: Iterable[dict[str, str]]
) -> tuple[str, tuple[int, int]]:
raise NotImplementedError

async def achat_iter(self, messages: Iterable[dict[str, str]]) -> Any:
@@ -205,16 +207,17 @@ async def execute(
result.prompt_count = sum(
self.count_tokens(m["content"]) for m in messages
) + sum(self.count_tokens(m["role"]) for m in messages)
usage = (0, 0)
if callbacks is None:
output = await self.achat(messages)
output, usage = await self.achat(messages)
else:
sync_callbacks = [
f for f in callbacks if not is_coroutine_callable(f)
]
async_callbacks = [f for f in callbacks if is_coroutine_callable(f)]
completion = self.achat_iter(messages)
text_result = []
async for chunk in completion: # type: ignore[attr-defined]
async for chunk, _usage in completion: # type: ignore[attr-defined]
if chunk:
if result.seconds_to_first_token == 0:
result.seconds_to_first_token = (
@@ -224,8 +227,13 @@
await do_callbacks(
async_callbacks, sync_callbacks, chunk, name
)
usage = _usage
output = "".join(text_result)
result.completion_count = self.count_tokens(output)
# not always reliable
if sum(usage) > 0:
result.prompt_count, result.completion_count = usage
else:
result.completion_count = self.count_tokens(output)
result.text = output
result.name = name
result.seconds_to_last_token = (
@@ -254,8 +262,9 @@ async def execute(
formatted_prompt = completion_prompt.format(**data)
result.prompt_count = self.count_tokens(formatted_prompt)
result.prompt = formatted_prompt
usage = (0, 0)
if callbacks is None:
output = await self.acomplete(formatted_prompt)
output, usage = await self.acomplete(formatted_prompt)
else:
sync_callbacks = [
f for f in callbacks if not is_coroutine_callable(f)
@@ -266,7 +275,7 @@
formatted_prompt,
)
text_result = []
async for chunk in completion: # type: ignore[attr-defined]
async for chunk, _usage in completion: # type: ignore[attr-defined]
if chunk:
if result.seconds_to_first_token == 0:
result.seconds_to_first_token = (
@@ -276,8 +285,12 @@
await do_callbacks(
async_callbacks, sync_callbacks, chunk, name
)
usage = _usage
output = "".join(text_result)
result.completion_count = self.count_tokens(output)
if sum(usage) > 0:
result.prompt_count, result.completion_count = usage
else:
result.completion_count = self.count_tokens(output)
result.text = output
result.name = name
result.seconds_to_last_token = (
@@ -326,7 +339,7 @@ class LiteLLMModel(LLMModel):

"""

config: dict = Field(default={})
config: dict = Field(default_factory=dict)
name: str = "gpt-4o-mini"
_router: Router | None = None

@@ -375,31 +388,42 @@ def router(self):
)
return self._router

async def acomplete(self, prompt: str) -> str:
return (
(await self.router.atext_completion(model=self.name, prompt=prompt))
.choices[0]
.text
async def acomplete(self, prompt: str) -> tuple[str, tuple[int, int]]:
response = await self.router.atext_completion(model=self.name, prompt=prompt)
return response.choices[0].text, (
response.usage.prompt_tokens,
response.usage.completion_tokens,
)

async def acomplete_iter(self, prompt: str) -> Any:
completion = await self.router.atext_completion(
model=self.name, prompt=prompt, stream=True
model=self.name,
prompt=prompt,
stream=True,
stream_options={"include_usage": True},
)
async for chunk in completion:
yield chunk.choices[0].text

async def achat(self, messages: Iterable[dict[str, str]]) -> str:
return (
(await self.router.acompletion(self.name, messages))
.choices[0]
.message.content
yield chunk.choices[0].text, (0, 0)
if hasattr(chunk, "usage") and hasattr(chunk.usage, "prompt_tokens"):
yield None, (chunk.usage.prompt_tokens, chunk.usage.completion_tokens)

async def achat(
self, messages: Iterable[dict[str, str]]
) -> tuple[str, tuple[int, int]]:
response = await self.router.acompletion(self.name, messages)
return response.choices[0].message.content, (
response.usage.prompt_tokens,
response.usage.completion_tokens,
)

async def achat_iter(self, messages: Iterable[dict[str, str]]) -> Any:
completion = await self.router.acompletion(self.name, messages, stream=True)
completion = await self.router.acompletion(
self.name, messages, stream=True, stream_options={"include_usage": True}
)
async for chunk in completion:
yield chunk.choices[0].delta.content
yield chunk.choices[0].delta.content, (0, 0)
if hasattr(chunk, "usage") and hasattr(chunk.usage, "prompt_tokens"):
yield None, (chunk.usage.prompt_tokens, chunk.usage.completion_tokens)

def infer_llm_type(self) -> str:
if all(
19 changes: 17 additions & 2 deletions paperqa/types.py
@@ -10,6 +10,7 @@
from typing import Any, ClassVar
from uuid import UUID, uuid4

import litellm # for cost
import tiktoken
from pybtex.database import BibliographyData, Entry, Person
from pybtex.database.input.bibtex import Parser
@@ -91,6 +92,19 @@ class LLMResult(BaseModel):
def __str__(self):
return self.text

@computed_field # type: ignore[prop-decorator]
@property
def cost(self) -> float:
"""Return the cost of the result in dollars."""
if self.prompt_count and self.completion_count:
try:
pc = litellm.model_cost[self.model]["input_cost_per_token"]
oc = litellm.model_cost[self.model]["output_cost_per_token"]
return pc * self.prompt_count + oc * self.completion_count
except KeyError:
logger.warning(f"Could not find cost for model {self.model}.")
return 0.0


class Embeddable(BaseModel):
embedding: list[float] | None = Field(default=None, repr=False)
@@ -142,8 +156,7 @@ class Answer(BaseModel):
contexts: list[Context] = []
references: str = ""
formatted_answer: str = ""
# just for convenience you can override this
cost: float | None = None
cost: float = 0.0
# Map model name to a two-item list of LLM prompt token counts
# and LLM completion token counts
token_counts: dict[str, list[int]] = Field(default_factory=dict)
@@ -192,6 +205,8 @@ def add_tokens(self, result: LLMResult):
self.token_counts[result.model][0] += result.prompt_count
self.token_counts[result.model][1] += result.completion_count

self.cost += result.cost

def get_unique_docs_from_contexts(self, score_threshold: int = 0) -> set[Doc]:
"""Parse contexts for docs with scores above the input threshold."""
return {
14 changes: 11 additions & 3 deletions tests/test_paperqa.py
@@ -426,6 +426,8 @@ def accum(x):
assert completion.seconds_to_first_token == 0
assert completion.seconds_to_last_token > 0

assert completion.cost > 0


@pytest.mark.asyncio
async def test_chain_chat():
@@ -456,16 +458,19 @@ def accum(x):
assert completion.prompt_count > 0
assert completion.completion_count > 0
assert str(completion) == "".join(outputs)
assert completion.cost > 0

completion = await call({"animal": "duck"}) # type: ignore[call-arg]
assert completion.seconds_to_first_token == 0
assert completion.seconds_to_last_token > 0
assert completion.cost > 0

# check with mixed callbacks
async def ac(x):
pass

completion = await call({"animal": "duck"}, callbacks=[accum, ac]) # type: ignore[call-arg]
assert completion.cost > 0


@pytest.mark.skipif(os.environ.get("ANTHROPIC_API_KEY") is None, reason="No API key")
@@ -489,21 +494,24 @@ def accum(x):
assert completion.completion_count > 0
assert str(completion) == "".join(outputs)
assert isinstance(completion.text, str)
assert completion.cost > 0

completion = await call({"animal": "duck"}) # type: ignore[call-arg]
assert completion.seconds_to_first_token == 0
assert completion.seconds_to_last_token > 0
assert isinstance(completion.text, str)
assert completion.cost > 0

docs = Docs()
await docs.aadd(
stub_data_dir / "flag_day.html",
"National Flag of Canada Day",
settings=anthropic_settings,
)
await docs.aget_evidence(
result = await docs.aget_evidence(
"What is the national flag of Canada?", settings=anthropic_settings
)
assert result.cost > 0


def test_make_docs(stub_data_dir: Path):
@@ -698,10 +706,10 @@ class MyLLM(LLMModel):
name: str = "myllm"

async def acomplete(self, prompt): # noqa: ARG002
return "Echo"
return "Echo", (1, 1)

async def acomplete_iter(self, prompt): # noqa: ARG002
yield "Echo"
yield "Echo", (1, 1)

docs = Docs()
docs.add(