You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
feat: Add multilabel classification for training (#191)
* Added multilabel option to training
* Added multilabel option to training
* Added multilabel option to training
* Added multilabel option to training
* Added multilabel option to training
* Added multilabel option to training
* Added threshold to predict
* Updated docs
* Updated docs
* Removed fallback logic
* Updated docs
* Updated docs
* Resolved feedback
* Update model2vec/train/README.md
Co-authored-by: Stephan Tulkens <[email protected]>
* Resolved feedback
* Resolved feedback
* Resolved feedback
* Resolved feedback
* add multilabel targets, fix tests (#194)
* Fixed bug with array conversion
* Optimized inference performance
* Changed classes to np array
* Added int as possible label type
* Added int as possible label type
* Use previous logic
* Updated type check
* Updated type check
* Updated type check logic
* Only return object type array for multilabel
---------
Co-authored-by: Stephan Tulkens <[email protected]>
Copy file name to clipboardExpand all lines: model2vec/train/README.md
+50Lines changed: 50 additions & 0 deletions
Original file line number
Diff line number
Diff line change
@@ -2,6 +2,8 @@
2
2
3
3
Aside from [distillation](../../README.md#distillation), `model2vec` also supports training simple classifiers on top of static models, using [pytorch](https://pytorch.org/), [lightning](https://lightning.ai/) and [scikit-learn](https://scikit-learn.org/stable/index.html).
4
4
5
+
We support both single and multi-label classification, which work seamlessly based on the labels you provide.
6
+
5
7
# Installation
6
8
7
9
To train, make sure you install the training extra:
@@ -65,6 +67,54 @@ print(f"Took {int((perf_counter() - s) * 1000)} milliseconds for {len(test)} ins
65
67
# Took 67 milliseconds for 2000 instances on CPU.
66
68
```
67
69
70
+
## Multi-label classification
71
+
72
+
Multi-label classification is supported out of the box. Just pass a list of lists to the `fit` function (e.g. `[[label1, label2], [label1, label3]]`), and a multi-label classifier will be trained. For example, the following code trains a multi-label classifier on the [go_emotions](https://huggingface.co/datasets/google-research-datasets/go_emotions) dataset:
73
+
74
+
```python
75
+
from datasets import load_dataset
76
+
from model2vec.train import StaticModelForClassification
77
+
78
+
# Initialize a classifier from a pre-trained model
The scores are competitive with the popular [roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) model, while our model is orders of magnitude faster.
117
+
68
118
# Persistence
69
119
70
120
You can turn a classifier into a scikit-learn compatible pipeline, as follows:
0 commit comments