diff --git a/captum/attr/_utils/visualization.py b/captum/attr/_utils/visualization.py index 90b570d7b8..ad4b7a32d4 100644 --- a/captum/attr/_utils/visualization.py +++ b/captum/attr/_utils/visualization.py @@ -39,12 +39,13 @@ def _prepare_image(attr_visual: ndarray): def _normalize_scale(attr: ndarray, scale_factor: float): + assert scale_factor != 0, "Cannot normalize by scale factor = 0" if abs(scale_factor) < 1e-5: warnings.warn( - "Attempting to normalize by value approximately 0, skipping normalization." - "This likely means that attribution values are all close to 0." + "Attempting to normalize by value approximately 0, visualized results" + "may be misleading. This likely means that attribution values are all" + "close to 0." ) - return np.clip(attr, -1, 1) attr_norm = attr / scale_factor return np.clip(attr_norm, -1, 1)