diff --git a/model2vec/distill/distillation.py b/model2vec/distill/distillation.py index db1f0b3e..8df47b88 100644 --- a/model2vec/distill/distillation.py +++ b/model2vec/distill/distillation.py @@ -31,7 +31,7 @@ logger = logging.getLogger(__name__) -PCADimType = Union[int, None, Literal["auto"]] +PCADimType = Union[int, None, float, Literal["auto"]] def distill_from_model( @@ -258,10 +258,13 @@ def _post_process_embeddings(embeddings: np.ndarray, pca_dims: PCADimType, apply f"PCA dimension ({pca_dims}) is larger than the number of tokens in the vocabulary ({embeddings.shape[0]}). Not applying PCA." ) elif pca_dims <= embeddings.shape[1]: - logger.info(f"Applying PCA with n_components {pca_dims}") + if isinstance(pca_dims, float): + logger.info(f"Applying PCA with {pca_dims} explained variance.") + else: + logger.info(f"Applying PCA with n_components {pca_dims}") orig_dims = embeddings.shape[1] - p = PCA(n_components=pca_dims, whiten=False) + p = PCA(n_components=pca_dims, svd_solver="full") embeddings = p.fit_transform(embeddings) if embeddings.shape[1] < orig_dims: diff --git a/tests/test_distillation.py b/tests/test_distillation.py index 168404f2..bdf3d891 100644 --- a/tests/test_distillation.py +++ b/tests/test_distillation.py @@ -37,6 +37,7 @@ (None, True, "auto", False), # Subword, PCA set to 'auto' (None, True, 1024, False), # Subword, PCA set to high number. (None, True, None, True), # No PCA applied + (None, True, 0.9, True), # PCA as float applied (["wordA", "wordB"], False, 4, True), # Custom vocab without subwords PCA and Zipf applied (None, False, 256, True), # use_subword = False without passing a vocabulary should raise an error ],