UI - fix edit azure public model name + support changing model names post create #10249

Merged: 5 commits, Apr 24, 2025
40 changes: 40 additions & 0 deletions tests/litellm/test_router.py
@@ -78,3 +78,43 @@ def test_router_with_model_info_and_model_group():
        model_group="gpt-3.5-turbo",
        user_facing_model_group_name="gpt-3.5-turbo",
    )


+@pytest.mark.asyncio
+async def test_router_with_tags_and_fallbacks():
+    """
+    If the fallback model is missing the required tag, raise an error.
+    """
+    from litellm import Router
+
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "mock_response": "Hello, world!",
+                    "tags": ["test"],
+                },
+            },
+            {
+                "model_name": "anthropic-claude-3-5-sonnet",
+                "litellm_params": {
+                    "model": "claude-3-5-sonnet-latest",
+                    "mock_response": "Hello, world 2!",
+                },
+            },
+        ],
+        fallbacks=[
+            {"gpt-3.5-turbo": ["anthropic-claude-3-5-sonnet"]},
+        ],
+        enable_tag_filtering=True,
+    )
+
+    with pytest.raises(Exception):
+        response = await router.acompletion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hello, world!"}],
+            mock_testing_fallbacks=True,
+            metadata={"tags": ["test"]},
+        )
55 changes: 2 additions & 53 deletions tests/local_testing/test_amazing_vertex_completion.py
@@ -147,7 +147,7 @@ async def test_get_response():
    prompt = '\ndef count_nums(arr):\n """\n Write a function count_nums which takes an array of integers and returns\n the number of elements which has a sum of digits > 0.\n If a number is negative, then its first signed digit will be negative:\n e.g. -123 has signed digits -1, 2, and 3.\n >>> count_nums([]) == 0\n >>> count_nums([-1, 11, -11]) == 1\n >>> count_nums([1, 1, 2]) == 3\n """\n'
    try:
        response = await acompletion(
-            model="gemini-pro",
+            model="gemini-1.5-flash",
            messages=[
                {
                    "role": "system",
@@ -1784,7 +1784,7 @@ async def test_gemini_pro_function_calling_streaming(sync_mode):
    load_vertex_ai_credentials()
    litellm.set_verbose = True
    data = {
-        "model": "vertex_ai/gemini-pro",
+        "model": "vertex_ai/gemini-1.5-flash",
        "messages": [
            {
                "role": "user",
@@ -1844,57 +1844,6 @@ async def test_gemini_pro_function_calling_streaming(sync_mode):
pass


-@pytest.mark.asyncio
-@pytest.mark.flaky(retries=3, delay=1)
-async def test_gemini_pro_async_function_calling():
-    load_vertex_ai_credentials()
-    litellm.set_verbose = True
-    try:
-        tools = [
-            {
-                "type": "function",
-                "function": {
-                    "name": "get_current_weather",
-                    "description": "Get the current weather in a given location.",
-                    "parameters": {
-                        "type": "object",
-                        "properties": {
-                            "location": {
-                                "type": "string",
-                                "description": "The city and state, e.g. San Francisco, CA",
-                            },
-                            "unit": {
-                                "type": "string",
-                                "enum": ["celsius", "fahrenheit"],
-                            },
-                        },
-                        "required": ["location"],
-                    },
-                },
-            }
-        ]
-        messages = [
-            {
-                "role": "user",
-                "content": "What's the weather like in Boston today in fahrenheit?",
-            }
-        ]
-        completion = await litellm.acompletion(
-            model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
-        )
-        print(f"completion: {completion}")
-        print(f"message content: {completion.choices[0].message.content}")
-        assert completion.choices[0].message.content is None
-        assert len(completion.choices[0].message.tool_calls) == 1
-
-        # except litellm.APIError as e:
-        #     pass
-    except litellm.RateLimitError as e:
-        pass
-    except Exception as e:
-        pytest.fail(f"An exception occurred - {str(e)}")
-        # raise Exception("it worked!")


# asyncio.run(gemini_pro_async_function_calling())

@@ -14,6 +14,7 @@ const ConditionalPublicModelName: React.FC = () => {
  const customModelName = Form.useWatch('custom_model_name', form);
  const showPublicModelName = !selectedModels.includes('all-wildcard');


  // Force table to re-render when custom model name changes
  useEffect(() => {
    if (customModelName && selectedModels.includes('custom')) {
@@ -35,20 +36,33 @@ const ConditionalPublicModelName: React.FC = () => {
  // Initial setup of model mappings when models are selected
  useEffect(() => {
    if (selectedModels.length > 0 && !selectedModels.includes('all-wildcard')) {
-      const mappings = selectedModels.map((model: string) => {
-        if (model === 'custom' && customModelName) {
-          return {
-            public_name: customModelName,
-            litellm_model: customModelName
-          };
-        }
-        return {
-          public_name: model,
-          litellm_model: model
-        };
-      });
-      form.setFieldValue('model_mappings', mappings);
-      setTableKey(prev => prev + 1); // Force table re-render
+      // Check if we already have mappings that match the selected models
+      const currentMappings = form.getFieldValue('model_mappings') || [];
+
+      // Only update if the mappings don't exist or don't match the selected models
+      const shouldUpdateMappings = currentMappings.length !== selectedModels.length ||
+        !selectedModels.every(model =>
+          currentMappings.some((mapping: { public_name: string; litellm_model: string }) =>
+            mapping.public_name === model ||
+            (model === 'custom' && mapping.public_name === customModelName)));
+
+      if (shouldUpdateMappings) {
+        const mappings = selectedModels.map((model: string) => {
+          if (model === 'custom' && customModelName) {
+            return {
+              public_name: customModelName,
+              litellm_model: customModelName
+            };
+          }
+          return {
+            public_name: model,
+            litellm_model: model
+          };
+        });
+
+        form.setFieldValue('model_mappings', mappings);
+        setTableKey(prev => prev + 1); // Force table re-render
+      }
    }
  }, [selectedModels, customModelName, form]);

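With this change, the effect only rewrites model_mappings when the selection no longer matches what is already in the form, so a public model name edited after creation is not clobbered on every re-render. Read in isolation, the guard is roughly this predicate (the helper name and sample data below are illustrative, not part of the diff):

type ModelMapping = { public_name: string; litellm_model: string };

// True when every selected model already has a matching mapping and the
// counts line up, i.e. there is nothing to rewrite.
function mappingsMatchSelection(
  selected: string[],
  mappings: ModelMapping[],
  customModelName?: string
): boolean {
  return (
    mappings.length === selected.length &&
    selected.every(model =>
      mappings.some(
        m =>
          m.public_name === model ||
          (model === 'custom' && m.public_name === customModelName)
      )
    )
  );
}

// A renamed custom model still "matches", so the effect leaves the edit alone:
console.log(
  mappingsMatchSelection(
    ['custom'],
    [{ public_name: 'my-azure-gpt4', litellm_model: 'azure/gpt-4' }],
    'my-azure-gpt4'
  )
); // true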
@@ -23,22 +23,34 @@ const LiteLLMModelNameField: React.FC<LiteLLMModelNameFieldProps> = ({

// If "all-wildcard" is selected, clear the model_name field
if (values.includes("all-wildcard")) {
form.setFieldsValue({ model_name: undefined, model_mappings: [] });
form.setFieldsValue({ model: undefined, model_mappings: [] });
} else {
// Update model mappings immediately for each selected model
const mappings = values
.map(model => ({
// Get current model value to check if we need to update
const currentModel = form.getFieldValue('model');

// Only update if the value has actually changed
if (JSON.stringify(currentModel) !== JSON.stringify(values)) {

// Create mappings first
const mappings = values.map(model => ({
public_name: model,
litellm_model: model
}));
form.setFieldsValue({ model_mappings: mappings });

// Update both fields in one call to reduce re-renders
form.setFieldsValue({
model: values,
model_mappings: mappings
});

}
}
};

  // Handle custom model name changes
  const handleCustomModelNameChange = (e: React.ChangeEvent<HTMLInputElement>) => {
    const customName = e.target.value;

    // Immediately update the model mappings
    const currentMappings = form.getFieldValue('model_mappings') || [];
    const updatedMappings = currentMappings.map((mapping: any) => {
@@ -69,7 +81,11 @@ const LiteLLMModelNameField: React.FC<LiteLLMModelNameFieldProps> = ({
        {(selectedProvider === Providers.Azure) ||
        (selectedProvider === Providers.OpenAI_Compatible) ||
        (selectedProvider === Providers.Ollama) ? (
-          <TextInput placeholder={getPlaceholder(selectedProvider)} />
+          <>
+            <TextInput
+              placeholder={getPlaceholder(selectedProvider)}
+            />
+          </>
        ) : providerModels.length > 0 ? (
          <AntSelect
            mode="multiple"
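The JSON.stringify comparison above is a cheap equality check that is safe here because the watched value is a flat array of strings. As a standalone sketch (the helper name and values are illustrative):

// Cheap "did the selection change?" check; fine for flat string arrays,
// since JSON.stringify is order- and content-sensitive.
function selectionChanged(current: string[] | undefined, next: string[]): boolean {
  return JSON.stringify(current) !== JSON.stringify(next);
}

console.log(selectionChanged(['gpt-4'], ['gpt-4']));           // false - skip the form update
console.log(selectionChanged(['gpt-4'], ['gpt-4', 'custom'])); // true  - rewrite the mappings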
6 changes: 4 additions & 2 deletions ui/litellm-dashboard/src/components/model_info_view.tsx
@@ -15,7 +15,7 @@ import {
} from "@tremor/react";
import NumericalInput from "./shared/numerical_input";
import { ArrowLeftIcon, TrashIcon, KeyIcon } from "@heroicons/react/outline";
-import { modelDeleteCall, modelUpdateCall, CredentialItem, credentialGetCall, credentialCreateCall, modelInfoCall, modelInfoV1Call } from "./networking";
+import { modelDeleteCall, modelUpdateCall, CredentialItem, credentialGetCall, credentialCreateCall, modelInfoCall, modelInfoV1Call, modelPatchUpdateCall } from "./networking";
import { Button, Form, Input, InputNumber, message, Select, Modal } from "antd";
import EditModelModal from "./edit_model/edit_model_modal";
import { handleEditModelSubmit } from "./edit_model/edit_model_modal";
@@ -118,6 +118,8 @@ export default function ModelInfoView({
    try {
      if (!accessToken) return;
      setIsSaving(true);

+      console.log("values.model_name, ", values.model_name);

      let updatedLitellmParams = {
        ...localModelData.litellm_params,
@@ -149,7 +151,7 @@ }
      }
    };

-    await modelUpdateCall(accessToken, updateData);
+    await modelPatchUpdateCall(accessToken, updateData, modelId);

    const updatedModelData = {
      ...localModelData,
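Note the swap from modelUpdateCall to modelPatchUpdateCall: the new call is a PATCH keyed by model id, so only the fields included in updateData should change on the server. The full construction of updateData is elided above; a hypothetical payload for renaming just the public model name might look like:

// Hypothetical sketch - the real updateData construction is elided above;
// field names follow the visible parts of the diff.
const updateData = {
  model_name: "my-renamed-public-model", // new public model name
  litellm_params: {
    model: "azure/gpt-4", // underlying litellm model left unchanged
  },
};
// then: await modelPatchUpdateCall(accessToken, updateData, modelId);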
43 changes: 43 additions & 0 deletions ui/litellm-dashboard/src/components/networking.tsx
@@ -3149,6 +3149,49 @@ export const teamUpdateCall = async (
  }
};

+/**
+ * Patch update a model
+ *
+ * @param accessToken - admin access token used for the Authorization header
+ * @param formValues - the fields to change, sent as the PATCH body
+ * @param modelId - id of the model to update
+ * @returns the parsed JSON response from the proxy
+ */
+export const modelPatchUpdateCall = async (
+  accessToken: string,
+  formValues: Record<string, any>, // Assuming formValues is an object
+  modelId: string
+) => {
+  try {
+    console.log("Form Values in modelPatchUpdateCall:", formValues); // Log the form values before making the API call
+
+    const url = proxyBaseUrl ? `${proxyBaseUrl}/model/${modelId}/update` : `/model/${modelId}/update`;
+    const response = await fetch(url, {
+      method: "PATCH",
+      headers: {
+        [globalLitellmHeaderName]: `Bearer ${accessToken}`,
+        "Content-Type": "application/json",
+      },
+      body: JSON.stringify({
+        ...formValues, // Include formValues in the request body
+      }),
+    });
+
+    if (!response.ok) {
+      const errorData = await response.text();
+      handleError(errorData);
+      console.error("Error response from the server:", errorData);
+      throw new Error("Network response was not ok");
+    }
+    const data = await response.json();
+    console.log("Update model Response:", data);
+    return data;
+    // Handle success - you might want to update some state or UI based on the updated model
+  } catch (error) {
+    console.error("Failed to update model:", error);
+    throw error;
+  }
+};

export const modelUpdateCall = async (
accessToken: string,
formValues: Record<string, any> // Assuming formValues is an object
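For reference, a minimal sketch of calling the new helper from elsewhere in the dashboard (the token and model id are placeholders, and the PATCH body here assumes model_name is an accepted field, as the save handler above suggests):

import { modelPatchUpdateCall } from "./networking";

// Rename a model's public name without resending the full model definition.
async function renamePublicModelName(): Promise<void> {
  const accessToken = "sk-...";  // placeholder admin token
  const modelId = "model-1234";  // placeholder model id
  const updated = await modelPatchUpdateCall(
    accessToken,
    { model_name: "my-new-public-name" }, // PATCH body: only the changed field
    modelId
  );
  console.log("updated model:", updated);
}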