diff --git a/README.md b/README.md index 944b978e..83f806ab 100644 --- a/README.md +++ b/README.md @@ -105,7 +105,7 @@ from model2vec.train import StaticModelForClassification # Initialize a classifier from a pre-trained model classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32M") -# Load a dataset +# Load a dataset. Note: both single and multi-label classification datasets are supported ds = load_dataset("setfit/subj") # Train the classifier on text (X) and labels (y) diff --git a/model2vec/inference/model.py b/model2vec/inference/model.py index 5b08dad8..6f5a17ba 100644 --- a/model2vec/inference/model.py +++ b/model2vec/inference/model.py @@ -7,6 +7,7 @@ import huggingface_hub import numpy as np import skops.io +from sklearn.neural_network import MLPClassifier from sklearn.pipeline import Pipeline from model2vec.hf_utils import _create_model_card @@ -21,6 +22,20 @@ def __init__(self, model: StaticModel, head: Pipeline) -> None: """Create a pipeline with a StaticModel encoder.""" self.model = model self.head = head + classifier = self.head[-1] + # Check if the classifier is a multilabel classifier. + # NOTE: this doesn't look robust, but it is. + # Different classifiers, such as OVR wrappers, support multilabel output natively, so we + # can just use predict. + self.multilabel = False + if isinstance(classifier, MLPClassifier): + if classifier.out_activation_ == "logistic": + self.multilabel = True + + @property + def classes_(self) -> np.ndarray: + """The classes of the classifier.""" + return self.head.classes_ @classmethod def from_pretrained( @@ -60,7 +75,7 @@ def push_to_hub(self, repo_id: str, token: str | None = None, private: bool = Fa self.model.save_pretrained(temp_dir) push_folder_to_hub(Path(temp_dir), repo_id, private, token) - def _predict_and_coerce_to_2d( + def _encode_and_coerce_to_2d( self, X: list[str] | str, show_progress_bar: bool, @@ -69,7 +84,7 @@ def _predict_and_coerce_to_2d( use_multiprocessing: bool, multiprocessing_threshold: int, ) -> np.ndarray: - """Predict the labels of the input and coerce the output to a matrix.""" + """Encode the instances and coerce the output to a matrix.""" encoded = self.model.encode( X, show_progress_bar=show_progress_bar, @@ -91,9 +106,21 @@ def predict( batch_size: int = 1024, use_multiprocessing: bool = True, multiprocessing_threshold: int = 10_000, + threshold: float = 0.5, ) -> np.ndarray: - """Predict the labels of the input.""" - encoded = self._predict_and_coerce_to_2d( + """ + Predict the labels of the input. + + :param X: The input data to predict. Can be a list of strings or a single string. + :param show_progress_bar: Whether to display a progress bar during prediction. Defaults to False. + :param max_length: The maximum length of the input sequences. Defaults to 512. + :param batch_size: The batch size for prediction. Defaults to 1024. + :param use_multiprocessing: Whether to use multiprocessing for encoding. Defaults to True. + :param multiprocessing_threshold: The threshold for the number of samples to use multiprocessing. Defaults to 10,000. + :param threshold: The threshold for multilabel classification. Defaults to 0.5. Ignored if not multilabel. + :return: The predicted labels or probabilities. + """ + encoded = self._encode_and_coerce_to_2d( X, show_progress_bar=show_progress_bar, max_length=max_length, @@ -102,6 +129,13 @@ def predict( multiprocessing_threshold=multiprocessing_threshold, ) + if self.multilabel: + out_labels = [] + proba = self.head.predict_proba(encoded) + for vector in proba: + out_labels.append(self.classes_[vector > threshold]) + return np.asarray(out_labels, dtype=object) + return self.head.predict(encoded) def predict_proba( @@ -113,8 +147,18 @@ def predict_proba( use_multiprocessing: bool = True, multiprocessing_threshold: int = 10_000, ) -> np.ndarray: - """Predict the probabilities of the labels of the input.""" - encoded = self._predict_and_coerce_to_2d( + """ + Predict the labels of the input. + + :param X: The input data to predict. Can be a list of strings or a single string. + :param show_progress_bar: Whether to display a progress bar during prediction. Defaults to False. + :param max_length: The maximum length of the input sequences. Defaults to 512. + :param batch_size: The batch size for prediction. Defaults to 1024. + :param use_multiprocessing: Whether to use multiprocessing for encoding. Defaults to True. + :param multiprocessing_threshold: The threshold for the number of samples to use multiprocessing. Defaults to 10,000. + :return: The predicted labels or probabilities. + """ + encoded = self._encode_and_coerce_to_2d( X, show_progress_bar=show_progress_bar, max_length=max_length, diff --git a/model2vec/train/README.md b/model2vec/train/README.md index 87365fc1..2c908fff 100644 --- a/model2vec/train/README.md +++ b/model2vec/train/README.md @@ -2,6 +2,8 @@ Aside from [distillation](../../README.md#distillation), `model2vec` also supports training simple classifiers on top of static models, using [pytorch](https://pytorch.org/), [lightning](https://lightning.ai/) and [scikit-learn](https://scikit-learn.org/stable/index.html). +We support both single and multi-label classification, which work seamlessly based on the labels you provide. + # Installation To train, make sure you install the training extra: @@ -65,6 +67,54 @@ print(f"Took {int((perf_counter() - s) * 1000)} milliseconds for {len(test)} ins # Took 67 milliseconds for 2000 instances on CPU. ``` +## Multi-label classification + +Multi-label classification is supported out of the box. Just pass a list of lists to the `fit` function (e.g. `[[label1, label2], [label1, label3]]`), and a multi-label classifier will be trained. For example, the following code trains a multi-label classifier on the [go_emotions](https://huggingface.co/datasets/google-research-datasets/go_emotions) dataset: + +```python +from datasets import load_dataset +from model2vec.train import StaticModelForClassification + +# Initialize a classifier from a pre-trained model +classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32M") + +# Load a multi-label dataset +ds = load_dataset("google-research-datasets/go_emotions") + +# Inspect some of the labels +print(ds["train"]["labels"][40:50]) +# [[0, 15], [15, 18], [16, 27], [27], [7, 13], [10], [20], [27], [27], [27]] + +# Train the classifier on text (X) and labels (y) +classifier.fit(ds["train"]["text"], ds["train"]["labels"]) +``` + +Then, we can evaluate the classifier: + +```python +from sklearn import metrics +from sklearn.preprocessing import MultiLabelBinarizer + +# Make predictions on the test set with a threshold of 0.3 +predictions = classifier.predict(ds["test"]["text"], threshold=0.3) + +# Evaluate the classifier +mlb = MultiLabelBinarizer(classes=classifier.classes) +y_true = mlb.fit_transform(ds["test"]["labels"]) +y_pred = mlb.transform(predictions) + +print(f"Accuracy: {metrics.accuracy_score(y_true, y_pred):.3f}") +print(f"Precision: {metrics.precision_score(y_true, y_pred, average='macro', zero_division=0):.3f}") +print(f"Recall: {metrics.recall_score(y_true, y_pred, average='macro', zero_division=0):.3f}") +print(f"F1: {metrics.f1_score(y_true, y_pred, average='macro', zero_division=0):.3f}") +# Accuracy: 0.410 +# Precision: 0.527 +# Recall: 0.410 +# F1: 0.439 +``` + +The scores are competitive with the popular [roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) model, while our model is orders of magnitude faster. + # Persistence You can turn a classifier into a scikit-learn compatible pipeline, as follows: diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index d90986a8..213b75ab 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -2,13 +2,16 @@ import logging from collections import Counter +from itertools import chain from tempfile import TemporaryDirectory +from typing import TypeVar, cast import lightning as pl import numpy as np import torch from lightning.pytorch.callbacks import Callback, EarlyStopping from lightning.pytorch.utilities.types import OptimizerLRScheduler +from sklearn.metrics import jaccard_score from sklearn.model_selection import train_test_split from sklearn.neural_network import MLPClassifier from sklearn.pipeline import make_pipeline @@ -20,9 +23,10 @@ from model2vec.train.base import FinetunableStaticModel, TextDataset logger = logging.getLogger(__name__) - _RANDOM_SEED = 42 +LabelType = TypeVar("LabelType", list[str], list[list[str]]) + class StaticModelForClassification(FinetunableStaticModel): def __init__( @@ -40,12 +44,14 @@ def __init__( self.hidden_dim = hidden_dim # Alias: Follows scikit-learn. Set to dummy classes self.classes_: list[str] = [str(x) for x in range(out_dim)] + # multilabel flag will be set based on the type of `y` passed to fit. + self.multilabel: bool = False super().__init__(vectors=vectors, out_dim=out_dim, pad_id=pad_id, tokenizer=tokenizer) @property - def classes(self) -> list[str]: + def classes(self) -> np.ndarray: """Return all clasess in the correct order.""" - return self.classes_ + return np.array(self.classes_) def construct_head(self) -> nn.Sequential: """Constructs a simple classifier head.""" @@ -66,14 +72,35 @@ def construct_head(self) -> nn.Sequential: return nn.Sequential(*modules) - def predict(self, X: list[str], show_progress_bar: bool = False, batch_size: int = 1024) -> np.ndarray: - """Predict a class for a set of texts.""" - pred: list[str] = [] + def predict( + self, X: list[str], show_progress_bar: bool = False, batch_size: int = 1024, threshold: float = 0.5 + ) -> np.ndarray: + """ + Predict labels for a set of texts. + + In single-label mode, each prediction is a single class. + In multilabel mode, each prediction is a list of classes. + + :param X: The texts to predict on. + :param show_progress_bar: Whether to show a progress bar. + :param batch_size: The batch size. + :param threshold: The threshold for multilabel classification. + :return: The predictions. + """ + pred = [] for batch in trange(0, len(X), batch_size, disable=not show_progress_bar): logits = self._predict_single_batch(X[batch : batch + batch_size]) - pred.extend([self.classes[idx] for idx in logits.argmax(1)]) - - return np.asarray(pred) + if self.multilabel: + probs = torch.sigmoid(logits) + mask = (probs > threshold).cpu().numpy() + pred.extend([self.classes[np.flatnonzero(row)] for row in mask]) + else: + pred.extend([self.classes[idx] for idx in logits.argmax(dim=1).tolist()]) + if self.multilabel: + # Return as object array to allow for lists of varying lengths. + return np.array(pred, dtype=object) + else: + return np.array(pred) @torch.no_grad() def _predict_single_batch(self, X: list[str]) -> torch.Tensor: @@ -82,18 +109,25 @@ def _predict_single_batch(self, X: list[str]) -> torch.Tensor: return vectors def predict_proba(self, X: list[str], show_progress_bar: bool = False, batch_size: int = 1024) -> np.ndarray: - """Predict the probability of each class.""" - pred: list[np.ndarray] = [] + """ + Predict probabilities for each class. + + In single-label mode, returns softmax probabilities. + In multilabel mode, returns sigmoid probabilities. + """ + pred = [] for batch in trange(0, len(X), batch_size, disable=not show_progress_bar): logits = self._predict_single_batch(X[batch : batch + batch_size]) - pred.append(torch.softmax(logits, dim=1).numpy()) - - return np.concatenate(pred) + if self.multilabel: + pred.append(torch.sigmoid(logits).cpu().numpy()) + else: + pred.append(torch.softmax(logits, dim=1).cpu().numpy()) + return np.concatenate(pred, axis=0) def fit( self, X: list[str], - y: list[str], + y: LabelType, learning_rate: float = 1e-3, batch_size: int | None = None, min_epochs: int | None = None, @@ -106,16 +140,16 @@ def fit( Fit a model. This function creates a Lightning Trainer object and fits the model to the data. - We use early stopping. After training, the weigths of the best model are loaded back into the model. + It supports both single-label and multi-label classification. + We use early stopping. After training, the weights of the best model are loaded back into the model. This function seeds everything with a seed of 42, so the results are reproducible. It also splits the data into a train and validation set, again with a random seed. :param X: The texts to train on. - :param y: The labels to train on. + :param y: The labels to train on. If the first element is a list, multi-label classification is assumed. :param learning_rate: The learning rate. - :param batch_size: The batch size. - If this is None, a good batch size is chosen automatically. + :param batch_size: The batch size. If None, a good batch size is chosen automatically. :param min_epochs: The minimum number of epochs to train for. :param max_epochs: The maximum number of epochs to train for. If this is -1, the model trains until early stopping is triggered. @@ -127,10 +161,15 @@ def fit( """ pl.seed_everything(_RANDOM_SEED) logger.info("Re-initializing model.") + + # Determine whether the task is multilabel based on the type of y. + self._initialize(y) train_texts, validation_texts, train_labels, validation_labels = self._train_test_split( - X, y, test_size=test_size + X, + y, + test_size=test_size, ) if batch_size is None: @@ -186,24 +225,44 @@ def fit( self.load_state_dict(state_dict) self.eval() - return self - def _initialize(self, y: list[str]) -> None: - """Sets the out dimensionality, the classes and initializes the head.""" - classes = sorted(set(y)) - self.classes_ = classes + def _initialize(self, y: LabelType) -> None: + """ + Sets the output dimensionality, the classes, and initializes the head. - if len(self.classes) != self.out_dim: - self.out_dim = len(self.classes) + :param y: The labels. + :raises ValueError: If the labels are inconsistent. + """ + if isinstance(y[0], (str, int)): + # Check if all labels are strings or integers. + if not all(isinstance(label, (str, int)) for label in y): + raise ValueError("Inconsistent label types in y. All labels must be strings or integers.") + self.multilabel = False + classes = sorted(set(y)) + else: + # Check if all labels are lists or tuples. + if not all(isinstance(label, (list, tuple)) for label in y): + raise ValueError("Inconsistent label types in y. All labels must be lists or tuples.") + self.multilabel = True + classes = sorted(set(chain.from_iterable(y))) + self.classes_ = classes + self.out_dim = len(self.classes_) # Update output dimension self.head = self.construct_head() self.embeddings = nn.Embedding.from_pretrained(self.vectors.clone(), freeze=False, padding_idx=self.pad_id) self.w = self.construct_weights() self.train() - def _prepare_dataset(self, X: list[str], y: list[str], max_length: int = 512) -> TextDataset: - """Prepare a dataset.""" + def _prepare_dataset(self, X: list[str], y: LabelType, max_length: int = 512) -> TextDataset: + """ + Prepare a dataset. For multilabel classification, each target is converted into a multi-hot vector. + + :param X: The texts. + :param y: The labels. + :param max_length: The maximum length of the input. + :return: A TextDataset. + """ # This is a speed optimization. # assumes a mean token length of 10, which is really high, so safe. truncate_length = max_length * 10 @@ -211,20 +270,39 @@ def _prepare_dataset(self, X: list[str], y: list[str], max_length: int = 512) -> tokenized: list[list[int]] = [ encoding.ids[:max_length] for encoding in self.tokenizer.encode_batch_fast(X, add_special_tokens=False) ] - labels_tensor = torch.Tensor([self.classes.index(label) for label in y]).long() - + if self.multilabel: + # Convert labels to multi-hot vectors + num_classes = len(self.classes_) + labels_tensor = torch.zeros(len(y), num_classes, dtype=torch.float) + mapping = {label: idx for idx, label in enumerate(self.classes_)} + for i, sample_labels in enumerate(y): + indices = [mapping[label] for label in sample_labels] + labels_tensor[i, indices] = 1.0 + else: + labels_tensor = torch.tensor([self.classes_.index(label) for label in cast(list[str], y)], dtype=torch.long) return TextDataset(tokenized, labels_tensor) - @staticmethod def _train_test_split( - X: list[str], y: list[str], test_size: float - ) -> tuple[list[str], list[str], list[str], list[str]]: - """Split the data.""" - label_counts = Counter(y) - if min(label_counts.values()) < 2: - logger.info("Some classes have less than 2 samples. Stratification is disabled.") + self, + X: list[str], + y: list[str] | list[list[str]], + test_size: float, + ) -> tuple[list[str], list[str], LabelType, LabelType]: + """ + Split the data. + + For single-label classification, stratification is attempted (if possible). + For multilabel classification, a random split is performed. + """ + if not self.multilabel: + label_counts = Counter(y) + if min(label_counts.values()) < 2: + logger.info("Some classes have less than 2 samples. Stratification is disabled.") + return train_test_split(X, y, test_size=test_size, random_state=42, shuffle=True) + return train_test_split(X, y, test_size=test_size, random_state=42, shuffle=True, stratify=y) + else: + # Multilabel classification does not support stratification. return train_test_split(X, y, test_size=test_size, random_state=42, shuffle=True) - return train_test_split(X, y, test_size=test_size, random_state=42, shuffle=True, stratify=y) def to_pipeline(self) -> StaticModelPipeline: """Convert the model to an sklearn pipeline.""" @@ -248,56 +326,58 @@ def to_pipeline(self) -> StaticModelPipeline: # To convert correctly, we need to set the outputs correctly, and fix the activation function. # Make sure n_outputs is set to > 1. mlp_head.n_outputs_ = self.out_dim - # Set to softmax - mlp_head.out_activation_ = "softmax" + # Set to softmax or sigmoid + mlp_head.out_activation_ = "logistic" if self.multilabel else "softmax" return StaticModelPipeline(static_model, converted) class _ClassifierLightningModule(pl.LightningModule): def __init__(self, model: StaticModelForClassification, learning_rate: float) -> None: - """Initialize the lightningmodule.""" + """Initialize the LightningModule.""" super().__init__() self.model = model self.learning_rate = learning_rate + self.loss_function = nn.CrossEntropyLoss() if not model.multilabel else nn.BCEWithLogitsLoss() def forward(self, x: torch.Tensor) -> torch.Tensor: """Simple forward pass.""" return self.model(x) def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: - """Simple training step using cross entropy loss.""" + """Training step using cross-entropy loss for single-label and binary cross-entropy for multilabel training.""" x, y = batch head_out, _ = self.model(x) - loss = nn.functional.cross_entropy(head_out, y).mean() - + loss = self.loss_function(head_out, y) self.log("train_loss", loss) return loss def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: - """Simple validation step using cross entropy loss and accuracy.""" + """Validation step computing loss and accuracy.""" x, y = batch head_out, _ = self.model(x) - loss = nn.functional.cross_entropy(head_out, y).mean() - accuracy = (head_out.argmax(1) == y).float().mean() - + loss = self.loss_function(head_out, y) + if self.model.multilabel: + preds = (torch.sigmoid(head_out) > 0.5).float() + # Multilabel accuracy is defined as the Jaccard score averaged over samples. + accuracy = jaccard_score(y.cpu(), preds.cpu(), average="samples") + else: + accuracy = (head_out.argmax(dim=1) == y).float().mean() self.log("val_loss", loss) self.log("val_accuracy", accuracy, prog_bar=True) return loss def configure_optimizers(self) -> OptimizerLRScheduler: - """Simple Adam optimizer.""" + """Configure optimizer and learning rate scheduler.""" optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode="min", factor=0.5, patience=3, - verbose=True, min_lr=1e-6, threshold=0.03, threshold_mode="rel", ) - return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}} diff --git a/tests/conftest.py b/tests/conftest.py index 62203292..2f258868 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,15 +5,12 @@ import numpy as np import pytest import torch -from sklearn.neural_network import MLPClassifier -from sklearn.pipeline import make_pipeline from tokenizers import Tokenizer from tokenizers.models import WordLevel from tokenizers.pre_tokenizers import Whitespace from transformers import AutoModel, AutoTokenizer from model2vec.inference import StaticModelPipeline -from model2vec.model import StaticModel from model2vec.train import StaticModelForClassification @@ -86,13 +83,23 @@ def mock_inference_pipeline(mock_trained_pipeline: StaticModelForClassification) return mock_trained_pipeline.to_pipeline() -@pytest.fixture(scope="session") -def mock_trained_pipeline() -> StaticModelForClassification: +@pytest.fixture(params=[False, True], ids=["single_label", "multilabel"], scope="session") +def mock_trained_pipeline(request: pytest.FixtureRequest) -> StaticModelForClassification: """Mock staticmodelforclassification.""" tokenizer = AutoTokenizer.from_pretrained("tests/data/test_tokenizer").backend_tokenizer torch.random.manual_seed(42) vectors_torched = torch.randn(len(tokenizer.get_vocab()), 12) - s = StaticModelForClassification(vectors=vectors_torched, tokenizer=tokenizer, hidden_dim=12).to("cpu") - s.fit(["dog", "cat"], ["a", "b"], device="cpu") + model = StaticModelForClassification(vectors=vectors_torched, tokenizer=tokenizer, hidden_dim=12).to("cpu") + + X = ["dog", "cat"] + y: list[str] | list[list[str]] + if request.param: + # Use multilabel targets. + y = [["a", "b"], ["a"]] + else: + # Use singlelabel targets. + y = ["a", "b"] + + model.fit(X, y) - return s + return model diff --git a/tests/test_inference.py b/tests/test_inference.py index 9f4618df..9a12894b 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -10,8 +10,13 @@ def test_init_predict(mock_inference_pipeline: StaticModelPipeline) -> None: """Test successful initialization of StaticModelPipeline.""" - assert mock_inference_pipeline.predict("dog").tolist() == ["b"] - assert mock_inference_pipeline.predict(["dog"]).tolist() == ["b"] + target: list[str] | list[list[str]] + if mock_inference_pipeline.multilabel: + target = [["a", "b"]] + else: + target = ["b"] + assert mock_inference_pipeline.predict("dog").tolist() == target + assert mock_inference_pipeline.predict(["dog"]).tolist() == target def test_init_predict_proba(mock_inference_pipeline: StaticModelPipeline) -> None: @@ -25,8 +30,13 @@ def test_roundtrip_save(mock_inference_pipeline: StaticModelPipeline) -> None: with TemporaryDirectory() as temp_dir: mock_inference_pipeline.save_pretrained(temp_dir) loaded = StaticModelPipeline.from_pretrained(temp_dir) - assert loaded.predict("dog") == ["b"] - assert loaded.predict(["dog"]) == ["b"] + target: list[str] | list[list[str]] + if mock_inference_pipeline.multilabel: + target = [["a", "b"]] + else: + target = ["b"] + assert loaded.predict("dog").tolist() == target + assert loaded.predict(["dog"]).tolist() == target assert loaded.predict_proba("dog").argmax() == 1 assert loaded.predict_proba(["dog"]).argmax(1).tolist() == [1] diff --git a/tests/test_trainable.py b/tests/test_trainable.py index dc9bb811..2fd11e88 100644 --- a/tests/test_trainable.py +++ b/tests/test_trainable.py @@ -17,8 +17,8 @@ def test_init_predict(n_layers: int, mock_vectors: np.ndarray, mock_tokenizer: T s = StaticModelForClassification(vectors=vectors_torched, tokenizer=mock_tokenizer, n_layers=n_layers) assert s.vectors.shape == mock_vectors.shape assert s.w.shape[0] == mock_vectors.shape[0] - assert s.classes == s.classes_ - assert s.classes == ["0", "1"] + assert list(s.classes) == s.classes_ + assert list(s.classes) == ["0", "1"] head = s.construct_head() assert head[0].in_features == mock_vectors.shape[1] @@ -112,7 +112,10 @@ def test_textdataset_init_incorrect() -> None: def test_predict(mock_trained_pipeline: StaticModelForClassification) -> None: """Test the predict function.""" result = mock_trained_pipeline.predict(["dog cat", "dog"]).tolist() - assert result == ["b", "b"] + if mock_trained_pipeline.multilabel: + assert result == [["a", "b"], ["a", "b"]] + else: + assert result == ["b", "b"] def test_predict_proba(mock_trained_pipeline: StaticModelForClassification) -> None: @@ -136,9 +139,9 @@ def test_convert_to_pipeline(mock_trained_pipeline: StaticModelForClassification assert np.allclose(p1, p2) -def test_train_test_split() -> None: +def test_train_test_split(mock_trained_pipeline: StaticModelForClassification) -> None: """Test the train test split function.""" - a, b, c, d = StaticModelForClassification._train_test_split(["0", "1", "2", "3"], ["1", "1", "0", "0"], 0.5) + a, b, c, d = mock_trained_pipeline._train_test_split(["0", "1", "2", "3"], ["1", "1", "0", "0"], 0.5) assert len(a) == 2 assert len(b) == 2 assert len(c) == len(a) diff --git a/uv.lock b/uv.lock index f7d37b8c..a0465750 100644 --- a/uv.lock +++ b/uv.lock @@ -791,7 +791,7 @@ wheels = [ [[package]] name = "model2vec" -version = "0.3.8" +version = "0.4.0" source = { editable = "." } dependencies = [ { name = "jinja2" },