
This is a hugely exciting release for us, as it is our first foray into the domain of computer vision. With this update, we are adding support for image classification models in the Hugging Face Transformers ecosystem. We are very excited to bring a simple API for calculating and visualizing attributions for vision transformers and their numerous variants in just 3 lines of code.

ImageClassificationExplainer (#105)

The ImageClassificationExplainer is designed to work with all models from the Transformers library that are trained for image classification (Swin, ViT, etc.). It provides attributions for every pixel in the input image, which can be easily visualized using the explainer's built-in visualize method.

Initialising an image classification explainer is very simple: all you need is an image classification model, fine-tuned or trained to work with Hugging Face, and its feature extractor.

For this example we are using google/vit-base-patch16-224, a Vision Transformer (ViT) model pre-trained on ImageNet-21k and fine-tuned on ImageNet, which predicts one of 1,000 possible classes.

:::python
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from transformers_interpret import ImageClassificationExplainer
from PIL import Image
import requests

model_name = "google/vit-base-patch16-224"
model = AutoModelForImageClassification.from_pretrained(model_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

# With the model and feature extractor initialized, we can now get
# explanations for an image. We will use a simple image of a golden retriever.
image_link = "https://imagesvc.meredithcorp.io/v3/mm/image?url=https%3A%2F%2Fstatic.onecms.io%2Fwp-content%2Fuploads%2Fsites%2F47%2F2020%2F08%2F16%2Fgolden-retriever-177213599-2000.jpg"

image = Image.open(requests.get(image_link, stream=True).raw)

image_classification_explainer = ImageClassificationExplainer(
    model=model, feature_extractor=feature_extractor
)

image_attributions = image_classification_explainer(image)

print(image_attributions.shape)

This prints the shape of the attribution tensor, which matches the shape of the model's input:

:::python
>>> torch.Size([1, 3, 224, 224])
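
The attributions have one value per channel per pixel. If you want a single 2D saliency map, for example for custom plotting, a minimal sketch (using plain torch, not part of the library's API) is to collapse the channel dimension and normalize:

:::python
# Minimal sketch, assuming image_attributions from the example above.
saliency = image_attributions.detach().squeeze(0).sum(dim=0)  # (224, 224)
saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min())
print(saliency.shape)  # torch.Size([224, 224])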

Visualizing Image Attributions

Because we are dealing with images, visualization is even more straightforward than with text models.

Attributions can be easily visualized using the explainer's visualize method. There are currently four supported visualization methods:

  • heatmap - a heatmap of positive and negative attributions, drawn using the dimensions of the image.
  • overlay - the heatmap is overlaid on a grayscale version of the original image.
  • masked_image - the absolute value of the attributions is used to mask the original image (see the sketch after this list).
  • alpha_scaling - the alpha channel (transparency) of each pixel is set to its normalized attribution value.
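
To make these methods concrete, below is a rough, hand-rolled sketch of the masked_image idea using torch, numpy, and PIL. This is an illustration of the concept only, not the library's actual implementation:

:::python
import numpy as np
from PIL import Image

# Illustration only -- not how the library implements masked_image.
# Build a [0, 1] mask from the absolute attribution values...
attr = image_attributions.detach().squeeze(0).abs().sum(dim=0)  # (224, 224)
mask = (attr / attr.max()).numpy()

# ...and scale the original image's pixels by it.
img = np.asarray(image.resize((224, 224)), dtype=np.float32)  # (224, 224, 3)
masked = Image.fromarray((img * mask[..., None]).astype(np.uint8))
masked.show()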

Heatmap

:::python
image_classification_explainer.visualize(
    method="heatmap",
    side_by_side=True,
    outlier_threshold=0.03
)

Overlay

:::python
image_classification_explainer.visualize(
    method="overlay",
    side_by_side=True,
    outlier_threshold=0.03
)

Masked Image

:::python
image_classification_explainer.visualize(
    method="masked_image",
    side_by_side=True,
    outlier_threshold=0.03
)

Alpha Scaling

:::python
image_classification_explainer.visualize(
    method="alpha_scaling",
    side_by_side=True,
    outlier_threshold=0.03
)
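
In each of the calls above, side_by_side=True renders the original image next to the attribution visualization, and outlier_threshold clips the most extreme attribution values before the color scale is normalized, so a handful of outlier pixels do not wash out the rest of the map.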
