import PIL
import torch
import numpy as np
import cv2
from xai_methods.rollout_vit import show_mask_on_image
import gc
[docs]
def show_cam_on_image(img, mask):
    """Blend a JET-colormapped heatmap with an image and rescale by the peak value.

    Args:
        img: Array added element-wise to the heatmap after conversion to
            float32; presumably an (H, W, 3) image already scaled to [0, 1]
            -- TODO confirm at call sites.
        mask: 2-D array expected in [0, 1]; it is scaled by 255 and cast to
            uint8 before OpenCV applies ``COLORMAP_JET``.

    Returns:
        float32 array ``(heatmap + img) / max(heatmap + img)``.
    """
    colored = cv2.applyColorMap(np.uint8(mask * 255), cv2.COLORMAP_JET)
    overlay = np.float32(colored) / 255 + np.float32(img)
    return overlay / overlay.max()
# Rule 5 from the paper: gradient-weighted attention, clamped to positive values and averaged over heads.
[docs]
def avg_heads(cam, grad):
    """Average gradient-weighted attention maps over all leading dimensions.

    Both tensors are flattened to ``(-1, rows, cols)`` using their last two
    dimensions, multiplied element-wise, clamped at zero and averaged over
    the flattened leading axis.

    Args:
        cam: Attention tensor whose last two dims form the attention matrix.
        grad: Gradient tensor with the same number of elements as ``cam``.

    Returns:
        Tensor of shape ``(rows, cols)``.
    """
    rows, cols = cam.shape[-2:]
    attention = cam.reshape(-1, rows, cols)
    weights = grad.reshape(-1, grad.shape[-2], grad.shape[-1])
    positive = (weights * attention).clamp(min=0)
    return positive.mean(dim=0)
# Rule 6 from the paper: relevance update cam_ss @ R_ss for one self-attention layer.
[docs]
def apply_self_attention_rules(R_ss, cam_ss):
    """Return the self-attention relevance update ``cam_ss @ R_ss``.

    Args:
        R_ss: Current relevance matrix.
        cam_ss: Head-averaged, gradient-weighted attention matrix.

    Returns:
        The matrix product; in this module the caller adds it onto ``R_ss``.
    """
    update = cam_ss @ R_ss
    return update
[docs]
def generate_relevance(model, input, index=None):
    """Propagate gradient-weighted attention relevance through ``model.blocks``.

    A forward pass is run with ``register_hook=True`` and the logit(s)
    selected by ``index`` are back-propagated. For every block, the stored
    attention map and its gradient are combined with ``avg_heads`` (rule 5).
    The result is accumulated with ``apply_self_attention_rules`` (rule 6)
    into a relevance matrix ``R`` that starts as the identity.

    Args:
        model: Model called as ``model(input, register_hook=True)`` that
            exposes ``blocks``. Each block's ``attn`` must provide
            ``get_attention_map()`` and ``get_attn_gradients()`` once the
            backward pass has run.
        input: Batched input tensor, already on the model's device.
        index: Target class index; ``None`` selects the argmax of the logits.

    Returns:
        ``R[0, 1:]``: row 0 of the relevance matrix without its first column
        (presumably the CLS token's relevance over the patch tokens).

    Raises:
        ValueError: if ``model.blocks`` is empty.
    """
    output = model(input, register_hook=True)
    if index is None:
        index = np.argmax(output.detach().cpu().numpy(), axis=-1)
    # One-hot mask so that only the target logit(s) seed the backward pass.
    one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
    one_hot[0, index] = 1
    target = torch.sum(torch.from_numpy(one_hot).to(output.device) * output)
    model.zero_grad()
    target.backward(retain_graph=True)

    R = None
    for blk in model.blocks:
        grad = blk.attn.get_attn_gradients()
        cam = avg_heads(blk.attn.get_attention_map(), grad).to(output.device)
        if R is None:
            # Size the identity from the attention map instead of hard-coding
            # 197 tokens, so other resolutions / patch sizes also work.
            num_tokens = cam.shape[-1]
            R = torch.eye(num_tokens, num_tokens, device=output.device)
        R += apply_self_attention_rules(R, cam)
    if R is None:
        raise ValueError("model.blocks is empty; cannot compute relevance")
    return R[0, 1:]
