
Commit 96541bd

test: add tests for stream_usage and token tracking
1 parent 2423af8 commit 96541bd


2 files changed: +250 additions, 0 deletions


tests/test_callbacks.py

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from uuid import uuid4
+
+import pytest
+from langchain.schema import Generation, LLMResult
+from langchain_core.messages import AIMessage
+from langchain_core.outputs import ChatGeneration
+
+from nemoguardrails.context import explain_info_var, llm_call_info_var, llm_stats_var
+from nemoguardrails.logging.callbacks import LoggingCallbackHandler
+from nemoguardrails.logging.explain import ExplainInfo, LLMCallInfo
+from nemoguardrails.logging.stats import LLMStats
+
+
+@pytest.mark.asyncio
+async def test_token_usage_tracking_with_usage_metadata():
+    """Test that token usage is tracked when usage_metadata is available (stream_usage=True scenario)."""
+
+    llm_call_info = LLMCallInfo()
+    llm_call_info_var.set(llm_call_info)
+
+    llm_stats = LLMStats()
+    llm_stats_var.set(llm_stats)
+
+    explain_info = ExplainInfo()
+    explain_info_var.set(explain_info)
+
+    handler = LoggingCallbackHandler()
+
+    # simulate the LLM response with usage metadata (as would happen with stream_usage=True)
+    ai_message = AIMessage(
+        content="Hello! How can I help you?",
+        usage_metadata={"input_tokens": 10, "output_tokens": 6, "total_tokens": 16},
+    )
+
+    chat_generation = ChatGeneration(message=ai_message)
+    llm_result = LLMResult(generations=[[chat_generation]])
+
+    # call the on_llm_end method
+    await handler.on_llm_end(llm_result, run_id=uuid4())
+
+    assert llm_call_info.total_tokens == 16
+    assert llm_call_info.prompt_tokens == 10
+    assert llm_call_info.completion_tokens == 6
+
+    assert llm_stats.get_stat("total_tokens") == 16
+    assert llm_stats.get_stat("total_prompt_tokens") == 10
+    assert llm_stats.get_stat("total_completion_tokens") == 6
+
+
+@pytest.mark.asyncio
+async def test_token_usage_tracking_with_llm_output_fallback():
+    """Test token usage tracking with legacy llm_output format."""
+
+    llm_call_info = LLMCallInfo()
+    llm_call_info_var.set(llm_call_info)
+
+    llm_stats = LLMStats()
+    llm_stats_var.set(llm_stats)
+
+    explain_info = ExplainInfo()
+    explain_info_var.set(explain_info)
+
+    handler = LoggingCallbackHandler()
+
+    # simulate LLM response with token usage in llm_output (fallback scenario)
+    generation = Generation(text="Fallback response")
+    llm_result = LLMResult(
+        generations=[[generation]],
+        llm_output={
+            "token_usage": {
+                "total_tokens": 20,
+                "prompt_tokens": 12,
+                "completion_tokens": 8,
+            }
+        },
+    )
+
+    await handler.on_llm_end(llm_result, run_id=uuid4())
+
+    assert llm_call_info.total_tokens == 20
+    assert llm_call_info.prompt_tokens == 12
+    assert llm_call_info.completion_tokens == 8
+
+    assert llm_stats.get_stat("total_tokens") == 20
+    assert llm_stats.get_stat("total_prompt_tokens") == 12
+    assert llm_stats.get_stat("total_completion_tokens") == 8
+
+
+@pytest.mark.asyncio
+async def test_no_token_usage_tracking_without_metadata():
+    """Test that no token usage is tracked when metadata is not available."""
+
+    llm_call_info = LLMCallInfo()
+    llm_call_info_var.set(llm_call_info)
+
+    llm_stats = LLMStats()
+    llm_stats_var.set(llm_stats)
+
+    explain_info = ExplainInfo()
+    explain_info_var.set(explain_info)
+
+    handler = LoggingCallbackHandler()
+
+    # simulate LLM response without usage metadata (stream_usage=False scenario)
+    ai_message = AIMessage(content="Hello! How can I help you?")
+    chat_generation = ChatGeneration(message=ai_message)
+    llm_result = LLMResult(generations=[[chat_generation]])
+
+    await handler.on_llm_end(llm_result, run_id=uuid4())
+
+    assert llm_call_info.total_tokens is None or llm_call_info.total_tokens == 0
+    assert llm_call_info.prompt_tokens is None or llm_call_info.prompt_tokens == 0
+    assert (
+        llm_call_info.completion_tokens is None or llm_call_info.completion_tokens == 0
+    )
+
+
+@pytest.mark.asyncio
+async def test_multiple_generations_token_accumulation():
+    """Test that token usage accumulates across multiple generations."""
+
+    llm_call_info = LLMCallInfo()
+    llm_call_info_var.set(llm_call_info)
+
+    llm_stats = LLMStats()
+    llm_stats_var.set(llm_stats)
+
+    explain_info = ExplainInfo()
+    explain_info_var.set(explain_info)
+
+    handler = LoggingCallbackHandler()
+
+    ai_message1 = AIMessage(
+        content="First response",
+        usage_metadata={"input_tokens": 5, "output_tokens": 3, "total_tokens": 8},
+    )
+
+    ai_message2 = AIMessage(
+        content="Second response",
+        usage_metadata={"input_tokens": 7, "output_tokens": 4, "total_tokens": 11},
+    )
+
+    chat_generation1 = ChatGeneration(message=ai_message1)
+    chat_generation2 = ChatGeneration(message=ai_message2)
+    llm_result = LLMResult(generations=[[chat_generation1, chat_generation2]])
+
+    await handler.on_llm_end(llm_result, run_id=uuid4())
+
+    assert llm_call_info.total_tokens == 19  # 8 + 11
+    assert llm_call_info.prompt_tokens == 12  # 5 + 7
+    assert llm_call_info.completion_tokens == 7  # 3 + 4
+
+    assert llm_stats.get_stat("total_tokens") == 19
+    assert llm_stats.get_stat("total_prompt_tokens") == 12
+    assert llm_stats.get_stat("total_completion_tokens") == 7
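Note: the sketch below is a minimal, standalone illustration of the aggregation rule these tests exercise (prefer per-generation usage_metadata, fall back to a token_usage dict in llm_output). The function name extract_token_usage is hypothetical; it is not the NeMo Guardrails implementation of LoggingCallbackHandler.on_llm_end, only a reading of what the assertions above imply.

    from langchain_core.outputs import LLMResult


    def extract_token_usage(result: LLMResult) -> dict:
        """Aggregate token counts across all generations in an LLMResult."""
        totals = {"total_tokens": 0, "prompt_tokens": 0, "completion_tokens": 0}
        found = False

        # Preferred path: per-message usage_metadata, populated e.g. when
        # stream_usage=True is passed to a chat model.
        for generation in (g for batch in result.generations for g in batch):
            usage = getattr(getattr(generation, "message", None), "usage_metadata", None)
            if usage:
                totals["total_tokens"] += usage.get("total_tokens", 0)
                totals["prompt_tokens"] += usage.get("input_tokens", 0)
                totals["completion_tokens"] += usage.get("output_tokens", 0)
                found = True

        # Fallback path: provider-level token_usage in llm_output (legacy format).
        if not found and result.llm_output:
            token_usage = result.llm_output.get("token_usage", {})
            if token_usage:
                for key in totals:
                    totals[key] += token_usage.get(key, 0)
                found = True

        return totals if found else {}

Applied to the fixtures above, this would yield 16/10/6 for the usage_metadata case, 20/12/8 for the llm_output fallback, and an empty dict when no metadata is present.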

tests/test_llmrails.py

Lines changed: 80 additions & 0 deletions
@@ -1068,3 +1068,83 @@ def __init__(self):
 
     assert kwargs["api_key"] == "direct-key"
     assert kwargs["temperature"] == 0.3
+
+
+@pytest.mark.asyncio
+@patch("nemoguardrails.rails.llm.llmrails.init_llm_model")
+async def test_stream_usage_enabled_for_streaming_supported_providers(
+    mock_init_llm_model,
+):
+    """Test that stream_usage=True is set when streaming is enabled for supported providers."""
+    config = RailsConfig.from_content(
+        config={
+            "models": [
+                {
+                    "type": "main",
+                    "engine": "openai",
+                    "model": "gpt-4",
+                }
+            ],
+            "streaming": True,
+        }
+    )
+
+    LLMRails(config=config)
+
+    mock_init_llm_model.assert_called_once()
+    call_args = mock_init_llm_model.call_args
+    kwargs = call_args.kwargs.get("kwargs", {})
+
+    assert kwargs.get("stream_usage") is True
+
+
+@pytest.mark.asyncio
+@patch("nemoguardrails.rails.llm.llmrails.init_llm_model")
+async def test_stream_usage_not_set_without_streaming(mock_init_llm_model):
+    """Test that stream_usage is not set when streaming is disabled."""
+    config = RailsConfig.from_content(
+        config={
+            "models": [
+                {
+                    "type": "main",
+                    "engine": "openai",
+                    "model": "gpt-4",
+                }
+            ],
+            "streaming": False,
+        }
+    )
+
+    LLMRails(config=config)
+
+    mock_init_llm_model.assert_called_once()
+    call_args = mock_init_llm_model.call_args
+    kwargs = call_args.kwargs.get("kwargs", {})
+
+    assert "stream_usage" not in kwargs
+
+
+@pytest.mark.asyncio
+@patch("nemoguardrails.rails.llm.llmrails.init_llm_model")
+async def test_stream_usage_not_set_without_supported_providers(mock_init_llm_model):
+    """Test that stream_usage is not set with an unsupported provider."""
+    config = RailsConfig.from_content(
+        config={
+            "models": [
+                {
+                    "type": "main",
+                    "engine": "unsupported",
+                    "model": "whatever",
+                }
+            ],
+            "streaming": True,
+        }
+    )
+
+    LLMRails(config=config)
+
+    mock_init_llm_model.assert_called_once()
+    call_args = mock_init_llm_model.call_args
+    kwargs = call_args.kwargs.get("kwargs", {})
+
+    assert "stream_usage" not in kwargs
