# Source code for evaluator.consistency.RIS
from evaluator.consistency.consistency import Stability
import numpy as np
import PIL
[docs]
class RIS(Stability):
    """
    Relative Input Stability (RIS) metric for an explainability method.

    Measures how much an explanation changes, relative to how much the input
    itself changes, across transformed versions of the same input.
    Lower is better.
    """

    def __init__(self, model, processor, target_layers, targets, cam_method='gradcamelementwise',
                 explainability_method="cam"):
        """Forward all configuration unchanged to the ``Stability`` base class."""
        super().__init__(model, processor, target_layers, targets,
                         explainability_method=explainability_method,
                         cam_method=cam_method)

    @staticmethod
    def _relative_change_norms(reference, perturbed, p_val, min_eps):
        """Return the per-sample L_p norm of ``(reference - perturbed) / max(reference, min_eps)``."""
        # NOTE(review): np.maximum floors the denominator element-wise at
        # min_eps, so any non-positive reference value is replaced by min_eps.
        # Confirm that heatmaps / normalized images are expected to be
        # non-negative here.
        rel_change = (reference - perturbed) / np.maximum(reference, min_eps)
        # Flatten everything except the leading axis so that one norm is
        # produced per entry along that axis.
        flattened = rel_change.reshape(rel_change.shape[0], -1)
        return np.linalg.norm(flattened, ord=p_val, axis=1)

    def evaluation_metric(self, image: PIL.Image.Image, transformed_images=None, transform_type: str = 'rotation', p_val=2, min_eps=0.00001, **kwargs):
        """
        Calculate the log of the worst-case Relative Input Stability (RIS).

        Only transformed images whose predicted label equals the first
        predicted label of the original image contribute to the result.

        Args:
            image: The original input image.
            transformed_images: Optional sequence of transformed versions of
                ``image``. When ``None``, they are produced by
                ``self.generate_transformed_images``.
            transform_type: Transformation name forwarded to
                ``generate_transformed_images`` when images must be generated.
            p_val: Order of the L_p norm applied to both the relative
                explanation change and the relative input change.
            min_eps: Lower bound applied to every denominator to prevent
                division by zero.
            **kwargs: Extra options, forwarded as a single ``kwargs`` argument
                to ``generate_transformed_images``.

        Returns:
            float: ``0.0`` when no transformed images are available or none of
            them keeps the original predicted label; otherwise the natural
            logarithm of the maximum RIS value among the retained images.
        """
        if transformed_images is None:
            transformed_images = self.generate_transformed_images(image, transform_type=transform_type, kwargs=kwargs)

        _, _, orig_heatmap, _, orig_label = self._generate_heatmaps(image)
        orig_norm_image = self._normalize(image)

        # The generator may still yield nothing; treat that as "no instability".
        if transformed_images is None:
            return 0.0

        # Predict labels for all transformed images at once.
        _, labels_batch = self._logits_and_predictions(transformed_images)

        # Keep only the images whose predicted label is unchanged.
        reference_label = orig_label[0]
        valid_indices = [i for i, label in enumerate(labels_batch) if label == reference_label]
        if not valid_indices:
            return 0.0

        valid_images = [transformed_images[i] for i in valid_indices]
        _, _, valid_heatmaps, _, _ = self._generate_heatmaps(valid_images)
        valid_norm_images = np.stack([self._normalize(img) for img in valid_images])

        # Add a leading axis so the original image broadcasts against the batch.
        orig_norm_image_expanded = np.expand_dims(orig_norm_image, axis=0)

        rel_exp_change_norms = self._relative_change_norms(
            orig_heatmap, valid_heatmaps, p_val, min_eps)
        rel_input_change_norms = self._relative_change_norms(
            orig_norm_image_expanded, valid_norm_images, p_val, min_eps)

        # RIS per image: explanation change relative to input change.
        ris_batch = rel_exp_change_norms / np.maximum(rel_input_change_norms, min_eps)
        if not ris_batch.size:
            return 0.0

        # NOTE(review): np.log yields -inf (with a RuntimeWarning) when the
        # largest ratio is exactly 0, i.e. all retained explanations are
        # identical to the original.
        return np.log(np.max(ris_batch))