[docs]
def generate_visualization(original_image, model=None, class_index=None):
    """Overlay the relevance map of ``model`` for ``class_index`` on an image.

    Args:
        original_image: Unbatched image tensor in (C, H, W) layout. It is
            batched with ``unsqueeze(0)`` for the model and permuted to
            (H, W, C) for display.
        model: Model forwarded to ``generate_relevance``.
        class_index: Target class; ``None`` uses the predicted class.

    Returns:
        uint8 array of shape (H, W, 3): the heatmap overlay after OpenCV's
        RGB->BGR conversion.
    """
    # CUDA when available (as before), CPU otherwise instead of crashing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    attribution = generate_relevance(
        model, original_image.unsqueeze(0).to(device), index=class_index
    ).detach()
    height, width = original_image.shape[-2:]
    # The relevance vector covers a square patch grid; the previous code
    # hard-coded 14x14 (196 patches).
    grid_size = int(np.sqrt(attribution.numel()))
    attribution = attribution.reshape(1, 1, grid_size, grid_size)
    # Upsample straight to the image size. For a 224x224 image and a 14x14
    # grid this matches the previous hard-coded scale_factor=16.
    attribution = torch.nn.functional.interpolate(
        attribution, size=(height, width), mode='bilinear'
    )
    attribution = _min_max_normalize(attribution.reshape(height, width).cpu().numpy())

    image = original_image.permute(1, 2, 0).detach().cpu().numpy()
    image = _min_max_normalize(image)

    vis = show_cam_on_image(image, attribution)
    vis = np.uint8(255 * vis)
    return cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)


def _min_max_normalize(array):
    """Rescale ``array`` to [0, 1]; a constant array becomes zeros instead of NaN."""
    low, high = array.min(), array.max()
    if high > low:
        return (array - low) / (high - low)
    return np.zeros_like(array)
