Commit 196f61c

Merge pull request #31 from codelion/fix-hf-integration
Fix hf integration
2 parents: 27a05dd + fd580d0

File tree

4 files changed, +68 −7 lines changed


requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -6,4 +6,5 @@ numpy>=1.24.0
 tqdm>=4.65.0
 setuptools>=65.0.0
 wheel>=0.40.0
-scikit-learn
+scikit-learn
+huggingface_hub>=0.17.0

setup.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@

 setup(
     name="adaptive-classifier",
-    version="0.0.10",
+    version="0.0.11",
     author="codelion",
     author_email="[email protected]",
     description="A flexible, adaptive classification system for dynamic text classification",

src/adaptive_classifier/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 from .classifier import AdaptiveClassifier
 from .models import Example, AdaptiveHead, ModelConfig
 from .memory import PrototypeMemory
+from huggingface_hub import ModelHubMixin

 import os
 import re

src/adaptive_classifier/classifier.py

Lines changed: 64 additions & 5 deletions
@@ -10,7 +10,7 @@
 from safetensors.torch import save_file, load_file
 import json
 from sklearn.cluster import KMeans
-from huggingface_hub import ModelHubMixin
+from huggingface_hub import ModelHubMixin, hf_hub_download
 import os
 import shutil


@@ -369,21 +369,80 @@ def _save_pretrained(
     @classmethod
     def _from_pretrained(
         cls,
-        model_id: Union[str, Path],
-        config: Optional[Dict[str, Any]] = None,
+        *,
+        model_id: str,
+        revision: Optional[str] = None,
+        cache_dir: Optional[str] = None,
+        force_download: bool = False,
+        proxies: Optional[Dict] = None,
+        resume_download: bool = False,
+        local_files_only: bool = False,
+        token: Optional[Union[str, bool]] = None,
         **kwargs
     ) -> "AdaptiveClassifier":
         """Load a model from the HuggingFace Hub or local directory.

         Args:
             model_id: HuggingFace Hub model ID or path to local directory
-            config: Optional configuration overrides
+            revision: Revision of the model on the Hub
+            cache_dir: Cache directory for downloaded models
+            force_download: Force download of models
+            proxies: Proxies to use for downloading
+            resume_download: Resume downloading if interrupted
+            local_files_only: Use local files only, don't download
+            token: Authentication token for Hub
             **kwargs: Additional arguments passed to from_pretrained

         Returns:
             Loaded AdaptiveClassifier instance
         """
-        model_path = Path(model_id)
+
+        # Check if model_id is a local directory
+        if os.path.isdir(model_id):
+            model_path = Path(model_id)
+        else:
+            # Download config file from the Hub
+            try:
+                config_file = hf_hub_download(
+                    repo_id=model_id,
+                    filename="config.json",
+                    revision=revision,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    token=token,
+                    local_files_only=local_files_only,
+                )
+                model_path = Path(os.path.dirname(config_file))
+
+                # Download examples file
+                examples_file = hf_hub_download(
+                    repo_id=model_id,
+                    filename="examples.json",
+                    revision=revision,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    token=token,
+                    local_files_only=local_files_only,
+                )
+
+                # Download model file
+                model_file = hf_hub_download(
+                    repo_id=model_id,
+                    filename="model.safetensors",
+                    revision=revision,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    token=token,
+                    local_files_only=local_files_only,
+                )
+            except Exception as e:
+                raise ValueError(f"Error downloading model from {model_id}: {e}")

         # Load configuration
         with open(model_path / "config.json", "r", encoding="utf-8") as f:
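
For reference, a minimal, hypothetical usage sketch of the code path this commit fixes. It assumes AdaptiveClassifier mixes in huggingface_hub's ModelHubMixin, whose public from_pretrained/save_pretrained methods delegate to the _from_pretrained/_save_pretrained hooks shown above; the repository ID and local path below are placeholders, not names taken from this repo.

# Hypothetical usage sketch (not part of this commit).
from adaptive_classifier import AdaptiveClassifier

# Hub path: ModelHubMixin.from_pretrained routes into the patched
# _from_pretrained, which downloads config.json, examples.json and
# model.safetensors via hf_hub_download before loading.
clf = AdaptiveClassifier.from_pretrained("codelion/adaptive-classifier-demo")

# Local path: the new os.path.isdir(model_id) branch loads directly
# from a directory previously written by save_pretrained.
clf.save_pretrained("./adaptive-classifier-demo")
clf_local = AdaptiveClassifier.from_pretrained("./adaptive-classifier-demo")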
