import numpy as np
from utils.heatmap_metrics import HeatmapMetrics
from pytorch_grad_cam.utils.image import show_cam_on_image
import cv2
from utils.transformations import Transforms
from xai_methods.gradcam_vit import ViTCAM
from xai_methods.rollout_vit import (
show_mask_on_image,
VITAttentionGradRollout,
VITAttentionRollout,
)
from xai_methods.chefer_vit import CheferRelevance
# Evaluation helpers for ViT explainability heatmaps.
class Evaluate:
    """Wrap one explainability method and offer heatmap comparison helpers.

    The constructor builds a single explainer (CAM, gradient rollout,
    attention rollout or Chefer relevance) selected by
    ``explainability_method``. The remaining methods generate heatmaps with
    that explainer, overlay them on images, occlude regions of interest,
    compare two heatmaps and draw bounding boxes.
    """

    def __init__(
        self,
        model,
        processor,
        target_layers,
        targets,
        explainability_method="cam",
        cam_method='gradcamelementwise',
        discard_ratio=0.9,
        head_fusion="min",
    ):
        """Store the configuration and instantiate the selected explainer.

        Args:
            model: Model handed to the explainer.
            processor: Processor handed to the explainer together with the model.
            target_layers: Forwarded to ``ViTCAM`` (only used by ``"cam"``).
            targets: Forwarded to ``ViTCAM`` (only used by ``"cam"``).
            explainability_method (str): One of ``"cam"``, ``"grad_rollout"``,
                ``"attention_rollout"`` or ``"Trans_MM_Explainability"``.
            cam_method (str): Forwarded to ``ViTCAM`` (only used by ``"cam"``).
            discard_ratio (float): Forwarded to both rollout explainers.
            head_fusion (str): Forwarded to ``VITAttentionRollout``.

        Raises:
            NameError: If ``explainability_method`` is not one of the
                supported names.
        """
        self.targets = targets
        self.target_layers = target_layers
        self.cam_method = cam_method
        self.model = model
        self.processor = processor
        self.discard_ratio = discard_ratio
        self.head_fusion = head_fusion
        self.transforms = Transforms()
        self.explainability_method = explainability_method

        # Every branch tests the stored attribute so the dispatch reads uniformly.
        if self.explainability_method == "cam":
            self.explainer = ViTCAM(
                self.model,
                self.processor,
                target_layers=self.target_layers,
                targets=self.targets,
                cam_method=self.cam_method,
            )
        elif self.explainability_method == "grad_rollout":
            self.explainer = VITAttentionGradRollout(
                self.model,
                self.processor,
                discard_ratio=self.discard_ratio,
            )
        elif self.explainability_method == "attention_rollout":
            self.explainer = VITAttentionRollout(
                self.model,
                self.processor,
                head_fusion=self.head_fusion,
                discard_ratio=self.discard_ratio,
            )
        elif self.explainability_method == "Trans_MM_Explainability":
            self.explainer = CheferRelevance(self.model, self.processor)
        else:
            # NameError is kept (rather than ValueError) because existing
            # callers may already catch it.
            raise NameError("Explainability method not supported")

        self.heatmap_metrics = HeatmapMetrics()

    def _generate_heatmaps(self, image, explainer=None):
        """Return ``generate(image)`` from ``explainer`` or the default explainer.

        Args:
            image: Input forwarded unchanged to the explainer.
            explainer: Optional object exposing ``generate(image)``. When
                ``None``, the explainer built in ``__init__`` is used.

        Returns:
            Whatever the chosen explainer's ``generate`` returns.
        """
        active_explainer = self.explainer if explainer is None else explainer
        return active_explainer.generate(image)

    def _logits_and_predictions(self, image):
        """Query the configured explainer for logits and the predicted label.

        Args:
            image (PIL.Image.Image, np.ndarray): The input image.

        Returns:
            tuple: ``(logits, label)`` as returned by the explainer's
            ``_get_logits`` and ``_predict_label``.
        """
        logits = self.explainer._get_logits(image)
        label = self.explainer._predict_label(image)
        return logits, label

    def get_cam_on_image(self, heatmap, image):
        """Overlay ``heatmap`` on ``image`` with the helper matching the method.

        ``"cam"`` uses ``pytorch_grad_cam``'s ``show_cam_on_image`` (with
        ``use_rgb=True``); every other method uses ``show_mask_on_image``.

        Args:
            heatmap: Heatmap to overlay.
            image: Image to draw on.
                NOTE(review): ``show_cam_on_image`` expects a float image
                scaled to [0, 1] - confirm callers normalise accordingly.

        Returns:
            The overlay produced by the selected helper.
        """
        if self.explainability_method == "cam":
            return show_cam_on_image(image, heatmap, use_rgb=True)
        return show_mask_on_image(image, heatmap)

    def assert_heatmap_range(self, heatmap1, heatmap2):
        """Validate that two heatmaps share a shape and lie within [0, 1].

        Args:
            heatmap1 (np.ndarray): First heatmap.
            heatmap2 (np.ndarray): Second heatmap.

        Raises:
            ValueError: If the shapes differ.
            AssertionError: If either heatmap has a value below 0 or above 1.
                ``heatmap1`` is checked first.
        """
        if heatmap1.shape != heatmap2.shape:
            raise ValueError("Heatmaps must have the same shape.")
        for name, heatmap in (("heatmap1", heatmap1), ("heatmap2", heatmap2)):
            if not (np.all(heatmap >= 0) and np.all(heatmap <= 1)):
                raise AssertionError(f"{name} has values outside the 0-1 range.")

    def _remove_roi(self, image, bbox):
        """Return a ``uint8`` copy of ``image`` with ``bbox`` filled in black.

        Args:
            image: Image convertible with ``np.array``; it is copied, so the
                caller's object is not modified.
            bbox: Four coordinates ``(x1, y1, x2, y2)``.

        Returns:
            np.ndarray: The occluded ``uint8`` image.
        """
        # Cast to int like draw_boxes does, so float boxes are accepted too.
        x1, y1, x2, y2 = (int(coord) for coord in bbox)
        perturbed_image = np.array(image).astype('uint8')
        # thickness=-1 asks OpenCV for a filled rectangle.
        return cv2.rectangle(
            perturbed_image, (x1, y1), (x2, y2), color=(0, 0, 0), thickness=-1
        )

    def calculate_metrics(self, heatmap1: np.ndarray, heatmap2: np.ndarray):
        """Compare two heatmaps with the ``HeatmapMetrics`` helper.

        Args:
            heatmap1 (np.ndarray): First heatmap, values in [0, 1].
            heatmap2 (np.ndarray): Second heatmap, same shape, values in [0, 1].

        Returns:
            dict: Results of ``mse``, ``ssim`` and ``tanimoto`` under the keys
            ``"mse"``, ``"ssim"`` and ``"tanimoto"``.

        Raises:
            ValueError: If the shapes differ.
            AssertionError: If a heatmap has values outside [0, 1].
        """
        # Sanity checks before any metric is computed.
        self.assert_heatmap_range(heatmap1, heatmap2)
        return {
            "mse": self.heatmap_metrics.mse(heatmap1, heatmap2),
            "ssim": self.heatmap_metrics.ssim(heatmap1, heatmap2),
            "tanimoto": self.heatmap_metrics.tanimoto(heatmap1, heatmap2),
        }

    def draw_boxes(
        self,
        image,
        boxes,
        classes=("heatmap", "ground_truth"),
        colors=((255, 0, 0), (0, 0, 0)),
    ):
        """Draw bounding-box outlines on ``image`` in place.

        Args:
            image (numpy.ndarray): Image to draw on; it is modified in place.
            boxes (list): Bounding boxes, each ``(x1, y1, x2, y2)`` with
                ``(x1, y1)`` and ``(x2, y2)`` as opposite corners. Coordinates
                are truncated with ``int()``.
            classes (sequence, optional): One label per box. Labels are not
                rendered, but the sequence still takes part in the ``zip``,
                so iteration stops at the shortest of ``boxes``, ``colors``
                and ``classes``. Defaults to ``("heatmap", "ground_truth")``.
            colors (sequence, optional): One colour tuple per box, in the
                image's own channel order. Defaults to
                ``((255, 0, 0), (0, 0, 0))``; the second colour is black.

        Returns:
            numpy.ndarray: The same ``image`` object with the boxes drawn.
        """
        # `classes` stays in the zip so a short label list still limits how
        # many boxes are drawn, exactly as before.
        for box, color, _label in zip(boxes, colors, classes):
            top_left = (int(box[0]), int(box[1]))
            bottom_right = (int(box[2]), int(box[3]))
            cv2.rectangle(image, top_left, bottom_right, color, 2)
        return image