@@ -1,4 +1,7 @@
 # pyre-strict
+
+import warnings
+
 from copy import copy

 from textwrap import shorten
@@ -216,6 +219,11 @@ def plot_seq_attr(
            return fig, ax


+def _clean_up_pretty_token(token: str) -> str:
+    """Escape newlines and strip leading/trailing whitespace from a token."""
+    return token.replace("\n", "\\n").strip()
+
+
 def _convert_ids_to_pretty_tokens(ids: Tensor, tokenizer: TokenizerLike) -> List[str]:
     """
     Convert ids to tokens without ugly unicode characters (e.g., Ġ). See:
@@ -230,10 +238,63 @@ def _convert_ids_to_pretty_tokens(ids: Tensor, tokenizer: TokenizerLike) -> List
     > BPE splitting mostly to avoid digesting spaces since the standard BPE algorithm
     > used spaces in its process
     """
+    txt = tokenizer.decode(ids)
+    # Don't add special tokens (they're either already there, or we don't want them)
+    enc = tokenizer(txt, return_offsets_mapping=True, add_special_tokens=False)
+    input_ids = cast(List[int], enc["input_ids"])
+    offset_mapping = cast(List[Tuple[int, int]], enc["offset_mapping"])
+
+    pretty_tokens = []
+    end_prev = -1
+    idx = 0
+    for i, (input_id, offset) in enumerate(zip(input_ids, offset_mapping)):
+        start, end = offset
+        if start == end:
+            # For the case where offsets are not set properly (the end and start are
+            # equal for all tokens - fall back on the start of the next span in the
+            # offset mapping)
+            if (i + 1) < len(input_ids):
+                end = offset_mapping[i + 1][0]
+            else:
+                end = len(txt)
+        if input_id != ids[idx]:
+            # When the re-encoded string doesn't match the original encoding we skip
+            # this token and hope for the best, falling back on a naive method. This
+            # can happen when a tokenizer might add a token that corresponds to
+            # a space only when add_special_tokens=False.
+            warnings.warn(
+                f"(i={i}) input_id {input_id} != ids[idx] {ids[idx]} (corresponding to "
+                f"text: {repr(txt[start:end])}). Skipping this token.",
+                stacklevel=2,
+            )
+            continue
+        pretty_tokens.append(
+            _clean_up_pretty_token(txt[start:end])
+            + (" [OVERLAP]" if end_prev > start else "")
+        )
+        end_prev = end
+        idx += 1
+    if len(pretty_tokens) != len(ids):
+        warnings.warn(
+            f"Pretty tokens length {len(pretty_tokens)} != ids length {len(ids)}! "
+            "Falling back to naive decoding logic.",
+            stacklevel=2,
+        )
+        return _convert_ids_to_pretty_tokens_fallback(ids, tokenizer)
+    return pretty_tokens
+
+
+def _convert_ids_to_pretty_tokens_fallback(
+    ids: Tensor, tokenizer: TokenizerLike
+) -> List[str]:
+    """
+    Fallback function that naively handles logic when multiple ids map to one string.
+    """
     pretty_tokens = []
     idx = 0
     while idx < len(ids):
         decoded = tokenizer.decode(ids[idx])
+        decoded_pretty = _clean_up_pretty_token(decoded)
         # Handle case where single token (e.g. unicode) is split into multiple IDs
         # NOTE: This logic will fail if a tokenizer splits a token into 3+ IDs
         if decoded.strip() == "�" and tokenizer.encode(decoded) != [ids[idx]]:
@@ -244,17 +305,17 @@ def _convert_ids_to_pretty_tokens(ids: Tensor, tokenizer: TokenizerLike) -> List
             ]:
                 # Both tokens are from a split, combine them
                 decoded = tokenizer.decode(ids[idx : idx + 2])
-                pretty_tokens.append(decoded + "[1/2]")
-                pretty_tokens.append(decoded + "[2/2]")
+                pretty_tokens.append(_clean_up_pretty_token(decoded))
+                pretty_tokens.append(_clean_up_pretty_token(decoded) + " [OVERLAP]")
             else:
                 # Treat tokens as separate
-                pretty_tokens.append(decoded)
-                pretty_tokens.append(decoded_next)
+                pretty_tokens.append(decoded_pretty)
+                pretty_tokens.append(_clean_up_pretty_token(decoded_next))
             idx += 2
         else:
             # Just a normal token
             idx += 1
-            pretty_tokens.append(decoded)
+            pretty_tokens.append(decoded_pretty)
     return pretty_tokens

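Review sketch (not part of the commit): the core technique in the new _convert_ids_to_pretty_tokens is re-encoding the decoded text with return_offsets_mapping=True, so each token can be recovered as a slice of the readable string rather than via convert_ids_to_tokens(). A minimal demonstration, assuming the transformers package is installed; "gpt2" is an illustrative checkpoint only, and offset mappings require a "fast" tokenizer:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # loads a "fast" tokenizer

txt = "Hello world"
enc = tokenizer(txt, return_offsets_mapping=True, add_special_tokens=False)
for input_id, (start, end) in zip(enc["input_ids"], enc["offset_mapping"]):
    # Each (start, end) pair indexes into txt, so slicing recovers readable
    # token text with no byte-level BPE markers such as Ġ.
    print(input_id, repr(txt[start:end]))
# With the GPT-2 vocabulary this prints: 15496 'Hello' then 995 ' world'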
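A companion sketch (same assumptions) of the situation the naive fallback handles: a single visible character whose UTF-8 bytes span several byte-level BPE ids, so decoding one id alone produces the replacement character "�" that the fallback checks for:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint

ids = tokenizer.encode("🔥")      # one emoji, typically two or more BPE ids
print(tokenizer.decode(ids[:1]))  # '�': an incomplete UTF-8 byte sequence
print(tokenizer.decode(ids))      # '🔥': decoding all ids restores the glyph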