Skip to content

Commit fd661d7

Browse files
csauperfacebook-github-bot
authored andcommitted
Condense long text in plot axis (pytorch#1349)
Summary: Pull Request resolved: pytorch#1349 Captum will display complete plot axis labels, which makes the plot unreadable if you have longer segments than words; cap the max length at 50 characters. Reviewed By: cyrjano Differential Revision: D62758379
1 parent 6636f4d commit fd661d7

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

captum/attr/_core/llm_attr.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# pyre-strict
22
from copy import copy
33

4+
from textwrap import shorten
5+
46
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
57

68
import matplotlib.pyplot as plt
@@ -103,7 +105,10 @@ def plot_token_attr(
103105
cbar.ax.set_ylabel("Token Attribuiton", rotation=-90, va="bottom")
104106

105107
# Show all ticks and label them with the respective list entries.
106-
ax.set_xticks(np.arange(data.shape[1]), labels=self.input_tokens)
108+
shortened_tokens = [
109+
shorten(t, width=50, placeholder="...") for t in self.input_tokens
110+
]
111+
ax.set_xticks(np.arange(data.shape[1]), labels=shortened_tokens)
107112
ax.set_yticks(np.arange(data.shape[0]), labels=self.output_tokens)
108113

109114
# Let the horizontal axes labeling appear on top.
@@ -149,7 +154,10 @@ def plot_seq_attr(
149154

150155
data = self.seq_attr.cpu().numpy()
151156

152-
ax.set_xticks(range(data.shape[0]), labels=self.input_tokens)
157+
shortened_tokens = [
158+
shorten(t, width=50, placeholder="...") for t in self.input_tokens
159+
]
160+
ax.set_xticks(range(data.shape[0]), labels=shortened_tokens)
153161

154162
ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
155163

0 commit comments

Comments
 (0)