Source code for xai_methods.gradcam_vit

from pytorch_grad_cam import (GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, EigenGradCAM, LayerCAM, FullGrad, GradCAMElementWise, KPCA_CAM)
from pytorch_grad_cam.ablation_layer import AblationLayerVit
from pytorch_grad_cam.utils.reshape_transforms import vit_reshape_transform
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import numpy as np
import torch
import matplotlib.pyplot as plt
from functools import partial
from textwrap import wrap
import gc

# Helper functions

# Registry of supported CAM algorithms: maps the lowercase ``cam_method`` name
# accepted by ViTCAM to its pytorch-grad-cam implementation class.
# ``methods`` is an alias bound to the very same dict object.
cam_methods = methods = dict([
    ("gradcam", GradCAM),
    ("hirescam", HiResCAM),
    ("scorecam", ScoreCAM),
    ("gradcam++", GradCAMPlusPlus),
    ("ablationcam", AblationCAM),
    ("xgradcam", XGradCAM),
    ("eigencam", EigenCAM),
    ("eigengradcam", EigenGradCAM),
    ("layercam", LayerCAM),
    ("fullgrad", FullGrad),
    ("gradcamelementwise", GradCAMElementWise),
    ("kpcacam", KPCA_CAM),
])


# Reshape transform and model wrapper used by ViTCAM
def vit_reshape_attention_transform(tensor, height=14, width=14):
    """Lay out the first token's attention over the remaining tokens as a 2-D grid.

    Selects ``tensor[:, :, 0, 1:]`` (row 0 of the last two axes, dropping
    column 0) and reshapes it to ``(dim0, dim1, height, width)``.

    Args:
        tensor (torch.Tensor): Attention tensor indexed as
            ``[batch, dim1, query, key]``; presumably ``(batch, heads, tokens,
            tokens)`` with the CLS token at index 0 -- confirm against the model.
        height (int): Number of grid rows. Defaults to 14.
        width (int): Number of grid columns. Defaults to 14.

    Returns:
        torch.Tensor: Tensor of shape ``(batch, dim1, height, width)``; the
        number of keys minus one must equal ``height * width``.
    """
    first_token_row = tensor[:, :, 0, 1:].squeeze(2)
    n_batch, n_second = first_token_row.size(0), first_token_row.size(1)
    return first_token_row.reshape(n_batch, n_second, height, width)
class HuggingfaceToTensorModelWrapper(torch.nn.Module):
    """Wrap a model whose output object carries ``.logits`` so that calling the
    wrapper returns the logits tensor itself (as the CAM methods here expect)."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Call the wrapped model with ``output_attentions=True`` and return ``.logits``."""
        model_output = self.model(x, output_attentions=True)
        return model_output.logits
