diff --git a/model2vec/distill/tokenizer.py b/model2vec/distill/tokenizer.py
index c50c5d5..858876c 100644
--- a/model2vec/distill/tokenizer.py
+++ b/model2vec/distill/tokenizer.py
@@ -111,8 +111,7 @@ def _process_wordpiece(
     tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str], unk_token: str | None
 ) -> dict[str, Any]:
     """Process the WordPiece tokenizer JSON."""
-    unk_token = unk_token or tokenizer_json["model"]["unk_token"]
-    tokenizer_json["model"]["unk_token"] = "[UNK]" if unk_token else None
+    tokenizer_json["model"]["unk_token"] = unk_token
     tokenizer_json["model"]["vocab"] = {token: idx for idx, token in enumerate(pre_tokenized_tokens)}
     return tokenizer_json
 
@@ -128,20 +127,15 @@ def _process_bpe(tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str]
     return tokenizer_json
 
 
-def _process_unigram(
-    tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str], unk_token: str | None
-) -> dict[str, Any]:
+def _process_unigram(tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str], unk_token: str) -> dict[str, Any]:
     """Process the Unigram tokenizer JSON."""
-    unk_id = tokenizer_json["model"]["unk_id"]
-    vocab = tokenizer_json["model"]["vocab"]
-    unk_token = vocab[unk_id][0] if unk_id is not None else None
     current_probas = dict(tokenizer_json["model"]["vocab"])
     avg_proba = sum(current_probas.values()) / len(current_probas)
     new_probas = {word: current_probas.get(word, avg_proba) for word in pre_tokenized_tokens}
     tokenizer_json["model"]["vocab"] = sorted(new_probas.items(), key=lambda x: x[1], reverse=True)
 
     tokens, _ = zip(*tokenizer_json["model"]["vocab"])
-    tokenizer_json["model"]["unk_id"] = list(tokens).index(unk_token) if unk_token in tokens else None
+    tokenizer_json["model"]["unk_id"] = list(tokens).index(unk_token)
 
     return tokenizer_json
 
@@ -168,11 +162,11 @@ def replace_vocabulary(
     tokenizer_json["added_tokens"] = [x for x in added_tokens if x["content"] in {"[UNK]", "[PAD]"}]
 
     if model_type == "WordPiece":
-        tokenizer_json = _process_wordpiece(tokenizer_json, pre_tokenized_tokens, unk_token)
+        tokenizer_json = _process_wordpiece(tokenizer_json, pre_tokenized_tokens, "[UNK]")
     elif model_type == "BPE":
         tokenizer_json = _process_bpe(tokenizer_json, pre_tokenized_tokens)
     elif model_type == "Unigram":
-        tokenizer_json = _process_unigram(tokenizer_json, pre_tokenized_tokens, unk_token)
+        tokenizer_json = _process_unigram(tokenizer_json, pre_tokenized_tokens, "[UNK]")
     else:
         raise ValueError(f"Unknown model type {model_type}")
 
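
For review context, a minimal sketch (not part of the patch) of how the patched _process_unigram behaves on a toy Unigram tokenizer JSON. The toy vocab, log-probabilities, and token list below are invented for illustration; the real function operates on a full tokenizer.json. Note the hardened contract: the old code derived unk_token from the JSON's unk_id and tolerated None, while the new code assumes the caller always passes an unk_token that is present in pre_tokenized_tokens (which replace_vocabulary now guarantees by always passing "[UNK]" and keeping it in added_tokens), since list(tokens).index(unk_token) raises ValueError otherwise.

# Sketch for review (not part of the patch): the patched _process_unigram on a
# toy tokenizer JSON. All values here are invented for illustration.
from typing import Any

def _process_unigram(tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str], unk_token: str) -> dict[str, Any]:
    """Process the Unigram tokenizer JSON."""
    current_probas = dict(tokenizer_json["model"]["vocab"])
    avg_proba = sum(current_probas.values()) / len(current_probas)
    # Tokens not in the original vocab get the average log-probability.
    new_probas = {word: current_probas.get(word, avg_proba) for word in pre_tokenized_tokens}
    tokenizer_json["model"]["vocab"] = sorted(new_probas.items(), key=lambda x: x[1], reverse=True)

    tokens, _ = zip(*tokenizer_json["model"]["vocab"])
    # Assumes unk_token is present in the new vocab; .index() raises ValueError otherwise.
    tokenizer_json["model"]["unk_id"] = list(tokens).index(unk_token)
    return tokenizer_json

toy = {"model": {"type": "Unigram", "unk_id": 0, "vocab": [["[UNK]", -6.0], ["hello", -2.0]]}}
out = _process_unigram(toy, ["[UNK]", "hello", "world"], "[UNK]")
print(out["model"]["vocab"])   # [('hello', -2.0), ('world', -4.0), ('[UNK]', -6.0)]
print(out["model"]["unk_id"])  # 2 -- recomputed against the re-sorted vocab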