Skip to content

Commit 72b5751

Browse files
committed
fix: improve error handling and remove unnecessary OpenAI client initialization in HuggingFaceProvider
1 parent: e800225 · commit: 72b5751

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

apps/backend/app/agent/manager.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@ async def _get_provider(self, **kwargs: Any) -> OllamaProvider | OpenAIProvider
3838
raise ProviderError(
3939
f"Ollama Model '{model}' is not found. Run `ollama pull {model} or pick from any available models {installed_ollama_models}"
4040
)
41-
return OllamaProvider(model_name=model, host="http://localhost:11434")
41+
return OllamaProvider(model_name=model)
4242

4343
async def run(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
4444
"""

apps/backend/app/agent/providers/huggingface.py

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -15,11 +15,7 @@ def __init__(self, model_name: str = "microsoft/Phi-3-mini-4k-instruct"):
1515
self._api_key = os.getenv("HF_API_KEY")
1616
if not self._api_key:
1717
raise ProviderError("Hugging Face API key is missing")
18-
# Initialize the OpenAI client with Hugging Face API endpoint
19-
self._client = OpenAI(
20-
base_url="https://api-inference.huggingface.co/models",
21-
api_key=self._api_key,
22-
)
18+
2319

2420
async def __call__(self, prompt: str, **generation_args: Any) -> str:
2521
opts = {
@@ -38,7 +34,7 @@ def _generate_sync(self, prompt: str, options: Dict[str, Any]) -> str:
3834
"parameters": {
3935
"temperature": options.get("temperature", 0.7),
4036
"top_p": options.get("top_p", 0.9),
41-
"max_new_tokens": options.get("max_tokens", 20000),
37+
"max_tokens": options.get("max_tokens", 20000),
4238
}
4339
}
4440
response = requests.post(
@@ -47,13 +43,20 @@ def _generate_sync(self, prompt: str, options: Dict[str, Any]) -> str:
4743
json=payload
4844
)
4945
response.raise_for_status() # Raise an exception for 4XX/5XX responses
50-
return response.json()[0]["generated_text"]
46+
response_data = response.json()
47+
if not response_data or not isinstance(response_data, list) or len(response_data) == 0:
48+
raise ProviderError("Invalid response format from Hugging Face API")
49+
if "generated_text" not in response_data[0]:
50+
raise ProviderError("Missing 'generated_text' in Hugging Face API response")
51+
return response_data[0]["generated_text"]
5152
except Exception as e:
5253
raise ProviderError(f"Hugging Face - error generating response: {e}") from e
5354

5455
class HuggingFaceEmbeddingProvider(EmbeddingProvider):
5556
def __init__(self, api_key: str | None = None, embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"):
5657
self._api_key = api_key or os.getenv("HF_API_KEY")
58+
if not self._api_key:
59+
raise ProviderError("Hugging Face API key is missing")
5760
self._model = SentenceTransformer(embedding_model)
5861

5962
async def embed(self, text: str) -> list[float]:

0 commit comments

Comments
 (0)