Closed
🐛 Bug
For the LimeBase example provided at https://captum.ai/api/lime.html (after making the updates in #908), using the provided `SGDLasso()`, `SGDRidge()`, and `SGDLinearRegression()` classes for the `interpretable_model` argument in `LimeBase()` leads to the following error message: "RuntimeError: expected scalar type Float but found Double" (more info below).
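For context, this kind of error is what PyTorch raises when a float32 module receives float64 input. The snippet below is only a minimal sketch of that general mismatch in plain PyTorch. It does not use Captum, and whether this is what actually happens inside the SGD-based models is an assumption, not something confirmed here. The names `layer` and `x_double` are illustrative.

```python
import torch
import torch.nn as nn

# A float32 linear layer, standing in (hypothetically) for a linear model trained with SGD.
layer = nn.Linear(5, 1)

# float64 input, standing in for data that became double somewhere upstream.
x_double = torch.randn(4, 5, dtype=torch.float64)

try:
    layer(x_double)
    mismatch_error = None
except RuntimeError as err:
    # The exact wording of the message depends on the PyTorch version.
    mismatch_error = err

print("dtype mismatch raised:", mismatch_error is not None)

# Casting the input to the layer's parameter dtype avoids the mismatch.
out = layer(x_double.to(layer.weight.dtype))
print(out.dtype)
```

If the same mismatch is the cause here, checking the dtypes of the tensors passed to the interpretable model would be a reasonable first debugging step.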
To Reproduce
import torch
import torch.nn as nn
from captum.attr import LimeBase
from captum._utils.models.linear_model import SkLearnLinearModel, SGDLasso, SGDRidge, SGDLinearRegression

class SimpleClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.linear(x)
        out = self.sigmoid(out)
        return out

net = SimpleClassifier()

def similarity_kernel(original_input, perturbed_input, perturbed_interpretable_input, **kwargs):
    # kernel_width will be provided to attribute as a kwarg
    kernel_width = kwargs["kernel_width"]
    l2_dist = torch.norm(original_input - perturbed_input)
    return torch.exp(- (l2_dist**2) / (kernel_width**2))

def perturb_func(original_input, **kwargs):
    return original_input + torch.randn_like(original_input)

def to_interp_rep_transform(curr_sample, original_inp, **kwargs):
    return curr_sample

input = torch.randn(1, 5)
lime_attr = LimeBase(net,
                     interpretable_model=SGDLinearRegression(),
                     similarity_func=similarity_kernel,
                     perturb_func=perturb_func,
                     perturb_interpretable_space=False,
                     from_interp_rep_transform=None,
                     to_interp_rep_transform=to_interp_rep_transform)
attr_coefs = lime_attr.attribute(input, target=1, kernel_width=1.1)
The code above runs if `interpretable_model=SkLearnLinearModel("linear_model.Ridge")`, but does not run if `interpretable_model=SGDLasso()`, `interpretable_model=SGDRidge()`, or `interpretable_model=SGDLinearRegression()`.
Expected behavior
There should be no error message, and `attr_coefs` should contain the feature attributions.
Environment
- Captum version: 0.5.0
- PyTorch version: 1.10.0+cu111
- OS (e.g., Linux): macOS
- How you installed Captum (`conda`, `pip`, source): `conda` and `pip`. The error arises whether I install Captum with `conda install captum -c pytorch` or `pip install captum`.
- Python version: 3.7.12