1 change: 1 addition & 0 deletions .gitignore
@@ -25,6 +25,7 @@ config.yaml
.nfs*
.env
.venv
.ipynb_checkpoints/
# Documentation build files
docs/_build/
docs/build/
3 changes: 2 additions & 1 deletion README.md
@@ -150,7 +150,8 @@ You might encounter the error `with block not found at line xyz` when running th
- Create a git tag with the version number `git tag vx.y.z; git push origin vx.y.z`
- Build with `python -m build`
- Publish with e.g. `twine upload dist/*x.y.z*`
- test with `pytest --cache-clear`. **cache-clear is mandatory for now otherwise `NNsight`'s source can break.** It might not be sufficient, in which case you can do `make clean` to remove Python cache.
- Test with `uv run pytest nnterp/tests`, or `uv run pytest nnterp/tests --model-names gpt2` to test with a specific model. You can also use `uv run pytest nnterp/tests --class-names LlamaForCausalLM` to test with a specific class.
<!--commented out as it is likely fixed now - test with `pytest --cache-clear`. **cache-clear is mandatory for now otherwise `NNsight`'s source can break.** It might not be sufficient, in which case you can do `make clean` to remove Python cache. -->


## Citation
28 changes: 28 additions & 0 deletions nnterp/__init__.py
@@ -1,9 +1,37 @@
from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    from .standardized_vllm import StandardizedVLLM
from .standardized_transformer import StandardizedTransformer
from .rename_utils import get_rename_dict
from .nnsight_utils import ModuleAccessor


__all__ = [
    "StandardizedTransformer",
    "load_model",
    "get_rename_dict",
    "ModuleAccessor",
]


def load_model(
    model: str, use_vllm: bool = False, **model_kwargs
) -> Union[StandardizedTransformer, "StandardizedVLLM"]:
    """
    Load a model using the appropriate wrapper.

    Args:
        model: The model to load.
        use_vllm: Whether to use the VLLM wrapper.
        **model_kwargs: Keyword arguments to pass to the model wrapper.

    Returns:
        A StandardizedTransformer or StandardizedVLLM instance.
    """
    if use_vllm:
        from .standardized_vllm import StandardizedVLLM

        return StandardizedVLLM(model, **model_kwargs)
    else:
        return StandardizedTransformer(model, **model_kwargs)
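
For context, a minimal usage sketch of the new `load_model` entry point (the model name is just a placeholder; the vLLM path assumes the optional vllm dependency is installed):

```python
from nnterp import load_model

# Standard HuggingFace-backed wrapper (placeholder model name).
model = load_model("gpt2")

# vLLM-backed wrapper; requires vllm to be installed.
vllm_model = load_model("gpt2", use_vllm=True)
```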
19 changes: 19 additions & 0 deletions nnterp/rename_utils.py
@@ -743,3 +743,22 @@ def check_model_renaming(
allow_dispatch,
errors_to_raise=(RenamingError,),
)


HF_TO_VLLM_KWARGS_MAP = dict(
    max_num_tokens="max_tokens",
)
VLLM_TO_HF_KWARGS_MAP = {v: k for k, v in HF_TO_VLLM_KWARGS_MAP.items()}


def hf_kwargs_to_vllm_kwargs(args, kwargs: dict) -> dict:
    # Iterate over a copy since kwargs is modified during the loop.
    for k, v in list(kwargs.items()):
        if k not in HF_TO_VLLM_KWARGS_MAP:
            continue
        vllm_key = HF_TO_VLLM_KWARGS_MAP[k]
        if vllm_key in kwargs and kwargs[vllm_key] != v:
            raise ValueError(
                f"Conflicting values for {vllm_key} and {k}, which correspond to the same "
                f"argument in vllm and hf but are set to different values: "
                f"{kwargs[vllm_key]} and {v}"
            )
        # Replace the hf-style key with its vllm equivalent.
        kwargs.pop(k)
        kwargs[vllm_key] = v

    return kwargs
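
As an illustration, a small sketch of how the kwargs translation above is expected to behave (based solely on the `max_num_tokens` → `max_tokens` mapping defined in this diff):

```python
from nnterp.rename_utils import hf_kwargs_to_vllm_kwargs

# The hf-style name is translated to its vllm equivalent.
print(hf_kwargs_to_vllm_kwargs((), {"max_num_tokens": 128}))
# {'max_tokens': 128}

# Passing both names with conflicting values raises a ValueError.
try:
    hf_kwargs_to_vllm_kwargs((), {"max_num_tokens": 128, "max_tokens": 256})
except ValueError as e:
    print(e)
```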
194 changes: 138 additions & 56 deletions nnterp/standardized_transformer.py
@@ -1,13 +1,10 @@
from __future__ import annotations
from typing import Callable
from loguru import logger
import torch as th
from torch.nn import Module
from torch import Size
from nnsight import LanguageModel
from transformers import AutoTokenizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from .utils import (
TraceTensor,
DummyCache,
@@ -28,21 +28,11 @@
)


class StandardizedTransformer(LanguageModel):
class StandardizationMixin:
"""
Renames the LanguageModel modules to match a standardized architecture.

The model structure is organized as follows::

StandardizedTransformer
├── embed_tokens
├── layers
│ ├── self_attn
│ └── mlp
├── ln_final
└── lm_head
Mixin class for standardizing the architecture of a model.

In addition to renaming modules, this class provides built-in accessors to extract and set intermediate activations:
This class provides built-in accessors to extract and set intermediate activations:

- embed_tokens: Get embedding module
- token_embeddings: Get/set token embeddings (equivalent to embed_tokens.output)
@@ -55,9 +55,7 @@ class StandardizedTransformer(LanguageModel):
- mlps_input[i] / mlps_output[i]: Get/set MLP input/output at layer i

Args:
repo_id (str): Hugging Face repository ID or path of the model to load.
trust_remote_code (bool, optional): If True, remote code will be trusted when
loading the model. Defaults to False.
model (str or Module): Hugging Face repository ID or path of the model to load, or an already-loaded model.
check_renaming (bool, default True): If True, the renaming of modules is validated.
Defaults to True.
allow_dispatch (bool, default True): If True, allows using trace() to dispatch the model
@@ -73,47 +73,18 @@ class StandardizedTransformer(LanguageModel):
num_heads: int
hidden_size: int
vocab_size: int
is_vllm: bool

def __init__(
def _init_standardization(
self,
model: str | Module,
trust_remote_code: bool = False,
check_renaming: bool = True,
allow_dispatch: bool = True,
enable_attention_probs: bool = False,
check_attn_probs_with_trace: bool = True,
rename_config: RenameConfig | None = None,
**kwargs,
):
kwargs.setdefault("device_map", "auto")
if "attn_implementation" in kwargs and enable_attention_probs:
if kwargs["attn_implementation"] != "eager":
raise ValueError(
f"Cannot use attn_implementation='{kwargs['attn_implementation']}' with enable_attention_probs=True. "
"Either set enable_attention_probs=False or don't pass attn_implementation."
)
attn_implementation = (
"eager"
if enable_attention_probs
else kwargs.pop("attn_implementation", None)
)

tokenizer_kwargs = kwargs.pop("tokenizer_kwargs", {})
rename = get_rename_dict(rename_config=rename_config)
user_rename = kwargs.pop("rename", None)
if user_rename is not None:
logger.info(
f"Updating default rename with user-provided rename: {user_rename}"
)
rename.update(user_rename)
super().__init__(
model,
attn_implementation=attn_implementation,
tokenizer_kwargs=tokenizer_kwargs,
trust_remote_code=trust_remote_code,
rename=rename,
**kwargs,
)
"""Initialize standardization after the base model has been initialized."""
if isinstance(model, str):
model_name = model
else:
@@ -149,6 +105,10 @@ def __init__(
rename_config=rename_config,
initialized_with_enable=enable_attention_probs,
)
if self.is_vllm and enable_attention_probs:
raise NotImplementedError(
"nnterp VLLM wrapper doesn't support attention probabilities yet, please set enable_attention_probs=False."
)
if check_renaming and enable_attention_probs:
self.attention_probabilities.check_source(
allow_dispatch=allow_dispatch,
@@ -159,6 +119,19 @@ def __init__(
self.attention_probabilities.disable()
self._add_prefix_false_tokenizer = None

def _get_rename(
self,
rename_config: RenameConfig | None = None,
user_rename: dict[str, str] | None = None,
):
rename = get_rename_dict(rename_config=rename_config)
if user_rename is not None:
logger.info(
f"Updating default rename with user-provided rename: {user_rename}"
)
rename.update(user_rename)
return rename

def detect_layer_output_type(self):
if self.layers_output.returns_tuple is None:

@@ -176,6 +149,13 @@ def test_layer_0():

@property
def add_prefix_false_tokenizer(self) -> PreTrainedTokenizerBase:
"""
Returns the tokenizer with add_prefix_space=False, which means that "word" and " word" will be tokenized as different tokens.
"""
if self.is_vllm:
raise ValueError(
"nnterp VLLM wrapper doesn't support add_prefix_space=False, the normal tokenizer might already work but it might be model dependent."
)
if self._add_prefix_false_tokenizer is None:
self._add_prefix_false_tokenizer = AutoTokenizer.from_pretrained(
self.name_or_path, add_prefix_space=False
@@ -188,18 +168,30 @@ def attn_probs_available(self) -> bool:

@property
def input_ids(self) -> TraceTensor:
if self.is_vllm:
raise NotImplementedError(
"input_ids is not supported for VLLM models as it is flattened and without padding."
)
return self.inputs[1]["input_ids"]

@property
def input_size(self) -> Size:
"""
Returns the shape of the input tensor (batch_size, sequence_length)
"""
if self.is_vllm:
raise NotImplementedError(
"input_size is not supported for VLLM models as it is flattened and without padding."
)
return self.input_ids.shape

@property
def attention_mask(self) -> TraceTensor:
"""Returns the attention mask tensor."""
if self.is_vllm:
raise NotImplementedError(
"attention_mask is not supported yet for VLLM models as it's not in the inputs dictionary."
)
return self.inputs[1]["attention_mask"]

@property
@@ -212,11 +204,6 @@ def token_embeddings(self, value: TraceTensor):
"""Sets the token embeddings. Equivalent to self.embed_tokens.output = value"""
self.embed_tokens.output = value

@property
def logits(self) -> TraceTensor:
"""Returns the predicted logits."""
return self.output.logits

@property
def next_token_probs(self) -> TraceTensor:
"""Returns the predicted probabilities for the next token.
@@ -339,3 +326,98 @@ def get_topk_closest_tokens(
raise ValueError(
f"Unsupported hidden state shape {hidden_state.shape}. Expected 1D or 2D tensor."
)


class StandardizedTransformer(LanguageModel, StandardizationMixin):
"""
Renames the LanguageModel modules to match a standardized architecture.

The model structure is organized as follows::

StandardizedTransformer
├── embed_tokens
├── layers
│ ├── self_attn
│ └── mlp
├── ln_final
└── lm_head

The following properties are also available:

- num_layers: int
- num_heads: int
- hidden_size: int
- vocab_size: int

In addition to renaming modules, this class provides built-in accessors to extract and set intermediate activations:

- embed_tokens: Get embedding module
- token_embeddings: Get/set token embeddings (equivalent to embed_tokens.output)
- layers[i]: Get layer module at layer i
- layers_input[i]: Get/set layer input at layer i
- layers_output[i]: Get/set layer output at layer i
- attentions[i]: Get attention module at layer i
- attentions_input[i] / attentions_output[i]: Get/set attention input/output at layer i
- mlps[i]: Get MLP module at layer i
- mlps_input[i] / mlps_output[i]: Get/set MLP input/output at layer i

Args:
model (str or Module): Hugging Face repository ID or path of the model to load, or an already-loaded model.
check_renaming (bool, default True): If True, the renaming of modules is validated.
Defaults to True.
allow_dispatch (bool, default True): If True, allows using trace() to dispatch the model
when scan() fails during renaming checks. Defaults to True. You should set this to False
if you plan to use the model remotely.
enable_attention_probs (bool, default False): If True, enables attention probabilities
tracing by setting attn_implementation="eager". Defaults to False.
check_attn_probs_with_trace (bool, default True): If True, the model will be dispatched and a test will ensure that the attention probabilities returned sum to 1.
rename_config (RenameConfig, default None): A RenameConfig object to use for renaming the model. If None, a default RenameConfig will be used.
"""

is_vllm: bool = False

def __init__(
self,
model: str | Module,
check_renaming: bool = True,
allow_dispatch: bool = True,
enable_attention_probs: bool = False,
check_attn_probs_with_trace: bool = True,
rename_config: RenameConfig | None = None,
**kwargs,
):
kwargs.setdefault("device_map", "auto")
if "attn_implementation" in kwargs and enable_attention_probs:
if kwargs["attn_implementation"] != "eager":
raise ValueError(
f"Cannot use attn_implementation='{kwargs['attn_implementation']}' with enable_attention_probs=True. "
"Either set enable_attention_probs=False or don't pass attn_implementation."
)
attn_implementation = (
"eager"
if enable_attention_probs
else kwargs.pop("attn_implementation", None)
)

rename = self._get_rename(
rename_config=rename_config, user_rename=kwargs.pop("rename", None)
)
super().__init__(
model,
attn_implementation=attn_implementation,
rename=rename,
**kwargs,
)
self._init_standardization(
model=model,
check_renaming=check_renaming,
allow_dispatch=allow_dispatch,
enable_attention_probs=enable_attention_probs,
check_attn_probs_with_trace=check_attn_probs_with_trace,
rename_config=rename_config,
)

@property
def logits(self) -> TraceTensor:
"""Returns the predicted logits."""
return self.output.logits
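
For reference, a minimal sketch of the standardized accessors described in the docstring above, assuming the usual nnsight-style trace context (the model name and prompt are placeholders):

```python
from nnterp import StandardizedTransformer

model = StandardizedTransformer("gpt2")

with model.trace("Hello, world"):
    # Residual stream after layer 3 and the final logits.
    layer_3_out = model.layers_output[3].save()
    logits = model.logits.save()

print(layer_3_out.shape, logits.shape)
```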