# Support for multiple pytorch-grad-cam methods (see cam_methods dict).
class ViTCAM:
    """
    A class for generating and visualizing Class Activation Maps (CAMs) for
    Vision Transformer (ViT) models. It supports the CAM methods from the
    pytorch-grad-cam library listed in the ``cam_methods`` dictionary.
    """

    def __init__(self, model, processor, target_layers, targets, cam_method="gradcamelementwise",
                 use_attentions=False, patch_size=16, image_size=224):
        """
        Initializes the ViTCAM object.

        Args:
            model (torch.nn.Module): The pre-trained ViT model. Its output must
                expose ``.logits`` and it must accept ``output_attentions``.
            processor (transformers.ImageProcessor): The image processor
                associated with the ViT model.
            target_layers (torch.nn.Module or list of torch.nn.Module): The
                layer(s) whose activations are used to compute the CAM. A single
                module is wrapped into a one-element list.
            targets (list or None): pytorch-grad-cam target objects (e.g.
                ``ClassifierOutputTarget``), one per image in the batch. With
                ``None`` the highest-scoring class of each image is used.
            cam_method (str, optional): Key into ``cam_methods``. Defaults to
                "gradcamelementwise".
            use_attentions (bool, optional): Use
                ``vit_reshape_attention_transform`` instead of
                ``vit_reshape_transform`` to reshape the hooked outputs.
                Defaults to False.
            patch_size (int, optional): ViT patch size in pixels. Defaults to 16.
            image_size (int, optional): Side length of the square model input in
                pixels. Defaults to 224.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.cam_method = cam_method
        self.model = model.to(self.device)
        self.processor = processor
        self.targets = targets
        self.target_layers = target_layers
        self.patch_size = patch_size
        self.image_size = image_size
        self.cam = None
        # Initialize the CAM method and store it in self.cam
        self._init_cam(use_attentions)

    def _target_layer_list(self):
        """Return ``self.target_layers`` as a list of modules.

        A single module is wrapped in a one-element list, so every CAM method
        (AblationCAM included) receives the same list form.
        """
        layers = self.target_layers
        if isinstance(layers, (list, tuple, torch.nn.ModuleList)):
            return list(layers)
        return [layers]

    def _init_cam(self, use_attentions):
        """Build the pytorch-grad-cam object for ``self.cam_method`` into ``self.cam``.

        Args:
            use_attentions (bool): Select the attention reshape transform instead
                of the token-activation one.
        """
        transform_function = vit_reshape_attention_transform if use_attentions else vit_reshape_transform
        # Number of patches along each side of the square input image.
        grid_size = self.image_size // self.patch_size
        reshape_transform = partial(transform_function, width=grid_size, height=grid_size)
        wrapped_model = HuggingfaceToTensorModelWrapper(self.model)
        target_layers = self._target_layer_list()
        if self.cam_method == "ablationcam":
            # AblationCAM is given the ViT-specific ablation layer.
            self.cam = AblationCAM(model=wrapped_model, target_layers=target_layers,
                                   reshape_transform=reshape_transform,
                                   ablation_layer=AblationLayerVit())
        else:
            self.cam = cam_methods[self.cam_method](model=wrapped_model, target_layers=target_layers,
                                                    reshape_transform=reshape_transform)
        self.cam.batch_size = 1

    def _forward_logits(self, input_tensor):
        """Run the model on an already-processed batch without tracking gradients.

        Args:
            input_tensor (torch.Tensor): Processed pixel values (see ``_preprocess``).

        Returns:
            torch.Tensor: The raw logits output from the model.
        """
        with torch.no_grad():
            return self.model(input_tensor.to(self.device)).logits

    def _get_logits(self, input_image):
        """
        Processes an input image and gets the raw logits from the ViT model.

        Args:
            input_image (PIL.Image.Image, np.ndarray): The input image.

        Returns:
            torch.Tensor: The raw logits output from the model.
        """
        return self._forward_logits(self._preprocess(input_image))

    def _labels_from_logits(self, logits):
        """Map each logits row to the first comma-separated name of its argmax class.

        Args:
            logits (torch.Tensor): Logits whose last axis indexes classes.

        Returns:
            list of str: One label per row, looked up in ``model.config.id2label``.
        """
        predictions = logits.argmax(-1)
        return [self.model.config.id2label[i.item()].split(',')[0] for i in predictions]

    def _predict_label(self, inputs):
        """
        Predicts the class label of each image in ``inputs``.

        Args:
            inputs (PIL.Image.Image, np.ndarray or list): Raw image(s) accepted
                by ``self.processor``.

        Returns:
            list of str: One predicted label per image (argmax of the logits).
        """
        return self._labels_from_logits(self._get_logits(inputs))

    def cleanup(self, clear_cache=False):
        """Releases the activations and gradients stored by the CAM hooks and runs garbage collection.

        Args:
            clear_cache (bool, optional): If True, also call
                ``torch.cuda.empty_cache()`` after collection. Defaults to False.
        """
        act_grads = getattr(self.cam, "activations_and_grads", None)
        if act_grads is not None:
            # NOTE: assumes the hooks repopulate these lists on the next CAM call
            # (pytorch-grad-cam's ActivationsAndGradients resets them per call).
            act_grads.activations = []
            act_grads.gradients = []
        # Collect first so freed tensors can actually be returned by empty_cache.
        gc.collect()
        if clear_cache:
            torch.cuda.empty_cache()

    def _preprocess(self, image):
        """Converts raw image(s) to model-ready pixel values using ``self.processor``.

        Args:
            image (PIL.Image.Image, np.ndarray or list): Raw input image(s),
                assumed resized to ``image_size`` x ``image_size``.

        Returns:
            torch.Tensor: The processor's ``pixel_values`` (presumably (B, C, H, W)).
        """
        return self.processor(images=image, return_tensors="pt")['pixel_values']

    def _get_heatmap(self, inputs):
        """Generates grayscale heatmaps for a processed batch with ``self.cam`` and ``self.targets``.

        Args:
            inputs (torch.Tensor): The processed input batch.

        Returns:
            np.ndarray: One grayscale heatmap per image in the batch.
        """
        return self.cam(input_tensor=inputs, targets=self.targets)

    def _get_cam_image(self, image):
        """Computes the unscaled CAM of ``image`` for the first target layer.

        Intended to follow the forward/backward steps of a pytorch-grad-cam
        method, but returns the CAM at activation resolution (no resizing).

        Args:
            image (PIL.Image.Image, np.ndarray): Raw input image.

        Returns:
            np.ndarray: The CAM with negative values clipped to 0.
        """
        inputs = self._preprocess(image).to(self.device)
        # Forward through the hooked model so this call's activations are recorded.
        outputs = self.cam.activations_and_grads(inputs)
        targets = self.targets
        if targets is None:
            target_categories = np.argmax(outputs.detach().cpu().numpy(), axis=-1)
            targets = [ClassifierOutputTarget(category) for category in target_categories]
        if getattr(self.cam, "uses_gradients", True):
            # Gradients are only captured by the hooks during a backward pass.
            self.model.zero_grad()
            loss = sum(target(output) for target, output in zip(targets, outputs))
            loss.backward()
        activations = self.cam.activations_and_grads.activations
        gradients = self.cam.activations_and_grads.gradients
        layer_activations = np.asarray(activations[0]) if activations else None
        layer_grads = np.asarray(gradients[0]) if gradients else None
        cam_image = self.cam.get_cam_image(inputs, self._target_layer_list()[0], targets,
                                           layer_activations, layer_grads, eigen_smooth=False)
        return np.clip(cam_image, 0, None)

    def _show_cam_on_image(self, input_image, grayscale_cam):
        """Overlays the grayscale heatmap on the input image to visualize the regions of interest.

        Args:
            input_image (PIL.Image.Image or np.ndarray): Image on a 0-1 or 0-255
                scale; values above 1 are treated as 0-255 and rescaled.
            grayscale_cam (np.ndarray): The grayscale heatmap for this image.

        Returns:
            tuple:
                - np.ndarray: The CAM overlaid on the image, 0-255 scale (uint8).
                - np.ndarray: The input image as float32 on the 0-1 scale.
        """
        image = np.float32(input_image)
        if image.max() > 1:
            image = image / 255  # Range 0-1
        cam_image = show_cam_on_image(image, grayscale_cam, use_rgb=True)
        return cam_image, image

    def generate(self, input_image):
        """
        Generates heatmaps for the given input image(s) using the selected CAM method.

        Args:
            input_image (PIL.Image.Image): The input image (assumed resized to
                ``image_size`` x ``image_size``) (H, W, C).

        Returns:
            tuple: A tuple containing:
                - input_tensor (torch.Tensor): The processed input tensor.
                - image (np.ndarray or None): The last processed image, min-max
                  scaled to 0-1 for display (None for an empty batch).
                - grayscale_cam (np.ndarray): One grayscale heatmap per image.
                - cam_images (list of np.ndarray): The CAM overlaid on each
                  image (uint8).
                - label (list of str): The predicted label of each image.
        """
        input_tensor = self._preprocess(input_image)
        # Reuse the processed tensor rather than preprocessing the image again.
        label = self._labels_from_logits(self._forward_logits(input_tensor))
        grayscale_cam = self._get_heatmap(input_tensor)
        cam_images = []
        image = None
        for gray_cam, img in zip(grayscale_cam, input_tensor):
            # Min-max scale the normalized pixel values back to 0-1 for display.
            image_for_viz = (img - img.min()) / (img.max() - img.min() + 1e-8)
            image_for_viz = image_for_viz.permute(1, 2, 0).detach().cpu().numpy()
            cam_image, image = self._show_cam_on_image(image_for_viz, gray_cam)
            cam_images.append(cam_image)
        # Processed tensor, rescaled image, grayscale_cam, cam_on_image, label
        return input_tensor, image, grayscale_cam, cam_images, label

    def plot_heatmaps(self, input_image, rotations=(0,)):
        """
        Generates and plots heatmaps for an input image at different rotation angles.

        Each row shows, for one angle, the processed image, the grayscale
        heatmap and the heatmap overlaid on the processed image. The figure is
        saved under ``./outputs/`` (created if missing) and then shown.

        Args:
            input_image (PIL.Image.Image): The input image to generate heatmaps for.
            rotations (sequence of int, optional): Rotation angles in degrees
                (passed to ``input_image.rotate``) applied before generating each
                heatmap. Defaults to (0,) (no rotation).

        Raises:
            ValueError: If ``rotations`` is empty.
        """
        import os

        rotations = list(rotations)
        if not rotations:
            raise ValueError("rotations must contain at least one angle")
        n_rows = len(rotations)
        # squeeze=False keeps `axes` 2-D even when there is a single row.
        fig, axes = plt.subplots(n_rows, 3, figsize=(15, n_rows * 5), squeeze=False)
        title_label = None
        for (ax1, ax2, ax3), angle in zip(axes, rotations):
            image_rotated = input_image.rotate(angle)
            _, image, grayscale_cam, cam_images, labels = self.generate(image_rotated)
            label = labels[0]
            # Prefer the unrotated prediction for the figure title; otherwise
            # fall back to the first row's prediction.
            if title_label is None or angle == 0:
                title_label = label
            ax1.imshow(image)
            ax1.set_title(f"Rotation: {angle}")
            ax2.imshow(grayscale_cam[0], cmap="jet")
            ax2.set_title(f"Label: {label}")
            ax3.imshow(cam_images[0])
            ax3.set_title(f"Processed Image, Label: {label}")
            for ax in (ax1, ax2, ax3):
                ax.axis('off')
        fig.suptitle("\n".join(wrap(f"{self.cam_method} Heatmaps for ViT-Imagenet, Label: {title_label}", 40)),
                     fontsize=24)
        os.makedirs("./outputs", exist_ok=True)
        fig.savefig(f"./outputs/{self.cam_method}_{title_label.replace(' ', '_')}")
        fig.show()

    def get_rot_bbox_true(self, bbox, rot, size=None):
        """
        Gets the bounding box coordinates after rotating a square image.

        Args:
            bbox (tuple or list): Four integers (x_min, y_min, x_max, y_max)
                describing the original bounding box.
            rot (int): The rotation angle in degrees (0, 90, 180, or 270).
            size (int, optional): Side length of the square image in pixels.
                Defaults to ``self.image_size``.

        Returns:
            list: [x_min_rot, y_min_rot, x_max_rot, y_max_rot], the bounding box
            coordinates after rotation.

        Raises:
            KeyError: If ``rot`` is not one of 0, 90, 180, 270.
        """
        if size is None:
            size = self.image_size
        x_min, y_min, x_max, y_max = bbox
        bbox_rot = {
            0: [x_min, y_min, x_max, y_max],
            90: [y_min, size - x_max, y_max, size - x_min],
            180: [size - x_max, size - y_max, size - x_min, size - y_min],
            270: [size - y_max, x_min, size - y_min, x_max],
        }
        return bbox_rot[rot]
# main()