[Feat] DD Trace - Add instrumentation for streaming chunks #11338

Merged (5 commits) on Jun 2, 2025
1 change: 1 addition & 0 deletions docs/my-website/docs/proxy/config_settings.md
@@ -371,6 +371,7 @@ router_settings:
| DD_API_KEY | API key for Datadog integration
| DD_SITE | Site URL for Datadog (e.g., datadoghq.com)
| DD_SOURCE | Source identifier for Datadog logs
| DD_TRACER_STREAMING_CHUNK_YIELD_RESOURCE | Resource name for Datadog tracing of streaming chunk yields. Default is "streaming.chunk.yield"
| DD_ENV | Environment identifier for Datadog logs. Only supported for `datadog_llm_observability` callback
| DD_SERVICE | Service identifier for Datadog logs. Defaults to "litellm-server"
| DD_VERSION | Version identifier for Datadog logs. Defaults to "unknown"
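For quick reference, a minimal sketch of overriding this resource name via the environment, assuming the constant is resolved at import time as the litellm/constants.py change below shows. The override value here is purely illustrative:

```python
# Illustrative override of the streaming-chunk span resource name.
# Set the variable before litellm is imported, since the constant is
# read from the environment at import time (see litellm/constants.py).
import os

os.environ["DD_TRACER_STREAMING_CHUNK_YIELD_RESOURCE"] = "proxy.stream.chunk"  # example value

from litellm.constants import DD_TRACER_STREAMING_CHUNK_YIELD_RESOURCE

print(DD_TRACER_STREAMING_CHUNK_YIELD_RESOURCE)  # -> proxy.stream.chunk
```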
3 changes: 3 additions & 0 deletions litellm/constants.py
@@ -156,6 +156,9 @@
MAX_LANGFUSE_INITIALIZED_CLIENTS = int(
    os.getenv("MAX_LANGFUSE_INITIALIZED_CLIENTS", 50)
)
DD_TRACER_STREAMING_CHUNK_YIELD_RESOURCE = os.getenv(
    "DD_TRACER_STREAMING_CHUNK_YIELD_RESOURCE", "streaming.chunk.yield"
)

############### LLM Provider Constants ###############
### ANTHROPIC CONSTANTS ###
3 changes: 3 additions & 0 deletions litellm/main.py
@@ -59,6 +59,7 @@
from litellm.exceptions import LiteLLMUnknownProvider
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_for_health_check
from litellm.litellm_core_utils.dd_tracing import tracer
from litellm.litellm_core_utils.health_check_utils import (
_create_health_check_response,
_filter_model_params,
@@ -314,6 +315,7 @@ async def create(self, messages, model=None, **kwargs):
        return response


@tracer.wrap()
@client
async def acompletion(
    model: str,
@@ -810,6 +812,7 @@ def mock_completion(
        raise Exception("Mock completion response failed - {}".format(e))


@tracer.wrap()
@client
def completion(  # type: ignore # noqa: PLR0915
    model: str,
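Both `acompletion` and `completion` are now wrapped with `tracer.wrap()` from `litellm.litellm_core_utils.dd_tracing`. The sketch below shows the kind of fallback such a tracing shim typically provides so the decorator stays harmless when `ddtrace` is not installed; it is a hypothetical illustration, not litellm's actual dd_tracing module:

```python
# Hypothetical dd_tracing-style shim: use the real Datadog tracer when ddtrace
# is installed, otherwise fall back to a no-op so @tracer.wrap() and
# tracer.trace(...) never raise and add no overhead.
from contextlib import contextmanager

try:
    from ddtrace import tracer  # real Datadog tracer
except ImportError:
    class _NullTracer:
        def wrap(self, *args, **kwargs):
            # Return the decorated function unchanged.
            def decorator(func):
                return func
            return decorator

        @contextmanager
        def trace(self, name, **kwargs):
            # No span is created; yield a placeholder instead.
            yield None

    tracer = _NullTracer()
```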
14 changes: 14 additions & 0 deletions litellm/model_prices_and_context_window_backup.json
@@ -8750,6 +8750,20 @@
            "notes": "'supports_image_input' is a deprecated field. Use 'supports_embedding_image_input' instead."
        }
    },
    "embed-v4.0": {
        "max_tokens": 1024,
        "max_input_tokens": 1024,
        "input_cost_per_token": 1.2e-07,
        "input_cost_per_image": 4.7e-07,
        "output_cost_per_token": 0.0,
        "litellm_provider": "cohere",
        "mode": "embedding",
        "supports_image_input": true,
        "supports_embedding_image_input": true,
        "metadata": {
            "notes": "'supports_image_input' is a deprecated field. Use 'supports_embedding_image_input' instead."
        }
    },
    "replicate/meta/llama-2-13b": {
        "max_tokens": 4096,
        "max_input_tokens": 4096,
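The price-map entry is unrelated to the tracing change, but as a quick sanity check on how its per-unit fields read, here is some purely illustrative arithmetic (not a litellm API call):

```python
# Illustrative cost arithmetic using the embed-v4.0 prices above.
input_cost_per_token = 1.2e-07  # USD per input token
input_cost_per_image = 4.7e-07  # USD per input image

tokens, images = 1_000, 10
cost = tokens * input_cost_per_token + images * input_cost_per_image
print(f"${cost:.7f}")  # 1,000 tokens + 10 images comes to roughly $0.0001247
```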
8 changes: 6 additions & 2 deletions litellm/proxy/common_request_processing.py
@@ -20,6 +20,8 @@

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.constants import DD_TRACER_STREAMING_CHUNK_YIELD_RESOURCE
from litellm.litellm_core_utils.dd_tracing import tracer
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.proxy._types import ProxyException, UserAPIKeyAuth
from litellm.proxy.auth.auth_utils import check_response_size_is_safe
@@ -142,9 +144,11 @@ async def error_gen_message() -> AsyncGenerator[str, None]:

    async def combined_generator() -> AsyncGenerator[str, None]:
        if first_chunk_value is not None:
-            yield first_chunk_value
+            with tracer.trace(DD_TRACER_STREAMING_CHUNK_YIELD_RESOURCE):
+                yield first_chunk_value
        async for chunk in generator:
-            yield chunk
+            with tracer.trace(DD_TRACER_STREAMING_CHUNK_YIELD_RESOURCE):
+                yield chunk

    return StreamingResponse(
        combined_generator(),
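This is the core of the change: the first chunk and every subsequent chunk yielded by `combined_generator()` is wrapped in its own tracer span, so Datadog can show per-chunk timing at the proxy's streaming boundary. Below is a standalone sketch of the same pattern; names such as `_demo_span` and `stream_chunks` are made up for illustration, while the real code uses `tracer.trace(...)` from litellm's dd_tracing module:

```python
# Standalone illustration of per-chunk span wrapping on an async generator.
import asyncio
from contextlib import contextmanager


@contextmanager
def _demo_span(resource: str):
    # Stand-in for tracer.trace(resource); a real span would be reported to Datadog.
    print(f"open span: {resource}")
    try:
        yield
    finally:
        print(f"close span: {resource}")


async def stream_chunks(chunks):
    for chunk in chunks:
        # One short-lived span per yielded chunk, mirroring combined_generator() above.
        with _demo_span("streaming.chunk.yield"):
            yield chunk


async def main():
    chunks = ['data: {"content": "chunk 1"}\n\n', "data: [DONE]\n\n"]
    async for chunk in stream_chunks(chunks):
        print("client received:", chunk.strip())


asyncio.run(main())
```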
19 changes: 5 additions & 14 deletions litellm/proxy/proxy_config.yaml
@@ -1,18 +1,9 @@
model_list:
-  - model_name: anthropic/*
+  - model_name: openai/*
    litellm_params:
-      model: anthropic/*
+      model: openai/*
-
-guardrails:
-  - guardrail_name: "bedrock-pre-guard"
-    litellm_params:
-      guardrail: bedrock # supported values: "aporia", "bedrock", "lakera"
-      mode: "post_call"
-      guardrailIdentifier: wf0hkdb5x07f # your guardrail ID on bedrock
-      guardrailVersion: "DRAFT" # your guardrail version on bedrock
-      default_on: true

general_settings:
  store_prompts_in_spend_logs: true

100 changes: 100 additions & 0 deletions tests/test_litellm/proxy/test_common_request_processing.py
@@ -281,3 +281,103 @@ async def mock_generator():
'data: {"content": "actual data"}\n\n',
"data: [DONE]\n\n",
]

    async def test_create_streaming_response_all_chunks_have_dd_trace(self):
        """Test that all stream chunks are wrapped with dd trace at the streaming generator level"""
        import json
        from unittest.mock import patch

        # Create a mock tracer
        mock_tracer = MagicMock()
        mock_span = MagicMock()
        mock_tracer.trace.return_value.__enter__.return_value = mock_span
        mock_tracer.trace.return_value.__exit__.return_value = None

        # Mock generator with multiple chunks
        async def mock_generator():
            yield 'data: {"content": "chunk 1"}\n\n'
            yield 'data: {"content": "chunk 2"}\n\n'
            yield 'data: {"content": "chunk 3"}\n\n'
            yield "data: [DONE]\n\n"

        # Patch the tracer in the common_request_processing module
        with patch("litellm.proxy.common_request_processing.tracer", mock_tracer):
            response = await create_streaming_response(
                mock_generator(), "text/event-stream", {}
            )

            assert response.status_code == 200

            # Consume the stream to trigger the tracer calls
            content = await self.consume_stream(response)

            # Verify all chunks are present
            assert len(content) == 4
            assert content[0] == 'data: {"content": "chunk 1"}\n\n'
            assert content[1] == 'data: {"content": "chunk 2"}\n\n'
            assert content[2] == 'data: {"content": "chunk 3"}\n\n'
            assert content[3] == "data: [DONE]\n\n"

            # Verify that tracer.trace was called for each chunk (4 chunks total)
            assert mock_tracer.trace.call_count == 4

            # Verify that each call was made with the correct operation name
            expected_calls = [
                (("streaming.chunk.yield",), {}),
                (("streaming.chunk.yield",), {}),
                (("streaming.chunk.yield",), {}),
                (("streaming.chunk.yield",), {}),
            ]

            actual_calls = mock_tracer.trace.call_args_list
            assert len(actual_calls) == 4

            for i, call in enumerate(actual_calls):
                args, kwargs = call
                assert (
                    args[0] == "streaming.chunk.yield"
                ), f"Call {i} should have operation name 'streaming.chunk.yield', got {args[0]}"

    async def test_create_streaming_response_dd_trace_with_error_chunk(self):
        """Test that dd trace is applied even when the first chunk contains an error"""
        from unittest.mock import patch

        # Create a mock tracer
        mock_tracer = MagicMock()
        mock_span = MagicMock()
        mock_tracer.trace.return_value.__enter__.return_value = mock_span
        mock_tracer.trace.return_value.__exit__.return_value = None

        # Mock generator with error in first chunk
        async def mock_generator():
            yield 'data: {"error": {"code": 400, "message": "bad request"}}\n\n'
            yield 'data: {"content": "chunk after error"}\n\n'
            yield "data: [DONE]\n\n"

        # Patch the tracer in the common_request_processing module
        with patch("litellm.proxy.common_request_processing.tracer", mock_tracer):
            response = await create_streaming_response(
                mock_generator(), "text/event-stream", {}
            )

            # Even with error, status should be set to error code but tracing should still work
            assert response.status_code == 400

            # Consume the stream to trigger the tracer calls
            content = await self.consume_stream(response)

            # Verify all chunks are present
            assert len(content) == 3

            # Verify that tracer.trace was called for each chunk
            assert mock_tracer.trace.call_count == 3

            # Verify that each call was made with the correct operation name
            actual_calls = mock_tracer.trace.call_args_list
            assert len(actual_calls) == 3

            for i, call in enumerate(actual_calls):
                args, kwargs = call
                assert (
                    args[0] == "streaming.chunk.yield"
                ), f"Call {i} should have operation name 'streaming.chunk.yield', got {args[0]}"