Skip to content

Commit 2d51516

Browse files
Pringledstephantul
andauthored
feat: Add multilabel classification for training (#191)
* Added multilabel option to training * Added multilabel option to training * Added multilabel option to training * Added multilabel option to training * Added multilabel option to training * Added multilabel option to training * Added threshold to predict * Updated docs * Updated docs * Removed fallback logic * Updated docs * Updated docs * Resolved feedback * Update model2vec/train/README.md Co-authored-by: Stephan Tulkens <[email protected]> * Resolved feedback * Resolved feedback * Resolved feedback * Resolved feedback * add multilabel targets, fix tests (#194) * Fixed bug with array conversion * Optimized inference performance * Changed classes to np array * Added int as possible label type * Added int as possible label type * Use previous logic * Updated type check * Updated type check * Updated type check logic * Only return object type array for multilabel --------- Co-authored-by: Stephan Tulkens <[email protected]>
1 parent 84e3fa8 commit 2d51516

File tree

8 files changed

+271
-77
lines changed

8 files changed

+271
-77
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ from model2vec.train import StaticModelForClassification
105105
# Initialize a classifier from a pre-trained model
106106
classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32M")
107107

108-
# Load a dataset
108+
# Load a dataset. Note: both single and multi-label classification datasets are supported
109109
ds = load_dataset("setfit/subj")
110110

111111
# Train the classifier on text (X) and labels (y)

model2vec/inference/model.py

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import huggingface_hub
88
import numpy as np
99
import skops.io
10+
from sklearn.neural_network import MLPClassifier
1011
from sklearn.pipeline import Pipeline
1112

1213
from model2vec.hf_utils import _create_model_card
@@ -21,6 +22,20 @@ def __init__(self, model: StaticModel, head: Pipeline) -> None:
2122
"""Create a pipeline with a StaticModel encoder."""
2223
self.model = model
2324
self.head = head
25+
classifier = self.head[-1]
26+
# Check if the classifier is a multilabel classifier.
27+
# NOTE: this doesn't look robust, but it is.
28+
# Different classifiers, such as OVR wrappers, support multilabel output natively, so we
29+
# can just use predict.
30+
self.multilabel = False
31+
if isinstance(classifier, MLPClassifier):
32+
if classifier.out_activation_ == "logistic":
33+
self.multilabel = True
34+
35+
@property
36+
def classes_(self) -> np.ndarray:
37+
"""The classes of the classifier."""
38+
return self.head.classes_
2439

2540
@classmethod
2641
def from_pretrained(
@@ -60,7 +75,7 @@ def push_to_hub(self, repo_id: str, token: str | None = None, private: bool = Fa
6075
self.model.save_pretrained(temp_dir)
6176
push_folder_to_hub(Path(temp_dir), repo_id, private, token)
6277

63-
def _predict_and_coerce_to_2d(
78+
def _encode_and_coerce_to_2d(
6479
self,
6580
X: list[str] | str,
6681
show_progress_bar: bool,
@@ -69,7 +84,7 @@ def _predict_and_coerce_to_2d(
6984
use_multiprocessing: bool,
7085
multiprocessing_threshold: int,
7186
) -> np.ndarray:
72-
"""Predict the labels of the input and coerce the output to a matrix."""
87+
"""Encode the instances and coerce the output to a matrix."""
7388
encoded = self.model.encode(
7489
X,
7590
show_progress_bar=show_progress_bar,
@@ -91,9 +106,21 @@ def predict(
91106
batch_size: int = 1024,
92107
use_multiprocessing: bool = True,
93108
multiprocessing_threshold: int = 10_000,
109+
threshold: float = 0.5,
94110
) -> np.ndarray:
95-
"""Predict the labels of the input."""
96-
encoded = self._predict_and_coerce_to_2d(
111+
"""
112+
Predict the labels of the input.
113+
114+
:param X: The input data to predict. Can be a list of strings or a single string.
115+
:param show_progress_bar: Whether to display a progress bar during prediction. Defaults to False.
116+
:param max_length: The maximum length of the input sequences. Defaults to 512.
117+
:param batch_size: The batch size for prediction. Defaults to 1024.
118+
:param use_multiprocessing: Whether to use multiprocessing for encoding. Defaults to True.
119+
:param multiprocessing_threshold: The threshold for the number of samples to use multiprocessing. Defaults to 10,000.
120+
:param threshold: The threshold for multilabel classification. Defaults to 0.5. Ignored if not multilabel.
121+
:return: The predicted labels or probabilities.
122+
"""
123+
encoded = self._encode_and_coerce_to_2d(
97124
X,
98125
show_progress_bar=show_progress_bar,
99126
max_length=max_length,
@@ -102,6 +129,13 @@ def predict(
102129
multiprocessing_threshold=multiprocessing_threshold,
103130
)
104131

132+
if self.multilabel:
133+
out_labels = []
134+
proba = self.head.predict_proba(encoded)
135+
for vector in proba:
136+
out_labels.append(self.classes_[vector > threshold])
137+
return np.asarray(out_labels, dtype=object)
138+
105139
return self.head.predict(encoded)
106140

107141
def predict_proba(
@@ -113,8 +147,18 @@ def predict_proba(
113147
use_multiprocessing: bool = True,
114148
multiprocessing_threshold: int = 10_000,
115149
) -> np.ndarray:
116-
"""Predict the probabilities of the labels of the input."""
117-
encoded = self._predict_and_coerce_to_2d(
150+
"""
151+
Predict the labels of the input.
152+
153+
:param X: The input data to predict. Can be a list of strings or a single string.
154+
:param show_progress_bar: Whether to display a progress bar during prediction. Defaults to False.
155+
:param max_length: The maximum length of the input sequences. Defaults to 512.
156+
:param batch_size: The batch size for prediction. Defaults to 1024.
157+
:param use_multiprocessing: Whether to use multiprocessing for encoding. Defaults to True.
158+
:param multiprocessing_threshold: The threshold for the number of samples to use multiprocessing. Defaults to 10,000.
159+
:return: The predicted labels or probabilities.
160+
"""
161+
encoded = self._encode_and_coerce_to_2d(
118162
X,
119163
show_progress_bar=show_progress_bar,
120164
max_length=max_length,

model2vec/train/README.md

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
Aside from [distillation](../../README.md#distillation), `model2vec` also supports training simple classifiers on top of static models, using [pytorch](https://pytorch.org/), [lightning](https://lightning.ai/) and [scikit-learn](https://scikit-learn.org/stable/index.html).
44

5+
We support both single and multi-label classification, which work seamlessly based on the labels you provide.
6+
57
# Installation
68

79
To train, make sure you install the training extra:
@@ -65,6 +67,54 @@ print(f"Took {int((perf_counter() - s) * 1000)} milliseconds for {len(test)} ins
6567
# Took 67 milliseconds for 2000 instances on CPU.
6668
```
6769

70+
## Multi-label classification
71+
72+
Multi-label classification is supported out of the box. Just pass a list of lists to the `fit` function (e.g. `[[label1, label2], [label1, label3]]`), and a multi-label classifier will be trained. For example, the following code trains a multi-label classifier on the [go_emotions](https://huggingface.co/datasets/google-research-datasets/go_emotions) dataset:
73+
74+
```python
75+
from datasets import load_dataset
76+
from model2vec.train import StaticModelForClassification
77+
78+
# Initialize a classifier from a pre-trained model
79+
classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32M")
80+
81+
# Load a multi-label dataset
82+
ds = load_dataset("google-research-datasets/go_emotions")
83+
84+
# Inspect some of the labels
85+
print(ds["train"]["labels"][40:50])
86+
# [[0, 15], [15, 18], [16, 27], [27], [7, 13], [10], [20], [27], [27], [27]]
87+
88+
# Train the classifier on text (X) and labels (y)
89+
classifier.fit(ds["train"]["text"], ds["train"]["labels"])
90+
```
91+
92+
Then, we can evaluate the classifier:
93+
94+
```python
95+
from sklearn import metrics
96+
from sklearn.preprocessing import MultiLabelBinarizer
97+
98+
# Make predictions on the test set with a threshold of 0.3
99+
predictions = classifier.predict(ds["test"]["text"], threshold=0.3)
100+
101+
# Evaluate the classifier
102+
mlb = MultiLabelBinarizer(classes=classifier.classes)
103+
y_true = mlb.fit_transform(ds["test"]["labels"])
104+
y_pred = mlb.transform(predictions)
105+
106+
print(f"Accuracy: {metrics.accuracy_score(y_true, y_pred):.3f}")
107+
print(f"Precision: {metrics.precision_score(y_true, y_pred, average='macro', zero_division=0):.3f}")
108+
print(f"Recall: {metrics.recall_score(y_true, y_pred, average='macro', zero_division=0):.3f}")
109+
print(f"F1: {metrics.f1_score(y_true, y_pred, average='macro', zero_division=0):.3f}")
110+
# Accuracy: 0.410
111+
# Precision: 0.527
112+
# Recall: 0.410
113+
# F1: 0.439
114+
```
115+
116+
The scores are competitive with the popular [roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) model, while our model is orders of magnitude faster.
117+
68118
# Persistence
69119

70120
You can turn a classifier into a scikit-learn compatible pipeline, as follows:

0 commit comments

Comments
 (0)