Fix label option #3636

Open · SFatemehM wants to merge 5 commits into master
Conversation

@SFatemehM commented May 1, 2024

Fixes label option for more than 1 row

Overview

When shap.plots.image is given more than one test image in pixel_values and the labels option is defined, the plot has more than one row, and setting the row titles raises this error:

IndexError: index 1 is out of bounds for axis 0 with size 1

It can be fixed by checking row == 0 and setting the titles for the first row only, as sketched below.
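
A minimal sketch of the idea (variable names follow shap.plots._image; the exact diff is shown in the review comments below):

for i in range(len(shap_values)):
    if labels is not None:
        if row == 0:  # set column titles only once, above the first row of images
            axes[row, i + 1].set_title(labels[row, i], **label_kwargs)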

Checklist

  • All pre-commit checks pass.
  • Unit tests added (if fixing a bug or adding a new feature)

fix label option for more than 1 row - convert shap_value to list with compatible dimensions for plotting
@connortann added the "visualization (Relating to plotting)" label on May 3, 2024
Fix label option for more than 1 row
@SFatemehM changed the title from "Update _image.py" to "Fix label option" on May 3, 2024
@CloseChoice (Collaborator) commented:

@SFatemehM thanks for the PR; I would like to review this. Do you have a code snippet I could use to test it?

@SFatemehM (Author) replied:

> @SFatemehM thanks for the PR; I would like to review this. Do you have a code snippet I could use to test it?

@CloseChoice
Thank you, of course.
I was running the PyTorch Deep Explainer MNIST example and I wanted to add class labels on top of the plot. My code is the same as the tutorial, except for the last 3 lines:

shap_numpy = list(np.array(shap_values).transpose(4, 0, 2, 3, 1))
test_numpy = np.array(test_images).transpose(0, 2, 3, 1)

labels = [f'{x}' for x in range(10)]
shap.image_plot(shap_numpy, -test_numpy, labels=labels)

So the plot would look like this:
[screenshot: mnist_shap_labels]

Full code snippet:

#%%
import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
import shap

#%%
device = torch.device("cpu")

batch_size = 128
num_epochs = 2

#%%
class Net(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=5),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(10, 20, kernel_size=5),
            nn.Dropout(),
            nn.MaxPool2d(2),
            nn.ReLU(),
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(320, 50),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(50, 10),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(-1, 320)
        x = self.fc_layers(x)
        return x
#%%
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output.log(), target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(
                f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}"
                f" ({100.0 * batch_idx / len(train_loader):.0f}%)]"
                f"\tLoss: {loss.item():.6f}"
            )
#%%
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output.log(), target).item()  # sum up batch loss
            pred = output.max(1, keepdim=True)[
                1
            ]  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(
        f"\nTest set: Average loss: {test_loss:.4f},"
        f" Accuracy: {correct}/{len(test_loader.dataset)}"
        f" ({100.0 * correct / len(test_loader.dataset):.0f}%)\n"
    )

#%%
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "mnist_data",
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()]),
    ),
    batch_size=batch_size,
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "mnist_data", train=False, transform=transforms.Compose([transforms.ToTensor()])
    ),
    batch_size=batch_size,
    shuffle=True,
)
#%%
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

for epoch in range(1, num_epochs + 1):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)

#%%
batch = next(iter(test_loader))
images, labels = batch

background = images[:100]
test_images = images[100:103]

e = shap.DeepExplainer(model, background)
shap_values = e.shap_values(test_images)

shap_numpy = list(np.array(shap_values).transpose(4, 0, 2, 3, 1))
test_numpy = np.array(test_images).transpose(0, 2, 3, 1)

labels = [f'{x}' for x in range(10)]
shap.image_plot(shap_numpy, -test_numpy, labels=labels)

@CloseChoice (Collaborator) left a review comment:

This looks pretty good. Since we are not adding a unit test, I would like to get the PyTorch Deep Explainer MNIST example notebook working again. You transposed the axes of the shap values correctly; could you please add that to the notebook as well?

And to check that the notebook actually runs, could you please try to remove the notebook from here? I'm not sure if that works; I guess it'll fail. If it fails, we should add a unit test so that this passes and future code changes won't break it.

@@ -158,7 +158,8 @@ def image(shap_values: Explanation or np.ndarray,
     max_val = np.nanpercentile(abs_vals, 99.9)
     for i in range(len(shap_values)):
         if labels is not None:
-            axes[row, i + 1].set_title(labels[row, i], **label_kwargs)
+            if row==0:
A collaborator commented:
Suggested change:
-            if row==0:
+            if row == 0:

@CloseChoice (Collaborator) commented May 26, 2024

I just looked at this and I think it would be really good if we could test this function. I hope we can write a test that is shorter than the code you posted but still tests the behaviour.

EDIT: I am thinking of copying one image array into the test and creating some dummy shap values from it (e.g. set the shap value to x if the image value is above y, etc.), but I am totally open to suggestions here; a rough sketch of that idea follows below. Feel free to ping me if you need help with that. It would be great if you could help with this; otherwise I am also willing to just merge your changes.
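
A rough sketch of that dummy-data idea (the array sizes, threshold, and class count below are illustrative assumptions, not code from the PR):

import numpy as np
import shap

rng = np.random.default_rng(0)
images = rng.random((2, 20, 20, 1))              # two small grayscale "images"
# dummy shap values: +0.5 where a pixel is bright, -0.5 where it is dark
dummy = np.where(images > 0.5, 0.5, -0.5)
shap_values = [dummy.copy() for _ in range(3)]   # one dummy array per "class"
labels = [f"class {i}" for i in range(3)]
shap.image_plot(shap_values, images, labels=labels, show=False)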

transpose the axis of test images and shap values
@SFatemehM (Author) replied:

> This looks pretty good. Since we are not adding a unit test, I would like to get the PyTorch Deep Explainer MNIST example notebook working again. You transposed the axes of the shap values correctly; could you please add that to the notebook as well?
>
> And to check that the notebook actually runs, could you please try to remove the notebook from here? I'm not sure if that works; I guess it'll fail. If it fails, we should add a unit test so that this passes and future code changes won't break it.

I added shap_values.transpose() and test_images.transpose() to the notebook. But I didn't understand the second part: do you mean removing this line?

Path("image_examples/image_classification/PyTorch Deep Explainer MNIST example.ipynb"),

@CloseChoice (Collaborator) replied:

> > This looks pretty good. Since we are not adding a unit test, I would like to get the PyTorch Deep Explainer MNIST example notebook working again. You transposed the axes of the shap values correctly; could you please add that to the notebook as well?
> > And to check that the notebook actually runs, could you please try to remove the notebook from here? I'm not sure if that works; I guess it'll fail. If it fails, we should add a unit test so that this passes and future code changes won't break it.
>
> I added shap_values.transpose() and test_images.transpose() to the notebook. But I didn't understand the second part: do you mean removing this line? Path("image_examples/image_classification/PyTorch Deep Explainer MNIST example.ipynb"),

Yes, that's what I mean. If it does not pass, that means the notebook takes too long to execute, and I would really like to have a test for the function call you make.

@SFatemehM (Author) replied:

> I just looked at this and I think it would be really good if we could test this function. I hope we can write a test that is shorter than the code you posted but still tests the behaviour.
>
> EDIT: I am thinking of copying one image array into the test and creating some dummy shap values from it (e.g. set the shap value to x if the image value is above y, etc.), but I am totally open to suggestions here. Feel free to ping me if you need help with that. It would be great if you could help with this; otherwise I am also willing to just merge your changes.

How about testing it with random values? A function like this, for example:

import numpy as np
import shap

def random_image_plot(n_images=3, n_classes=5):
    # random test images and random shap values with matching dimensions
    img_test = np.random.randn(n_images, 20, 20, 1)
    shap_test = list(np.random.randn(n_classes, n_images, 20, 20, 1))
    labels = [f'L{x+1}' for x in range(n_classes)]
    shap.image_plot(shap_test, img_test, labels=labels)

random_image_plot()

Do you think this would help? I can get the plot from it, but if it isn't useful for testing, I'll try to write it as you suggested.

removed PyTorch Deep Explainer MNIST example.ipynb
@CloseChoice (Collaborator) replied:

> > I just looked at this and I think it would be really good if we could test this function. I hope we can write a test that is shorter than the code you posted but still tests the behaviour.
> > EDIT: I am thinking of copying one image array into the test and creating some dummy shap values from it (e.g. set the shap value to x if the image value is above y, etc.), but I am totally open to suggestions here. Feel free to ping me if you need help with that. It would be great if you could help with this; otherwise I am also willing to just merge your changes.
>
> How about testing it with random values? A function like this, for example:
>
> def random_image_plot(n_images=3, n_classes=5):
>     img_test = np.random.randn(n_images, 20, 20, 1)
>     shap_test = list(np.random.randn(n_classes, n_images, 20, 20, 1))
>     labels = [f'L{x+1}' for x in range(n_classes)]
>     shap.image_plot(shap_test, img_test, labels=labels)
>
> random_image_plot()
>
> Do you think this would help? I can get the plot from it, but if it isn't useful for testing, I'll try to write it as you suggested.

I would like to have a test that we can decorate with @pytest.mark.mpl_image_compare, like this one. For these tests, the pytest plugin generates a figure and compares it with the one in the baselines folder, in this case: https://github.com/shap/shap/blob/master/tests/plots/baseline/test_simple_beeswarm.png

Then we have something to look at. I guess you could also use the imagenet function to load images.
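
A hedged sketch of what such a test might look like (the test name, image sizes, and use of seeded random data are assumptions, not part of the PR):

import matplotlib.pyplot as plt
import numpy as np
import pytest

import shap


@pytest.mark.mpl_image_compare
def test_image_plot_with_labels_multiple_rows():
    """Plot several images with class labels; titles should only appear above the first row."""
    rng = np.random.default_rng(0)
    pixel_values = rng.random((3, 20, 20, 1))                              # 3 grayscale images
    shap_values = [rng.standard_normal((3, 20, 20, 1)) for _ in range(5)]  # 5 "classes"
    labels = [f"class {i}" for i in range(5)]
    shap.image_plot(shap_values, pixel_values, labels=labels, show=False)
    # pytest-mpl compares the returned figure against the stored baseline image
    return plt.gcf()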

Labels: visualization (Relating to plotting)
Projects: none yet
Linked issues: none yet
Participants: 3