Source code for evaluator.consistency.ROS

from evaluator.consistency.consistency import Stability
import numpy as np
import PIL


class ROS(Stability):
    """Relative Output Stability (ROS) of an explainability method.

    Measures stability as the relative change in explanations (heatmaps)
    between an image and transformed versions of it, divided by the change
    in the model's output logits for those same inputs. Lower is better.
    """

    def __init__(self, model, processor, target_layers, targets,
                 cam_method='gradcamelementwise', explainability_method="cam"):
        """Forward all configuration unchanged to the ``Stability`` base class."""
        super().__init__(
            model, processor, target_layers, targets,
            explainability_method=explainability_method,
            cam_method=cam_method,
        )

    def evaluation_metric(self, image: PIL.Image.Image, transformed_images=None,
                          transform_type: str = 'rotation', p_val=2,
                          min_eps=0.00001, **kwargs):
        """Calculate the Relative Output Stability (ROS) for one image.

        Only transformed images whose predicted label equals the label
        predicted for the original image take part in the computation.

        Args:
            image: The original input image.
            transformed_images: Optional sequence of transformed versions of
                ``image``. When ``None`` or empty, they are produced with
                ``self.generate_transformed_images``.
            transform_type: Transformation name handed to
                ``self.generate_transformed_images``.
            p_val: Order of the L_p norm applied to both the explanation
                change and the logit change.
            min_eps: Lower bound applied to denominators to prevent division
                by zero.
            **kwargs: Extra options handed to
                ``self.generate_transformed_images``.

        Returns:
            ``0.0`` when there are no transformed images or none of them
            keeps the original predicted label; otherwise the natural
            logarithm of the largest ROS ratio among the valid images.
        """
        if transformed_images is None or len(transformed_images) == 0:
            # NOTE(review): the options are passed as ONE keyword argument
            # named ``kwargs`` rather than unpacked with ``**`` -- kept as in
            # the original call; confirm against the helper's signature.
            transformed_images = self.generate_transformed_images(
                image, transform_type=transform_type, kwargs=kwargs)

        # Explicit None/length test instead of truthiness, so a NumPy array
        # of images does not raise "truth value of an array is ambiguous".
        # Checked before any model call so nothing is computed needlessly.
        if transformed_images is None or len(transformed_images) == 0:
            return 0.0

        # Explanation and logits for the unmodified image.
        _, _, orig_heatmap, _, orig_label = self._generate_heatmaps(image)
        orig_logits, _ = self._logits_and_predictions(image)
        orig_logits = orig_logits.detach().cpu()

        # Keep only the transformed images whose prediction is unchanged.
        _, labels_batch = self._logits_and_predictions(transformed_images)
        reference_label = orig_label[0]
        valid_indices = [
            idx for idx, label in enumerate(labels_batch)
            if label == reference_label
        ]
        if not valid_indices:
            return 0.0
        valid_images = [transformed_images[idx] for idx in valid_indices]

        # Explanations and logits for the retained images, as one batch.
        _, _, valid_heatmaps, _, _ = self._generate_heatmaps(valid_images)
        valid_logits, _ = self._logits_and_predictions(valid_images)
        valid_logits = valid_logits.detach().cpu()

        # Element-wise explanation change relative to the original heatmap,
        # reduced to one L_p norm per retained image.
        rel_exp_change = (orig_heatmap - valid_heatmaps) / np.maximum(orig_heatmap, min_eps)
        rel_exp_change_norms = np.linalg.norm(
            rel_exp_change.reshape(rel_exp_change.shape[0], -1),
            ord=p_val, axis=1)

        # Change in output logits (plain difference, not divided by the
        # original logits), reduced to one L_p norm per retained image.
        output_change = valid_logits - orig_logits
        output_change_norms = np.linalg.norm(
            output_change.reshape(output_change.shape[0], -1),
            ord=p_val, axis=1)

        ros_batch = rel_exp_change_norms / np.maximum(output_change_norms, min_eps)
        if not ros_batch.size:
            return 0.0

        # NOTE(review): when the largest ratio is exactly 0 (identical
        # explanations) np.log yields -inf; confirm callers tolerate that.
        return np.log(np.max(ros_batch))