# Source code for xai_methods.rollout_vit

# Code adapted for Huggingface ViT from Source: https://github.com/jacobgil/vit-explain/blob/main/vit_rollout.py

import PIL
import torch
import numpy
import numpy as np
import cv2
import gc


def show_mask_on_image(img, mask):
    """Blend a JET-coloured rendering of ``mask`` onto ``img``.

    Args:
        img: Image data accepted by ``np.float32``; it is divided by 255
            before blending (presumably 8-bit RGB, H x W x 3 -- confirm
            against callers).
        mask: Array that is multiplied by 255 and cast to ``uint8`` before
            colour-mapping (presumably values in [0, 1] with the same
            height/width as ``img``).

    Returns:
        ``uint8`` array: the sum of the colour map and the image, divided by
        its own maximum and rescaled to 0-255.
    """
    base = np.float32(img) / 255

    # OpenCV colour maps come back in BGR order; convert so the overlay is RGB.
    colored = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    colored = np.float32(cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)) / 255

    blended = colored + np.float32(base)
    blended = blended / np.max(blended)
    return np.uint8(255 * blended)
# class VITAttentionRollout: # def __init__(self, model, processor, attention_layer_name='attention.dropout', head_fusion="mean", # discard_ratio=0.9): # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # self.model = model.to(self.device) # self.processor = processor # self.head_fusion = head_fusion # self.discard_ratio = discard_ratio # self.hooks = [] # self.attentions = [] # for name, module in self.model.named_modules(): # if attention_layer_name in name: # hook = module.register_forward_hook(self.get_attention) # self.hooks.append(hook) # def _preprocess(self, image): # """Preprocesses an input tensor using the processor for ViT # Args: # image (torch.Tensor): The input image (assumed to be resized to 224x224) (C, H, W) # """ # #image = transforms.ToTensor()(image).unsqueeze(0) # inputs = self.processor(images=image, return_tensors="pt").to(self.device) # inputs = inputs['pixel_values'] # return inputs # def get_attention(self, module, input, output): # self.attentions.append(output.detach().cpu()) # def remove_hooks(self): # for hook in self.hooks: # hook.remove() # self.hooks = [] # def cleanup(self, clear_cache = False): # self.remove_hooks() # self.attentions = [] # if clear_cache: # torch.cuda.empty_cache() # gc.collect() # def __call__(self, input_tensor): # self.attentions = [] # with torch.no_grad(): # output = self.model(input_tensor, output_attentions = True) # return self.rollout() # def rollout(self): # result = torch.eye(self.attentions[0].size(-1)) # with torch.no_grad(): # for attention in self.attentions: # if self.head_fusion == "mean": # attention_heads_fused = attention.mean(axis=1) # elif self.head_fusion == "max": # attention_heads_fused = attention.max(axis=1)[0] # elif self.head_fusion == "min": # attention_heads_fused = attention.min(axis=1)[0] # else: # raise "Attention head fusion type Not supported" # # Drop the lowest attentions, but # # don't drop the class token # flat = 
attention_heads_fused.view(attention_heads_fused.size(0), -1) # _, indices = flat.topk(int(flat.size(-1)*self.discard_ratio), -1, False) # indices = indices[indices != 0] # flat[0, indices] = 0 # I = torch.eye(attention_heads_fused.size(-1)) # a = (attention_heads_fused + 1.0*I)/2 # a = a / a.sum(dim=-1) # result = torch.matmul(a, result) # # Look at the total attention between the class token, # # and the image patches # mask = result[0, 0 , 1 :] # # In case of 224x224 image, this brings us from 196 to 14 # width = int(mask.size(-1)**0.5) # mask = mask.reshape(width, width).numpy() # mask = mask / np.max(mask) # return mask # def generate(self, images): # if isinstance(images, PIL.Image.Image): # images = [images] # input_tensors = self._preprocess(images) # print(input_tensors.shape) # labels = self._predict_label(input_tensors) # heatmaps = [] # mask_images = [] # for image, input_tensor in zip(images, input_tensors): # input_tensor = input_tensor.unsqueeze(0) # width = input_tensor.shape[-2] # height = input_tensor.shape[-1] # mask = self(input_tensor) # heatmap = cv2.resize(mask, (height, width)) # heatmaps.append(heatmap) # mask_image = show_mask_on_image(image, heatmap) # mask_images.append(mask_image) # #heatmap = np.expand_dims(np.float32(heatmap), axis=0) # mask_image = np.expand_dims(mask_image, axis=0) # plt.imshow(heatmap) # plt.show() # return input_tensors, images, np.array(heatmaps), mask_images, labels # def _get_logits(self, input_image): # """ # Processes an input image and gets the raw logits from the ViT model. # Args: # input_image (PIL.Image.Image, np.ndarray): The input image. # Returns: # torch.Tensor: The raw logits output from the model. # """ # input_tensor = self._preprocess(input_image).to(self.device) # #Prediction label # with torch.no_grad(): # logits = self.model(input_tensor).logits # return logits # def _predict_label(self, inputs): # """ # Predicts the class label for a given input tensor using the provided model. 
# Args: # model (torch.nn.Module): The trained PyTorch model for prediction. # input_tensor (torch.Tensor): The input tensor to the model. # Returns: # str: The predicted class label (obtained by finding the argmax of the logits). # """ # logits = self._get_logits(inputs) # predictions = logits.argmax(-1) # return [self.model.config.id2label[i.item()].split(',')[0] for i in predictions]
class VITAttentionRollout:
    """Attention rollout for ViT-style models exposing a Hugging Face interface.

    Forward hooks are registered on every sub-module whose qualified name
    contains ``attention_layer_name``. The captured outputs (expected to be
    attention maps of shape ``(batch, heads, tokens, tokens)``) are fused over
    the head dimension, mixed with the identity, row-normalised and multiplied
    layer by layer. The row of the first token (assumed to be the class token)
    is returned as a per-image square mask.

    Args:
        model: ``torch.nn.Module``; it is moved to CUDA when available,
            otherwise to CPU.
        processor: Callable invoked as ``processor(images=..., return_tensors="pt")``
            whose result supports ``.to(device)`` and ``['pixel_values']``.
        attention_layer_name: Substring selecting the modules to hook.
        head_fusion: ``"mean"``, ``"max"`` or ``"min"``.
        discard_ratio: Fraction of the lowest fused attention values that is
            zeroed in every layer.
    """

    _HEAD_FUSIONS = ("mean", "max", "min")

    def __init__(self, model, processor, attention_layer_name='attention.dropout', head_fusion="mean",
                 discard_ratio=0.9):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = model.to(self.device)
        self.processor = processor
        self.head_fusion = head_fusion
        self.discard_ratio = discard_ratio
        self.hooks = []
        self.attentions = []

        for name, module in self.model.named_modules():
            if attention_layer_name in name:
                self.hooks.append(module.register_forward_hook(self.get_attention))

    def _preprocess(self, image):
        """Run the processor on ``image`` and return ``pixel_values`` on ``self.device``."""
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)
        return inputs['pixel_values']

    def get_attention(self, module, input, output):
        """Forward hook: keep the detached module output (no device transfer)."""
        self.attentions.append(output.detach())

    def remove_hooks(self):
        """Detach every registered hook from the model."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    def cleanup(self, clear_cache=False):
        """Remove hooks and drop captured attentions.

        Args:
            clear_cache: When true, also empty the CUDA cache and run the
                garbage collector.
        """
        self.remove_hooks()
        self.attentions = []
        if clear_cache:
            torch.cuda.empty_cache()
            gc.collect()

    def _forward(self, input_tensor):
        """Run one gradient-free forward pass, refilling ``self.attentions`` via the hooks."""
        self.attentions = []  # drop whatever an earlier pass captured
        with torch.no_grad():
            return self.model(input_tensor, output_attentions=True)

    def __call__(self, input_tensor):
        """Return the rollout masks for an already preprocessed batch.

        Args:
            input_tensor: Model input, as returned by ``_preprocess``.

        Returns:
            ``numpy.ndarray`` of shape ``(batch, width, width)``; see ``rollout``.
        """
        self._forward(input_tensor)
        return self.rollout()

    def _fuse_heads(self, attention):
        """Collapse the head dimension (dim 1) according to ``self.head_fusion``."""
        if self.head_fusion == "mean":
            return attention.mean(dim=1)
        if self.head_fusion == "max":
            return attention.max(dim=1)[0]
        return attention.min(dim=1)[0]

    def _discard_lowest(self, fused):
        """Zero, per batch item, every value not above the k-th smallest one."""
        flat = fused.reshape(fused.size(0), -1)
        # Clamp k: topk rejects k > numel, and with k == 0 there is no
        # threshold to read (indexing the empty result used to raise).
        k = min(int(flat.size(-1) * self.discard_ratio), flat.size(-1))
        if k <= 0:
            return fused
        lowest, _ = torch.topk(flat, k, dim=-1, largest=False)
        thresholds = lowest[:, -1].view(-1, 1, 1)
        return torch.where(fused <= thresholds, torch.zeros_like(fused), fused)

    def rollout(self):
        """Combine the captured attentions into one mask per batch item.

        Returns:
            ``numpy.ndarray`` of shape ``(batch, width, width)`` where
            ``width = int(sqrt(tokens - 1))``. Each mask is min-max scaled to
            [0, 1]; a constant mask becomes all zeros.

        Raises:
            RuntimeError: If no attention was captured.
            ValueError: If ``head_fusion`` is not one of mean/max/min.
        """
        if not self.attentions:
            raise RuntimeError(
                "No attention maps were captured: check that attention_layer_name "
                "matches modules that are actually executed in the forward pass."
            )
        if self.head_fusion not in self._HEAD_FUSIONS:
            raise ValueError("Attention head fusion type Not supported")

        first = self.attentions[0]
        batch_size, num_tokens = first.shape[0], first.shape[-1]
        # Half-precision activations are promoted so that the repeated
        # matmuls and the final ``.numpy()`` conversion are well defined.
        if first.dtype in (torch.float32, torch.float64):
            work_dtype = first.dtype
        else:
            work_dtype = torch.float32
        # Build on the tensors' own device rather than ``self.device`` so the
        # math follows wherever the hooks captured the activations.
        eye = torch.eye(num_tokens, device=first.device, dtype=work_dtype)
        result = eye.unsqueeze(0).repeat(batch_size, 1, 1)

        for attention in self.attentions:
            fused = self._fuse_heads(attention.to(work_dtype))
            if self.discard_ratio > 0:
                fused = self._discard_lowest(fused)
            # A = (attention + identity) / 2, then row-normalise.
            fused = (fused + eye) / 2
            fused = fused / fused.sum(dim=-1, keepdim=True)
            result = torch.matmul(fused, result)

        # Row 0 = first (class) token; drop its own column to keep the patches.
        mask = result[:, 0, 1:]
        width = int(mask.size(-1) ** 0.5)
        mask = mask.reshape(batch_size, width, width).cpu().numpy()

        # Per-image min-max scaling; a zero range yields zeros instead of NaN.
        low = mask.min(axis=(1, 2), keepdims=True)
        span = mask.max(axis=(1, 2), keepdims=True) - low
        valid = span > 0
        scaled = (mask - low) / np.where(valid, span, 1)
        return np.where(valid, scaled, np.zeros_like(mask)).astype(mask.dtype)

    @staticmethod
    def _image_hw(image):
        """Return ``(height, width)`` of a PIL image or an ``(H, W[, C])`` array."""
        size = getattr(image, 'size', None)
        # Only PIL-style images expose ``size`` as a (width, height) tuple;
        # for ndarrays it is an int and for tensors a method.
        if isinstance(size, tuple) and len(size) == 2:
            return size[1], size[0]
        shape = np.asarray(image).shape
        if len(shape) >= 2:
            return shape[0], shape[1]
        return 224, 224

    def _labels_from_logits(self, logits):
        """Map the arg-max class of every row to the first comma-separated part of its label."""
        id2label = self.model.config.id2label
        return [id2label[i.item()].split(',')[0] for i in logits.argmax(-1)]

    def generate(self, images):
        """Explain one image or a batch of images.

        Args:
            images: A single ``PIL.Image.Image`` or a sequence of images
                accepted by the processor.

        Returns:
            Tuple ``(input_tensors, images, heatmaps, mask_images, labels)``.
            ``heatmaps`` is a stacked array when all images share one size,
            otherwise a 1-D object array of per-image heatmaps.
        """
        # ``import PIL`` alone does not load the ``Image`` submodule.
        import PIL.Image

        if isinstance(images, PIL.Image.Image):
            images = [images]

        input_tensors = self._preprocess(images)
        # One forward pass provides both the logits and the hooked attentions.
        outputs = self._forward(input_tensors)
        labels = self._labels_from_logits(outputs.logits)
        masks = self.rollout()

        heatmaps = []
        mask_images = []
        for image, mask in zip(images, masks):
            height, width = self._image_hw(image)
            heatmap = cv2.resize(mask, (width, height))  # cv2 expects (width, height)
            heatmaps.append(heatmap)
            mask_images.append(show_mask_on_image(image, heatmap))

        if len({h.shape for h in heatmaps}) <= 1:
            heatmap_array = np.array(heatmaps)
        else:
            # Recent NumPy refuses to build an array from ragged input.
            heatmap_array = np.empty(len(heatmaps), dtype=object)
            for idx, heatmap in enumerate(heatmaps):
                heatmap_array[idx] = heatmap

        return input_tensors, images, heatmap_array, mask_images, labels

    def _get_logits(self, input_image):
        """Preprocess ``input_image`` and return the model logits (no gradients)."""
        input_tensor = self._preprocess(input_image)
        with torch.no_grad():
            logits = self.model(input_tensor).logits
        return logits

    def _predict_label(self, images):
        """Return the predicted label string for every image."""
        return self._labels_from_logits(self._get_logits(images))
class VITAttentionGradRollout:
    """Gradient-weighted attention rollout for ViT-style models.

    A forward hook and a backward hook are registered on every sub-module
    whose qualified name contains ``attention_layer_name``. The captured
    outputs are multiplied element-wise with the captured gradients,
    averaged over dim 1 (assumed to be the head dimension), clamped at zero
    and rolled out layer by layer.

    Args:
        model: ``torch.nn.Module`` returning an object with ``.logits``. It is
            used on whatever device its parameters already live on.
        processor: Callable invoked as ``processor(images=..., return_tensors="pt")``
            whose result is indexed with ``['pixel_values']``.
        attention_layer_name: Substring selecting the modules to hook.
        discard_ratio: Fraction of the lowest fused values zeroed per layer.
    """

    def __init__(self, model, processor, attention_layer_name='attention.dropout', discard_ratio=0.9):
        self.model = model
        self.processor = processor
        self.discard_ratio = discard_ratio
        self.attention_layer_name = attention_layer_name
        # ``_get_logits`` relies on ``self.device``; follow the model instead
        # of moving it.
        first_param = next(model.parameters(), None)
        self.device = first_param.device if first_param is not None else torch.device('cpu')
        self.hooks = []
        self.attentions = []
        self.attention_gradients = []
        self._register_hooks()

    def _register_hooks(self):
        """Attach forward/backward hooks to every module matching the layer name."""
        for name, module in self.model.named_modules():
            if self.attention_layer_name in name:
                fwd_hook = module.register_forward_hook(self.get_attention)
                # ``register_backward_hook`` is deprecated in PyTorch; the full
                # variant reports the gradient w.r.t. the module inputs.
                bwd_hook = module.register_full_backward_hook(self.get_attention_gradient)
                self.hooks.append(fwd_hook)
                self.hooks.append(bwd_hook)

    def get_attention(self, module, input, output):
        """Forward hook: store a detached CPU copy of the module output."""
        self.attentions.append(output.detach().cpu())

    def get_attention_gradient(self, module, grad_input, grad_output):
        """Backward hook: store a CPU copy of the gradient w.r.t. the module input."""
        # The backward pass reaches the layers in reverse order of the forward
        # pass; prepending keeps this list aligned with ``self.attentions``.
        self.attention_gradients.insert(0, grad_input[0].detach().cpu())

    def _preprocess(self, image):
        """Run the processor on ``image`` and return ``pixel_values`` on ``self.device``."""
        inputs = self.processor(images=image, return_tensors="pt")['pixel_values']
        return inputs.to(self.device)

    def _get_logits(self, input_image):
        """Preprocess a raw image and return the model logits (no gradients).

        Args:
            input_image: Raw image accepted by the processor.

        Returns:
            torch.Tensor: The ``logits`` attribute of the model output.
        """
        input_tensor = self._preprocess(input_image).to(self.device)
        with torch.no_grad():
            logits = self.model(input_tensor).logits
        return logits

    def _labels_from_logits(self, logits):
        """Map the arg-max class of every row to the first comma-separated part of its label."""
        id2label = self.model.config.id2label
        return [id2label[i.item()].split(',')[0] for i in logits.argmax(-1)]

    def _predict_label(self, inputs):
        """Return the predicted label strings for raw image(s).

        Args:
            inputs: Raw image(s); they are preprocessed by ``_get_logits``.

        Returns:
            list[str]: One label per input row.
        """
        return self._labels_from_logits(self._get_logits(inputs))

    def _explain(self, input_tensor, category_index=None):
        """Run forward + backward on a preprocessed tensor.

        Args:
            input_tensor: Preprocessed model input.
            category_index: Logit column to differentiate; ``None`` selects
                each row's arg-max class.

        Returns:
            Tuple ``(mask, logits)`` with ``mask`` from ``grad_rollout`` and
            the detached logits of the same forward pass.
        """
        if not self.hooks:
            # ``cleanup`` removed the hooks (``generate`` does so every time).
            self._register_hooks()
        # Start from empty lists so earlier passes cannot leak into this one.
        self.attentions = []
        self.attention_gradients = []

        self.model.zero_grad()
        with torch.enable_grad():
            output = self.model(input_tensor, output_attentions=True)
            logits = output.logits
            if category_index is None:
                predicted = logits.argmax(-1, keepdim=True)
                loss = logits.gather(1, predicted).sum()
            else:
                # Same value as summing ``logits * one_hot_mask`` but without
                # allocating a mask on a possibly different device.
                loss = logits[:, category_index].sum()
            loss.backward()
        return self.grad_rollout(), logits.detach()

    def __call__(self, image, category_index=1):
        """Return the gradient rollout mask of a raw image for one class.

        Args:
            image: Raw image; it is passed through ``_preprocess``.
            category_index: Index of the logit to explain.

        Returns:
            ``numpy.ndarray`` of shape ``(width, width)``; see ``grad_rollout``.
        """
        input_tensor = self._preprocess(image)
        mask, _ = self._explain(input_tensor, category_index)
        return mask

    def remove_hooks(self):
        """Detach every registered hook from the model."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    def cleanup(self, clear_cache=False):
        """Remove hooks and drop captured tensors.

        Args:
            clear_cache: When true, also empty the CUDA cache and run the
                garbage collector.
        """
        self.remove_hooks()
        self.attentions = []
        self.attention_gradients = []
        if clear_cache:
            torch.cuda.empty_cache()
            gc.collect()

    def grad_rollout(self):
        """Combine captured attentions and gradients into a single mask.

        Only the first batch item is returned.

        Returns:
            ``numpy.ndarray`` of shape ``(width, width)`` with
            ``width = int(sqrt(tokens - 1))``, divided by its maximum; an
            all-zero mask is returned unchanged instead of producing NaN.

        Raises:
            RuntimeError: If no attention or no gradient was captured.
        """
        if not self.attentions or not self.attention_gradients:
            raise RuntimeError(
                "No attention maps/gradients were captured: check that "
                "attention_layer_name matches modules that are actually executed."
            )

        eye = torch.eye(self.attentions[0].size(-1))
        result = eye
        with torch.no_grad():
            for attention, grad in zip(self.attentions, self.attention_gradients):
                fused = (attention.float() * grad.float()).mean(dim=1).clamp(min=0)

                # Zero the k lowest entries of each item (class token included).
                flat = fused.view(fused.size(0), -1)
                k = min(int(flat.size(-1) * self.discard_ratio), flat.size(-1))
                if k > 0:
                    _, indices = flat.topk(k, -1, False)
                    flat.scatter_(-1, indices, 0.0)  # ``flat`` is a view of ``fused``

                a = (fused + 1.0 * eye) / 2
                # keepdim=True so every row is divided by its own sum.
                a = a / a.sum(dim=-1, keepdim=True)
                result = torch.matmul(a, result)

        # Row 0 = first (class) token; drop its own column to keep the patches.
        mask = result[0, 0, 1:]
        # For a 224x224 image with 16x16 patches this maps 196 values to 14x14.
        width = int(mask.size(-1) ** 0.5)
        mask = mask.reshape(width, width).numpy()
        peak = np.max(mask)
        if peak > 0:
            return mask / peak
        return np.zeros_like(mask)

    @staticmethod
    def _image_hw(image, default):
        """Return ``(height, width)`` of a PIL image or ``(H, W[, C])`` array, else ``default``."""
        size = getattr(image, 'size', None)
        if isinstance(size, tuple) and len(size) == 2:  # PIL: (width, height)
            return size[1], size[0]
        shape = np.asarray(image).shape
        if len(shape) >= 2:
            return shape[0], shape[1]
        return default

    def generate(self, image, category_index=None):
        """Explain a single image and overlay the result on it.

        Hooks are removed afterwards; they are re-attached automatically on
        the next call.

        Args:
            image: Raw image accepted by the processor and by
                ``show_mask_on_image``.
            category_index: Class to explain; ``None`` (default) explains the
                predicted class, so the heatmap matches the returned label.

        Returns:
            Tuple ``(input_tensor, image, heatmap, mask_image, label)``.
        """
        try:
            input_tensor = self._preprocess(image)
            # The raw image is preprocessed exactly once and a single forward
            # pass yields both the label and the attentions.
            mask, logits = self._explain(input_tensor, category_index)
            label = self._labels_from_logits(logits)

            fallback = (input_tensor.shape[-2], input_tensor.shape[-1])
            height, width = self._image_hw(image, fallback)
            # Size the heatmap like the image it is blended with.
            heatmap = cv2.resize(mask, (width, height))  # cv2 expects (width, height)
            mask_image = show_mask_on_image(image, heatmap)
            return input_tensor, image, heatmap, mask_image, label
        finally:
            self.cleanup()