|
10 | 10 | from safetensors.torch import save_file, load_file
|
11 | 11 | import json
|
12 | 12 | from sklearn.cluster import KMeans
|
13 |
| -from huggingface_hub import ModelHubMixin |
| 13 | +from huggingface_hub import ModelHubMixin, hf_hub_download |
14 | 14 | import os
|
15 | 15 | import shutil
|
16 | 16 |
|
@@ -369,21 +369,80 @@ def _save_pretrained(
|
369 | 369 | @classmethod
|
370 | 370 | def _from_pretrained(
|
371 | 371 | cls,
|
372 |
| - model_id: Union[str, Path], |
373 |
| - config: Optional[Dict[str, Any]] = None, |
| 372 | + *, |
| 373 | + model_id: str, |
| 374 | + revision: Optional[str] = None, |
| 375 | + cache_dir: Optional[str] = None, |
| 376 | + force_download: bool = False, |
| 377 | + proxies: Optional[Dict] = None, |
| 378 | + resume_download: bool = False, |
| 379 | + local_files_only: bool = False, |
| 380 | + token: Optional[Union[str, bool]] = None, |
374 | 381 | **kwargs
|
375 | 382 | ) -> "AdaptiveClassifier":
|
376 | 383 | """Load a model from the HuggingFace Hub or local directory.
|
377 | 384 |
|
378 | 385 | Args:
|
379 | 386 | model_id: HuggingFace Hub model ID or path to local directory
|
380 |
| - config: Optional configuration overrides |
| 387 | + revision: Revision of the model on the Hub |
| 388 | + cache_dir: Cache directory for downloaded models |
| 389 | + force_download: Force download of models |
| 390 | + proxies: Proxies to use for downloading |
| 391 | + resume_download: Resume downloading if interrupted |
| 392 | + local_files_only: Use local files only, don't download |
| 393 | + token: Authentication token for Hub |
381 | 394 | **kwargs: Additional arguments passed to from_pretrained
|
382 | 395 |
|
383 | 396 | Returns:
|
384 | 397 | Loaded AdaptiveClassifier instance
|
385 | 398 | """
|
386 |
| - model_path = Path(model_id) |
| 399 | + |
| 400 | + # Check if model_id is a local directory |
| 401 | + if os.path.isdir(model_id): |
| 402 | + model_path = Path(model_id) |
| 403 | + else: |
| 404 | + # Download config file from the Hub |
| 405 | + try: |
| 406 | + config_file = hf_hub_download( |
| 407 | + repo_id=model_id, |
| 408 | + filename="config.json", |
| 409 | + revision=revision, |
| 410 | + cache_dir=cache_dir, |
| 411 | + force_download=force_download, |
| 412 | + proxies=proxies, |
| 413 | + resume_download=resume_download, |
| 414 | + token=token, |
| 415 | + local_files_only=local_files_only, |
| 416 | + ) |
| 417 | + model_path = Path(os.path.dirname(config_file)) |
| 418 | + |
| 419 | + # Download examples file |
| 420 | + examples_file = hf_hub_download( |
| 421 | + repo_id=model_id, |
| 422 | + filename="examples.json", |
| 423 | + revision=revision, |
| 424 | + cache_dir=cache_dir, |
| 425 | + force_download=force_download, |
| 426 | + proxies=proxies, |
| 427 | + resume_download=resume_download, |
| 428 | + token=token, |
| 429 | + local_files_only=local_files_only, |
| 430 | + ) |
| 431 | + |
| 432 | + # Download model file |
| 433 | + model_file = hf_hub_download( |
| 434 | + repo_id=model_id, |
| 435 | + filename="model.safetensors", |
| 436 | + revision=revision, |
| 437 | + cache_dir=cache_dir, |
| 438 | + force_download=force_download, |
| 439 | + proxies=proxies, |
| 440 | + resume_download=resume_download, |
| 441 | + token=token, |
| 442 | + local_files_only=local_files_only, |
| 443 | + ) |
| 444 | + except Exception as e: |
| 445 | + raise ValueError(f"Error downloading model from {model_id}: {e}") |
387 | 446 |
|
388 | 447 | # Load configuration
|
389 | 448 | with open(model_path / "config.json", "r", encoding="utf-8") as f:
|
|
0 commit comments