
Commit a5ac2cd

craymichael authored and facebook-github-bot committed on Oct 25, 2024

Improve tokenizer pretty-print logic + __call__ method (#1417)

Summary: Use the __call__ method of tokenizers, which returns a BatchEncoding with offsets. This lets us grab each token's text from the fully decoded string instead of making assumptions about how many tokens correspond to a single string.

Differential Revision: D64998804
1 parent 8a49bcb commit a5ac2cd
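
For context, a minimal sketch of the idea described in the summary (not part of the commit; it assumes the Hugging Face transformers package and uses "gpt2" purely as an example fast-tokenizer checkpoint). Calling the tokenizer returns a BatchEncoding whose offset_mapping holds (start, end) character spans into the input string, so each token's text can be sliced out of the fully decoded string instead of decoding ids one at a time:

    from transformers import AutoTokenizer  # assumed dependency for this sketch

    tok = AutoTokenizer.from_pretrained("gpt2")  # any fast tokenizer with offsets works
    ids = tok.encode("Hello, world!")            # example ids to attribute
    txt = tok.decode(ids)                        # fully decoded string
    enc = tok(txt, return_offsets_mapping=True, add_special_tokens=False)
    for token_id, (start, end) in zip(enc["input_ids"], enc["offset_mapping"]):
        # Slice the decoded text by character offsets rather than re-decoding each id
        print(token_id, repr(txt[start:end]))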

File tree

5 files changed, +306 −17 lines changed

 

captum/_utils/typing.py

Lines changed: 15 additions & 3 deletions
@@ -53,12 +53,23 @@ class TokenizerLike(Protocol):
     LLM attribution methods."""
 
     @overload
-    def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
+    def encode(
+        self, text: str, add_special_tokens: bool = ..., return_tensors: None = ...
+    ) -> List[int]: ...
+
     @overload
-    def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...
+    def encode(
+        self,
+        text: str,
+        add_special_tokens: bool = ...,
+        return_tensors: Literal["pt"] = ...,
+    ) -> Tensor: ...
 
     def encode(
-        self, text: str, return_tensors: Optional[str] = None
+        self,
+        text: str,
+        add_special_tokens: bool = True,
+        return_tensors: Optional[str] = None,
     ) -> Union[List[int], Tensor]: ...
 
     def decode(self, token_ids: Tensor) -> str: ...

@@ -84,5 +95,6 @@ def convert_tokens_to_ids(
     def __call__(
         self,
         text: Optional[Union[str, List[str], List[List[str]]]] = None,
+        add_special_tokens: bool = True,
         return_offsets_mapping: bool = False,
     ) -> BatchEncodingType: ...
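
For context, a usage sketch of the widened encode signature (not part of the diff; it assumes the transformers package and uses "gpt2" only as an example checkpoint). Hugging Face tokenizers already accept add_special_tokens alongside return_tensors, so they continue to satisfy the updated TokenizerLike protocol:

    from typing import List

    from torch import Tensor
    from transformers import AutoTokenizer  # assumed dependency for this sketch

    tok = AutoTokenizer.from_pretrained("gpt2")

    # First overload: no return_tensors -> a plain list of token ids
    ids: List[int] = tok.encode("hello world", add_special_tokens=False)

    # Second overload: return_tensors="pt" -> a 1 x n tensor of token ids
    tensor_ids: Tensor = tok.encode(
        "hello world", add_special_tokens=False, return_tensors="pt"
    )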

captum/attr/_core/llm_attr.py

Lines changed: 66 additions & 5 deletions
@@ -1,4 +1,7 @@
 # pyre-strict
+
+import warnings
+
 from copy import copy
 
 from textwrap import shorten

@@ -216,6 +219,11 @@ def plot_seq_attr(
     return fig, ax
 
 
+def _clean_up_pretty_token(token: str) -> str:
+    """Remove newlines and leading/trailing whitespace from token."""
+    return token.replace("\n", "\\n").strip()
+
+
 def _convert_ids_to_pretty_tokens(ids: Tensor, tokenizer: TokenizerLike) -> List[str]:
     """
     Convert ids to tokens without ugly unicode characters (e.g., Ġ). See:

@@ -230,10 +238,63 @@ def _convert_ids_to_pretty_tokens(ids: Tensor, tokenizer: TokenizerLike) -> List
     > BPE splitting mostly to avoid digesting spaces since the standard BPE algorithm
     > used spaces in its process
     """
+    txt = tokenizer.decode(ids)
+    # Don't add special tokens (they're either already there, or we don't want them)
+    enc = tokenizer(txt, return_offsets_mapping=True, add_special_tokens=False)
+    input_ids = cast(List[int], enc["input_ids"])
+    offset_mapping = cast(List[Tuple[int, int]], enc["offset_mapping"])
+
+    pretty_tokens = []
+    end_prev = -1
+    idx = 0
+    for i, (input_id, offset) in enumerate(zip(input_ids, offset_mapping)):
+        start, end = offset
+        if start == end:
+            # For the case where offsets are not set properly (the end and start are
+            # equal for all tokens - fall back on the start of the next span in the
+            # offset mapping)
+            if (i + 1) < len(input_ids):
+                end = offset_mapping[i + 1][0]
+            else:
+                end = len(txt)
+        if input_id != ids[idx]:
+            # When the re-encoded string doesn't match the original encoding we skip
+            # this token and hope for the best, falling back on a naive method. This
+            # can happen when a tokenizer might add a token that corresponds to
+            # a space only when add_special_tokens=False.
+            warnings.warn(
+                f"(i={i}) input_id {input_id} != ids[i] {ids[i]} (corresponding to "
+                f"text: {repr(txt[start:end])}). Skipping this token.",
+                stacklevel=2,
+            )
+            continue
+        pretty_tokens.append(
+            _clean_up_pretty_token(txt[start:end])
+            + (" [OVERLAP]" if end_prev > start else "")
+        )
+        end_prev = end
+        idx += 1
+    if len(pretty_tokens) != len(ids):
+        warnings.warn(
+            f"Pretty tokens length {len(pretty_tokens)} != ids length {len(ids)}! "
+            "Falling back to naive decoding logic.",
+            stacklevel=2,
+        )
+        return _convert_ids_to_pretty_tokens_fallback(ids, tokenizer)
+    return pretty_tokens
+
+
+def _convert_ids_to_pretty_tokens_fallback(
+    ids: Tensor, tokenizer: TokenizerLike
+) -> List[str]:
+    """
+    Fallback function that naively handles logic when multiple ids map to one string.
+    """
     pretty_tokens = []
     idx = 0
     while idx < len(ids):
         decoded = tokenizer.decode(ids[idx])
+        decoded_pretty = _clean_up_pretty_token(decoded)
         # Handle case where single token (e.g. unicode) is split into multiple IDs
         # NOTE: This logic will fail if a tokenizer splits a token into 3+ IDs
         if decoded.strip() == "�" and tokenizer.encode(decoded) != [ids[idx]]:

@@ -244,17 +305,17 @@ def _convert_ids_to_pretty_tokens(ids: Tensor, tokenizer: TokenizerLike) -> List
             ]:
                 # Both tokens are from a split, combine them
                 decoded = tokenizer.decode(ids[idx : idx + 2])
-                pretty_tokens.append(decoded + "[1/2]")
-                pretty_tokens.append(decoded + "[2/2]")
+                pretty_tokens.append(decoded_pretty)
+                pretty_tokens.append(decoded_pretty + " [OVERLAP]")
             else:
                 # Treat tokens as separate
-                pretty_tokens.append(decoded)
-                pretty_tokens.append(decoded_next)
+                pretty_tokens.append(decoded_pretty)
+                pretty_tokens.append(_clean_up_pretty_token(decoded_next))
             idx += 2
         else:
             # Just a normal token
             idx += 1
-            pretty_tokens.append(decoded)
+            pretty_tokens.append(decoded_pretty)
     return pretty_tokens
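
To illustrate the new offset-based path above, a rough standalone sketch (hypothetical code, not the committed helper; it assumes the transformers package and uses "gpt2" as an example). The decoded string is re-encoded with offsets, each token's text is sliced out by character span and cleaned up, and a span starting before the previous one ended would be flagged with " [OVERLAP]":

    from transformers import AutoTokenizer  # assumed dependency for this sketch

    tok = AutoTokenizer.from_pretrained("gpt2")
    ids = tok.encode("Tokenization,\nthe pretty way")  # example ids to prettify

    txt = tok.decode(ids)
    enc = tok(txt, return_offsets_mapping=True, add_special_tokens=False)

    pretty_tokens, end_prev = [], -1
    for start, end in enc["offset_mapping"]:
        # Same cleanup as _clean_up_pretty_token: escape newlines, strip whitespace
        piece = txt[start:end].replace("\n", "\\n").strip()
        pretty_tokens.append(piece + (" [OVERLAP]" if end_prev > start else ""))
        end_prev = end
    print(pretty_tokens)  # cleaned, human-readable text for each token id

Unlike the committed helper, this sketch omits the id-mismatch check and the naive decoding fallback.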

tests/attr/test_interpretable_input.py

Lines changed: 15 additions & 3 deletions
@@ -20,12 +20,23 @@ def __init__(self, vocab_list) -> None:
         self.unk_idx = len(vocab_list) + 1
 
     @overload
-    def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
+    def encode(
+        self, text: str, add_special_tokens: bool = ..., return_tensors: None = ...
+    ) -> List[int]: ...
+
     @overload
-    def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...
+    def encode(
+        self,
+        text: str,
+        add_special_tokens: bool = ...,
+        return_tensors: Literal["pt"] = ...,
+    ) -> Tensor: ...
 
     def encode(
-        self, text: str, return_tensors: Optional[str] = "pt"
+        self,
+        text: str,
+        add_special_tokens: bool = True,
+        return_tensors: Optional[str] = "pt",
     ) -> Union[List[int], Tensor]:
         assert return_tensors == "pt"
         return torch.tensor([self.convert_tokens_to_ids(text.split(" "))])

@@ -72,6 +83,7 @@ def decode(self, token_ids: Tensor) -> str:
     def __call__(
         self,
         text: Optional[Union[str, List[str], List[List[str]]]] = None,
+        add_special_tokens: bool = True,
         return_offsets_mapping: bool = False,
     ) -> BatchEncodingType:
         raise NotImplementedError

tests/attr/test_llm_attr.py

Lines changed: 15 additions & 3 deletions
@@ -43,12 +43,23 @@ class DummyTokenizer:
     special_tokens: Dict[int, str] = {sos: "<sos>", unk: "<unk>"}
 
     @overload
-    def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
+    def encode(
+        self, text: str, add_special_tokens: bool = ..., return_tensors: None = ...
+    ) -> List[int]: ...
+
     @overload
-    def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...
+    def encode(
+        self,
+        text: str,
+        add_special_tokens: bool = ...,
+        return_tensors: Literal["pt"] = ...,
+    ) -> Tensor: ...
 
     def encode(
-        self, text: str, return_tensors: Optional[str] = None
+        self,
+        text: str,
+        add_special_tokens: bool = True,
+        return_tensors: Optional[str] = None,
     ) -> Union[List[int], Tensor]:
         tokens = text.split(" ")
 

@@ -100,6 +111,7 @@ def decode(self, token_ids: Tensor) -> str:
     def __call__(
         self,
         text: Optional[Union[str, List[str], List[List[str]]]] = None,
+        add_special_tokens: bool = True,
         return_offsets_mapping: bool = False,
     ) -> BatchEncodingType:
         raise NotImplementedError
