Source code for evaluator.consistency.SSIM
from evaluator.consistency.consistency import Stability
import numpy as np
import PIL
[docs]
class SSIM(Stability):
    """
    Evaluate the stability of an explainability method using SSIM.

    The score compares the explanation heatmap of an image with the heatmaps
    produced for transformed versions of that image.
    Higher is better.
    """

    def __init__(self, model, processor, target_layers, targets, cam_method='gradcamelementwise',
                 explainability_method="cam"):
        """
        Initialise the metric by forwarding every argument to ``Stability``.

        Args:
            model: Forwarded unchanged to ``Stability.__init__``.
            processor: Forwarded unchanged to ``Stability.__init__``.
            target_layers: Forwarded unchanged to ``Stability.__init__``.
            targets: Forwarded unchanged to ``Stability.__init__``.
            cam_method (str): Forwarded as the ``cam_method`` keyword;
                defaults to ``'gradcamelementwise'``.
            explainability_method (str): Forwarded as the
                ``explainability_method`` keyword; defaults to ``"cam"``.
        """
        super().__init__(model, processor, target_layers, targets,
                         explainability_method=explainability_method,
                         cam_method=cam_method)

    def evaluation_metric(self, image: PIL.Image.Image, transformed_images=None, transform_type: str = 'rotation', p_val=2, min_eps=0.00001, **kwargs):
        """
        Return the worst-case SSIM between the original and transformed heatmaps.

        Transformed images whose predicted label differs from the label
        obtained for ``image`` are discarded. For every remaining image the
        value stored under ``"ssim"`` in ``self.calculate_metrics(...)`` is
        collected, comparing the original heatmap with that image's heatmap.
        The minimum of those values, clamped below at 0, is returned.

        Args:
            image (PIL.Image.Image): The original input image.
            transformed_images: Optional sequence of transformed versions of
                ``image``. When ``None`` they are produced by
                ``self.generate_transformed_images``.
            transform_type (str): Passed to ``generate_transformed_images``;
                only consulted when ``transformed_images`` is ``None``.
            p_val: Not read by this metric; kept so the signature stays
                unchanged for existing callers.
            min_eps: Not read by this metric; kept so the signature stays
                unchanged for existing callers.
            **kwargs: Extra options handed to ``generate_transformed_images``
                when ``transformed_images`` is ``None``.

        Returns:
            float: ``0.0`` when there are no transformed images or none of
            them keeps the original label; otherwise
            ``max(0.0, min(ssim scores))``.
        """
        if transformed_images is None:
            # NOTE(review): the extra options are passed as one dict under the
            # keyword ``kwargs`` rather than unpacked with ``**kwargs``;
            # confirm this matches the base-class signature.
            transformed_images = self.generate_transformed_images(image, transform_type=transform_type, kwargs=kwargs)

        # Nothing to compare against: bail out before the (potentially
        # expensive) heatmap generation for the original image.
        if not transformed_images:
            return 0.0

        _, _, orig_heatmap, _, orig_label = self._generate_heatmaps(image)

        # Only the predicted labels are needed here; the logits are ignored.
        _, labels_batch = self._logits_and_predictions(transformed_images)

        # Keep only the transformed images whose prediction matches the
        # original one, so the heatmaps being compared explain the same label.
        reference_label = orig_label[0]
        valid_images = [
            transformed_images[i]
            for i, label in enumerate(labels_batch)
            if label == reference_label
        ]
        if not valid_images:
            return 0.0

        # Heatmaps for all label-preserving images in a single call.
        _, _, valid_heatmaps, _, _ = self._generate_heatmaps(valid_images)

        reference_heatmap = orig_heatmap[0]
        ssim_scores = [
            self.calculate_metrics(reference_heatmap, heatmap)["ssim"]
            for heatmap in valid_heatmaps
        ]

        # Worst-case similarity; negative values are clamped to 0.
        return np.maximum(0.0, np.min(ssim_scores))