
Bug in provided SGDLasso(), SGDRidge(), and SGDLinearRegression() classes? #910

Closed
@th789

Description


🐛 Bug

For the LimeBase example provided at https://captum.ai/api/lime.html (after applying the fixes from #908), passing any of the provided SGDLasso(), SGDRidge(), or SGDLinearRegression() classes as the interpretable_model argument of LimeBase() raises "RuntimeError: expected scalar type Float but found Double" (details below).
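The message points to a dtype mismatch: a PyTorch module with float32 ("Float") parameters meeting float64 ("Double") tensors, or vice versa. Below is a minimal sketch of that class of error in plain PyTorch. It assumes, without confirming it against Captum's source, that the SGD-based interpretable models hold float32 weights while the training data they receive is float64. The sketch only illustrates the mismatch and two generic ways to resolve it. It does not show how Captum itself fixes the issue.

```python
import torch
import torch.nn as nn

# nn.Linear creates float32 parameters by default.
layer = nn.Linear(5, 1)

# Hypothetical training batch in float64 ("Double"), standing in for the
# data the interpretable model is assumed to receive.
x_double = torch.randn(8, 5, dtype=torch.float64)

raised = False
try:
    layer(x_double)
except RuntimeError as err:
    # The exact wording of the message varies across PyTorch versions.
    raised = True
    print("RuntimeError:", err)

# Fix 1: cast the inputs to the layer's dtype.
out_float = layer(x_double.to(layer.weight.dtype))

# Fix 2: cast the layer's parameters to float64 (.double() converts in place).
out_double = layer.double()(x_double)

print(out_float.dtype, out_double.dtype)  # torch.float32 torch.float64
```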

To Reproduce

import torch
import torch.nn as nn
from captum.attr import LimeBase
from captum._utils.models.linear_model import SkLearnLinearModel, SGDLasso, SGDRidge, SGDLinearRegression

class SimpleClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 3)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        out = self.linear(x)
        out = self.sigmoid(out)
        return out

net = SimpleClassifier()

def similarity_kernel(original_input, perturbed_input, perturbed_interpretable_input, **kwargs):
    # kernel_width will be provided to attribute as a kwarg
    kernel_width = kwargs["kernel_width"]
    l2_dist = torch.norm(original_input - perturbed_input)
    return torch.exp(- (l2_dist**2) / (kernel_width**2))

def perturb_func(original_input, **kwargs):
    return original_input + torch.randn_like(original_input)

def to_interp_rep_transform(curr_sample, original_inp, **kwargs):
    return curr_sample

input = torch.randn(1, 5)

lime_attr = LimeBase(net,
                     interpretable_model=SGDLinearRegression(),
                     similarity_func=similarity_kernel,
                     perturb_func=perturb_func,
                     perturb_interpretable_space=False,
                     from_interp_rep_transform=None,
                     to_interp_rep_transform=to_interp_rep_transform)

attr_coefs = lime_attr.attribute(input, target=1, kernel_width=1.1) 

The code above runs when interpretable_model=SkLearnLinearModel("linear_model.Ridge"), but fails with this error when interpretable_model is SGDLasso(), SGDRidge(), or SGDLinearRegression().


Expected behavior

The call should complete without an error, and attr_coefs should contain the feature attributions.
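Conceptually, LimeBase fits the interpretable model on (interpretable input, model output) pairs, weights each pair by its similarity, and returns the fitted coefficients as attributions. The sketch below illustrates that expected result with closed-form weighted least squares in place of the SGD-trained models. The substitution is deliberate: it is not Captum's implementation. All data is hypothetical: a known linear function replaces the classifier so that the recovered coefficients can be checked. Every tensor stays in float64, which sidesteps the dtype mismatch reported above.

```python
import torch

torch.manual_seed(0)

n_samples, n_features = 50, 5
kernel_width = 1.1

# Hypothetical original input and perturbations, mirroring perturb_func above
# (the interpretable space equals the input space, as in the report).
original = torch.randn(n_features, dtype=torch.float64)
X = original + torch.randn(n_samples, n_features, dtype=torch.float64)

# Hypothetical "model outputs": a known linear function plus an intercept.
true_coefs = torch.tensor([0.5, -1.0, 2.0, 0.0, 0.25], dtype=torch.float64)
y = X @ true_coefs + 0.1

# Similarity weights from an exponential kernel on the L2 distance.
dist_sq = ((X - original) ** 2).sum(dim=1)
weights = torch.exp(-dist_sq / kernel_width**2)

# Weighted least squares with an intercept column: scaling each row by
# sqrt(weight) reduces it to an ordinary least-squares problem.
X_aug = torch.cat([X, torch.ones(n_samples, 1, dtype=torch.float64)], dim=1)
sqrt_w = weights.sqrt().unsqueeze(1)
solution = torch.linalg.lstsq(sqrt_w * X_aug, sqrt_w * y.unsqueeze(1)).solution

attr_coefs = solution[:n_features, 0]  # per-feature attributions
intercept = solution[n_features, 0]
print(attr_coefs, intercept)
```

Because the synthetic outputs are exactly linear, the weights affect only how much each sample counts, not the solution. With a real, non-linear classifier, the similarity weighting is what makes the fit local to the original input.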

Environment

 - Captum version: 0.5.0
 - Pytorch version: 1.10.0+cu111
 - OS (e.g., Linux): macOS
 - How you installed Captum (`conda`, `pip`, source): both `conda` and `pip`; the error occurs whether Captum is installed with `conda install captum -c pytorch` or with `pip install captum`
 - Python version: 3.7.12
