Skip to content

Commit 0e4ba42

Browse files
committed
Merge branch 'main' into 56-oracle-db-integration
2 parents f24b3df + 627cc06 commit 0e4ba42

File tree

4 files changed

+54
-5
lines changed

4 files changed

+54
-5
lines changed

src/client/content/config/tabs/models.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def _initialize_model(action: str, model_type: str, model_id: str = None, model_
8686
quoted_model_id = urllib.parse.quote(model_id, safe="")
8787
model = api_call.get(endpoint=f"v1/models/{model_provider}/{quoted_model_id}")
8888
else:
89-
model = {"id": "unset", "type": model_type, "provider": "unset", "status": "CUSTOM"}
89+
model = {"id": "", "type": model_type, "provider": "unset", "status": "CUSTOM"}
9090

9191
if action == "add":
9292
model["enabled"] = True
@@ -128,6 +128,11 @@ def _render_model_selection(model: dict, provider_models: list, action: str) ->
128128
model_keys = [m["key"] for m in provider_models]
129129
model_index = next((i for i, key in enumerate(model_keys) if key == model["id"]), None)
130130

131+
# If the current model ID is not in the supported list, add it to the options
132+
if model_index is None and model["id"] not in model_keys:
133+
model_keys.append(model["id"])
134+
model_index = len(model_keys) - 1
135+
131136
model["id"] = st.selectbox(
132137
"Model (Required):",
133138
help=help_text.help_dict["model_id"],
@@ -145,8 +150,7 @@ def _render_model_selection(model: dict, provider_models: list, action: str) ->
145150
def _render_api_configuration(model: dict, provider_models: list, disable_for_oci: bool) -> dict:
146151
"""Render API configuration UI and return updated model"""
147152
api_base = next(
148-
(m.get("api_base", "") for m in provider_models if m.get("key") == model["id"]),
149-
model.get("api_base", "")
153+
(m.get("api_base", "") for m in provider_models if m.get("key") == model["id"]), model.get("api_base", "")
150154
)
151155

152156
model["api_base"] = st.text_input(
@@ -218,6 +222,8 @@ def _handle_form_submission(model: dict, action: str) -> bool:
218222

219223
try:
220224
if action == "add" and action_button.button(label="Add", type="primary", width="stretch"):
225+
if not all([model["id"], model["provider"]]):
226+
raise ValueError
221227
create_model(model=model)
222228
return True
223229
if action == "edit" and action_button.button(label="Save", type="primary", width="stretch"):
@@ -228,6 +234,11 @@ def _handle_form_submission(model: dict, action: str) -> bool:
228234
return True
229235
except api_call.ApiError as ex:
230236
st.error(f"Failed to {action} model: {ex}")
237+
except ValueError:
238+
if not model["id"]:
239+
st.error("Model name is required.")
240+
if not model["provider"]:
241+
st.error("Provider name is required.")
231242

232243
if cancel_button.button(label="Cancel", type="secondary"):
233244
st_common.clear_state_key("model_configs")

src/server/api/utils/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def update(payload: schema.Model) -> schema.Model:
9292
"""Update an existing Model definition"""
9393

9494
model_upd = get(model_provider=payload.provider, model_id=payload.id)
95-
if payload.enabled and not is_url_accessible(model_upd.api_base)[0]:
95+
if payload.enabled and model_upd.api_base and not is_url_accessible(model_upd.api_base)[0]:
9696
model_upd.enabled = False
9797
raise URLUnreachableError("Model: Unable to update. API URL is inaccessible.")
9898

src/server/api/v1/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ async def models_update(payload: schema.Model) -> schema.Model:
8787
async def models_create(
8888
payload: schema.Model,
8989
) -> schema.Model:
90-
"""Update a model"""
90+
"""Create a model"""
9191
logger.debug("Received model_create - payload: %s", payload)
9292

9393
try:

tests/client/content/config/tabs/test_models.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,3 +367,41 @@ def test_embedding_model_specific_fields(self, app_server, app_test):
367367
# Embedding model specific fields might be present
368368
if "max_chunk_size" in model:
369369
assert isinstance(model["max_chunk_size"], int)
370+
371+
def test_render_model_selection_with_custom_model_id(self, app_server, app_test):
372+
"""Test that _render_model_selection handles custom model IDs not in supported models list"""
373+
from client.content.config.tabs.models import _render_model_selection, get_supported_models
374+
375+
assert app_server is not None
376+
at = self._setup_function_test(app_test)
377+
378+
# Get actual supported models from API
379+
with patch("client.utils.api_call.state", at.session_state):
380+
supported_models = get_supported_models("ll")
381+
382+
# Find a provider and create a model with a custom ID not in their supported list
383+
openai_provider = next((p for p in supported_models if p["provider"] == "openai"), None)
384+
assert openai_provider is not None, "OpenAI provider should be available in supported models"
385+
386+
provider_models = openai_provider["models"]
387+
model_keys = [m["key"] for m in provider_models]
388+
389+
# Create a custom model ID that definitely won't be in the supported list
390+
custom_model_id = "custom-fine-tuned-model-12345"
391+
assert custom_model_id not in model_keys, f"Custom model ID {custom_model_id} should not be in supported models"
392+
393+
# Test model with custom ID
394+
model = {
395+
"id": custom_model_id,
396+
"provider": "openai",
397+
"type": "ll"
398+
}
399+
400+
action = "edit"
401+
402+
with patch("client.content.config.tabs.models.state", at.session_state):
403+
# This should preserve the custom model ID even though it's not in provider models
404+
result_model = _render_model_selection(model, provider_models, action)
405+
406+
# The model ID should be preserved
407+
assert result_model["id"] == custom_model_id

0 commit comments

Comments
 (0)