Skip to content

Commit 647fd09

Browse files
authored
Merge pull request #12 from codelion/feat-add-hf-hub-integration
Feat add hf hub integration
2 parents 27ecc04 + 8abab96 commit 647fd09

File tree

4 files changed

+242
-48
lines changed

4 files changed

+242
-48
lines changed

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,14 @@ classifier.save("./my_classifier")
5454

5555
# Load it later
5656
loaded_classifier = AdaptiveClassifier.load("./my_classifier")
57+
58+
# The library is also integrated with Hugging Face. So you can push and load from HF Hub.
59+
60+
# Save to Hub
61+
classifier.push_to_hub("username/model-name")
62+
63+
# Load from Hub
64+
classifier = AdaptiveClassifier.from_pretrained("username/model-name")
5765
```
5866

5967
## Advanced Usage

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
setup(
1717
name="adaptive-classifier",
18-
version="0.0.4",
18+
version="0.0.5",
1919
author="codelion",
2020
author_email="[email protected]",
2121
description="A flexible, adaptive classification system for dynamic text classification",

src/adaptive_classifier/classifier.py

Lines changed: 218 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@
33
import torch.nn.functional as F
44
import numpy as np
55
from transformers import AutoModel, AutoTokenizer
6-
from typing import List, Dict, Optional, Tuple, Any, Set
6+
from typing import List, Dict, Optional, Tuple, Any, Set, Union
77
import logging
88
import copy
99
from pathlib import Path
1010
from safetensors.torch import save_file, load_file
1111
import json
1212
from sklearn.cluster import KMeans
13+
from huggingface_hub import ModelHubMixin
14+
import os
15+
import shutil
1316

1417
from .models import Example, AdaptiveHead, ModelConfig
1518
from .memory import PrototypeMemory
@@ -18,7 +21,7 @@
1821

1922
logger = logging.getLogger(__name__)
2023

21-
class AdaptiveClassifier:
24+
class AdaptiveClassifier(ModelHubMixin):
2225
"""A flexible classifier that can adapt to new classes and examples."""
2326

2427
def __init__(
@@ -284,33 +287,43 @@ def predict(self, text: str, k: int = 5) -> List[Tuple[str, float]]:
284287

285288
return predictions[:k]
286289

287-
def save(self, save_dir: str):
288-
"""Save classifier state with representative examples."""
289-
save_dir = Path(save_dir)
290-
save_dir.mkdir(parents=True, exist_ok=True)
291-
292-
# Select representative examples for each class
293-
saved_examples = {}
294-
for label, examples in self.memory.examples.items():
295-
saved_examples[label] = [
296-
ex.to_dict() for ex in
297-
self.select_representative_examples(examples, k=5)
298-
]
290+
def _save_pretrained(
291+
self,
292+
save_directory: Union[str, Path],
293+
config: Optional[Dict[str, Any]] = None,
294+
**kwargs
295+
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
296+
"""Save the model to a directory.
299297
298+
Args:
299+
save_directory: Directory to save the model to
300+
config: Optional additional configuration
301+
**kwargs: Additional arguments passed to save_pretrained
302+
303+
Returns:
304+
Tuple of (dict of filenames, dict of objects to save)
305+
"""
306+
save_directory = Path(save_directory)
307+
os.makedirs(save_directory, exist_ok=True)
308+
300309
# Save configuration and metadata
301-
config = {
310+
config_dict = {
302311
'model_name': self.model.config._name_or_path,
303312
'embedding_dim': self.embedding_dim,
304313
'label_to_id': self.label_to_id,
305314
'id_to_label': {str(k): v for k, v in self.id_to_label.items()},
306315
'train_steps': self.train_steps,
307-
'config': self.config.to_dict(),
308-
'examples': saved_examples
316+
'config': self.config.to_dict()
309317
}
310-
311-
with open(save_dir / 'config.json', 'w') as f:
312-
json.dump(config, f)
313-
318+
319+
# Save examples in a separate file to keep config clean
320+
saved_examples = {}
321+
for label, examples in self.memory.examples.items():
322+
saved_examples[label] = [
323+
ex.to_dict() for ex in
324+
self.select_representative_examples(examples, k=5)
325+
]
326+
314327
# Save model tensors
315328
tensor_dict = {}
316329

@@ -322,65 +335,224 @@ def save(self, save_dir: str):
322335
if self.adaptive_head is not None:
323336
for name, param in self.adaptive_head.state_dict().items():
324337
tensor_dict[f'adaptive_head_{name}'] = param
325-
326-
# Save tensors
327-
save_file(tensor_dict, save_dir / 'tensors.safetensors')
328-
338+
339+
# Save files
340+
config_file = save_directory / "config.json"
341+
examples_file = save_directory / "examples.json"
342+
tensors_file = save_directory / "model.safetensors"
343+
344+
with open(config_file, "w", encoding="utf-8") as f:
345+
json.dump(config_dict, f, indent=2, sort_keys=True)
346+
347+
with open(examples_file, "w", encoding="utf-8") as f:
348+
json.dump(saved_examples, f, indent=2, sort_keys=True)
349+
350+
save_file(tensor_dict, tensors_file)
351+
352+
# Generate model card if it doesn't exist
353+
model_card_path = save_directory / "README.md"
354+
if not model_card_path.exists():
355+
model_card_content = self._generate_model_card()
356+
with open(model_card_path, "w", encoding="utf-8") as f:
357+
f.write(model_card_content)
358+
359+
# Return files that were created
360+
saved_files = {
361+
"config": config_file.name,
362+
"examples": examples_file.name,
363+
"model": tensors_file.name,
364+
"model_card": model_card_path.name,
365+
}
366+
367+
return saved_files, {}
368+
329369
@classmethod
330-
def load(cls, save_dir: str, device: Optional[str] = None) -> 'AdaptiveClassifier':
331-
"""Load classifier with saved examples."""
332-
save_dir = Path(save_dir)
370+
def _from_pretrained(
371+
cls,
372+
model_id: Union[str, Path],
373+
config: Optional[Dict[str, Any]] = None,
374+
**kwargs
375+
) -> "AdaptiveClassifier":
376+
"""Load a model from the HuggingFace Hub or local directory.
333377
378+
Args:
379+
model_id: HuggingFace Hub model ID or path to local directory
380+
config: Optional configuration overrides
381+
**kwargs: Additional arguments passed to from_pretrained
382+
383+
Returns:
384+
Loaded AdaptiveClassifier instance
385+
"""
386+
model_path = Path(model_id)
387+
334388
# Load configuration
335-
with open(save_dir / 'config.json', 'r') as f:
336-
config = json.load(f)
337-
389+
with open(model_path / "config.json", "r", encoding="utf-8") as f:
390+
config_dict = json.load(f)
391+
392+
# Load examples
393+
with open(model_path / "examples.json", "r", encoding="utf-8") as f:
394+
saved_examples = json.load(f)
395+
338396
# Initialize classifier
397+
device = kwargs.get("device", None)
339398
classifier = cls(
340-
config['model_name'],
399+
config_dict['model_name'],
341400
device=device,
342-
config=config.get('config', None)
401+
config=config_dict.get('config', None)
343402
)
344-
403+
345404
# Restore label mappings
346-
classifier.label_to_id = config['label_to_id']
405+
classifier.label_to_id = config_dict['label_to_id']
347406
classifier.id_to_label = {
348-
int(k): v for k, v in config['id_to_label'].items()
407+
int(k): v for k, v in config_dict['id_to_label'].items()
349408
}
350-
classifier.train_steps = config['train_steps']
351-
409+
classifier.train_steps = config_dict['train_steps']
410+
352411
# Load tensors
353-
tensors = load_file(save_dir / 'tensors.safetensors')
354-
412+
tensors = load_file(model_path / "model.safetensors")
413+
355414
# Restore saved examples
356-
saved_examples = config['examples']
357415
for label, examples_data in saved_examples.items():
358416
classifier.memory.examples[label] = [
359417
Example.from_dict(ex_data) for ex_data in examples_data
360418
]
361-
419+
362420
# Restore prototypes
363421
for label in classifier.label_to_id.keys():
364422
prototype_key = f'prototype_{label}'
365423
if prototype_key in tensors:
366424
prototype = tensors[prototype_key]
367425
classifier.memory.prototypes[label] = prototype
368-
426+
369427
# Rebuild memory system
370428
classifier.memory._restore_from_save()
371-
429+
372430
# Restore adaptive head if it exists
373431
adaptive_head_params = {
374432
k.replace('adaptive_head_', ''): v
375433
for k, v in tensors.items()
376434
if k.startswith('adaptive_head_')
377435
}
378-
436+
379437
if adaptive_head_params:
380438
classifier._initialize_adaptive_head()
381439
classifier.adaptive_head.load_state_dict(adaptive_head_params)
382-
440+
383441
return classifier
442+
443+
def _generate_model_card(self) -> str:
444+
"""Generate a model card for the classifier.
445+
446+
Returns:
447+
Model card content as string
448+
"""
449+
stats = self.get_memory_stats()
450+
451+
model_card = f"""---
452+
language: multilingual
453+
tags:
454+
- adaptive-classifier
455+
- text-classification
456+
- continuous-learning
457+
license: apache-2.0
458+
---
459+
460+
# Adaptive Classifier
461+
462+
This model is an instance of an [adaptive-classifier](https://github.com/codelion/adaptive-classifier) that allows for continuous learning and dynamic class addition.
463+
464+
You can install it with `pip install adaptive-classifier`.
465+
466+
## Model Details
467+
468+
- Base Model: {self.model.config._name_or_path}
469+
- Number of Classes: {stats['num_classes']}
470+
- Total Examples: {stats['total_examples']}
471+
- Embedding Dimension: {self.embedding_dim}
472+
473+
## Class Distribution
474+
475+
```
476+
{self._format_class_distribution(stats)}
477+
```
478+
479+
## Usage
480+
481+
```python
482+
from adaptive_classifier import AdaptiveClassifier
483+
484+
# Load the model
485+
classifier = AdaptiveClassifier.from_pretrained("{self.model.config._name_or_path}")
486+
487+
# Make predictions
488+
text = "Your text here"
489+
predictions = classifier.predict(text)
490+
print(predictions) # List of (label, confidence) tuples
491+
492+
# Add new examples
493+
texts = ["Example 1", "Example 2"]
494+
labels = ["class1", "class2"]
495+
classifier.add_examples(texts, labels)
496+
```
497+
498+
## Training Details
499+
500+
- Training Steps: {self.train_steps}
501+
- Examples per Class: See distribution above
502+
- Prototype Memory: Active
503+
- Neural Adaptation: {"Active" if self.adaptive_head is not None else "Inactive"}
504+
505+
## Limitations
506+
507+
This model:
508+
- Requires at least {self.config.min_examples_per_class} examples per class
509+
- Has a maximum of {self.config.max_examples_per_class} examples per class
510+
- Updates prototypes every {self.config.prototype_update_frequency} examples
511+
512+
## Citation
513+
514+
```bibtex
515+
@software{{adaptive_classifier,
516+
title = {{Adaptive Classifier: Dynamic Text Classification with Continuous Learning}},
517+
author = {{Sharma, Asankhaya}},
518+
year = {{2025}},
519+
publisher = {{GitHub}},
520+
url = {{https://github.com/codelion/adaptive-classifier}}
521+
}}
522+
```
523+
"""
524+
return model_card
525+
526+
def _format_class_distribution(self, stats: Dict[str, Any]) -> str:
527+
"""Format class distribution for model card.
528+
529+
Args:
530+
stats: Statistics from get_memory_stats()
531+
532+
Returns:
533+
Formatted string of class distribution
534+
"""
535+
if 'examples_per_class' not in stats:
536+
return "No examples stored"
537+
538+
lines = []
539+
total = sum(stats['examples_per_class'].values())
540+
541+
for label, count in sorted(stats['examples_per_class'].items()):
542+
percentage = (count / total) * 100 if total > 0 else 0
543+
lines.append(f"{label}: {count} examples ({percentage:.1f}%)")
544+
545+
return "\n".join(lines)
546+
547+
# Keep existing save/load methods for backwards compatibility
548+
def save(self, save_dir: str):
549+
"""Legacy save method for backwards compatibility."""
550+
self._save_pretrained(save_dir)
551+
552+
@classmethod
553+
def load(cls, save_dir: str, device: Optional[str] = None) -> 'AdaptiveClassifier':
554+
"""Legacy load method for backwards compatibility."""
555+
return cls._from_pretrained(save_dir, device=device)
384556

385557
def to(self, device: str) -> 'AdaptiveClassifier':
386558
"""Move the model to specified device.

tests/test_classifier.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,12 @@ def test_save_load(base_classifier, sample_data):
7474

7575
# Save
7676
base_classifier.save(save_path)
77+
78+
# Check all required files exist
7779
assert (save_path / "config.json").exists()
78-
assert (save_path / "tensors.safetensors").exists()
80+
assert (save_path / "model.safetensors").exists()
81+
assert (save_path / "examples.json").exists()
82+
assert (save_path / "README.md").exists()
7983

8084
# Load with same device
8185
loaded_classifier = AdaptiveClassifier.load(save_path, device=base_classifier.device)
@@ -105,6 +109,16 @@ def test_save_load(base_classifier, sample_data):
105109
assert abs(score1 - score2) < score_threshold, \
106110
f"Scores differ too much: {score1} vs {score2}"
107111

112+
# Test memory statistics match
113+
original_stats = base_classifier.get_memory_stats()
114+
loaded_stats = loaded_classifier.get_memory_stats()
115+
116+
assert original_stats['num_classes'] == loaded_stats['num_classes']
117+
assert original_stats['total_examples'] == loaded_stats['total_examples']
118+
for label in original_stats['examples_per_class']:
119+
assert original_stats['examples_per_class'][label] == \
120+
loaded_stats['examples_per_class'][label]
121+
108122
def test_dynamic_class_addition(base_classifier, sample_data):
109123
texts, labels = sample_data
110124
base_classifier.add_examples(texts[:3], labels[:3])

0 commit comments

Comments
 (0)