Skip to content

Commit 9e55123

Browse files
authored
feat: float pca dims (#163)
* feat: float pca dims
* fix test, set svd solver to full for reproducibility
* test: add test for float
1 parent 43eab43 commit 9e55123

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

model2vec/distill/distillation.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
logger = logging.getLogger(__name__)
3232

3333

34-
PCADimType = Union[int, None, Literal["auto"]]
34+
PCADimType = Union[int, None, float, Literal["auto"]]
3535

3636

3737
def distill_from_model(
@@ -258,10 +258,13 @@ def _post_process_embeddings(embeddings: np.ndarray, pca_dims: PCADimType, apply
258258
f"PCA dimension ({pca_dims}) is larger than the number of tokens in the vocabulary ({embeddings.shape[0]}). Not applying PCA."
259259
)
260260
elif pca_dims <= embeddings.shape[1]:
261-
logger.info(f"Applying PCA with n_components {pca_dims}")
261+
if isinstance(pca_dims, float):
262+
logger.info(f"Applying PCA with {pca_dims} explained variance.")
263+
else:
264+
logger.info(f"Applying PCA with n_components {pca_dims}")
262265

263266
orig_dims = embeddings.shape[1]
264-
p = PCA(n_components=pca_dims, whiten=False)
267+
p = PCA(n_components=pca_dims, svd_solver="full")
265268
embeddings = p.fit_transform(embeddings)
266269

267270
if embeddings.shape[1] < orig_dims:

tests/test_distillation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
(None, True, "auto", False), # Subword, PCA set to 'auto'
3838
(None, True, 1024, False), # Subword, PCA set to high number.
3939
(None, True, None, True), # No PCA applied
40+
(None, True, 0.9, True), # PCA as float applied
4041
(["wordA", "wordB"], False, 4, True), # Custom vocab without subwords PCA and Zipf applied
4142
(None, False, 256, True), # use_subword = False without passing a vocabulary should raise an error
4243
],

0 commit comments

Comments (0)