[docs]
class CheferRelevance:
    """Chefer-style gradient-weighted attention relevance for a hookable classifier.

    Every submodule whose qualified name contains ``attention_layer_name``
    gets two hooks:

    - a forward hook that records the module's output;
    - a full backward hook that records the gradient w.r.t. that output.

    ``__call__`` pairs the recorded maps and gradients layer by layer. It then
    applies rule 5 (``avg_heads``) and rule 6 (``apply_self_attention_rules``).
    With the default name this presumably targets the attention-probability
    dropout of Hugging Face ViT layers -- confirm for other architectures.
    """

    def __init__(self, model, processor, attention_layer_name='attention.dropout'):
        """Move ``model`` to CUDA (if available) or CPU and register the hooks.

        Args:
            model: Model whose forward accepts ``output_attentions`` and
                returns an object with ``.logits``.
            processor: Callable used as
                ``processor(images=..., return_tensors="pt")`` whose result
                supports ``.to(device)`` and ``['pixel_values']``.
            attention_layer_name: Substring of the module names to hook.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = model.to(self.device)
        self.processor = processor
        self.hooks = []
        self.attentions = []
        self.attention_gradients = []
        for name, module in self.model.named_modules():
            if attention_layer_name in name:
                self.hooks.append(module.register_forward_hook(self.get_attention))
                # register_full_backward_hook supersedes the deprecated register_backward_hook.
                self.hooks.append(module.register_full_backward_hook(self.get_attention_gradient))

    def get_attention(self, module, input, output):
        """Forward hook: append the hooked module's output, detached from the graph."""
        self.attentions.append(output.detach())

    def get_attention_gradient(self, module, grad_input, grad_output):
        """Backward hook: append the detached gradient w.r.t. the module's output.

        Backward hooks fire in backward order, i.e. from the last hooked layer
        to the first.
        """
        self.attention_gradients.append(grad_output[0].detach())

    def remove_hooks(self):
        """Remove every hook registered in ``__init__``."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    def cleanup(self, clear_cache=False):
        """Remove hooks and drop stored tensors.

        Args:
            clear_cache: If True, also call ``torch.cuda.empty_cache()`` and
                ``gc.collect()``.
        """
        self.remove_hooks()
        self.attentions = []
        self.attention_gradients = []
        if clear_cache:
            torch.cuda.empty_cache()
            gc.collect()

    def _preprocess(self, image):
        """Run the processor on ``image`` and return ``pixel_values`` on ``self.device``."""
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)
        return inputs['pixel_values']

    def avg_heads(self, cam, grad):
        """Rule 5: gradient-weighted attention, clamped at zero, averaged over heads.

        Both tensors are flattened to ``(-1, rows, cols)`` using their last two
        dimensions before the element-wise product.
        """
        cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1])
        grad = grad.reshape(-1, grad.shape[-2], grad.shape[-1])
        cam = grad * cam
        return cam.clamp(min=0).mean(dim=0)

    def apply_self_attention_rules(self, R_ss, cam_ss):
        """Rule 6: return the relevance update ``cam_ss @ R_ss``."""
        return torch.matmul(cam_ss, R_ss)

    def __call__(self, input_tensor, index=None):
        """Compute the relevance of token 0 over the remaining tokens.

        Only the first item of the batch contributes to the result.

        Args:
            input_tensor: Preprocessed batch on ``self.device``.
            index: Target class index; ``None`` uses the predicted
                (argmax) class.

        Returns:
            1-D tensor ``R[0, 1:]`` of length ``num_tokens - 1``.

        Raises:
            RuntimeError: if the hooks captured no attention maps, or if the
                numbers of captured maps and gradients differ.
        """
        self.attentions = []
        self.attention_gradients = []
        output = self.model(input_tensor, output_attentions=True).logits
        if not self.attentions:
            raise RuntimeError(
                "No attention maps were captured; check that attention_layer_name "
                "matches modules that run in the model's forward pass"
            )
        if index is None:
            index = output.argmax(dim=-1).item()
        # One-hot seed gradient selecting the target logit of the first item;
        # zeros_like keeps dtype and device consistent with the logits.
        one_hot = torch.zeros_like(output)
        one_hot[0, index] = 1
        self.model.zero_grad()
        output.backward(gradient=one_hot, retain_graph=True)

        # Backward hooks fire from the last layer to the first, so reverse the
        # gradients to pair each one with the attention map of the same layer.
        gradients = self.attention_gradients[::-1]
        if len(gradients) != len(self.attentions):
            raise RuntimeError(
                f"Captured {len(self.attentions)} attention maps but "
                f"{len(gradients)} gradients; cannot pair them by layer"
            )
        num_tokens = self.attentions[0].shape[-1]
        R = torch.eye(num_tokens, num_tokens, device=self.device)
        for cam, grad in zip(self.attentions, gradients):
            cam_weighted = self.avg_heads(cam[0], grad[0])
            R += self.apply_self_attention_rules(R, cam_weighted)
        return R[0, 1:]

    def generate(self, images, class_index=None):
        """Compute relevance heatmaps and overlays for one or more images.

        Args:
            images: A PIL image, a single (H, W, C) numpy array, or a sequence
                of images accepted by ``self.processor``.
            class_index: Target class used for every image; ``None`` uses
                each image's predicted class.

        Returns:
            A tuple ``(input_tensors, images, grayscale_cams, mask_images, labels)``:

            - ``input_tensors``: the preprocessed batch;
            - ``images``: the inputs as a list;
            - ``grayscale_cams``: heatmaps in [0, 1], resized to each image;
            - ``mask_images``: the overlays from ``show_mask_on_image``;
            - ``labels``: the predicted label names.
        """
        # `import PIL` alone does not guarantee the PIL.Image submodule is loaded.
        import PIL.Image

        input_tensors = self._preprocess(images)
        # Predict labels from the already-preprocessed batch instead of
        # running the processor a second time.
        with torch.no_grad():
            labels = self._logits_to_labels(self.model(input_tensors).logits)

        # A single image (PIL or one 3-D array) must not be iterated row by row.
        if isinstance(images, PIL.Image.Image) or (
            isinstance(images, np.ndarray) and images.ndim == 3
        ):
            images = [images]

        grayscale_cams = []
        mask_images = []
        for i, image in enumerate(images):
            rel_map = self(input_tensors[i:i + 1], index=class_index).detach()
            # Relevance covers a square grid of patch tokens.
            grid_size = int(np.sqrt(rel_map.shape[-1]))
            rel_map = rel_map.reshape(grid_size, grid_size).cpu().numpy()
            # Min-max normalization; the epsilon keeps a constant map finite.
            rel_map = (rel_map - rel_map.min()) / (rel_map.max() - rel_map.min() + 1e-8)

            orig_img = np.array(image)
            heatmap_resized = cv2.resize(rel_map, (orig_img.shape[1], orig_img.shape[0]))
            grayscale_cams.append(heatmap_resized)
            mask_images.append(
                show_mask_on_image(orig_img.astype(np.float32) / 255.0, heatmap_resized)
            )

        # Drop hook buffers so tensors are not kept alive between calls.
        self.attentions = []
        self.attention_gradients = []
        return input_tensors, images, np.array(grayscale_cams), mask_images, labels

    def _get_logits(self, input_image):
        """Preprocess ``input_image`` and return the model logits without autograd."""
        input_tensor = self._preprocess(input_image)
        with torch.no_grad():
            logits = self.model(input_tensor).logits
        return logits

    def _predict_label(self, images):
        """Return the predicted label name for each image in ``images``."""
        return self._logits_to_labels(self._get_logits(images))

    def _logits_to_labels(self, logits):
        """Map each row's argmax through ``config.id2label``, keeping the text before the first comma."""
        predictions = logits.argmax(-1)
        return [self.model.config.id2label[i.item()].split(',')[0] for i in predictions]
