# Source code for evaluator.fidelity.deletion

from evaluator.evaluate import Evaluate
import numpy as np
import copy
from scipy.integrate import simpson
from utils.plots import plot_logit_change_with_images
from torch.nn.functional import softmax


class Deletion(Evaluate):
    """Evaluate the faithfulness of an explainability method by deletion.

    The most important pixels (according to the heatmap scores) are
    progressively replaced by a baseline value and the model's confidence in
    its originally predicted class is recorded after every step. The final
    metric is the Area Under the Deletion Curve (AUC). A lower AUC indicates
    that the model's confidence drops more quickly, suggesting a more
    faithful explanation.

    NOTE: several names in this class say "logit", but the recorded values
    are softmax probabilities of the originally predicted class.
    """

    def __init__(self, model, processor, target_layers, targets,
                 cam_method='gradcamelementwise', explainability_method="cam"):
        """Forward every argument unchanged to the ``Evaluate`` base class."""
        super().__init__(model, processor, target_layers, targets,
                         explainability_method=explainability_method,
                         cam_method=cam_method)

    def get_idx_list(self, image):
        """Return flat heatmap indices ordered from highest to lowest score.

        Args:
            image: Input forwarded to ``self._generate_heatmaps``.

        Returns:
            numpy.ndarray: Indices into the flattened heatmap, most important
            position first.
        """
        # Only the third element of the 5-tuple (the heatmap) is needed here.
        _, _, heatmap, _, _ = self._generate_heatmaps(image)
        heatmap_flat = heatmap.flatten()
        # argsort is ascending; reverse it to get high-to-low importance.
        return np.argsort(heatmap_flat)[::-1]

    def modify_input(self, image, idx_list, mask_ratio, base_value=None):
        """Return a copy of ``image`` with the top-ranked pixels overwritten.

        Args:
            image: ``numpy.ndarray`` or anything ``numpy.array`` can convert
                (for example a PIL image). The input is never modified.
            idx_list: Pixel indices ordered by decreasing importance. They
                index the image flattened over every axis except the last
                one (or over all axes for a 2-D image).
                NOTE(review): this assumes the heatmap has the same spatial
                size as the image and that the image is channel-last --
                confirm against ``_generate_heatmaps``.
            mask_ratio: Fraction of ``idx_list`` to overwrite. Values outside
                ``[0, 1]`` are clamped to that range.
            base_value: Replacement value. Defaults to the mean of the image.

        Returns:
            numpy.ndarray: Modified copy with the same shape and dtype as the
            converted input.
        """
        # Clamp on both sides: a negative ratio would otherwise turn into a
        # negative slice bound and mask almost the whole image.
        mask_ratio = min(max(mask_ratio, 0), 1)
        n_pixels_masked = int(len(idx_list) * mask_ratio)
        deletion_indices = idx_list[:n_pixels_masked]

        # A fresh C-contiguous copy guarantees that the reshape below is a
        # view, so the in-place assignment lands in ``modified_image``.
        modified_image = np.array(image, copy=True, order="C")
        if base_value is None:
            base_value = modified_image.mean()

        if modified_image.ndim == 2:
            # Grayscale image: every element is one pixel.
            modified_image_flat = modified_image.reshape(-1)
        else:
            # One row per pixel, one column per channel.
            modified_image_flat = modified_image.reshape(
                -1, modified_image.shape[-1])
        modified_image_flat[deletion_indices] = base_value
        return modified_image_flat.reshape(modified_image.shape)

    def _calculate_auc_deletion(self, logit_values, fractions_removed):
        """Calculate the Area Under the Deletion Curve (AUC).

        The integration uses Simpson's rule (``scipy.integrate.simpson``).

        Args:
            logit_values (list): Confidence values, one per deletion step.
            fractions_removed (list): Fractions of the image removed, used as
                the x coordinates of the curve.

        Returns:
            float: The AUC score.
        """
        return simpson(logit_values, x=fractions_removed)

    def deletion(self, image, interval=0.1, base_value=None):
        """Run the deletion experiment for a single image.

        Args:
            image: Input image passed to the model and to the heatmap
                generator.
            interval (float): Step between consecutive deletion fractions.
                Must be positive.
            base_value: Replacement value for deleted pixels; see
                ``modify_input``.

        Returns:
            tuple: ``(modified_input_list, logit_list, fractions_removed)``
            where ``modified_input_list`` holds the masked images,
            ``logit_list`` the softmax probability of the originally
            predicted class for each of them, and ``fractions_removed`` the
            matching fractions as a ``numpy.ndarray``.

        Raises:
            ValueError: If ``interval`` is not positive.
        """
        if interval <= 0:
            raise ValueError(
                "interval must be a positive number, got {!r}".format(interval))

        logits, _ = self._logits_and_predictions(image)
        # Class predicted on the unmodified image; its probability is tracked.
        # .item() requires a single prediction, i.e. one image per call.
        pred_idx = logits.argmax(-1).item()

        # The small overshoot on the stop value keeps 1.0 in the range when
        # ``interval`` divides 1 evenly.
        fractions_removed = np.arange(0, 1.00001, interval)
        if fractions_removed[-1] < 1.0 - 1e-9:
            # ``interval`` does not divide 1: add the full-deletion end point
            # so the curve always spans the whole [0, 1] range.
            fractions_removed = np.append(fractions_removed, 1.0)

        idx_list = self.get_idx_list(image)

        modified_input_list = []
        logit_list = []
        for frac in fractions_removed:
            modified_input = self.modify_input(
                image, idx_list, mask_ratio=frac, base_value=base_value)
            modified_input_list.append(modified_input)

            mod_logits, _ = self._logits_and_predictions(modified_input)
            probabilities = softmax(mod_logits, dim=-1)
            logit_list.append(probabilities[0, pred_idx].item())

        return modified_input_list, logit_list, fractions_removed

    def plot_deletion_curve(self, image, interval=0.2, base_value=None,
                            title=None, save_path=None):
        """Run the deletion experiment on ``image`` and plot the curve.

        Args:
            image: Input image to evaluate.
            interval (float): Step between consecutive deletion fractions.
            base_value: Replacement value for deleted pixels; see
                ``modify_input``.
            title (str, optional): Forwarded to
                ``plot_logit_change_with_images``.
            save_path (str, optional): Forwarded to
                ``plot_logit_change_with_images``.
        """
        images_list, logit_list, fractions_removed = self.deletion(
            image, interval=interval, base_value=base_value)
        auc = self._calculate_auc_deletion(logit_list, fractions_removed)
        plot_logit_change_with_images(images=images_list,
                                      logits=logit_list,
                                      percentages=fractions_removed,
                                      auc=auc,
                                      title=title,
                                      save_path=save_path)

    def __call__(self, images_list, interval=0.1, base_value=None):
        """Compute the deletion AUC for every image in ``images_list``.

        Args:
            images_list: Iterable of input images.
            interval (float): Step between consecutive deletion fractions.
            base_value: Replacement value for deleted pixels; see
                ``modify_input``.

        Returns:
            list: One AUC score per image, in input order.
        """
        auc_scores = []
        for image in images_list:
            _, logit_list, fractions_removed = self.deletion(
                image, interval=interval, base_value=base_value)
            auc_scores.append(
                self._calculate_auc_deletion(logit_list, fractions_removed))
        return auc_scores