from pytorch_grad_cam import (GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, EigenGradCAM, LayerCAM, FullGrad, GradCAMElementWise, KPCA_CAM)
from pytorch_grad_cam.ablation_layer import AblationLayerVit
from pytorch_grad_cam.utils.reshape_transforms import vit_reshape_transform
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import numpy as np
import torch
import matplotlib.pyplot as plt
from functools import partial
from textwrap import wrap
import gc
# Registry of supported CAM methods, keyed by the name passed as `cam_method`
# Maps each supported method name to the pytorch-grad-cam class that implements it.
cam_methods = {
    "gradcam": GradCAM,
    "hirescam": HiResCAM,
    "scorecam": ScoreCAM,
    "gradcam++": GradCAMPlusPlus,
    "ablationcam": AblationCAM,
    "xgradcam": XGradCAM,
    "eigencam": EigenCAM,
    "eigengradcam": EigenGradCAM,
    "layercam": LayerCAM,
    "fullgrad": FullGrad,
    "gradcamelementwise": GradCAMElementWise,
    "kpcacam": KPCA_CAM,
}
# `methods` is a second name bound to the very same dict object.
methods = cam_methods
# Wrapper that makes a Hugging Face model return its logits tensor directly
class HuggingfaceToTensorModelWrapper(torch.nn.Module):
    """Adapt a model whose output exposes ``.logits`` so that it returns a tensor.

    The wrapped model is called with ``output_attentions=True`` and only the
    ``logits`` attribute of its output is returned.

    Args:
        model: The model to wrap. It must accept ``(x, output_attentions=True)``
            and return an object with a ``logits`` attribute.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Run the wrapped model on ``x`` and return the ``logits`` of its output."""
        return self.model(x, output_attentions=True).logits
