# Source code for xai_methods.chefer_vit

import PIL
import torch
import numpy as np
import cv2
from xai_methods.rollout_vit import show_mask_on_image
import gc


def show_cam_on_image(img, mask):
    """Blend a JET-coloured rendering of ``mask`` on top of ``img``.

    Args:
        img: Image array; ``generate_visualization`` passes it min-max scaled
            to [0, 1] in HWC layout.
        mask: 2-D relevance map, expected in [0, 1] -- it is multiplied by 255
            and cast to uint8 before colour-mapping.

    Returns:
        float32 array equal to ``heatmap + img`` divided by its global maximum.

    NOTE(review): ``cv2.applyColorMap`` yields BGR channels while ``img`` is
    presumably RGB; callers convert afterwards -- confirm this is intended.
    """
    colour_map = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    blended = np.float32(colour_map) / 255 + np.float32(img)
    return blended / np.max(blended)
# rule 5 from paper
def avg_heads(cam, grad):
    """Rule 5: weight attention by its gradient, keep the positive part, average heads.

    Any leading dimensions of ``cam`` and ``grad`` are merged into one axis
    via ``reshape(-1, rows, cols)`` and averaged away.

    Returns:
        Tensor of shape ``(rows, cols)``.
    """
    rows, cols = cam.shape[-2], cam.shape[-1]
    flat_cam = cam.reshape(-1, rows, cols)
    flat_grad = grad.reshape(-1, grad.shape[-2], grad.shape[-1])
    weighted = flat_grad * flat_cam
    return weighted.clamp(min=0).mean(dim=0)
# rule 6 from paper
def apply_self_attention_rules(R_ss, cam_ss):
    """Rule 6: propagate relevance one layer; returns the product ``cam_ss @ R_ss``."""
    return cam_ss @ R_ss
def generate_relevance(model, input, index=None):
    """Propagate gradient-weighted attention relevance through ``model.blocks``.

    Starting from the identity matrix ``R``, every block contributes
    ``avg_heads(attn, grad) @ R`` (rules 5 and 6).

    Args:
        model: A ViT whose forward accepts ``register_hook=True`` and whose
            ``blocks[i].attn`` exposes ``get_attention_map()`` and
            ``get_attn_gradients()``; both are read after the backward pass.
        input: Batched image tensor, already on the model's device. The
            one-hot target is written into batch row 0 only.
        index: Target class index; defaults to the arg-max prediction.

    Returns:
        ``R[0, 1:]`` -- relevance of token 0 towards every other token.

    Raises:
        ValueError: if ``model.blocks`` is empty.
    """
    output = model(input, register_hook=True)
    if index is None:
        index = np.argmax(output.cpu().data.numpy(), axis=-1)

    one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
    one_hot[0, index] = 1
    # Select the target logit on whatever device produced it (previously
    # hard-coded to ``.cuda()``, which fails on CPU-only setups).
    target = torch.sum(torch.from_numpy(one_hot).to(output.device) * output)

    model.zero_grad()
    target.backward(retain_graph=True)

    blocks = list(model.blocks)
    if not blocks:
        raise ValueError("model.blocks is empty; nothing to propagate relevance through")

    # Size R from the captured attention maps instead of assuming 197 tokens
    # (which only holds for a 16-px patch ViT at 224x224 input).
    first_map = blocks[0].attn.get_attention_map()
    num_tokens = first_map.shape[-1]
    R = torch.eye(num_tokens, num_tokens, device=first_map.device)
    for blk in blocks:
        grad = blk.attn.get_attn_gradients()
        cam = avg_heads(blk.attn.get_attention_map(), grad)
        R += apply_self_attention_rules(R, cam.to(R.device))
    return R[0, 1:]
def _min_max_normalize(arr):
    """Scale ``arr`` linearly onto [0, 1]; a constant array maps to all zeros."""
    lo = arr.min()
    span = arr.max() - lo
    if span == 0:
        # Avoid 0/0 -> NaN, whose later uint8 cast is undefined.
        return np.zeros_like(arr)
    return (arr - lo) / span


