From 3efdbe58a87746d289b48032cfb2a1892a671be7 Mon Sep 17 00:00:00 2001 From: stephantul Date: Thu, 23 Jan 2025 21:25:55 +0100 Subject: [PATCH 1/3] feat: float pca dims --- model2vec/distill/distillation.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/model2vec/distill/distillation.py b/model2vec/distill/distillation.py index db1f0b3e..0a083eeb 100644 --- a/model2vec/distill/distillation.py +++ b/model2vec/distill/distillation.py @@ -31,7 +31,7 @@ logger = logging.getLogger(__name__) -PCADimType = Union[int, None, Literal["auto"]] +PCADimType = Union[int, None, float, Literal["auto"]] def distill_from_model( @@ -258,7 +258,10 @@ def _post_process_embeddings(embeddings: np.ndarray, pca_dims: PCADimType, apply f"PCA dimension ({pca_dims}) is larger than the number of tokens in the vocabulary ({embeddings.shape[0]}). Not applying PCA." ) elif pca_dims <= embeddings.shape[1]: - logger.info(f"Applying PCA with n_components {pca_dims}") + if isinstance(pca_dims, float): + logger.info(f"Applying PCA with {pca_dims} explained variance.") + else: + logger.info(f"Applying PCA with n_components {pca_dims}") orig_dims = embeddings.shape[1] p = PCA(n_components=pca_dims, whiten=False) From b7c0624f609246ceeacc4ff327e8ba58382af382 Mon Sep 17 00:00:00 2001 From: stephantul Date: Thu, 23 Jan 2025 21:37:41 +0100 Subject: [PATCH 2/3] fix test, set svd solver to full for reproducibility --- model2vec/distill/distillation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model2vec/distill/distillation.py b/model2vec/distill/distillation.py index 0a083eeb..8df47b88 100644 --- a/model2vec/distill/distillation.py +++ b/model2vec/distill/distillation.py @@ -264,7 +264,7 @@ def _post_process_embeddings(embeddings: np.ndarray, pca_dims: PCADimType, apply logger.info(f"Applying PCA with n_components {pca_dims}") orig_dims = embeddings.shape[1] - p = PCA(n_components=pca_dims, whiten=False) + p = PCA(n_components=pca_dims, svd_solver="full") embeddings = p.fit_transform(embeddings) if embeddings.shape[1] < orig_dims: From 5caab7d13ae130b7d2dc0b35562b3cd3c4668261 Mon Sep 17 00:00:00 2001 From: stephantul Date: Thu, 23 Jan 2025 21:38:41 +0100 Subject: [PATCH 3/3] test: add test for float --- tests/test_distillation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_distillation.py b/tests/test_distillation.py index 168404f2..bdf3d891 100644 --- a/tests/test_distillation.py +++ b/tests/test_distillation.py @@ -37,6 +37,7 @@ (None, True, "auto", False), # Subword, PCA set to 'auto' (None, True, 1024, False), # Subword, PCA set to high number. (None, True, None, True), # No PCA applied + (None, True, 0.9, True), # PCA as float applied (["wordA", "wordB"], False, 4, True), # Custom vocab without subwords PCA and Zipf applied (None, False, 256, True), # use_subword = False without passing a vocabulary should raise an error ],