Fix mypy issue in visualization.py #1416

Closed · wants to merge 4 commits
2 changes: 1 addition & 1 deletion .github/workflows/retry.yml
@@ -18,7 +18,7 @@ jobs:
echo "event: ${{ github.event.workflow_run.conclusion }}"
echo "event: ${{ github.event.workflow_run.event }}"
- name: Rerun Failed Workflows
if: github.event.workflow_run.conclusion == 'failure' && github.event.run_attempt <= 3
if: github.event.workflow_run.conclusion == 'failure' && github.event.workflow_run.run_attempt <= 3
env:
GH_TOKEN: ${{ github.token }}
RUN_ID: ${{ github.event.workflow_run.id }}
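
The retry.yml fix matters because the workflow_run event payload carries the attempt counter on the nested workflow_run object, not at the top level of the event. As a rough illustration only (not part of this PR; function name and threshold are made up), the corrected lookup corresponds to reading the event JSON that Actions exposes via GITHUB_EVENT_PATH like this:

import json
import os


def should_retry(max_attempts: int = 3) -> bool:
    # GITHUB_EVENT_PATH points at the JSON payload Actions received for this run.
    with open(os.environ["GITHUB_EVENT_PATH"]) as f:
        event = json.load(f)
    run = event["workflow_run"]
    # Mirrors the corrected expression: conclusion and run_attempt both live
    # under the nested workflow_run object.
    return run["conclusion"] == "failure" and run["run_attempt"] <= max_attempts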
4 changes: 2 additions & 2 deletions captum/_utils/common.py
@@ -90,8 +90,8 @@ def _is_tuple(inputs: Tensor) -> Literal[False]: ...

@typing.overload
def _is_tuple(
inputs: TensorOrTupleOfTensorsGeneric,
) -> bool: ... # type: ignore
inputs: TensorOrTupleOfTensorsGeneric, # type: ignore
) -> bool: ...


def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
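
For readers unfamiliar with the pattern being touched here: typing.overload lets the checker give _is_tuple a precise return type per call site while a single runtime implementation does the work. A simplified sketch of that pattern (the isinstance body is assumed, not copied from this hunk):

import typing
from typing import Literal, Tuple, Union

from torch import Tensor


@typing.overload
def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...


@typing.overload
def _is_tuple(inputs: Tensor) -> Literal[False]: ...


def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
    # The overloads above exist only for the type checker; at runtime this
    # single implementation handles both shapes of input.
    return isinstance(inputs, tuple)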
44 changes: 40 additions & 4 deletions captum/_utils/typing.py
@@ -2,7 +2,18 @@

# pyre-strict

from typing import List, Literal, Optional, overload, Protocol, Tuple, TypeVar, Union
from collections import UserDict
from typing import (
List,
Literal,
Optional,
overload,
Protocol,
Tuple,
TYPE_CHECKING,
TypeVar,
Union,
)

from torch import Tensor
from torch.nn import Module
@@ -30,17 +41,35 @@
]


# Necessary for Python >=3.7 and <3.9!
if TYPE_CHECKING:
BatchEncodingType = UserDict[Union[int, str], object]
else:
BatchEncodingType = UserDict


class TokenizerLike(Protocol):
"""A protocol for tokenizer-like objects that can be used with Captum
LLM attribution methods."""

@overload
def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
def encode(
self, text: str, add_special_tokens: bool = ..., return_tensors: None = ...
) -> List[int]: ...

@overload
def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...
def encode(
self,
text: str,
add_special_tokens: bool = ...,
return_tensors: Literal["pt"] = ...,
) -> Tensor: ...

def encode(
self, text: str, return_tensors: Optional[str] = None
self,
text: str,
add_special_tokens: bool = True,
return_tensors: Optional[str] = None,
) -> Union[List[int], Tensor]: ...

def decode(self, token_ids: Tensor) -> str: ...
@@ -62,3 +91,10 @@ def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]: ...
def convert_tokens_to_ids(
self, tokens: Union[List[str], str]
) -> Union[List[int], int]: ...

def __call__(
self,
text: Optional[Union[str, List[str], List[List[str]]]] = None,
add_special_tokens: bool = True,
return_offsets_mapping: bool = False,
) -> BatchEncodingType: ...
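
Aside: the TYPE_CHECKING split above is needed because subscripting collections.UserDict at runtime only became legal in Python 3.9. A minimal standalone sketch of the same trick (the alias name and demo function are illustrative, not part of Captum):

from collections import UserDict
from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    # Only type checkers evaluate this branch, so the 3.9+ subscription
    # syntax never runs on older interpreters.
    EncodingType = UserDict[Union[int, str], object]
else:
    # At runtime the unsubscripted class serves as the alias.
    EncodingType = UserDict


def demo_encoding() -> "EncodingType":
    enc: "EncodingType" = UserDict()
    enc["input_ids"] = [0, 104, 105]
    return enc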
71 changes: 66 additions & 5 deletions captum/attr/_core/llm_attr.py
@@ -1,4 +1,7 @@
# pyre-strict

import warnings

from copy import copy

from textwrap import shorten
@@ -216,6 +219,11 @@ def plot_seq_attr(
return fig, ax


def _clean_up_pretty_token(token: str) -> str:
"""Remove newlines and leading/trailing whitespace from token."""
return token.replace("\n", "\\n").strip()


def _convert_ids_to_pretty_tokens(ids: Tensor, tokenizer: TokenizerLike) -> List[str]:
"""
Convert ids to tokens without ugly unicode characters (e.g., Ġ). See:
@@ -230,10 +238,63 @@ def _convert_ids_to_pretty_tokens(ids: Tensor, tokenizer: TokenizerLike) -> List
> BPE splitting mostly to avoid digesting spaces since the standard BPE algorithm
> used spaces in its process
"""
txt = tokenizer.decode(ids)
# Don't add special tokens (they're either already there, or we don't want them)
enc = tokenizer(txt, return_offsets_mapping=True, add_special_tokens=False)
input_ids = cast(List[int], enc["input_ids"])
offset_mapping = cast(List[Tuple[int, int]], enc["offset_mapping"])

pretty_tokens = []
end_prev = -1
idx = 0
for i, (input_id, offset) in enumerate(zip(input_ids, offset_mapping)):
start, end = offset
if start == end:
# For the case where offsets are not set properly (the end and start are
# equal for all tokens - fall back on the start of the next span in the
# offset mapping)
if (i + 1) < len(input_ids):
end = offset_mapping[i + 1][0]
else:
end = len(txt)
if input_id != ids[idx]:
# When the re-encoded string doesn't match the original encoding we skip
# this token and hope for the best, falling back on a naive method. This
# can happen when a tokenizer might add a token that corresponds to
# a space only when add_special_tokens=False.
warnings.warn(
f"(i={i}) input_id {input_id} != ids[idx] {ids[idx]} (corresponding "
f"to text: {repr(txt[start:end])}). Skipping this token.",
stacklevel=2,
)
continue
pretty_tokens.append(
_clean_up_pretty_token(txt[start:end])
+ (" [OVERLAP]" if end_prev > start else "")
)
end_prev = end
idx += 1
if len(pretty_tokens) != len(ids):
warnings.warn(
f"Pretty tokens length {len(pretty_tokens)} != ids length {len(ids)}! "
"Falling back to naive decoding logic.",
stacklevel=2,
)
return _convert_ids_to_pretty_tokens_fallback(ids, tokenizer)
return pretty_tokens


def _convert_ids_to_pretty_tokens_fallback(
ids: Tensor, tokenizer: TokenizerLike
) -> List[str]:
"""
Fallback function that naively handles logic when multiple ids map to one string.
"""
pretty_tokens = []
idx = 0
while idx < len(ids):
decoded = tokenizer.decode(ids[idx])
decoded_pretty = _clean_up_pretty_token(decoded)
# Handle case where single token (e.g. unicode) is split into multiple IDs
# NOTE: This logic will fail if a tokenizer splits a token into 3+ IDs
if decoded.strip() == "�" and tokenizer.encode(decoded) != [ids[idx]]:
@@ -244,17 +305,17 @@ def _convert_ids_to_pretty_tokens(ids: Tensor, tokenizer: TokenizerLike) -> List
]:
# Both tokens are from a split, combine them
decoded = tokenizer.decode(ids[idx : idx + 2])
pretty_tokens.append(decoded + "[1/2]")
pretty_tokens.append(decoded + "[2/2]")
pretty_tokens.append(decoded_pretty)
pretty_tokens.append(decoded_pretty + " [OVERLAP]")
else:
# Treat tokens as separate
pretty_tokens.append(decoded)
pretty_tokens.append(decoded_next)
pretty_tokens.append(decoded_pretty)
pretty_tokens.append(_clean_up_pretty_token(decoded_next))
idx += 2
else:
# Just a normal token
idx += 1
pretty_tokens.append(decoded)
pretty_tokens.append(decoded_pretty)
return pretty_tokens
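
The new pretty-token logic above relies on the tokenizer returning character offsets: the ids are decoded back into text, the text is re-encoded with return_offsets_mapping=True, and each (start, end) span is sliced out of the decoded string. A rough sketch of that idea against a Hugging Face fast tokenizer (model choice and variable names are just an example and assume a cached gpt2 tokenizer):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example model, not from the PR

ids = tokenizer.encode("Captum attributes tokens")
txt = tokenizer.decode(ids)
enc = tokenizer(txt, return_offsets_mapping=True, add_special_tokens=False)

pretty_tokens = []
for start, end in enc["offset_mapping"]:
    # Each span indexes the decoded string, so the slice is the readable
    # surface form of the token (no byte-level Ġ/Ċ markers to clean up).
    pretty_tokens.append(txt[start:end].strip())

print(pretty_tokens)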


14 changes: 9 additions & 5 deletions captum/attr/_utils/visualization.py
@@ -3,7 +3,7 @@
# pyre-strict
import warnings
from enum import Enum
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Tuple, Union

import matplotlib

@@ -444,7 +444,7 @@ def visualize_image_attr_multiple(
fig_size: Tuple[int, int] = (8, 6),
use_pyplot: bool = True,
**kwargs: Any,
) -> Tuple[Figure, Axes]:
) -> Tuple[Figure, Union[Axes, List[Axes]]]:
r"""
Visualizes attribution using multiple visualization methods displayed
in a 1 x k grid, where k is the number of desired visualizations.
@@ -516,15 +516,19 @@ def visualize_image_attr_multiple(
plt_fig = plt.figure(figsize=fig_size)
else:
plt_fig = Figure(figsize=fig_size)
plt_axis = plt_fig.subplots(1, len(methods))
plt_axis_np = plt_fig.subplots(1, len(methods), squeeze=True)

plt_axis: Union[Axes, List[Axes]]
plt_axis_list: List[Axes] = []
# When visualizing one
if len(methods) == 1:
plt_axis_list = [plt_axis] # type: ignore
plt_axis = cast(Axes, plt_axis_np)
plt_axis_list = [plt_axis]
# Figure.subplots returns Axes or array of Axes
else:
plt_axis_list = plt_axis # type: ignore
# https://github.com/numpy/numpy/issues/24738
plt_axis = cast(List[Axes], cast(npt.NDArray, plt_axis_np).tolist())
plt_axis_list = plt_axis
# Figure.subplots returns Axes or array of Axes

for i in range(len(methods)):
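
The visualization.py change works around the fact that Figure.subplots collapses a 1 x 1 grid to a single Axes but otherwise returns a NumPy object array, so the result has to be normalized before iterating. A condensed sketch of the same normalization (the wrapper function is illustrative, not Captum API):

from typing import List, Union, cast

import numpy.typing as npt
from matplotlib.axes import Axes
from matplotlib.figure import Figure


def make_axes(n_methods: int) -> List[Axes]:
    fig = Figure(figsize=(8, 6))
    axes = fig.subplots(1, n_methods, squeeze=True)
    if n_methods == 1:
        # squeeze=True hands back a single Axes for a 1x1 grid.
        return [cast(Axes, axes)]
    # Otherwise it is a 1-D object array; tolist() yields a plain List[Axes].
    return cast(List[Axes], cast(npt.NDArray, axes).tolist())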
26 changes: 23 additions & 3 deletions tests/attr/test_interpretable_input.py
@@ -5,6 +5,7 @@
from typing import List, Literal, Optional, overload, Union

import torch
from captum._utils.typing import BatchEncodingType
from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput
from parameterized import parameterized
from tests.helpers import BaseTest
@@ -19,12 +20,23 @@ def __init__(self, vocab_list) -> None:
self.unk_idx = len(vocab_list) + 1

@overload
def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
def encode(
self, text: str, add_special_tokens: bool = ..., return_tensors: None = ...
) -> List[int]: ...

@overload
def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...
def encode(
self,
text: str,
add_special_tokens: bool = ...,
return_tensors: Literal["pt"] = ...,
) -> Tensor: ...

def encode(
self, text: str, return_tensors: Optional[str] = "pt"
self,
text: str,
add_special_tokens: bool = True,
return_tensors: Optional[str] = "pt",
) -> Union[List[int], Tensor]:
assert return_tensors == "pt"
return torch.tensor([self.convert_tokens_to_ids(text.split(" "))])
@@ -68,6 +80,14 @@ def convert_tokens_to_ids(
def decode(self, token_ids: Tensor) -> str:
raise NotImplementedError

def __call__(
self,
text: Optional[Union[str, List[str], List[List[str]]]] = None,
add_special_tokens: bool = True,
return_offsets_mapping: bool = False,
) -> BatchEncodingType:
raise NotImplementedError


class TestTextTemplateInput(BaseTest):
@parameterized.expand(
51 changes: 45 additions & 6 deletions tests/attr/test_llm_attr.py
@@ -3,6 +3,8 @@
# pyre-strict

import copy

from collections import UserDict
from typing import (
Any,
cast,
@@ -19,6 +21,7 @@

import torch
from captum._utils.models.linear_model import SkLearnLasso
from captum._utils.typing import BatchEncodingType
from captum.attr._core.feature_ablation import FeatureAblation
from captum.attr._core.kernel_shap import KernelShap
from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap
@@ -39,24 +42,38 @@ class DummyTokenizer:
vocab_size: int = 256
sos: int = 0
unk: int = 1
special_tokens: Dict[int, str] = {sos: "<sos>", unk: "<unk>"}
sos_str: str = "<sos>"
special_tokens: Dict[int, str] = {sos: sos_str, unk: "<unk>"}

@overload
def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
def encode(
self, text: str, add_special_tokens: bool = ..., return_tensors: None = ...
) -> List[int]: ...

@overload
def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...
def encode(
self,
text: str,
add_special_tokens: bool = ...,
return_tensors: Literal["pt"] = ...,
) -> Tensor: ...

def encode(
self, text: str, return_tensors: Optional[str] = None
self,
text: str,
add_special_tokens: bool = True,
return_tensors: Optional[str] = None,
) -> Union[List[int], Tensor]:
tokens = text.split(" ")

tokens_ids: Union[List[int], Tensor] = [
ord(s[0]) if len(s) == 1 else self.unk for s in tokens
ord(s[0]) if len(s) == 1 else (self.sos if s == self.sos_str else self.unk)
for s in tokens
]

# start with sos
tokens_ids = [self.sos, *tokens_ids]
if add_special_tokens:
tokens_ids = [self.sos, *tokens_ids]

if return_tensors:
return torch.tensor([tokens_ids])
@@ -96,6 +113,28 @@ def decode(self, token_ids: Tensor) -> str:
# pyre-fixme[7]: Expected `str` but got `Union[List[str], str]`.
return tokens if isinstance(tokens, str) else " ".join(tokens)

def __call__(
self,
text: Optional[Union[str, List[str], List[List[str]]]] = None,
add_special_tokens: bool = True,
return_offsets_mapping: bool = False,
) -> BatchEncodingType:
assert isinstance(text, str)
input_ids = self.encode(text, add_special_tokens=add_special_tokens)

result: BatchEncodingType = UserDict()
result["input_ids"] = input_ids

if return_offsets_mapping:
offset_mapping = []
idx = 0
for token in text.split(" "):
offset_mapping.append((idx - (0 if idx == 0 else 1), idx + len(token)))
idx += len(token) + 1 # +1 for space
result["offset_mapping"] = offset_mapping

return result


class Result(NamedTuple):
logits: Tensor
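
One more note on the test tokenizer: the offset arithmetic in DummyTokenizer.__call__ shifts every span after the first one character to the left so the leading space belongs to the token, roughly imitating byte-level tokenizers. A quick hand check with a made-up input (not taken from the test suite):

text = "a b cd"  # hypothetical input
offsets = []
idx = 0
for token in text.split(" "):
    offsets.append((idx - (0 if idx == 0 else 1), idx + len(token)))
    idx += len(token) + 1  # +1 for the separating space

assert offsets == [(0, 1), (1, 3), (3, 6)]
assert [text[s:e] for s, e in offsets] == ["a", " b", " cd"]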