@@ -15,11 +15,7 @@ def __init__(self, model_name: str = "microsoft/Phi-3-mini-4k-instruct"):
15
15
self ._api_key = os .getenv ("HF_API_KEY" )
16
16
if not self ._api_key :
17
17
raise ProviderError ("Hugging Face API key is missing" )
18
- # Initialize the OpenAI client with Hugging Face API endpoint
19
- self ._client = OpenAI (
20
- base_url = "https://api-inference.huggingface.co/models" ,
21
- api_key = self ._api_key ,
22
- )
18
+
23
19
24
20
async def __call__ (self , prompt : str , ** generation_args : Any ) -> str :
25
21
opts = {
@@ -38,7 +34,7 @@ def _generate_sync(self, prompt: str, options: Dict[str, Any]) -> str:
38
34
"parameters" : {
39
35
"temperature" : options .get ("temperature" , 0.7 ),
40
36
"top_p" : options .get ("top_p" , 0.9 ),
41
- "max_new_tokens " : options .get ("max_tokens" , 20000 ),
37
+ "max_tokens " : options .get ("max_tokens" , 20000 ),
42
38
}
43
39
}
44
40
response = requests .post (
@@ -47,13 +43,20 @@ def _generate_sync(self, prompt: str, options: Dict[str, Any]) -> str:
47
43
json = payload
48
44
)
49
45
response .raise_for_status () # Raise an exception for 4XX/5XX responses
50
- return response .json ()[0 ]["generated_text" ]
46
+ response_data = response .json ()
47
+ if not response_data or not isinstance (response_data , list ) or len (response_data ) == 0 :
48
+ raise ProviderError ("Invalid response format from Hugging Face API" )
49
+ if "generated_text" not in response_data [0 ]:
50
+ raise ProviderError ("Missing 'generated_text' in Hugging Face API response" )
51
+ return response_data [0 ]["generated_text" ]
51
52
except Exception as e :
52
53
raise ProviderError (f"Hugging Face - error generating response: { e } " ) from e
53
54
54
55
class HuggingFaceEmbeddingProvider (EmbeddingProvider ):
55
56
def __init__ (self , api_key : str | None = None , embedding_model : str = "sentence-transformers/all-MiniLM-L6-v2" ):
56
57
self ._api_key = api_key or os .getenv ("HF_API_KEY" )
58
+ if not self ._api_key :
59
+ raise ProviderError ("Hugging Face API key is missing" )
57
60
self ._model = SentenceTransformer (embedding_model )
58
61
59
62
async def embed (self , text : str ) -> list [float ]:
0 commit comments