# ViT CAM generator; supports multiple pytorch-grad-cam methods (see the cam_methods dict).
class ViTCAM:
    """
    Generate and visualize Class Activation Maps (CAMs) for Vision Transformer (ViT) models.

    The CAM implementation is selected by name from the module-level ``cam_methods``
    dictionary and is built around the model wrapped in ``HuggingfaceToTensorModelWrapper``.
    """

    def __init__(self, model, processor, target_layers, targets, cam_method="gradcamelementwise", use_attentions=False, patch_size=16, image_size=224):
        """
        Initializes the ViTCAM object.

        Args:
            model (torch.nn.Module): The ViT model. It is moved to CUDA when available,
                otherwise to the CPU. Its output must expose ``.logits`` and its config
                must expose ``id2label`` (both are read by this class).
            processor: Callable image processor; invoked as
                ``processor(images=..., return_tensors="pt")["pixel_values"]``.
            target_layers (torch.nn.Module or list of torch.nn.Module): The target layer(s)
                for which to compute the CAM. A single module is wrapped into a list.
            targets (list or None): Target objects handed to the CAM call as ``targets``
                (e.g. ``ClassifierOutputTarget`` instances). ``None`` is passed through;
                presumably pytorch-grad-cam then targets the top-scoring class — confirm
                against the installed library version.
            cam_method (str, optional): Key of the ``cam_methods`` dictionary.
                Defaults to "gradcamelementwise".
            use_attentions (bool, optional): Select the attention reshape transform instead
                of ``vit_reshape_transform``. Defaults to False.
            patch_size (int, optional): Patch edge length in pixels. Defaults to 16.
            image_size (int, optional): Input image edge length in pixels. Defaults to 224.

        Raises:
            KeyError: If ``cam_method`` is not a key of ``cam_methods``.
            NotImplementedError: If ``use_attentions`` is True but no
                ``vit_reshape_attention_transform`` exists in this module.
        """
        if cam_method not in cam_methods:
            # Same exception type as the plain dict lookup, but raised early and with a hint.
            raise KeyError(f"Unknown cam_method {cam_method!r}; expected one of {sorted(cam_methods)}")
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.cam_method = cam_method
        self.model = model.to(self.device)
        self.processor = processor
        self.targets = targets
        self.target_layers = target_layers
        self.patch_size = patch_size
        self.image_size = image_size
        self.cam = None
        # Initialize the cam method and store it in self.cam
        self._init_cam(use_attentions)

    @staticmethod
    def _as_layer_list(target_layers):
        """Return ``target_layers`` as a list; anything that is not a list/tuple is wrapped."""
        if isinstance(target_layers, (list, tuple)):
            return list(target_layers)
        return [target_layers]

    def _init_cam(self, use_attentions):
        """Build the CAM object for ``self.cam_method`` and store it in ``self.cam``."""
        if use_attentions:
            # The attention transform is not imported by this module; look it up lazily so
            # a missing definition yields a clear error instead of a NameError.
            transform_function = globals().get("vit_reshape_attention_transform")
            if transform_function is None:
                raise NotImplementedError(
                    "use_attentions=True requires a 'vit_reshape_attention_transform' "
                    "function, which is not defined in this module"
                )
        else:
            transform_function = vit_reshape_transform
        # Number of patches along one image edge.
        grid = self.image_size // self.patch_size
        reshape_transform = partial(transform_function, width=grid, height=grid)
        wrapped_model = HuggingfaceToTensorModelWrapper(self.model)
        # Both branches receive a flat list, whether one module or a list was supplied.
        layers = self._as_layer_list(self.target_layers)
        if self.cam_method == "ablationcam":
            self.cam = AblationCAM(model=wrapped_model, target_layers=layers, reshape_transform=reshape_transform, ablation_layer=AblationLayerVit())
        else:
            self.cam = cam_methods[self.cam_method](model=wrapped_model, target_layers=layers, reshape_transform=reshape_transform)
        # NOTE(review): presumably limits memory for methods that run many extra forward
        # passes (ablation / score based); other methods appear to ignore it — confirm.
        self.cam.batch_size = 1

    def _logits_from_tensor(self, input_tensor):
        """Run the model on an already preprocessed tensor (no gradients) and return its logits."""
        with torch.no_grad():
            return self.model(input_tensor.to(self.device)).logits

    def _labels_from_logits(self, logits):
        """Map each row of ``logits`` to the first comma-separated part of its argmax label."""
        id2label = self.model.config.id2label
        return [id2label[index.item()].split(',')[0] for index in logits.argmax(-1)]

    def _get_logits(self, input_image):
        """
        Preprocesses an input image and gets the raw logits from the model.

        Args:
            input_image: Anything accepted by the processor (e.g. a PIL image or array).
        Returns:
            torch.Tensor: The ``logits`` of the model output, computed under ``torch.no_grad()``.
        """
        return self._logits_from_tensor(self._preprocess(input_image))

    def _predict_label(self, inputs):
        """
        Predicts the class label(s) for raw (not yet preprocessed) input image(s).

        Args:
            inputs: Anything accepted by the processor.
        Returns:
            list of str: One label per batch item, taken from ``model.config.id2label`` at
            the argmax of the logits and truncated at the first comma.
        """
        return self._labels_from_logits(self._get_logits(inputs))

    def cleanup(self, clear_cache=False):
        """
        Drops the activations and gradients stored on the CAM hooks and runs the garbage collector.

        Args:
            clear_cache (bool, optional): Also call ``torch.cuda.empty_cache()`` when CUDA
                is available. Defaults to False.
        """
        hooks = getattr(self.cam, "activations_and_grads", None)
        if hooks is not None:
            # NOTE(review): pytorch-grad-cam appears to reset both lists on every forward
            # pass, so emptying them here should be safe — confirm.
            hooks.activations = []
            hooks.gradients = []
        # Collect first so memory held by unreachable tensors can be returned to the cache
        # before the cache itself is emptied.
        gc.collect()
        if clear_cache and torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _preprocess(self, image):
        """
        Runs the processor on ``image`` and returns its ``pixel_values`` tensor.

        Args:
            image: Anything accepted by the processor.
        Returns:
            The ``"pixel_values"`` entry of the processor output (requested as PyTorch tensors).
        """
        return self.processor(images=image, return_tensors="pt")['pixel_values']

    def _get_heatmap(self, inputs):
        """
        Generates the grayscale heatmap(s) for a preprocessed input using self.cam and self.targets.

        Args:
            inputs (torch.Tensor): The preprocessed input tensor (batch first).
        Returns:
            The result of calling ``self.cam``; ``generate`` iterates it per batch item.
        """
        # Move to the model's device explicitly so this does not depend on the CAM doing it.
        return self.cam(input_tensor=inputs.to(self.device), targets=self.targets)

    def _get_cam_image(self, image):
        """
        Computes the CAM of the first target layer for ``image``, clipped at zero from below.

        Runs a fresh forward pass through the CAM hooks (and a backward pass for
        gradient-based methods) and passes the captured activations/gradients to
        ``self.cam.get_cam_image``.

        NOTE(review): this mirrors what pytorch-grad-cam's own forward pass is understood
        to do; verify against the installed library version.

        Args:
            image: Raw image accepted by the processor.
        Returns:
            np.ndarray: ``np.clip(cam, 0, None)`` of the value returned by ``get_cam_image``.
        Raises:
            RuntimeError: If the forward pass captured no activations.
        """
        inputs = self._preprocess(image).to(self.device)
        hooks = self.cam.activations_and_grads
        uses_gradients = getattr(self.cam, "uses_gradients", True)
        if uses_gradients:
            inputs = inputs.detach().requires_grad_(True)
        # The forward pass must come first: the hooks fill the lists that are read below.
        outputs = hooks(inputs)
        targets = self.targets
        if targets is None:
            target_categories = np.argmax(outputs.detach().cpu().numpy(), axis=-1)
            targets = [ClassifierOutputTarget(category) for category in target_categories]
        if uses_gradients:
            # Gradients only exist after a backward pass on the target scores.
            self.cam.model.zero_grad()
            loss = sum(target(output) for target, output in zip(targets, outputs))
            loss.backward(retain_graph=True)
        activations_list = [a.detach().cpu().numpy() for a in hooks.activations]
        grad_list = [g.detach().cpu().numpy() for g in hooks.gradients]
        if not activations_list:
            raise RuntimeError("No activations were captured for the target layers")
        layer_grads = grad_list[0] if grad_list else None
        target_layer = self._as_layer_list(self.target_layers)[0]
        cam_image = self.cam.get_cam_image(inputs, target_layer, targets, activations_list[0], layer_grads, eigen_smooth=False)
        return np.clip(cam_image, 0, None)

    def _show_cam_on_image(self, input_image, grayscale_cam):
        """
        Overlays the grayscale heatmap on the input image.

        Args:
            input_image: Image convertible with ``np.float32``; divided by 255 when its
                maximum exceeds 1.
            grayscale_cam (np.ndarray): The grayscale heatmap for the image.
        Returns:
            tuple: ``(cam_image, image)`` where ``cam_image`` is the result of
            ``show_cam_on_image(..., use_rgb=True)`` (documented upstream as uint8, 0-255)
            and ``image`` is the float32 image in the 0-1 range.
        """
        image = np.float32(input_image)
        if image.max() > 1:
            image = image / 255  # Range 0-1
        cam_image = show_cam_on_image(image, grayscale_cam, use_rgb=True)
        return cam_image, image

    def generate(self, input_image):
        """
        Generates heatmap(s) for the given input image(s) using the selected CAM method.

        Args:
            input_image: Raw image(s) accepted by the processor.
        Returns:
            tuple: A tuple containing:
                - input_tensor (torch.Tensor): The preprocessed input tensor.
                - image (np.ndarray or None): The min-max rescaled (0-1), channel-last
                  image of the last batch item; None for an empty batch.
                - grayscale_cam: The grayscale CAM heatmaps as returned by the CAM object.
                - cam_images (list of np.ndarray): One CAM overlay per batch item.
                - label (list of str): One predicted label per batch item.
        """
        # Image processing (done once; the prediction below reuses the same tensor)
        input_tensor = self._preprocess(input_image)
        # Prediction label
        label = self._labels_from_logits(self._logits_from_tensor(input_tensor))
        # CAM
        grayscale_cam = self._get_heatmap(input_tensor)
        # Show CAM on image
        cam_images = []
        image = None
        for gray_cam, img in zip(grayscale_cam, input_tensor):
            lowest, highest = img.min(), img.max()
            # Min-max rescale; the epsilon avoids division by zero for constant images.
            image_for_viz = (img - lowest) / (highest - lowest + 1e-8)
            image_for_viz = image_for_viz.permute(1, 2, 0).detach().cpu().numpy()
            cam_image, image = self._show_cam_on_image(image_for_viz, gray_cam)
            cam_images.append(cam_image)
        # Processed tensor, rescaled_image, grayscale_cam, cam_on_image, label
        return input_tensor, image, grayscale_cam, cam_images, label

    def plot_heatmaps(self, input_image, rotations=(0,)):
        """
        Generates and plots heatmaps for an input image at different rotation angles.

        Each row shows the rescaled model input, the grayscale CAM and the CAM overlay
        for one rotation. The figure is saved under ``./outputs/`` (created if missing)
        and then shown.

        Args:
            input_image (PIL.Image.Image): The input image; must provide ``rotate(angle)``.
            rotations (sequence, optional): Rotation angles in degrees applied before
                generating heatmaps. Defaults to ``(0,)`` (no rotation).
        Raises:
            ValueError: If ``rotations`` is empty.
        """
        import os

        rotations = list(rotations)
        if not rotations:
            raise ValueError("rotations must contain at least one angle")
        figsize = (15, len(rotations) * 5)
        # squeeze=False keeps the axes array two-dimensional even for a single row.
        fig, axes = plt.subplots(len(rotations), 3, figsize=figsize, squeeze=False)
        title_label = None
        for (ax1, ax2, ax3), angle in zip(axes, rotations):
            image_rotated = input_image.rotate(angle)
            # generate() returns per-batch lists; one image in means one item out.
            _, image, grayscale_cam, cam_images, labels = self.generate(image_rotated)
            label = labels[0]
            # Prefer the label of the unrotated image; otherwise use the first row's label.
            if angle == 0 or title_label is None:
                title_label = label
            ax1.imshow(image)
            ax1.set_title(f"Rotation: {angle}")
            ax2.imshow(grayscale_cam[0])
            ax2.set_title("CAM heatmap")
            ax3.imshow(cam_images[0])
            ax3.set_title(f"CAM overlay, Label: {label}")
            ax1.axis('off')
            ax2.axis('off')
            ax3.axis('off')
        fig.suptitle("\n".join(wrap(f"{self.cam_method} Heatmaps for ViT-Imagenet, Label: {title_label}", 40)), fontsize=24)
        os.makedirs("./outputs", exist_ok=True)
        fig.savefig(f"./outputs/{self.cam_method}_{title_label.replace(' ', '_')}")
        fig.show()

    def get_rot_bbox_true(self, bbox, rot):
        """
        Gets the bounding box coordinates after rotating a square image.

        The mapping matches a counter-clockwise rotation in image coordinates
        (for 90 degrees, ``x' = y`` and ``y' = size - x``), with ``size`` taken from
        ``self.image_size``.

        Args:
            bbox (tuple or list): Four numbers (x_min, y_min, x_max, y_max) of the original box.
            rot (int): The rotation angle in degrees; reduced modulo 360 and must then be
                0, 90, 180 or 270.
        Returns:
            list: [x_min_rot, y_min_rot, x_max_rot, y_max_rot] after rotation.
        Raises:
            KeyError: If the angle is not a multiple of 90 degrees.
        """
        x_min, y_min, x_max, y_max = bbox
        size = self.image_size
        bbox_rot = {
            0: [x_min, y_min, x_max, y_max],
            90: [y_min, size - x_max, y_max, size - x_min],
            180: [size - x_max, size - y_max, size - x_min, size - y_min],
            270: [size - y_max, x_min, size - y_min, x_max],
        }
        return bbox_rot[rot % 360]
# main()