def generate_visualization(original_image, model=None, class_index=None):
    """Render the relevance map of ``model`` for ``original_image`` as an overlay.

    Args:
        original_image: Image tensor in CHW layout (it is ``permute``-d to HWC
            for display).
        model: Model accepted by ``generate_relevance``.
        class_index: Target class; ``None`` selects the arg-max prediction.

    Returns:
        uint8 array of shape ``(H, W, 3)`` -- the blended heatmap after an
        RGB->BGR conversion.
    """
    # Run on the model's device instead of forcing CUDA.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    attribution = generate_relevance(
        model, original_image.unsqueeze(0).to(device), index=class_index
    ).detach()

    # Derive the patch grid and output size instead of hard-coding 14x14 -> 224x224.
    grid = int(round(attribution.numel() ** 0.5))
    height, width = original_image.shape[-2], original_image.shape[-1]
    attribution = attribution.reshape(1, 1, grid, grid)
    attribution = torch.nn.functional.interpolate(
        attribution, size=(height, width), mode='bilinear'
    )
    attribution = _min_max_normalize(attribution.reshape(height, width).cpu().numpy())

    image = original_image.permute(1, 2, 0).detach().cpu().numpy()
    image = _min_max_normalize(image)

    vis = show_cam_on_image(image, attribution)
    vis = np.uint8(255 * vis)
    return cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
