
Commit 9fe0a98

craymichael authored and facebook-github-bot committed
Improve tokenizer pretty-pretty logic + __call__ method (pytorch#1417)
Summary: Use the __call__ method of tokenizers that returns a BatchEncoding with offsets. This allows us to grab text from the fully decoded string and not make assumptions about how many tokens correspond to a single string.

Differential Revision: D64998804
1 parent 85d3130 commit 9fe0a98
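
The mechanism described in the summary, as a minimal sketch: decode the attributed ids once, re-encode the decoded string with return_offsets_mapping=True, and slice that string by the returned character spans instead of decoding id-by-id. This sketch assumes a Hugging Face fast tokenizer (gpt2 is only an illustration, not something the commit prescribes); any tokenizer whose __call__ returns a BatchEncoding with an offset_mapping behaves the same way.

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")  # fast tokenizer -> offsets available
    ids = tok.encode("Attribution over LLM tokens", add_special_tokens=False)

    txt = tok.decode(ids)
    enc = tok(txt, return_offsets_mapping=True, add_special_tokens=False)
    spans = enc["offset_mapping"]                       # one (start, end) pair per token
    pieces = [txt[start:end] for start, end in spans]   # text per token, no Ġ artifacts
    print(pieces)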

File tree

5 files changed: +329, -21 lines changed


captum/_utils/typing.py

Lines changed: 15 additions & 3 deletions
@@ -53,12 +53,23 @@ class TokenizerLike(Protocol):
     LLM attribution methods."""
 
     @overload
-    def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
+    def encode(
+        self, text: str, add_special_tokens: bool = ..., return_tensors: None = ...
+    ) -> List[int]: ...
+
     @overload
-    def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...
+    def encode(
+        self,
+        text: str,
+        add_special_tokens: bool = ...,
+        return_tensors: Literal["pt"] = ...,
+    ) -> Tensor: ...
 
     def encode(
-        self, text: str, return_tensors: Optional[str] = None
+        self,
+        text: str,
+        add_special_tokens: bool = True,
+        return_tensors: Optional[str] = None,
     ) -> Union[List[int], Tensor]: ...
 
     def decode(self, token_ids: Tensor) -> str: ...
@@ -84,5 +95,6 @@ def convert_tokens_to_ids(
     def __call__(
         self,
         text: Optional[Union[str, List[str], List[List[str]]]] = None,
+        add_special_tokens: bool = True,
         return_offsets_mapping: bool = False,
     ) -> BatchEncodingType: ...
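
As a sanity check on the widened protocol (an assumption on my part, not something stated in the diff): a Hugging Face fast tokenizer should satisfy TokenizerLike, since its encode and __call__ accept add_special_tokens and it can return offset mappings.

    from transformers import AutoTokenizer
    from captum._utils.typing import TokenizerLike

    tok: TokenizerLike = AutoTokenizer.from_pretrained("gpt2")
    ids = tok.encode("hello world", add_special_tokens=False)          # List[int]
    enc = tok("hello world", add_special_tokens=False, return_offsets_mapping=True)
    print(ids, enc["input_ids"], enc["offset_mapping"])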

captum/attr/_core/llm_attr.py

Lines changed: 66 additions & 5 deletions
@@ -1,4 +1,7 @@
 # pyre-strict
+
+import warnings
+
 from copy import copy
 
 from textwrap import shorten
@@ -216,6 +219,11 @@ def plot_seq_attr(
         return fig, ax
 
 
+def _clean_up_pretty_token(token: str) -> str:
+    """Remove newlines and leading/trailing whitespace from token."""
+    return token.replace("\n", "\\n").strip()
+
+
 def _convert_ids_to_pretty_tokens(ids: Tensor, tokenizer: TokenizerLike) -> List[str]:
     """
     Convert ids to tokens without ugly unicode characters (e.g., Ġ). See:
@@ -230,10 +238,63 @@ def _convert_ids_to_pretty_tokens(ids: Tensor, tokenizer: TokenizerLike) -> List
     > BPE splitting mostly to avoid digesting spaces since the standard BPE algorithm
     > used spaces in its process
     """
+    txt = tokenizer.decode(ids)
+    # Don't add special tokens (they're either already there, or we don't want them)
+    enc = tokenizer(txt, return_offsets_mapping=True, add_special_tokens=False)
+    input_ids = cast(List[int], enc["input_ids"])
+    offset_mapping = cast(List[Tuple[int, int]], enc["offset_mapping"])
+
+    pretty_tokens = []
+    end_prev = -1
+    idx = 0
+    for i, (input_id, offset) in enumerate(zip(input_ids, offset_mapping)):
+        start, end = offset
+        if start == end:
+            # For the case where offsets are not set properly (the end and start are
+            # equal for all tokens - fall back on the start of the next span in the
+            # offset mapping)
+            if (i + 1) < len(input_ids):
+                end = offset_mapping[i + 1][0]
+            else:
+                end = len(txt)
+        if input_id != ids[idx]:
+            # When the re-encoded string doesn't match the original encoding we skip
+            # this token and hope for the best, falling back on a naive method. This
+            # can happen when a tokenizer might add a token that corresponds to
+            # a space only when add_special_tokens=False.
+            warnings.warn(
+                f"(i={i}) input_id {input_id} != ids[idx] {ids[idx]} (corresponding "
+                f"to text: {repr(txt[start:end])}). Skipping this token.",
+                stacklevel=2,
+            )
+            continue
+        pretty_tokens.append(
+            _clean_up_pretty_token(txt[start:end])
+            + (" [OVERLAP]" if end_prev > start else "")
+        )
+        end_prev = end
+        idx += 1
+    if len(pretty_tokens) != len(ids):
+        warnings.warn(
+            f"Pretty tokens length {len(pretty_tokens)} != ids length {len(ids)}! "
+            "Falling back to naive decoding logic.",
+            stacklevel=2,
+        )
+        return _convert_ids_to_pretty_tokens_fallback(ids, tokenizer)
+    return pretty_tokens
+
+
+def _convert_ids_to_pretty_tokens_fallback(
+    ids: Tensor, tokenizer: TokenizerLike
+) -> List[str]:
+    """
+    Fallback function that naively handles logic when multiple ids map to one string.
+    """
     pretty_tokens = []
     idx = 0
     while idx < len(ids):
         decoded = tokenizer.decode(ids[idx])
+        decoded_pretty = _clean_up_pretty_token(decoded)
         # Handle case where single token (e.g. unicode) is split into multiple IDs
         # NOTE: This logic will fail if a tokenizer splits a token into 3+ IDs
         if decoded.strip() == "�" and tokenizer.encode(decoded) != [ids[idx]]:
@@ -244,17 +305,17 @@ def _convert_ids_to_pretty_tokens(ids: Tensor, tokenizer: TokenizerLike) -> List
             ]:
                 # Both tokens are from a split, combine them
                 decoded = tokenizer.decode(ids[idx : idx + 2])
-                pretty_tokens.append(decoded + "[1/2]")
-                pretty_tokens.append(decoded + "[2/2]")
+                pretty_tokens.append(decoded_pretty)
+                pretty_tokens.append(decoded_pretty + " [OVERLAP]")
             else:
                 # Treat tokens as separate
-                pretty_tokens.append(decoded)
-                pretty_tokens.append(decoded_next)
+                pretty_tokens.append(decoded_pretty)
+                pretty_tokens.append(_clean_up_pretty_token(decoded_next))
             idx += 2
         else:
             # Just a normal token
             idx += 1
-            pretty_tokens.append(decoded)
+            pretty_tokens.append(decoded_pretty)
     return pretty_tokens
 
 
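
A hedged usage sketch of the reworked helper (it is a private function, and the output below is illustrative rather than asserted; assumes the gpt2 fast tokenizer from transformers): newlines are escaped by _clean_up_pretty_token, and a token whose character span overlaps the previous span is rendered with an " [OVERLAP]" suffix instead of the old "[1/2]"/"[2/2]" labels.

    import torch
    from transformers import AutoTokenizer
    from captum.attr._core.llm_attr import _convert_ids_to_pretty_tokens

    tok = AutoTokenizer.from_pretrained("gpt2")
    ids = torch.tensor(tok.encode("Hello\nworld 🙂", add_special_tokens=False))
    print(_convert_ids_to_pretty_tokens(ids, tok))
    # e.g. the emoji decodes to several byte-level ids that share one character span,
    # so each id after the first in that span is shown with " [OVERLAP]".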

tests/attr/test_interpretable_input.py

Lines changed: 15 additions & 3 deletions
@@ -20,12 +20,23 @@ def __init__(self, vocab_list) -> None:
         self.unk_idx = len(vocab_list) + 1
 
     @overload
-    def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
+    def encode(
+        self, text: str, add_special_tokens: bool = ..., return_tensors: None = ...
+    ) -> List[int]: ...
+
     @overload
-    def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...
+    def encode(
+        self,
+        text: str,
+        add_special_tokens: bool = ...,
+        return_tensors: Literal["pt"] = ...,
+    ) -> Tensor: ...
 
     def encode(
-        self, text: str, return_tensors: Optional[str] = "pt"
+        self,
+        text: str,
+        add_special_tokens: bool = True,
+        return_tensors: Optional[str] = "pt",
     ) -> Union[List[int], Tensor]:
         assert return_tensors == "pt"
         return torch.tensor([self.convert_tokens_to_ids(text.split(" "))])
@@ -72,6 +83,7 @@ def decode(self, token_ids: Tensor) -> str:
     def __call__(
         self,
         text: Optional[Union[str, List[str], List[List[str]]]] = None,
+        add_special_tokens: bool = True,
         return_offsets_mapping: bool = False,
     ) -> BatchEncodingType:
         raise NotImplementedError

tests/attr/test_llm_attr.py

Lines changed: 38 additions & 7 deletions
@@ -3,6 +3,8 @@
 # pyre-strict
 
 import copy
+
+from collections import UserDict
 from typing import (
     Any,
     cast,
@@ -40,24 +42,38 @@ class DummyTokenizer:
     vocab_size: int = 256
     sos: int = 0
     unk: int = 1
-    special_tokens: Dict[int, str] = {sos: "<sos>", unk: "<unk>"}
+    sos_str: str = "<sos>"
+    special_tokens: Dict[int, str] = {sos: sos_str, unk: "<unk>"}
 
     @overload
-    def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
+    def encode(
+        self, text: str, add_special_tokens: bool = ..., return_tensors: None = ...
+    ) -> List[int]: ...
+
     @overload
-    def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...
+    def encode(
+        self,
+        text: str,
+        add_special_tokens: bool = ...,
+        return_tensors: Literal["pt"] = ...,
+    ) -> Tensor: ...
 
     def encode(
-        self, text: str, return_tensors: Optional[str] = None
+        self,
+        text: str,
+        add_special_tokens: bool = True,
+        return_tensors: Optional[str] = None,
     ) -> Union[List[int], Tensor]:
         tokens = text.split(" ")
 
         tokens_ids: Union[List[int], Tensor] = [
-            ord(s[0]) if len(s) == 1 else self.unk for s in tokens
+            ord(s[0]) if len(s) == 1 else (self.sos if s == self.sos_str else self.unk)
+            for s in tokens
         ]
 
         # start with sos
-        tokens_ids = [self.sos, *tokens_ids]
+        if add_special_tokens:
+            tokens_ids = [self.sos, *tokens_ids]
 
         if return_tensors:
             return torch.tensor([tokens_ids])
@@ -100,9 +116,24 @@ def decode(self, token_ids: Tensor) -> str:
     def __call__(
         self,
         text: Optional[Union[str, List[str], List[List[str]]]] = None,
+        add_special_tokens: bool = True,
         return_offsets_mapping: bool = False,
     ) -> BatchEncodingType:
-        raise NotImplementedError
+        assert isinstance(text, str)
+        input_ids = self.encode(text, add_special_tokens=add_special_tokens)
+
+        result: BatchEncodingType = UserDict()
+        result["input_ids"] = input_ids
+
+        if return_offsets_mapping:
+            offset_mapping = []
+            idx = 0
+            for token in text.split(" "):
+                offset_mapping.append((idx - (0 if idx == 0 else 1), idx + len(token)))
+                idx += len(token) + 1  # +1 for space
+            result["offset_mapping"] = offset_mapping
+
+        return result
 
 
 class Result(NamedTuple):
