Source code for evaluator.consistency.ROS
from evaluator.consistency.consistency import Stability
import numpy as np
import PIL
[docs]
class ROS(Stability):
    """
    Relative Output Stability (ROS) of an explainability method.

    Measures how strongly the explanation heatmap changes, relative to the
    change in the model's output logits, when the input image is replaced by
    a transformed version that keeps the same predicted label.
    Lower is better.
    """

    def __init__(self, model, processor, target_layers, targets, cam_method='gradcamelementwise',
                 explainability_method="cam"):
        """Forward all configuration unchanged to the ``Stability`` base class."""
        super().__init__(model, processor, target_layers, targets,
                         explainability_method=explainability_method,
                         cam_method=cam_method)

    def evaluation_metric(self, image: PIL.Image.Image, transformed_images=None, transform_type: str = 'rotation', p_val=2, min_eps=0.00001, **kwargs):
        """
        Calculate the Relative Output Stability (ROS) metric for one image.

        For every transformed image whose predicted label matches the label
        of the original image, the L_p norm of the relative heatmap change is
        divided by the L_p norm of the logit change. The natural logarithm of
        the largest such ratio is returned.

        Args:
            image (PIL.Image.Image): The original input image.
            transformed_images: Transformed versions of the original image.
                When empty or ``None`` they are produced by
                ``self.generate_transformed_images``.
            transform_type (str): Passed to ``generate_transformed_images``
                when the transformed images have to be generated.
            p_val: Order of the norm, passed as ``ord`` to
                ``numpy.linalg.norm`` for both the explanation change and the
                output change.
            min_eps (float): Lower bound applied to both denominators to
                prevent division by zero.
            **kwargs: Extra options handed to ``generate_transformed_images``.

        Returns:
            float: ``0.0`` when there are no transformed images or none of
            them keeps the original label; otherwise the natural log of the
            maximum ROS ratio over the label-preserving images. A maximum
            ratio of exactly zero therefore yields ``-inf``.
        """
        if not transformed_images:
            # NOTE(review): the extra keyword arguments are passed as a single
            # dict named ``kwargs`` rather than unpacked with ``**`` -- confirm
            # this matches the signature of ``generate_transformed_images``.
            transformed_images = self.generate_transformed_images(image, transform_type=transform_type, kwargs=kwargs)

        _, _, orig_heatmap, _, orig_label = self._generate_heatmaps(image)

        # Generation may still have produced nothing to compare against.
        if not transformed_images:
            return 0.0

        orig_logits, _ = self._logits_and_predictions(image)
        orig_logits = orig_logits.detach().cpu()

        # Only the predicted labels are needed here; logits are recomputed
        # below for the label-preserving subset.
        _, labels_batch = self._logits_and_predictions(transformed_images)

        # Keep only the transformed images whose label is unchanged.
        valid_indices = [
            idx for idx, label in enumerate(labels_batch)
            if label == orig_label[0]
        ]
        if len(valid_indices) == 0:
            return 0.0
        valid_images = [transformed_images[idx] for idx in valid_indices]

        # Heatmaps and logits for the label-preserving images.
        # Heatmaps are assumed to be numpy arrays with a leading batch
        # axis -- TODO confirm against ``_generate_heatmaps``.
        _, _, valid_heatmaps, _, _ = self._generate_heatmaps(valid_images)
        valid_images_logits, _ = self._logits_and_predictions(valid_images)
        valid_images_logits = valid_images_logits.detach().cpu()

        # Relative explanation change, one L_p norm per transformed image.
        rel_exp_change = (orig_heatmap - valid_heatmaps) / np.maximum(orig_heatmap, min_eps)
        rel_exp_change_norms = np.linalg.norm(
            rel_exp_change.reshape(rel_exp_change.shape[0], -1),
            ord=p_val,
            axis=1,
        )

        # Output (logit) change, one L_p norm per transformed image.
        rel_output_change = (valid_images_logits - orig_logits)
        rel_output_change_norms = np.linalg.norm(
            rel_output_change.reshape(rel_output_change.shape[0], -1),
            ord=p_val,
            axis=1,
        )

        # ROS ratio per image; the denominator is floored at ``min_eps``.
        ros_batch = rel_exp_change_norms / np.maximum(rel_output_change_norms, min_eps)
        if not ros_batch.size:
            return 0.0
        return np.log(np.max(ros_batch))