# Source code for evaluator.evaluate


import numpy as np
from utils.heatmap_metrics import HeatmapMetrics
from pytorch_grad_cam.utils.image import show_cam_on_image

import cv2

from utils.transformations import Transforms
from xai_methods.gradcam_vit import ViTCAM
from xai_methods.rollout_vit import (
    show_mask_on_image,
    VITAttentionGradRollout,
    VITAttentionRollout,
)
from xai_methods.chefer_vit import CheferRelevance


class Evaluate:
    """Generate explanation heatmaps for a model and compare them.

    An explainer object is selected at construction time from
    ``explainability_method`` and is then used to produce heatmaps, logits
    and predicted labels. Helper methods overlay heatmaps on images, mask
    out regions of interest, draw bounding boxes and compute similarity
    metrics between two heatmaps.
    """

    def __init__(self, model, processor, target_layers, targets,
                 explainability_method="cam",
                 cam_method='gradcamelementwise',
                 discard_ratio=0.9,
                 head_fusion="min"):
        """Store the configuration and build the selected explainer.

        Args:
            model: The model to explain; forwarded to the explainer.
            processor: The input processor; forwarded to the explainer.
            target_layers: Forwarded to ``ViTCAM`` (used by ``"cam"`` only).
            targets: Forwarded to ``ViTCAM`` (used by ``"cam"`` only).
            explainability_method (str): One of ``"cam"``,
                ``"grad_rollout"``, ``"attention_rollout"`` or
                ``"Trans_MM_Explainability"``.
            cam_method (str): Forwarded to ``ViTCAM`` as ``cam_method``.
            discard_ratio (float): Forwarded to the rollout explainers.
            head_fusion (str): Forwarded to ``VITAttentionRollout``.

        Raises:
            NameError: If ``explainability_method`` is not one of the
                supported values.
        """
        self.targets = targets
        self.target_layers = target_layers
        self.cam_method = cam_method
        self.model = model
        self.processor = processor
        self.discard_ratio = discard_ratio
        self.head_fusion = head_fusion
        self.transforms = Transforms()
        self.explainability_method = explainability_method

        if self.explainability_method == "cam":
            self.explainer = ViTCAM(
                self.model,
                self.processor,
                target_layers=self.target_layers,
                targets=self.targets,
                cam_method=self.cam_method,
            )
        elif self.explainability_method == "grad_rollout":
            self.explainer = VITAttentionGradRollout(
                self.model,
                self.processor,
                discard_ratio=self.discard_ratio,
            )
        elif self.explainability_method == "attention_rollout":
            self.explainer = VITAttentionRollout(
                self.model,
                self.processor,
                head_fusion=self.head_fusion,
                discard_ratio=self.discard_ratio,
            )
        elif self.explainability_method == "Trans_MM_Explainability":
            self.explainer = CheferRelevance(self.model, self.processor)
        else:
            # NameError is kept (rather than ValueError) so existing callers
            # that catch it keep working.
            raise NameError("Explainability method not supported")

        self.heatmap_metrics = HeatmapMetrics()

    def _generate_heatmaps(self, image, explainer=None):
        """Return ``generate(image)`` from ``explainer`` or the default one.

        Args:
            image: The input passed straight to the explainer's ``generate``.
            explainer: Optional explainer to use instead of ``self.explainer``.

        Returns:
            Whatever the chosen explainer's ``generate`` returns.
        """
        active = self.explainer if explainer is None else explainer
        return active.generate(image)

    def _logits_and_predictions(self, image):
        """Query the configured explainer for logits and predicted label.

        Args:
            image (PIL.Image.Image, np.ndarray): The input image.

        Returns:
            tuple: The raw logits output from the model, the label for the
            input.
        """
        logits = self.explainer._get_logits(image)
        label = self.explainer._predict_label(image)
        return logits, label

    def get_cam_on_image(self, heatmap, image):
        """Overlay ``heatmap`` on ``image`` using the method-specific helper.

        ``show_cam_on_image`` (with ``use_rgb=True``) is used for the
        ``"cam"`` method and ``show_mask_on_image`` for every other method.

        Args:
            heatmap: The heatmap to overlay.
            image: The image to overlay it on.

        Returns:
            The overlay produced by the selected helper.
        """
        if self.explainability_method == "cam":
            return show_cam_on_image(image, heatmap, use_rgb=True)
        return show_mask_on_image(image, heatmap)

    def assert_heatmap_range(self, heatmap1, heatmap2):
        """Check that two heatmaps share a shape and lie within [0, 1].

        Args:
            heatmap1 (numpy.ndarray): First heatmap.
            heatmap2 (numpy.ndarray): Second heatmap.

        Raises:
            ValueError: If the two shapes differ.
            AssertionError: If either heatmap has a value outside 0-1.
        """
        if heatmap1.shape != heatmap2.shape:
            raise ValueError("Heatmaps must have the same shape.")
        if not (np.all(heatmap1 >= 0) and np.all(heatmap1 <= 1)):
            raise AssertionError("heatmap1 has values outside the 0-1 range.")
        if not (np.all(heatmap2 >= 0) and np.all(heatmap2 <= 1)):
            raise AssertionError("heatmap2 has values outside the 0-1 range.")

    def _remove_roi(self, image, bbox):
        """Return a uint8 copy of ``image`` with the ``bbox`` region blacked out.

        Args:
            image: Anything ``np.array`` can convert; the result is cast to
                ``uint8``.
            bbox: Four values ``(x1, y1, x2, y2)`` giving two opposite
                corners of the rectangle to fill.

        Returns:
            numpy.ndarray: The perturbed image.
        """
        # Cast to int, as draw_boxes does, so float or tensor coordinates
        # are accepted as well.
        x1, y1, x2, y2 = (int(coord) for coord in bbox)
        perturbed_image = np.array(image).astype('uint8')
        # thickness=-1 asks OpenCV for a filled rectangle.
        perturbed_image = cv2.rectangle(
            perturbed_image, (x1, y1), (x2, y2),
            color=(0, 0, 0), thickness=-1,
        )
        return perturbed_image

    def calculate_metrics(self, heatmap1: np.ndarray, heatmap2: np.ndarray):
        """Compute MSE, SSIM and Tanimoto metrics between two heatmaps.

        Both heatmaps are validated with ``assert_heatmap_range`` first.

        Args:
            heatmap1 (numpy.ndarray): First heatmap, values in [0, 1].
            heatmap2 (numpy.ndarray): Second heatmap, same shape and range.

        Returns:
            dict: Keys ``"mse"``, ``"ssim"`` and ``"tanimoto"`` mapped to
            the values returned by ``self.heatmap_metrics``.

        Raises:
            ValueError: If the heatmap shapes differ.
            AssertionError: If a heatmap has values outside 0-1.
        """
        # Sanity checks
        self.assert_heatmap_range(heatmap1, heatmap2)

        mse = self.heatmap_metrics.mse(heatmap1, heatmap2)
        ssim = self.heatmap_metrics.ssim(heatmap1, heatmap2)
        tanimoto = self.heatmap_metrics.tanimoto(heatmap1, heatmap2)
        return {"mse": mse, "ssim": ssim, "tanimoto": tanimoto}

    def draw_boxes(self, image, boxes,
                   classes=("heatmap", "ground_truth"),
                   colors=((255, 0, 0), (0, 0, 0))):
        """Draw bounding boxes on an image.

        Args:
            image (numpy.ndarray): The image to draw on. Boxes are drawn
                with ``cv2.rectangle`` directly on this array.
            boxes (list): Bounding boxes, each a sequence of four numbers
                ``(x1, y1, x2, y2)`` giving two opposite corners.
                Coordinates are converted with ``int``.
            classes (sequence, optional): One label per box. Defaults to
                ``("heatmap", "ground_truth")``. Labels are not rendered;
                they only take part in pairing boxes, colors and labels.
            colors (sequence, optional): One color tuple per box, passed
                to ``cv2.rectangle`` unchanged. Defaults to
                ``((255, 0, 0), (0, 0, 0))``.

        Returns:
            numpy.ndarray: The image with the bounding boxes drawn on it.

        Note:
            ``boxes``, ``colors`` and ``classes`` are paired with ``zip``,
            so iteration stops at the shortest of the three.
        """
        for box, color, _label in zip(boxes, colors, classes):
            top_left = (int(box[0]), int(box[1]))
            bottom_right = (int(box[2]), int(box[3]))
            cv2.rectangle(image, top_left, bottom_right, color, 2)
        return image