# Code adapted for Huggingface ViT from Source: https://github.com/jacobgil/vit-explain/blob/main/vit_rollout.py
import PIL
import torch
import numpy
import numpy as np
import cv2
import gc
[docs]
def show_mask_on_image(img, mask):
    """Overlay a saliency mask on an image as a JET-colormapped heatmap.

    Args:
        img: Base image convertible with ``np.asarray`` (e.g. a PIL image or an
            ``(H, W)`` / ``(H, W, C)`` array) with pixel values on a 0-255
            scale. Grayscale, single-channel and RGBA images are converted to
            three channels before blending.
        mask: 2-D saliency map, expected in ``[0, 1]``. NaNs become 0 and values
            outside the range are clipped. If its size differs from the
            image's ``(H, W)`` it is resized to match.

    Returns:
        np.ndarray: ``uint8`` array of shape ``(H, W, 3)`` holding heatmap plus
        image, rescaled so the brightest value maps to 255.
    """
    img = np.asarray(img, dtype=np.float32) / 255
    # The heatmap is always 3-channel RGB; bring the base image to match so
    # the two can be added element-wise.
    if img.ndim == 2:
        img = np.repeat(img[..., np.newaxis], 3, axis=-1)
    elif img.shape[-1] == 1:
        img = np.repeat(img, 3, axis=-1)
    elif img.shape[-1] == 4:
        img = img[..., :3]  # drop the alpha channel

    mask = np.asarray(mask)
    height, width = img.shape[:2]
    if mask.shape[:2] != (height, width):
        # cv2.resize expects dsize as (width, height).
        mask = cv2.resize(mask.astype(np.float32), (width, height))
    # np.uint8() wraps values outside [0, 255] instead of saturating, so clamp
    # the mask (and neutralize NaNs) before scaling.
    mask = np.clip(np.nan_to_num(mask), 0.0, 1.0)

    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + img
    peak = np.max(cam)
    if peak > 0:
        cam = cam / peak
    return np.uint8(255 * cam)
