import torch.nn.functional as F
import numpy as np
from transformers import AutoModel, AutoTokenizer
- from typing import List, Dict, Optional, Tuple, Any, Set
+ from typing import List, Dict, Optional, Tuple, Any, Set, Union
import logging
import copy
from pathlib import Path
from safetensors.torch import save_file, load_file
import json
from sklearn.cluster import KMeans
+ from huggingface_hub import ModelHubMixin
+ import os
+ import shutil

from .models import Example, AdaptiveHead, ModelConfig
from .memory import PrototypeMemory

logger = logging.getLogger(__name__)

- class AdaptiveClassifier:
+ class AdaptiveClassifier(ModelHubMixin):
    """A flexible classifier that can adapt to new classes and examples."""

    def __init__(
@@ -284,33 +287,43 @@ def predict(self, text: str, k: int = 5) -> List[Tuple[str, float]]:

        return predictions[:k]

-     def save(self, save_dir: str):
-         """Save classifier state with representative examples."""
-         save_dir = Path(save_dir)
-         save_dir.mkdir(parents=True, exist_ok=True)
-
-         # Select representative examples for each class
-         saved_examples = {}
-         for label, examples in self.memory.examples.items():
-             saved_examples[label] = [
-                 ex.to_dict() for ex in
-                 self.select_representative_examples(examples, k=5)
-             ]
+     def _save_pretrained(
+         self,
+         save_directory: Union[str, Path],
+         config: Optional[Dict[str, Any]] = None,
+         **kwargs
+     ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+         """Save the model to a directory.

+         Args:
+             save_directory: Directory to save the model to
+             config: Optional additional configuration
+             **kwargs: Additional arguments passed to save_pretrained
+
+         Returns:
+             Tuple of (dict of filenames, dict of objects to save)
+         """
+         save_directory = Path(save_directory)
+         os.makedirs(save_directory, exist_ok=True)
+
        # Save configuration and metadata
-         config = {
+         config_dict = {
            'model_name': self.model.config._name_or_path,
            'embedding_dim': self.embedding_dim,
            'label_to_id': self.label_to_id,
            'id_to_label': {str(k): v for k, v in self.id_to_label.items()},
            'train_steps': self.train_steps,
-             'config': self.config.to_dict(),
-             'examples': saved_examples
+             'config': self.config.to_dict()
        }
-
-         with open(save_dir / 'config.json', 'w') as f:
-             json.dump(config, f)
-
+
+         # Save examples in a separate file to keep config clean
+         saved_examples = {}
+         for label, examples in self.memory.examples.items():
+             saved_examples[label] = [
+                 ex.to_dict() for ex in
+                 self.select_representative_examples(examples, k=5)
+             ]
+
        # Save model tensors
        tensor_dict = {}

@@ -322,65 +335,224 @@ def save(self, save_dir: str):
        if self.adaptive_head is not None:
            for name, param in self.adaptive_head.state_dict().items():
                tensor_dict[f'adaptive_head_{name}'] = param
-
-         # Save tensors
-         save_file(tensor_dict, save_dir / 'tensors.safetensors')
-
+
+         # Save files
+         config_file = save_directory / "config.json"
+         examples_file = save_directory / "examples.json"
+         tensors_file = save_directory / "model.safetensors"
+
+         with open(config_file, "w", encoding="utf-8") as f:
+             json.dump(config_dict, f, indent=2, sort_keys=True)
+
+         with open(examples_file, "w", encoding="utf-8") as f:
+             json.dump(saved_examples, f, indent=2, sort_keys=True)
+
+         save_file(tensor_dict, tensors_file)
+
+         # Generate model card if it doesn't exist
+         model_card_path = save_directory / "README.md"
+         if not model_card_path.exists():
+             model_card_content = self._generate_model_card()
+             with open(model_card_path, "w", encoding="utf-8") as f:
+                 f.write(model_card_content)
+
+         # Return files that were created
+         saved_files = {
+             "config": config_file.name,
+             "examples": examples_file.name,
+             "model": tensors_file.name,
+             "model_card": model_card_path.name,
+         }
+
+         return saved_files, {}
+
    @classmethod
-     def load(cls, save_dir: str, device: Optional[str] = None) -> 'AdaptiveClassifier':
-         """Load classifier with saved examples."""
-         save_dir = Path(save_dir)
+     def _from_pretrained(
+         cls,
+         model_id: Union[str, Path],
+         config: Optional[Dict[str, Any]] = None,
+         **kwargs
+     ) -> "AdaptiveClassifier":
+         """Load a model from the HuggingFace Hub or local directory.

+         Args:
+             model_id: HuggingFace Hub model ID or path to local directory
+             config: Optional configuration overrides
+             **kwargs: Additional arguments passed to from_pretrained
+
+         Returns:
+             Loaded AdaptiveClassifier instance
+         """
+         model_path = Path(model_id)
+
        # Load configuration
-         with open(save_dir / 'config.json', 'r') as f:
-             config = json.load(f)
-
+         with open(model_path / "config.json", "r", encoding="utf-8") as f:
+             config_dict = json.load(f)
+
+         # Load examples
+         with open(model_path / "examples.json", "r", encoding="utf-8") as f:
+             saved_examples = json.load(f)
+
        # Initialize classifier
+         device = kwargs.get("device", None)
        classifier = cls(
-             config['model_name'],
+             config_dict['model_name'],
            device=device,
-             config=config.get('config', None)
+             config=config_dict.get('config', None)
        )
-
+
        # Restore label mappings
-         classifier.label_to_id = config['label_to_id']
+         classifier.label_to_id = config_dict['label_to_id']
        classifier.id_to_label = {
-             int(k): v for k, v in config['id_to_label'].items()
+             int(k): v for k, v in config_dict['id_to_label'].items()
        }
-         classifier.train_steps = config['train_steps']
-
+         classifier.train_steps = config_dict['train_steps']
+
        # Load tensors
-         tensors = load_file(save_dir / 'tensors.safetensors')
-
+         tensors = load_file(model_path / "model.safetensors")
+
        # Restore saved examples
-         saved_examples = config['examples']
        for label, examples_data in saved_examples.items():
            classifier.memory.examples[label] = [
                Example.from_dict(ex_data) for ex_data in examples_data
            ]
-
+
        # Restore prototypes
        for label in classifier.label_to_id.keys():
            prototype_key = f'prototype_{label}'
            if prototype_key in tensors:
                prototype = tensors[prototype_key]
                classifier.memory.prototypes[label] = prototype
-
+
        # Rebuild memory system
        classifier.memory._restore_from_save()
-
+
        # Restore adaptive head if it exists
        adaptive_head_params = {
            k.replace('adaptive_head_', ''): v
            for k, v in tensors.items()
            if k.startswith('adaptive_head_')
        }
-
+
        if adaptive_head_params:
            classifier._initialize_adaptive_head()
            classifier.adaptive_head.load_state_dict(adaptive_head_params)
-
+
        return classifier
+
+     def _generate_model_card(self) -> str:
+         """Generate a model card for the classifier.
+
+         Returns:
+             Model card content as string
+         """
+         stats = self.get_memory_stats()
+
+         model_card = f"""---
+ language: multilingual
+ tags:
+ - adaptive-classifier
+ - text-classification
+ - continuous-learning
+ license: apache-2.0
+ ---
+
+ # Adaptive Classifier
+
+ This model is an instance of an [adaptive-classifier](https://github.com/codelion/adaptive-classifier) that allows for continuous learning and dynamic class addition.
+
+ You can install it with `pip install adaptive-classifier`.
+
+ ## Model Details
+
+ - Base Model: {self.model.config._name_or_path}
+ - Number of Classes: {stats['num_classes']}
+ - Total Examples: {stats['total_examples']}
+ - Embedding Dimension: {self.embedding_dim}
+
+ ## Class Distribution
+
+ ```
+ {self._format_class_distribution(stats)}
+ ```
+
+ ## Usage
+
+ ```python
+ from adaptive_classifier import AdaptiveClassifier
+
+ # Load the model
+ classifier = AdaptiveClassifier.from_pretrained("{self.model.config._name_or_path}")
+
+ # Make predictions
+ text = "Your text here"
+ predictions = classifier.predict(text)
+ print(predictions) # List of (label, confidence) tuples
+
+ # Add new examples
+ texts = ["Example 1", "Example 2"]
+ labels = ["class1", "class2"]
+ classifier.add_examples(texts, labels)
+ ```
+
+ ## Training Details
+
+ - Training Steps: {self.train_steps}
+ - Examples per Class: See distribution above
+ - Prototype Memory: Active
+ - Neural Adaptation: {"Active" if self.adaptive_head is not None else "Inactive"}
+
+ ## Limitations
+
+ This model:
+ - Requires at least {self.config.min_examples_per_class} examples per class
+ - Has a maximum of {self.config.max_examples_per_class} examples per class
+ - Updates prototypes every {self.config.prototype_update_frequency} examples
+
+ ## Citation
+
+ ```bibtex
+ @software{{adaptive_classifier,
+   title = {{Adaptive Classifier: Dynamic Text Classification with Continuous Learning}},
+   author = {{Sharma, Asankhaya}},
+   year = {{2025}},
+   publisher = {{GitHub}},
+   url = {{https://github.com/codelion/adaptive-classifier}}
+ }}
+ ```
+ """
+         return model_card
+
+     def _format_class_distribution(self, stats: Dict[str, Any]) -> str:
+         """Format class distribution for model card.
+
+         Args:
+             stats: Statistics from get_memory_stats()
+
+         Returns:
+             Formatted string of class distribution
+         """
+         if 'examples_per_class' not in stats:
+             return "No examples stored"
+
+         lines = []
+         total = sum(stats['examples_per_class'].values())
+
+         for label, count in sorted(stats['examples_per_class'].items()):
+             percentage = (count / total) * 100 if total > 0 else 0
+             lines.append(f"{label}: {count} examples ({percentage:.1f}%)")
+
+         return "\n".join(lines)
+
+     # Keep existing save/load methods for backwards compatibility
+     def save(self, save_dir: str):
+         """Legacy save method for backwards compatibility."""
+         self._save_pretrained(save_dir)
+
+     @classmethod
+     def load(cls, save_dir: str, device: Optional[str] = None) -> 'AdaptiveClassifier':
+         """Legacy load method for backwards compatibility."""
+         return cls._from_pretrained(save_dir, device=device)

    def to(self, device: str) -> 'AdaptiveClassifier':
        """Move the model to specified device.
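Usage sketch (not part of the diff): inheriting from `ModelHubMixin` means the private `_save_pretrained` / `_from_pretrained` hooks above are driven by the mixin's public `save_pretrained`, `from_pretrained`, and `push_to_hub` methods. The snippet below is illustrative only; the base model name, example texts, local path, and repo id are placeholders, not values from this change.

```python
from adaptive_classifier import AdaptiveClassifier

# Build a small classifier (base model and examples are placeholders).
classifier = AdaptiveClassifier("bert-base-uncased")
classifier.add_examples(
    ["great product", "terrible service"],
    ["positive", "negative"],
)

# save_pretrained / from_pretrained come from ModelHubMixin and call the
# _save_pretrained / _from_pretrained hooks defined in this diff.
classifier.save_pretrained("./my-classifier")
restored = AdaptiveClassifier.from_pretrained("./my-classifier")
print(restored.predict("great product"))

# Publishing to the Hub also comes from the mixin; the repo id is hypothetical.
# classifier.push_to_hub("your-username/adaptive-classifier-demo")
```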