Source code for evaluator.consistency.consistency

from evaluator.evaluate import Evaluate
import numpy as np
from typing import List
import PIL


class Stability(Evaluate):
    """
    Evaluate the stability of an explainability method.

    Stability is the relative change in the explanations produced for
    slightly transformed (e.g. rotated or noised) versions of an input,
    measured w.r.t. the change in the input itself. Lower is better.

    This base class provides the batch entry point (``__call__``), min-max
    normalisation and generation of transformed images; the per-image score
    is produced by ``evaluation_metric``, which subclasses override.
    """

    def __init__(self, model, processor, target_layers, targets,
                 cam_method='gradcamelementwise', explainability_method="cam"):
        """
        Forward all configuration unchanged to ``Evaluate.__init__``.

        Args:
            model: Model whose explanations are evaluated.
            processor: Passed through to ``Evaluate``.
            target_layers: Passed through to ``Evaluate``.
            targets: Passed through to ``Evaluate``.
            cam_method (str): CAM variant name. Default 'gradcamelementwise'.
            explainability_method (str): Explainability method name.
                Default "cam".
        """
        super().__init__(model, processor, target_layers, targets,
                         explainability_method=explainability_method,
                         cam_method=cam_method)

    def _normalize(self, arr):
        """
        Min-max scale ``arr`` to the range [0, 1] as a new float array.

        A constant array (max == min) is returned as a float copy without
        scaling, avoiding a division by zero.
        """
        normalized_arr = np.array(arr).astype('float')  # astype returns a copy
        lo = normalized_arr.min()
        value_range = normalized_arr.max() - lo
        if value_range == 0:
            return normalized_arr
        return (normalized_arr - lo) / value_range

    def __call__(self, images: List[PIL.Image.Image],
                 transformed_images: List[List[PIL.Image.Image]] = None,
                 transform_type: str = 'rotation', p_val=2, min_eps=0.00001,
                 **kwargs):
        """
        Compute the stability metric for every image in a batch.

        Args:
            images (List[PIL.Image.Image]): The original input images.
            transformed_images (List[List[PIL.Image.Image]] | None): For each
                image, a list of its transformed versions. When None, the
                transformations are left to ``evaluation_metric`` via
                ``transform_type``.
            transform_type (str): Transformation to use when
                ``transformed_images`` is None ('rotation' or 'noise').
            p_val (int): The p-value for the L_p norm calculation.
            min_eps (float): Small epsilon to prevent division by zero.
            **kwargs: Extra options forwarded as keyword arguments to
                ``evaluation_metric`` (e.g. ``max_angle``, ``max_std``,
                ``step`` for ``generate_transformed_images``).

        Returns:
            list: One value per input image, in input order, as returned by
            ``evaluation_metric``.

        Raises:
            ValueError: If ``transformed_images`` is given but its length
                differs from ``images``.
        """
        try:
            metric_list = []
            if transformed_images is not None:
                # zip() would silently drop the unmatched tail; fail loudly.
                if len(images) != len(transformed_images):
                    raise ValueError(
                        f"Got {len(images)} images but "
                        f"{len(transformed_images)} lists of transformed images."
                    )
                for image, transformed_image_list in zip(images, transformed_images):
                    metric_list.append(self.evaluation_metric(
                        image=image, transformed_images=transformed_image_list,
                        p_val=p_val, min_eps=min_eps, **kwargs))
            else:
                for image in images:
                    metric_list.append(self.evaluation_metric(
                        image=image, transform_type=transform_type,
                        p_val=p_val, min_eps=min_eps, **kwargs))
            return metric_list
        finally:
            # Always release the explainer, even if a metric computation fails.
            self.explainer.cleanup()

    def evaluation_metric(self, image: PIL.Image.Image,
                          transformed_images: List[PIL.Image.Image] = None,
                          transform_type: str = 'rotation', p_val=2,
                          min_eps=0.00001, **kwargs):
        """
        Compute the metric for a single image; override in subclasses.

        Args:
            image (PIL.Image.Image): The original input image.
            transformed_images (List[PIL.Image.Image] | None): Transformed
                versions of ``image``, if precomputed.
            transform_type (str): Transformation to generate when
                ``transformed_images`` is not supplied.
            p_val (int): The p-value for the L_p norm calculation.
            min_eps (float): Small epsilon to prevent division by zero.
            **kwargs: Extra options (e.g. for ``generate_transformed_images``).

        Returns:
            The metric value for ``image``. This base implementation
            computes nothing and returns None.
        """
        return None

    @staticmethod
    def _inclusive_range(start, stop, step):
        """
        Return ``start, start + step, ...`` up to and including ``stop``.

        Unlike ``np.arange(start, stop + step, step)``, float round-off can
        never add a value beyond ``stop`` (that idiom with start=0, stop=0.2,
        step=0.1 yields an extra 0.30000000000000004).

        Raises:
            ValueError: If ``step`` is not positive.
        """
        if step <= 0:
            raise ValueError(f"step must be positive, got {step}.")
        # The small tolerance keeps ``stop`` when it is a multiple of ``step``
        # but the division lands just below the integer due to round-off.
        count = int(np.floor((stop - start) / step + 1e-9)) + 1
        return start + np.arange(max(count, 0)) * step

    def generate_transformed_images(self, image: PIL.Image.Image, transform_type: str,
                                    **kwargs) -> List[PIL.Image.Image]:
        """
        Generate a list of transformed images for the given transformation.

        Args:
            image (PIL.Image.Image): The original input image.
            transform_type (str): The type of transformation to apply
                ('rotation' or 'noise').

        Kwargs:
            max_angle (float): [For 'rotation'] Maximum rotation angle in
                degrees on either side of the original. Default is 2.
            step (float): [For 'rotation' and 'noise'] Increment between
                consecutive transformation levels. Must be positive.
                Default is 0.1.
            max_std (float): [For 'noise'] Maximum standard deviation of the
                Gaussian noise to be added. Default is 5.

        Returns:
            List[PIL.Image.Image]: The transformed images. The identity level
            (angle 0 / std 0) is excluded.

        Raises:
            ValueError: If an unknown transform type is specified or ``step``
                is not positive.
        """
        if transform_type == 'rotation':
            max_angle = kwargs.get('max_angle', 2)
            step = kwargs.get('step', 0.1)
            angles = np.around(self._inclusive_range(-max_angle, max_angle, step), 2)
            angle_list = [angle for angle in angles if angle != 0]
            transformed_images = self.transforms.get_rotated_images(
                image, rotations=angle_list).values()
        elif transform_type == 'noise':
            max_std = kwargs.get('max_std', 5)
            step = kwargs.get('step', 0.1)
            stds = self._inclusive_range(0, max_std, step)
            std_list = [std for std in stds if std != 0]
            transformed_images = self.transforms.get_noisy_images(
                image, stds=std_list).values()
        else:
            raise ValueError(f"Unknown transform type: {transform_type}. Choose 'rotation' or 'noise'.")
        return list(transformed_images)