# class CheferRelevance:
# def __init__(self, model, processor):
# self.model = model.cuda()
# self.processor = processor
# self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# self.hooks = []
# for name, module in self.model.named_modules():
# if 'attention.dropout' in name:
# fwd_hook = module.register_forward_hook(self.get_attention)
# bwd_hook = module.register_backward_hook(self.get_attention_gradient)
# self.hooks.append(fwd_hook)
# self.hooks.append(bwd_hook)
# self.attentions = []
# self.attention_gradients = []
# def get_attention(self, module, input, output):
# self.attentions.append(output.detach().cpu())
# def get_attention_gradient(self, module, grad_input, grad_output):
# self.attention_gradients.append(grad_input[0].detach().cpu())
# def remove_hooks(self):
# for hook in self.hooks:
# hook.remove()
# self.hooks = []
# def cleanup(self, clear_cache = False):
# self.remove_hooks()
# self.attentions = []
# self.attention_gradients = []
# if clear_cache:
# torch.cuda.empty_cache()
# gc.collect()
# def _preprocess(self, image):
# """Preprocesses an input tensor using the processor for ViT
# Args:
# image (torch.Tensor): The input image (assumed to be resized to 224x224) (C, H, W)
# """
# #image = transforms.ToTensor()(image).unsqueeze(0)
# inputs = self.processor(images=image, return_tensors="pt")['pixel_values']
# return inputs.cuda()
# def avg_heads(self, cam, grad):
# cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1])
# grad = grad.reshape(-1, grad.shape[-2], grad.shape[-1])
# cam = grad * cam
# cam = cam.clamp(min=0).mean(dim=0)
# return cam
# # rule 6 from paper
# def apply_self_attention_rules(self, R_ss, cam_ss):
# R_ss_addition = torch.matmul(cam_ss, R_ss)
# return R_ss_addition
# def __call__(self, input_tensor, index = 1):
# self.attentions = []
# self.attention_gradients = []
# self.model.zero_grad()
# output = self.model(input_tensor, output_attentions = True).logits
# if index == None:
# index = np.argmax(output.detach().cpu().data.numpy(), axis=-1)
# one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
# one_hot[0, index] = 1
# one_hot = torch.from_numpy(one_hot).requires_grad_(True).cuda()
# one_hot = torch.sum(one_hot * output)
# self.model.zero_grad()
# one_hot.backward(retain_graph=True)
# num_tokens = 197
# R = torch.eye(num_tokens, num_tokens).cuda()
# for cam, grad in zip(self.attentions, self.attention_gradients):
# cam = self.avg_heads(cam, grad)
# R += apply_self_attention_rules(R.cuda(), cam.cuda())
# return R[0, 1:]
# def generate(self, image, class_index = None):
# input_tensor = self._preprocess(image)
# label = self._predict_label(input_tensor)
# print(label)
# height, width = input_tensor.shape[-2], input_tensor.shape[-1]
# transformer_attribution = self(input_tensor, index=class_index).detach()
# transformer_attribution = transformer_attribution.reshape(14, 14).cpu().numpy()
# transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (transformer_attribution.max() - transformer_attribution.min())
# transformer_attribution = cv2.resize(transformer_attribution, (height, width))
# mask_image = show_mask_on_image(image, transformer_attribution)
# heatmap = np.expand_dims(np.float32(transformer_attribution), axis=0)
# #mask_image = np.expand_dims(mask_image, axis=0)
# print("heatmap shape:", heatmap.shape, "mask_image shape:", mask_image.shape)
# plt.imshow(mask_image)
# plt.show()
# self.attentions = []
# self.attention_gradients = []
# return input_tensor, image, heatmap, mask_image, label
# def _get_logits(self, input_image):
# """
# Processes an input image and gets the raw logits from the ViT model.
# Args:
# input_image (PIL.Image.Image, np.ndarray): The input image.
# Returns:
# torch.Tensor: The raw logits output from the model.
# """
# input_tensor = self._preprocess(input_image).to(self.device)
# #Prediction label
# with torch.no_grad():
# logits = self.model(input_tensor).logits
# return logits
# def _predict_label(self, inputs):
# """
# Predicts the class label for a given input tensor using the provided model.
# Args:
# model (torch.nn.Module): The trained PyTorch model for prediction.
# input_tensor (torch.Tensor): The input tensor to the model.
# Returns:
# str: The predicted class label (obtained by finding the argmax of the logits).
# """
# logits = self._get_logits(inputs)
# predictions = logits.argmax(-1)
# return [self.model.config.id2label[i.item()].split(',')[0] for i in predictions]