class CheferRelevance:
    """Relevance maps for a Hugging Face style ViT classifier via module hooks.

    A forward hook and a full backward hook are attached to every sub-module
    whose qualified name contains ``attention_layer_name``. The forward hook
    stores that module's output (presumably the attention probabilities --
    TODO confirm for the model in use). The backward hook stores the gradient
    w.r.t. that output.
    """

    def __init__(self, model, processor, attention_layer_name='attention.dropout'):
        """Move ``model`` to CUDA if available (else CPU) and register the hooks.

        Args:
            model: Model whose forward accepts ``output_attentions=True`` and
                returns an object with ``.logits``.
            processor: Callable invoked as ``processor(images=..., return_tensors="pt")``.
            attention_layer_name: Substring selecting the hooked sub-modules.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = model.to(self.device)
        self.processor = processor
        self.hooks = []
        self.attentions = []
        self.attention_gradients = []
        for name, module in self.model.named_modules():
            if attention_layer_name in name:
                self.hooks.append(module.register_forward_hook(self.get_attention))
                # Full backward hooks are the non-deprecated API in modern PyTorch.
                self.hooks.append(module.register_full_backward_hook(self.get_attention_gradient))

    def get_attention(self, module, input, output):
        """Forward hook: record a detached copy of the hooked module's output."""
        self.attentions.append(output.detach())

    def get_attention_gradient(self, module, grad_input, grad_output):
        """Backward hook: record the gradient w.r.t. the hooked module's output.

        Backward runs from the last layer to the first, so this list fills in
        reverse layer order relative to ``self.attentions``.
        """
        self.attention_gradients.append(grad_output[0].detach())

    def remove_hooks(self):
        """Detach every registered hook from the model."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    def cleanup(self, clear_cache=False):
        """Remove hooks and drop stored tensors; optionally free cached GPU memory."""
        self.remove_hooks()
        self.attentions = []
        self.attention_gradients = []
        if clear_cache:
            torch.cuda.empty_cache()
            gc.collect()

    def _preprocess(self, image):
        """Run the processor and return ``pixel_values`` on ``self.device``."""
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)
        return inputs['pixel_values']

    def avg_heads(self, cam, grad):
        """Rule 5: gradient-weighted attention, positive part, averaged over heads."""
        cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1])
        grad = grad.reshape(-1, grad.shape[-2], grad.shape[-1])
        cam = grad * cam
        return cam.clamp(min=0).mean(dim=0)

    def apply_self_attention_rules(self, R_ss, cam_ss):
        """Rule 6: propagate relevance one layer (``cam_ss @ R_ss``)."""
        return torch.matmul(cam_ss, R_ss)

    def __call__(self, input_tensor, index=None):
        """Compute the relevance of token 0 towards all other tokens.

        Args:
            input_tensor: Pre-processed batch; only element 0 is used.
            index: Target class; defaults to the arg-max prediction.

        Returns:
            1-D tensor ``R[0, 1:]``.

        Raises:
            RuntimeError: if no attention was captured or the forward/backward
                hook counts disagree (``zip`` would otherwise silently truncate).
        """
        self.attentions = []
        self.attention_gradients = []
        output = self.model(input_tensor, output_attentions=True).logits
        if index is None:
            index = output.argmax(dim=-1).item()
        one_hot = torch.zeros((1, output.size()[-1]), device=output.device, dtype=output.dtype)
        one_hot[0, index] = 1
        self.model.zero_grad()
        # Stored tensors are detached, so the graph need not be retained.
        output.backward(gradient=one_hot)

        if not self.attentions or len(self.attentions) != len(self.attention_gradients):
            raise RuntimeError(
                f"captured {len(self.attentions)} attention maps but "
                f"{len(self.attention_gradients)} gradients; check attention_layer_name"
            )
        # Gradients arrive last-layer-first; reverse so each attention map is
        # paired with the gradient from its own layer.
        gradients = self.attention_gradients[::-1]

        num_tokens = self.attentions[0].shape[-1]
        R = torch.eye(num_tokens, num_tokens, device=self.attentions[0].device)
        for cam, grad in zip(self.attentions, gradients):
            # Process single image in batch.
            cam_weighted = self.avg_heads(cam[0], grad[0])
            R += self.apply_self_attention_rules(R, cam_weighted)
        return R[0, 1:]

    def generate(self, images, class_index=None):
        """Compute relevance heatmaps and overlays for one image or a list of images.

        Returns:
            ``(input_tensors, images, grayscale_cams, mask_images, labels)``.
            ``images`` is a list, and ``grayscale_cams`` is an ndarray of
            heatmaps resized to each image.
        """
        # `import PIL` alone does not guarantee the `PIL.Image` submodule is loaded.
        import PIL.Image

        input_tensors = self._preprocess(images)
        labels = self._predict_label(images)
        # A single image (PIL or HWC ndarray) must not be iterated row by row.
        if isinstance(images, PIL.Image.Image) or (
            isinstance(images, np.ndarray) and images.ndim == 3
        ):
            images = [images]

        grayscale_cams = []
        mask_images = []
        try:
            for i, image in enumerate(images):
                rel_map = self(input_tensors[i:i + 1], index=class_index)

                grid_size = int(np.sqrt(rel_map.shape[-1]))
                rel_map = rel_map.reshape(grid_size, grid_size).cpu().numpy()
                # Safe normalization (epsilon guards a constant map).
                denom = rel_map.max() - rel_map.min()
                rel_map = (rel_map - rel_map.min()) / (denom + 1e-8)

                orig_img = np.array(image)
                heatmap_resized = cv2.resize(rel_map, (orig_img.shape[1], orig_img.shape[0]))
                grayscale_cams.append(heatmap_resized)
                mask_images.append(
                    show_mask_on_image(orig_img.astype(np.float32) / 255.0, heatmap_resized)
                )
        finally:
            # Release captured tensors even if an iteration fails.
            self.attentions = []
            self.attention_gradients = []
        return input_tensors, images, np.array(grayscale_cams), mask_images, labels

    def _get_logits(self, input_image):
        """Preprocess ``input_image`` and return the model's logits without gradients."""
        input_tensor = self._preprocess(input_image)
        with torch.no_grad():
            logits = self.model(input_tensor).logits
        return logits

    def _predict_label(self, images):
        """Return the arg-max label names, truncated at the first comma."""
        logits = self._get_logits(images)
        predictions = logits.argmax(-1)
        return [self.model.config.id2label[i.item()].split(',')[0] for i in predictions]