# class VITAttentionRollout:
# def __init__(self, model, processor, attention_layer_name='attention.dropout', head_fusion="mean",
# discard_ratio=0.9):
# self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# self.model = model.to(self.device)
# self.processor = processor
# self.head_fusion = head_fusion
# self.discard_ratio = discard_ratio
# self.hooks = []
# self.attentions = []
# for name, module in self.model.named_modules():
# if attention_layer_name in name:
# hook = module.register_forward_hook(self.get_attention)
# self.hooks.append(hook)
# def _preprocess(self, image):
# """Preprocesses an input tensor using the processor for ViT
# Args:
# image (torch.Tensor): The input image (assumed to be resized to 224x224) (C, H, W)
# """
# #image = transforms.ToTensor()(image).unsqueeze(0)
# inputs = self.processor(images=image, return_tensors="pt").to(self.device)
# inputs = inputs['pixel_values']
# return inputs
# def get_attention(self, module, input, output):
# self.attentions.append(output.detach().cpu())
# def remove_hooks(self):
# for hook in self.hooks:
# hook.remove()
# self.hooks = []
# def cleanup(self, clear_cache = False):
# self.remove_hooks()
# self.attentions = []
# if clear_cache:
# torch.cuda.empty_cache()
# gc.collect()
# def __call__(self, input_tensor):
# self.attentions = []
# with torch.no_grad():
# output = self.model(input_tensor, output_attentions = True)
# return self.rollout()
# def rollout(self):
# result = torch.eye(self.attentions[0].size(-1))
# with torch.no_grad():
# for attention in self.attentions:
# if self.head_fusion == "mean":
# attention_heads_fused = attention.mean(axis=1)
# elif self.head_fusion == "max":
# attention_heads_fused = attention.max(axis=1)[0]
# elif self.head_fusion == "min":
# attention_heads_fused = attention.min(axis=1)[0]
# else:
# raise "Attention head fusion type Not supported"
# # Drop the lowest attentions, but
# # don't drop the class token
# flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
# _, indices = flat.topk(int(flat.size(-1)*self.discard_ratio), -1, False)
# indices = indices[indices != 0]
# flat[0, indices] = 0
# I = torch.eye(attention_heads_fused.size(-1))
# a = (attention_heads_fused + 1.0*I)/2
# a = a / a.sum(dim=-1)
# result = torch.matmul(a, result)
# # Look at the total attention between the class token,
# # and the image patches
# mask = result[0, 0 , 1 :]
# # In case of 224x224 image, this brings us from 196 to 14
# width = int(mask.size(-1)**0.5)
# mask = mask.reshape(width, width).numpy()
# mask = mask / np.max(mask)
# return mask
# def generate(self, images):
# if isinstance(images, PIL.Image.Image):
# images = [images]
# input_tensors = self._preprocess(images)
# print(input_tensors.shape)
# labels = self._predict_label(input_tensors)
# heatmaps = []
# mask_images = []
# for image, input_tensor in zip(images, input_tensors):
# input_tensor = input_tensor.unsqueeze(0)
# width = input_tensor.shape[-2]
# height = input_tensor.shape[-1]
# mask = self(input_tensor)
# heatmap = cv2.resize(mask, (height, width))
# heatmaps.append(heatmap)
# mask_image = show_mask_on_image(image, heatmap)
# mask_images.append(mask_image)
# #heatmap = np.expand_dims(np.float32(heatmap), axis=0)
# mask_image = np.expand_dims(mask_image, axis=0)
# plt.imshow(heatmap)
# plt.show()
# return input_tensors, images, np.array(heatmaps), mask_images, labels
# def _get_logits(self, input_image):
# """
# Processes an input image and gets the raw logits from the ViT model.
# Args:
# input_image (PIL.Image.Image, np.ndarray): The input image.
# Returns:
# torch.Tensor: The raw logits output from the model.
# """
# input_tensor = self._preprocess(input_image).to(self.device)
# #Prediction label
# with torch.no_grad():
# logits = self.model(input_tensor).logits
# return logits
# def _predict_label(self, inputs):
# """
# Predicts the class label for a given input tensor using the provided model.
# Args:
# model (torch.nn.Module): The trained PyTorch model for prediction.
# input_tensor (torch.Tensor): The input tensor to the model.
# Returns:
# str: The predicted class label (obtained by finding the argmax of the logits).
# """
# logits = self._get_logits(inputs)
# predictions = logits.argmax(-1)
# return [self.model.config.id2label[i.item()].split(',')[0] for i in predictions]
[docs]
class VITAttentionRollout:
    """Attention-rollout saliency maps for Hugging Face ViT classifiers.

    A forward hook is registered on every module whose qualified name contains
    ``attention_layer_name``. Each hooked output is collected during a forward
    pass, and the per-layer attentions are then chained by matrix product
    ("rolled out"). The CLS-token row of the result gives one weight per image
    patch.

    Args:
        model: Image-classification model. It is called with
            ``output_attentions=True`` and is expected to expose ``.logits``
            and ``config.id2label``. It is moved to CUDA when available,
            otherwise to CPU.
        processor: Matching image processor, called as
            ``processor(images=..., return_tensors="pt")``.
        attention_layer_name: Substring that selects the modules to hook.
        head_fusion: How heads are merged per layer: ``"mean"``, ``"max"`` or
            ``"min"``.
        discard_ratio: Fraction of the lowest fused attention weights zeroed,
            per image and per layer, before rollout.
    """

    def __init__(self, model, processor, attention_layer_name='attention.dropout', head_fusion="mean", discard_ratio=0.9):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = model.to(self.device)
        self.processor = processor
        self.head_fusion = head_fusion
        self.discard_ratio = discard_ratio
        self.hooks = []
        self.attentions = []  # filled by get_attention on every forward pass
        for name, module in self.model.named_modules():
            if attention_layer_name in name:
                self.hooks.append(module.register_forward_hook(self.get_attention))

    def _preprocess(self, image):
        """Run the processor on ``image`` and return ``pixel_values`` on ``self.device``."""
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)
        return inputs['pixel_values']

    def get_attention(self, module, input, output):
        """Forward hook: store the module output, detached, on its current device."""
        self.attentions.append(output.detach())

    def remove_hooks(self):
        """Detach every registered forward hook from the model."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    def cleanup(self, clear_cache=False):
        """Remove hooks and drop stored attentions; optionally free the CUDA cache."""
        self.remove_hooks()
        self.attentions = []
        if clear_cache:
            torch.cuda.empty_cache()
            gc.collect()

    def _forward(self, input_tensor):
        """Run one no-grad forward pass and refresh ``self.attentions``.

        Returns the raw model output so callers can reuse its logits.
        """
        self.attentions = []  # drop data from any previous forward pass
        with torch.no_grad():
            outputs = self.model(input_tensor, output_attentions=True)
        if not self.attentions:
            # No hooked module fired (hooks removed by cleanup(), or the layer
            # name matched nothing): fall back to the attentions the model
            # itself returns for output_attentions=True.
            returned = getattr(outputs, 'attentions', None)
            if returned is not None:
                self.attentions = [a.detach() for a in returned]
        return outputs

    def __call__(self, input_tensor):
        """Return rollout masks for a batch of preprocessed pixel values.

        Returns:
            np.ndarray: ``(batch, side, side)`` masks; see :meth:`rollout`.
        """
        self._forward(input_tensor)
        return self.rollout()

    def rollout(self):
        """Fuse the collected attentions into one normalized mask per image.

        For every layer:
            1. Heads are merged with ``head_fusion``.
            2. The lowest ``discard_ratio`` fraction of weights is zeroed;
               flat index 0, the CLS->CLS weight, is never zeroed.
            3. The result is averaged with the identity.
            4. Rows are renormalized to sum to 1.
            5. The layer is chained onto the running product.

        The CLS row (patch columns only) is reshaped to a square grid.

        Returns:
            np.ndarray: Array of shape ``(batch, side, side)``. Each mask is
            min-max normalized to ``[0, 1]``, or all zeros if it is constant.

        Raises:
            ValueError: If ``head_fusion`` is not ``"mean"``, ``"max"`` or ``"min"``.
            IndexError: If no attentions have been collected.
        """
        first = self.attentions[0]
        batch_size = first.shape[0]
        num_tokens = first.shape[-1]
        eye = torch.eye(num_tokens, device=first.device)
        result = eye.unsqueeze(0).repeat(batch_size, 1, 1)
        num_entries = num_tokens * num_tokens
        num_discard = min(int(num_entries * self.discard_ratio), num_entries)
        with torch.no_grad():
            for attention in self.attentions:
                # dim=1 is the heads dimension of (batch, heads, tokens, tokens);
                # fuse in float32 for stable products.
                attention = attention.float()
                if self.head_fusion == "mean":
                    fused = attention.mean(dim=1)
                elif self.head_fusion == "max":
                    fused = attention.max(dim=1)[0]
                elif self.head_fusion == "min":
                    fused = attention.min(dim=1)[0]
                else:
                    raise ValueError("Attention head fusion type Not supported")
                # A k of 0 (small ratio or few tokens) means nothing to discard;
                # calling topk with k=0 would leave no threshold to index.
                if num_discard > 0:
                    # Zero exactly the num_discard smallest weights of each
                    # image. A "<= k-th value" threshold would also drop every
                    # tie. Keep flat index 0 (CLS->CLS), as the original
                    # unbatched code did.
                    flat = fused.reshape(batch_size, -1)
                    _, indices = flat.topk(num_discard, dim=-1, largest=False)
                    drop = torch.zeros_like(flat, dtype=torch.bool)
                    drop.scatter_(1, indices, torch.ones_like(indices, dtype=torch.bool))
                    drop[:, 0] = False
                    fused = flat.masked_fill(drop, 0.0).view_as(fused)
                a = (fused + eye) / 2
                a = a / a.sum(dim=-1, keepdim=True)
                result = torch.matmul(a, result)
        # CLS-token row, patch columns only.
        mask = result[:, 0, 1:]
        width = int(mask.size(-1) ** 0.5)
        mask = mask.reshape(batch_size, width, width).cpu().numpy()
        # Per-image min-max normalization; constant (or NaN) masks become zeros.
        low = mask.min(axis=(1, 2), keepdims=True)
        span = mask.max(axis=(1, 2), keepdims=True) - low
        valid = span > 0
        normalized = np.where(valid, (mask - low) / np.where(valid, span, 1.0), 0.0)
        return normalized.astype(mask.dtype)

    @staticmethod
    def _image_size(image):
        """Return ``(width, height)`` for a PIL image or a channels-last array.

        Falls back to ``(224, 224)`` when neither applies.
        """
        if isinstance(image, np.ndarray) and image.ndim >= 2:
            # Assumes (H, W[, C]) layout -- TODO confirm for channels-first input.
            return image.shape[1], image.shape[0]
        size = getattr(image, 'size', None)
        if isinstance(size, tuple) and len(size) == 2:
            return size
        return 224, 224

    @staticmethod
    def _stack_heatmaps(heatmaps):
        """Stack equal-sized heatmaps; mixed sizes go into a 1-D object array."""
        if len({h.shape for h in heatmaps}) <= 1:
            return np.array(heatmaps)
        stacked = np.empty(len(heatmaps), dtype=object)
        for index, heatmap in enumerate(heatmaps):
            stacked[index] = heatmap
        return stacked

    def generate(self, images):
        """Compute rollout heatmaps and overlays for one image or a batch.

        Args:
            images: A PIL image, a single ``(H, W[, C])`` numpy array, or a list
                of either.

        Returns:
            tuple: ``(input_tensors, images, heatmaps, mask_images, labels)``.

            - ``heatmaps`` holds one mask per image, resized to that image's
              size. It is a stacked array when all sizes agree, otherwise a
              1-D object array.
            - ``mask_images`` are the ``show_mask_on_image`` overlays.
            - ``labels`` are the predicted class names.
        """
        if isinstance(images, PIL.Image.Image) or (isinstance(images, np.ndarray) and images.ndim <= 3):
            images = [images]
        input_tensors = self._preprocess(images)
        # A single forward pass yields both the hooked attentions and logits.
        outputs = self._forward(input_tensors)
        labels = self._labels_from_logits(outputs.logits)
        masks = self.rollout()
        heatmaps = []
        mask_images = []
        for image, mask in zip(images, masks):
            orig_w, orig_h = self._image_size(image)
            heatmap = cv2.resize(mask, (orig_w, orig_h))
            heatmaps.append(heatmap)
            mask_images.append(show_mask_on_image(image, heatmap))
        return input_tensors, images, self._stack_heatmaps(heatmaps), mask_images, labels

    def _labels_from_logits(self, logits):
        """Map each row of ``logits`` to the first comma-separated name of its argmax class."""
        id2label = self.model.config.id2label
        return [id2label[i.item()].split(',')[0] for i in logits.argmax(-1)]

    def _get_logits(self, input_image):
        """Preprocess ``input_image`` and return the model's raw logits (no grad)."""
        input_tensor = self._preprocess(input_image)
        with torch.no_grad():
            logits = self.model(input_tensor).logits
        return logits

    def _predict_label(self, images):
        """Return the predicted class name for each image in ``images``."""
        return self._labels_from_logits(self._get_logits(images))
