Source code for evaluator.consistency.SSIM

import numpy as np
import PIL
import PIL.Image

from evaluator.consistency.consistency import Stability


class SSIM(Stability):
    """Evaluate explanation stability with the SSIM between heatmaps.

    The heatmap of an input image is compared, using the structural
    similarity index (SSIM), against the heatmaps of transformed versions of
    that image whose predicted label is unchanged. Higher is better.
    """

    def __init__(self, model, processor, target_layers, targets,
                 cam_method='gradcamelementwise', explainability_method="cam"):
        """Forward all configuration unchanged to ``Stability.__init__``.

        Args:
            model: Forwarded to ``Stability``.
            processor: Forwarded to ``Stability``.
            target_layers: Forwarded to ``Stability``.
            targets: Forwarded to ``Stability``.
            cam_method (str): Forwarded to ``Stability`` as ``cam_method``.
            explainability_method (str): Forwarded to ``Stability`` as
                ``explainability_method``.
        """
        super().__init__(model, processor, target_layers, targets,
                         explainability_method=explainability_method,
                         cam_method=cam_method)

    def evaluation_metric(self, image: PIL.Image.Image, transformed_images=None,
                          transform_type: str = 'rotation', p_val=2,
                          min_eps=0.00001, **kwargs):
        """Return the worst-case SSIM between original and transformed heatmaps.

        Only transformed images whose predicted label equals the original
        image's predicted label are compared. The result is the minimum SSIM
        over those images, clipped below at 0.

        Args:
            image (PIL.Image.Image): The original input image.
            transformed_images (list, optional): Transformed versions of
                ``image``. When ``None``, they are produced with
                ``self.generate_transformed_images``.
            transform_type (str): Transformation requested from
                ``generate_transformed_images`` when images are generated.
            p_val (int): Unused by this metric; retained so the signature
                stays unchanged.
            min_eps (float): Unused by this metric; retained so the signature
                stays unchanged.
            **kwargs: Passed to ``generate_transformed_images`` as a single
                keyword argument named ``kwargs``.

        Returns:
            float: ``max(0, min SSIM)`` over the label-preserving transformed
            images, or ``0.0`` when there are no transformed images or none of
            them keeps the original label.
        """
        if transformed_images is None:
            # NOTE(review): the collected kwargs are passed as one keyword
            # argument named ``kwargs`` rather than unpacked with ``**``;
            # confirm this matches ``generate_transformed_images``' signature.
            transformed_images = self.generate_transformed_images(
                image, transform_type=transform_type, kwargs=kwargs)

        # ``len`` instead of truthiness so NumPy arrays are accepted too; done
        # before computing the original heatmap to avoid wasted work.
        if len(transformed_images) == 0:
            return 0.0

        _, _, orig_heatmap, _, orig_label = self._generate_heatmaps(image)
        # Indexed with [0], presumably because _generate_heatmaps returns
        # batched outputs even for a single image — TODO confirm.
        reference_heatmap = orig_heatmap[0]
        reference_label = orig_label[0]

        _, labels_batch = self._logits_and_predictions(transformed_images)

        # Keep only transformed images whose predicted label is unchanged.
        valid_images = [
            transformed
            for transformed, label in zip(transformed_images, labels_batch)
            if label == reference_label
        ]
        if not valid_images:
            return 0.0

        _, _, valid_heatmaps, _, _ = self._generate_heatmaps(valid_images)
        ssim_scores = [
            self.calculate_metrics(reference_heatmap, heatmap)["ssim"]
            for heatmap in valid_heatmaps
        ]
        # Clip at 0 so negative similarity scores are reported as 0.
        return np.maximum(0.0, np.min(ssim_scores))