diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py
index 4fff7a8fc8e..f44da95d321 100644
--- a/benchmarks/benchmark_prefix_caching.py
+++ b/benchmarks/benchmark_prefix_caching.py
@@ -63,14 +63,16 @@ class Request:
     output_len: int
 
 
-def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> str:
+def sample_tokens(tokenizer: PreTrainedTokenizerBase,
+                  length: int) -> list[int]:
     vocab = tokenizer.get_vocab()
+    all_special_ids = set(tokenizer.all_special_ids)
+
     # Remove the special tokens.
-    vocab = {
-        k: v
-        for k, v in vocab.items() if k not in tokenizer.all_special_ids
-    }
-    return random.choices(list(vocab.values()), k=length)
+    return random.choices(
+        [v for k, v in vocab.items() if k not in all_special_ids],
+        k=length,
+    )
 
 
 def sample_requests_from_dataset(
diff --git a/tests/tokenization/test_cached_tokenizer.py b/tests/tokenization/test_cached_tokenizer.py
index cd60cefd7cc..c740fde4263 100644
--- a/tests/tokenization/test_cached_tokenizer.py
+++ b/tests/tokenization/test_cached_tokenizer.py
@@ -1,24 +1,43 @@
 # SPDX-License-Identifier: Apache-2.0
-
+import pickle
 from copy import deepcopy
 
+import pytest
 from transformers import AutoTokenizer
 
-from vllm.transformers_utils.tokenizer import get_cached_tokenizer
+from vllm.transformers_utils.tokenizer import (AnyTokenizer,
+                                               get_cached_tokenizer)
 
 
-def test_cached_tokenizer():
-    reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
+@pytest.mark.parametrize("model_id", ["gpt2", "THUDM/chatglm3-6b"])
+def test_cached_tokenizer(model_id: str):
+    reference_tokenizer = AutoTokenizer.from_pretrained(model_id,
+                                                        trust_remote_code=True)
     reference_tokenizer.add_special_tokens({"cls_token": "<CLS>"})
     reference_tokenizer.add_special_tokens(
         {"additional_special_tokens": ["<SEP>"]})
+
     cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer))
+    _check_consistency(cached_tokenizer, reference_tokenizer)
+
+    pickled_tokenizer = pickle.dumps(cached_tokenizer)
+    unpickled_tokenizer = pickle.loads(pickled_tokenizer)
+    _check_consistency(unpickled_tokenizer, reference_tokenizer)
+
+
+def _check_consistency(target: AnyTokenizer, expected: AnyTokenizer):
+    assert isinstance(target, type(expected))
+
+    # Cached attributes
+    assert target.all_special_ids == expected.all_special_ids
+    assert target.all_special_tokens == expected.all_special_tokens
+    assert (target.all_special_tokens_extended ==
+            expected.all_special_tokens_extended)
+    assert target.get_vocab() == expected.get_vocab()
+    assert len(target) == len(expected)
+
+    # Other attributes
+    assert getattr(target, "padding_side",
+                   None) == getattr(expected, "padding_side", None)
 
-    assert reference_tokenizer.encode("prompt") == cached_tokenizer.encode(
-        "prompt")
-    assert set(reference_tokenizer.all_special_ids) == set(
-        cached_tokenizer.all_special_ids)
-    assert set(reference_tokenizer.all_special_tokens) == set(
-        cached_tokenizer.all_special_tokens)
-    assert set(reference_tokenizer.all_special_tokens_extended) == set(
-        cached_tokenizer.all_special_tokens_extended)
+    assert target.encode("prompt") == expected.encode("prompt")
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index 1bfb5032833..da5bec85666 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import contextlib
+import copy
 import os
 import warnings
 from functools import lru_cache
@@ -70,18 +71,17 @@ def encode_tokens(
 
 
 def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
-    """Get tokenizer with cached properties.
-
-    This will patch the tokenizer object in place.
-
+    """
     By default, transformers will recompute multiple tokenizer properties
-    each time they are called, leading to a significant slowdown. This
-    function caches these properties for faster access."""
+    each time they are called, leading to a significant slowdown.
+    This proxy caches these properties for faster access.
+    """
+    cached_tokenizer = copy.copy(tokenizer)
 
-    tokenizer_all_special_ids = set(tokenizer.all_special_ids)
+    tokenizer_all_special_ids = tokenizer.all_special_ids
+    tokenizer_all_special_tokens = tokenizer.all_special_tokens
     tokenizer_all_special_tokens_extended = (
         tokenizer.all_special_tokens_extended)
-    tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
     tokenizer_vocab = tokenizer.get_vocab()
     tokenizer_len = len(tokenizer)
 
@@ -97,31 +97,34 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
     class CachedTokenizer(tokenizer.__class__):  # type: ignore
 
         @property
-        def all_special_ids(self):
+        def all_special_ids(self) -> list[int]:
             return tokenizer_all_special_ids
 
         @property
-        def all_special_tokens(self):
+        def all_special_tokens(self) -> list[str]:
             return tokenizer_all_special_tokens
 
         @property
-        def all_special_tokens_extended(self):
+        def all_special_tokens_extended(self) -> list[str]:
            return tokenizer_all_special_tokens_extended
 
         @property
-        def max_token_id(self):
+        def max_token_id(self) -> int:
             return max_token_id
 
-        def get_vocab(self):
+        def get_vocab(self) -> dict[str, int]:
             return tokenizer_vocab
 
-        def __len__(self):
+        def __len__(self) -> int:
             return tokenizer_len
 
+        def __reduce__(self):
+            return get_cached_tokenizer, (tokenizer, )
+
     CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"
 
-    tokenizer.__class__ = CachedTokenizer
-    return tokenizer
+    cached_tokenizer.__class__ = CachedTokenizer
+    return cached_tokenizer
 
 
 def patch_padding_side(tokenizer: PreTrainedTokenizer) -> None:
diff --git a/vllm/transformers_utils/tokenizer_base.py b/vllm/transformers_utils/tokenizer_base.py
index bb5ddaf88b2..b4eb081c9b9 100644
--- a/vllm/transformers_utils/tokenizer_base.py
+++ b/vllm/transformers_utils/tokenizer_base.py
@@ -2,7 +2,7 @@
 
 import importlib
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 if TYPE_CHECKING:
     from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
@@ -12,17 +12,17 @@ class TokenizerBase(ABC):
 
     @property
     @abstractmethod
-    def all_special_tokens_extended(self) -> List[str]:
+    def all_special_tokens_extended(self) -> list[str]:
         raise NotImplementedError()
 
     @property
     @abstractmethod
-    def all_special_tokens(self) -> List[str]:
+    def all_special_tokens(self) -> list[str]:
         raise NotImplementedError()
 
     @property
     @abstractmethod
-    def all_special_ids(self) -> List[int]:
+    def all_special_ids(self) -> list[int]:
         raise NotImplementedError()
 
     @property
@@ -66,7 +66,7 @@ def __len__(self) -> int:
     @abstractmethod
     def __call__(
         self,
-        text: Union[str, List[str], List[int]],
+        text: Union[str, list[str], list[int]],
         text_pair: Optional[str] = None,
         add_special_tokens: bool = False,
         truncation: bool = False,
@@ -75,11 +75,11 @@ def __call__(
         raise NotImplementedError()
 
     @abstractmethod
-    def get_vocab(self) -> Dict[str, int]:
+    def get_vocab(self) -> dict[str, int]:
         raise NotImplementedError()
 
     @abstractmethod
-    def get_added_vocab(self) -> Dict[str, int]:
+    def get_added_vocab(self) -> dict[str, int]:
         raise NotImplementedError()
 
     @abstractmethod
@@ -88,44 +88,44 @@ def encode_one(
         text: str,
         truncation: bool = False,
         max_length: Optional[int] = None,
-    ) -> List[int]:
+    ) -> list[int]:
         raise NotImplementedError()
 
     @abstractmethod
     def encode(self,
                text: str,
-               add_special_tokens: Optional[bool] = None) -> List[int]:
+               add_special_tokens: Optional[bool] = None) -> list[int]:
         raise NotImplementedError()
 
     @abstractmethod
     def apply_chat_template(self,
-                            messages: List["ChatCompletionMessageParam"],
-                            tools: Optional[List[Dict[str, Any]]] = None,
-                            **kwargs) -> List[int]:
+                            messages: list["ChatCompletionMessageParam"],
+                            tools: Optional[list[dict[str, Any]]] = None,
+                            **kwargs) -> list[int]:
         raise NotImplementedError()
 
     @abstractmethod
-    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+    def convert_tokens_to_string(self, tokens: list[str]) -> str:
         raise NotImplementedError()
 
     @abstractmethod
     def decode(self,
-               ids: Union[List[int], int],
+               ids: Union[list[int], int],
                skip_special_tokens: bool = True) -> str:
         raise NotImplementedError()
 
     @abstractmethod
     def convert_ids_to_tokens(
         self,
-        ids: List[int],
+        ids: list[int],
         skip_special_tokens: bool = True,
-    ) -> List[str]:
+    ) -> list[str]:
         raise NotImplementedError()
 
 
 class TokenizerRegistry:
     # Tokenizer name -> (tokenizer module, tokenizer class)
-    REGISTRY: Dict[str, Tuple[str, str]] = {}
+    REGISTRY: dict[str, tuple[str, str]] = {}
 
     @staticmethod
     def register(name: str, module: str, class_name: str) -> None:
diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py
index 58a114fa3a3..296149a4569 100644
--- a/vllm/transformers_utils/tokenizers/mistral.py
+++ b/vllm/transformers_utils/tokenizers/mistral.py
@@ -257,7 +257,7 @@ def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
     # the following attributes are set to fit vLLM's design and are used
     # by the guided structured output backends.
     @property
-    def all_special_tokens_extended(self) -> List[str]:
+    def all_special_tokens_extended(self) -> list[str]:
         from mistral_common.tokens.tokenizers.base import SpecialTokens
 
         # tekken defines its own extended special tokens list
@@ -271,11 +271,11 @@ def all_special_tokens_extended(self) -> List[str]:
         ]
 
     @property
-    def all_special_tokens(self) -> List[str]:
+    def all_special_tokens(self) -> list[str]:
         return self.all_special_tokens_extended
 
     @property
-    def all_special_ids(self) -> List[int]:
+    def all_special_ids(self) -> list[int]:
         return [
             self.all_special_tokens.index(t) for t in self.all_special_tokens
         ]
@@ -335,12 +335,12 @@ def __call__(
         input_ids = self.encode_one(text, truncation, max_length)
         return Encoding(input_ids=input_ids)
 
-    def get_vocab(self) -> Dict[str, int]:
+    def get_vocab(self) -> dict[str, int]:
         # NB: the dictionary form of the vocabulary collapses token ids that map
         # to the same string but have different bytes
         return self._vocab_dict
 
-    def get_added_vocab(self) -> Dict[str, int]:
+    def get_added_vocab(self) -> dict[str, int]:
         # Mistral tokenizers have no added vocabulary
         return {}
 
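# --- Illustrative usage sketch (not part of the patch above) ---
# A minimal, hedged example of how the proxy-based get_cached_tokenizer from
# vllm/transformers_utils/tokenizer.py is expected to behave after this change:
# the wrapper captures the expensive tokenizer properties once, and __reduce__
# makes the proxy picklable by re-wrapping the original tokenizer on load.
# The "gpt2" model id is simply the example used in the test parametrization.
import pickle

from transformers import AutoTokenizer

from vllm.transformers_utils.tokenizer import get_cached_tokenizer

tokenizer = get_cached_tokenizer(AutoTokenizer.from_pretrained("gpt2"))

# These reads return the values captured at wrap time instead of being
# recomputed by transformers on every access.
special_ids = tokenizer.all_special_ids
vocab = tokenizer.get_vocab()

# The proxy round-trips through pickle (e.g. when sent to worker processes).
restored = pickle.loads(pickle.dumps(tokenizer))
assert restored.all_special_ids == special_ids
assert restored.get_vocab() == vocab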