This is a hugely exciting release for us as it is our first foray into the domain of computer vision. With this update, we are adding support for image classification models inside the Huggingface Transformers ecosystem. We are very excited to bring a simple API for calculating and visualizing attributions for vision transformers and their numerous variants in just 3 lines of code.
ImageClassificationExplainer (#105)
The ImageClassificationExplainer is designed to work with all models from the Transformers library that are trained for image classification (Swin, ViT, etc.). It provides attributions for every pixel in the image, which can be easily visualized using the explainer's built-in visualize method.
Initialising an image classification explainer is very simple: all you need is an image classification model that has been fine-tuned or trained to work with Huggingface, together with its feature extractor.
For this example we are using google/vit-base-patch16-224, a Vision Transformer (ViT) model pre-trained on ImageNet-21k that predicts from 1000 possible classes.
:::python
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from transformers_interpret import ImageClassificationExplainer
from PIL import Image
import requests
model_name = "google/vit-base-patch16-224"
model = AutoModelForImageClassification.from_pretrained(model_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
# With both the model and feature extractor initialized, we can now get explanations for an image. We will use a simple image of a golden retriever.
image_link = "https://imagesvc.meredithcorp.io/v3/mm/image?url=https%3A%2F%2Fstatic.onecms.io%2Fwp-content%2Fuploads%2Fsites%2F47%2F2020%2F08%2F16%2Fgolden-retriever-177213599-2000.jpg"
image = Image.open(requests.get(image_link, stream=True).raw)
image_classification_explainer = ImageClassificationExplainer(model=model, feature_extractor=feature_extractor)
image_attributions = image_classification_explainer(
image
)
print(image_attributions.shape)
Which will print the shape of the returned attribution tensor:
:::python
>>> torch.Size([1, 3, 224, 224])
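If you want to work with the raw attributions directly rather than through the built-in visualizations, a minimal sketch is shown below. It assumes image_attributions behaves like a standard torch.Tensor (as the printed torch.Size suggests); the variable names are illustrative and not part of the library API.
:::python
# Assumption: image_attributions is a torch.Tensor of shape [1, 3, 224, 224],
# continuing from the snippet above.
# Sum over the channel dimension to get one attribution value per pixel.
pixel_attributions = image_attributions.squeeze(0).sum(dim=0)  # [224, 224]
# Normalize to [0, 1] so the map can be reused as a custom mask or heatmap.
normalized = (pixel_attributions - pixel_attributions.min()) / (
    pixel_attributions.max() - pixel_attributions.min()
)
print(normalized.shape)  # torch.Size([224, 224])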
Visualizing Image Attributions
Because we are dealing with images, visualization is even more straightforward than it is for text models.
Attributions can be easily visualized using the visualize method of the explainer. There are currently 4 supported visualization methods.
- heatmap - a heatmap of positive and negative attributions is drawn using the dimensions of the image.
- overlay - the heatmap is overlaid on a grayscale version of the original image.
- masked_image - the absolute value of attributions is used to create a mask over the original image.
- alpha_scaling - sets the alpha channel (transparency) of each pixel to be equal to the normalized attribution value.
Heatmap
:::python
image_classification_explainer.visualize(
method="heatmap",
side_by_side=True,
outlier_threshold=0.03
)
Overlay
:::python
image_classification_explainer.visualize(
method="overlay",
side_by_side=True,
outlier_threshold=0.03
)
Masked Image
:::python
image_classification_explainer.visualize(
method="masked_image",
side_by_side=True,
outlier_threshold=0.03
)
Alpha Scaling
:::python
image_classification_explainer.visualize(
method="alpha_scaling",
side_by_side=True,
outlier_threshold=0.03
)
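Since the four calls above differ only in the method argument, a small convenience sketch (using only the arguments already shown) renders all of them in turn:
:::python
for method in ["heatmap", "overlay", "masked_image", "alpha_scaling"]:
    image_classification_explainer.visualize(
        method=method,
        side_by_side=True,
        outlier_threshold=0.03,
    )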