I have some problems with shap (python) in Google Colab notebook; I want to use ResNet50 with CIFAR100 dataset with Partition explainer; I get a session crash but I don't understand exactly why:
Here is the data import (it's not the relevant part of the issue)
import torch
import torchvision
import torchvision.transforms as transforms
import shap
import numpy as np
from torchvision import models
from torch.utils.data import DataLoader
from torch import nn, optim
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5071, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2762])
])
testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
# use just a subset of data
subset_of_data = []
for it in range(100):
subset_of_data.append(np.array(testset[it][0]))
subset_of_data = [el.transpose(1, 2, 0) for el in subset_of_data]
subset_of_data = np.array(subset_of_data)
#subset_of_data = [image_instance[0] for image_instance in testset[0:100]]
print(subset_of_data.shape)
Now the important part - similar to this example but in my case I use CIFAR100
import json
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
import shap
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import cifar100
import tensorflow as tf
# ResNet50 with input size for CIFAR100
model = ResNet50(weights="imagenet", input_shape = (32, 32, 3), include_top=False)
subset_of_data = np.clip(subset_of_data, 0, 255).astype(np.uint8)
# name of classes manually
class_names = ['apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl',
'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar',
'cattle', 'chair', 'chimpanzee',
'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
'dolphin', 'elephant', 'flatfish',
'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard', 'lamp',
'lawn_mower', 'leopard', 'lion', 'lizard',
'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom',
'oak_tree', 'orange', 'orchid', 'otter',
'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate',
'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon',
'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk',
'skyscraper', 'snail', 'snake', 'spider', 'squirrel',
'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television',
'tiger', 'tractor', 'train', 'trout',
'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm']
# Function to preprocess input for ResNet50
def f(x):
tmp = x.copy()
preprocess_input(tmp)
return model(tmp)
# Define a masker that is used to mask out partitions of the input image.
masker = shap.maskers.Image("inpaint_telea", subset_of_data[0].shape)
# Create an explainer with model and image masker
explainer = shap.Explainer(f, masker, output_names=class_names)
# Select images from the CIFAR-100 test set to explain
shap_values = explainer(
subset_of_data[1:2], max_evals=10, batch_size=5, outputs=shap.Explanation.argsort.flip[:4]
)
# Plot SHAP values
shap.image_plot(shap_values)
According to log file (image below), there is an allocation problem, but I cannot figure out why