# class CheferRelevance: # def __init__(self, model, processor): # self.model = model.cuda() # self.processor = processor # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # self.hooks = [] # for name, module in self.model.named_modules(): # if 'attention.dropout' in name: # fwd_hook = module.register_forward_hook(self.get_attention) # bwd_hook = module.register_backward_hook(self.get_attention_gradient) # self.hooks.append(fwd_hook) # self.hooks.append(bwd_hook) # self.attentions = [] # self.attention_gradients = [] # def get_attention(self, module, input, output): # self.attentions.append(output.detach().cpu()) # def get_attention_gradient(self, module, grad_input, grad_output): # self.attention_gradients.append(grad_input[0].detach().cpu()) # def remove_hooks(self): # for hook in self.hooks: # hook.remove() # self.hooks = [] # def cleanup(self, clear_cache = False): # self.remove_hooks() # self.attentions = [] # self.attention_gradients = [] # if clear_cache: # torch.cuda.empty_cache() # gc.collect() # def _preprocess(self, image): # """Preprocesses an input tensor using the processor for ViT # Args: # image (torch.Tensor): The input image (assumed to be resized to 224x224) (C, H, W) # """ # #image = transforms.ToTensor()(image).unsqueeze(0) # inputs = self.processor(images=image, return_tensors="pt")['pixel_values'] # return inputs.cuda() # def avg_heads(self, cam, grad): # cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]) # grad = grad.reshape(-1, grad.shape[-2], grad.shape[-1]) # cam = grad * cam # cam = cam.clamp(min=0).mean(dim=0) # return cam # # rule 6 from paper # def apply_self_attention_rules(self, R_ss, cam_ss): # R_ss_addition = torch.matmul(cam_ss, R_ss) # return R_ss_addition # def __call__(self, input_tensor, index = 1): # self.attentions = [] # self.attention_gradients = [] # self.model.zero_grad() # output = self.model(input_tensor, output_attentions = True).logits # if index == None: # index = 
np.argmax(output.detach().cpu().data.numpy(), axis=-1) # one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32) # one_hot[0, index] = 1 # one_hot = torch.from_numpy(one_hot).requires_grad_(True).cuda() # one_hot = torch.sum(one_hot * output) # self.model.zero_grad() # one_hot.backward(retain_graph=True) # num_tokens = 197 # R = torch.eye(num_tokens, num_tokens).cuda() # for cam, grad in zip(self.attentions, self.attention_gradients): # cam = self.avg_heads(cam, grad) # R += apply_self_attention_rules(R.cuda(), cam.cuda()) # return R[0, 1:] # def generate(self, image, class_index = None): # input_tensor = self._preprocess(image) # label = self._predict_label(input_tensor) # print(label) # height, width = input_tensor.shape[-2], input_tensor.shape[-1] # transformer_attribution = self(input_tensor, index=class_index).detach() # transformer_attribution = transformer_attribution.reshape(14, 14).cpu().numpy() # transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (transformer_attribution.max() - transformer_attribution.min()) # transformer_attribution = cv2.resize(transformer_attribution, (height, width)) # mask_image = show_mask_on_image(image, transformer_attribution) # heatmap = np.expand_dims(np.float32(transformer_attribution), axis=0) # #mask_image = np.expand_dims(mask_image, axis=0) # print("heatmap shape:", heatmap.shape, "mask_image shape:", mask_image.shape) # plt.imshow(mask_image) # plt.show() # self.attentions = [] # self.attention_gradients = [] # return input_tensor, image, heatmap, mask_image, label # def _get_logits(self, input_image): # """ # Processes an input image and gets the raw logits from the ViT model. # Args: # input_image (PIL.Image.Image, np.ndarray): The input image. # Returns: # torch.Tensor: The raw logits output from the model. 
# """ # input_tensor = self._preprocess(input_image).to(self.device) # #Prediction label # with torch.no_grad(): # logits = self.model(input_tensor).logits # return logits # def _predict_label(self, inputs): # """ # Predicts the class label for a given input tensor using the provided model. # Args: # model (torch.nn.Module): The trained PyTorch model for prediction. # input_tensor (torch.Tensor): The input tensor to the model. # Returns: # str: The predicted class label (obtained by finding the argmax of the logits). # """ # logits = self._get_logits(inputs) # predictions = logits.argmax(-1) # return [self.model.config.id2label[i.item()].split(',')[0] for i in predictions]