[docs]
class VITAttentionGradRollout:
    """Gradient-weighted attention rollout for Hugging Face ViT classifiers.

    Forward hooks record the output of every module whose qualified name
    contains ``attention_layer_name``. Full backward hooks record the gradient
    flowing into those modules when a class score is back-propagated. Each
    layer's attention is weighted by its gradient and negative contributions
    are removed. The layers are then chained by matrix product into a saliency
    map for the first image of the batch.

    Args:
        model: Image-classification model. It is called with
            ``output_attentions=True`` and is expected to expose ``.logits``
            and ``config.id2label``. It is not moved; inputs are sent to the
            device of its parameters.
        processor: Matching image processor, called as
            ``processor(images=..., return_tensors="pt")``.
        attention_layer_name: Substring that selects the modules to hook.
        discard_ratio: Fraction of the lowest fused weights zeroed per layer.
    """

    def __init__(self, model, processor, attention_layer_name='attention.dropout', discard_ratio=0.9):
        self.model = model
        self.processor = processor
        self.discard_ratio = discard_ratio
        self.attention_layer_name = attention_layer_name
        # _preprocess/_get_logits need a device; the model itself is not moved.
        first_param = next(iter(self.model.parameters()), None)
        self.device = first_param.device if first_param is not None else torch.device('cpu')
        self.hooks = []
        self.attentions = []
        self.attention_gradients = []
        self._register_hooks()

    def _register_hooks(self):
        """Attach a forward and a full backward hook to every matching module."""
        for name, module in self.model.named_modules():
            if self.attention_layer_name in name:
                self.hooks.append(module.register_forward_hook(self.get_attention))
                # register_backward_hook is deprecated: its grad_input only
                # reflects the module's last autograd op, not the module's
                # inputs. The full hook reports gradients w.r.t. the inputs.
                self.hooks.append(module.register_full_backward_hook(self.get_attention_gradient))

    def get_attention(self, module, input, output):
        """Forward hook: keep the module output, detached and moved to CPU."""
        self.attentions.append(output.detach().cpu())

    def get_attention_gradient(self, module, grad_input, grad_output):
        """Full backward hook: keep the gradient w.r.t. the module input on CPU."""
        grad = grad_input[0] if grad_input[0] is not None else grad_output[0]
        self.attention_gradients.append(grad.detach().cpu())

    def _get_logits(self, input_image):
        """Preprocess ``input_image`` and return the model's raw logits (no grad).

        Args:
            input_image (PIL.Image.Image, np.ndarray): The input image.
        Returns:
            torch.Tensor: The raw logits output from the model.
        """
        input_tensor = self._preprocess(input_image).to(self.device)
        with torch.no_grad():
            logits = self.model(input_tensor).logits
        return logits

    def _labels_from_logits(self, logits):
        """Map each row of ``logits`` to the first comma-separated name of its argmax class."""
        id2label = self.model.config.id2label
        return [id2label[i.item()].split(',')[0] for i in logits.argmax(-1)]

    def _predict_label(self, inputs):
        """Return the predicted class name for each image in ``inputs``.

        Args:
            inputs: Raw image(s) accepted by the processor.
        Returns:
            list[str]: First comma-separated name of each argmax class.
        """
        return self._labels_from_logits(self._get_logits(inputs))

    def _preprocess(self, image):
        """Run the processor on ``image`` and return ``pixel_values`` on ``self.device``."""
        inputs = self.processor(images=image, return_tensors="pt")['pixel_values']
        return inputs.to(self.device)

    def _rollout_tensor(self, input_tensor, category_index=None):
        """Back-propagate one class score for already-preprocessed pixel values.

        Args:
            input_tensor: Pixel values as produced by :meth:`_preprocess`.
            category_index: Class whose logit is back-propagated. ``None``
                selects the top-scoring class of the first image.

        Returns:
            tuple: ``(mask, logits)``, i.e. the :meth:`grad_rollout` mask and
            the detached model logits.
        """
        if not self.hooks:
            # generate() removes the hooks when it finishes; re-attach them so
            # the instance can be used more than once.
            self._register_hooks()
        # Drop anything captured by earlier passes (e.g. _get_logits) so
        # attentions and gradients line up one-to-one per layer.
        self.attentions = []
        self.attention_gradients = []
        self.model.zero_grad()
        with torch.enable_grad():
            output = self.model(input_tensor.to(self.device), output_attentions=True)
            logits = output.logits
            if category_index is None:
                category_index = int(logits[0].argmax())
            # Same gradient as summing logits * one_hot(category_index), but
            # without building a CPU mask that mismatches CUDA logits.
            loss = logits[:, category_index].sum()
            loss.backward()
        return self.grad_rollout(), logits.detach()

    def __call__(self, image, category_index=1):
        """Return the gradient-rollout mask of ``category_index`` for a raw image.

        Args:
            image: Input accepted by the processor (e.g. a PIL image).
            category_index: Class whose logit is back-propagated.

        Returns:
            np.ndarray: Square ``(side, side)`` mask from :meth:`grad_rollout`.
        """
        mask, _ = self._rollout_tensor(self._preprocess(image), category_index)
        return mask

    def remove_hooks(self):
        """Detach all registered hooks from the model."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    def cleanup(self, clear_cache=False):
        """Remove hooks, drop captured tensors and optionally free the CUDA cache."""
        self.remove_hooks()
        self.attentions = []
        self.attention_gradients = []
        if clear_cache:
            torch.cuda.empty_cache()
            gc.collect()

    def grad_rollout(self):
        """Fuse gradient-weighted attentions into one mask (first batch item).

        Returns:
            np.ndarray: ``(side, side)`` mask scaled so its maximum is 1, or
            all zeros if it has no positive value.

        Raises:
            IndexError: If no attentions have been collected.
        """
        # Backward hooks fire from the last layer to the first, so gradients
        # arrive in reverse; flip them to pair each attention with its own
        # layer's gradient.
        gradients = self.attention_gradients[::-1]
        num_tokens = self.attentions[0].size(-1)
        eye = torch.eye(num_tokens)
        result = eye
        with torch.no_grad():
            for attention, grad in zip(self.attentions, gradients):
                fused = (attention.float() * grad.float()).mean(dim=1)
                fused = fused.clamp(min=0)  # keep only positive contributions
                flat = fused.reshape(fused.size(0), -1)
                num_discard = min(int(flat.size(-1) * self.discard_ratio), flat.size(-1))
                if num_discard > 0:
                    # Zero the lowest weights of every image in the batch
                    # (the CLS entry is not protected here, as before).
                    _, indices = flat.topk(num_discard, dim=-1, largest=False)
                    flat = flat.scatter(1, indices, 0.0)
                fused = flat.view_as(fused)
                a = (fused + eye) / 2
                # keepdim: each row is divided by its own sum (not column-wise).
                a = a / a.sum(dim=-1, keepdim=True)
                result = torch.matmul(a, result)
        # CLS-token row, patch columns only.
        mask = result[0, 0, 1:]
        width = int(mask.size(-1) ** 0.5)
        mask = mask.reshape(width, width).numpy()
        peak = np.max(mask)
        return mask / peak if peak > 0 else np.zeros_like(mask)

    @staticmethod
    def _image_size(image, input_tensor):
        """Return the ``(width, height)`` the mask should be resized to.

        Uses the PIL size, or the ``(H, W)`` of a numpy array (assumed
        channels-last). Otherwise it falls back to the processed tensor's
        spatial size.
        """
        if isinstance(image, np.ndarray) and image.ndim >= 2:
            return image.shape[1], image.shape[0]
        size = getattr(image, 'size', None)
        if isinstance(size, tuple) and len(size) == 2:
            return size
        return input_tensor.shape[-1], input_tensor.shape[-2]

    def generate(self, image, category_index=None):
        """Compute the heatmap and overlay for one image.

        Runs a single forward/backward pass; the predicted label comes from
        the same logits. Hooks are removed afterwards and re-attached on the
        next call.

        Args:
            image: A PIL image or ``(H, W[, C])`` array accepted by the processor.
            category_index: Class to explain. ``None`` (the default) explains
                the predicted class.

        Returns:
            tuple: ``(input_tensor, image, heatmap, mask_image, label)``.

            - ``heatmap`` is resized to the image's own size when it can be
              determined, otherwise to the processed tensor's size.
            - ``label`` is the list of predicted class names, one per batch
              item.
        """
        try:
            input_tensor = self._preprocess(image)
            mask, logits = self._rollout_tensor(input_tensor, category_index)
            label = self._labels_from_logits(logits)
            width, height = self._image_size(image, input_tensor)
            heatmap = cv2.resize(mask, (width, height))
            mask_image = show_mask_on_image(image, heatmap)
            return input_tensor, image, heatmap, mask_image, label
        finally:
            self.cleanup()