Source code for evaluator.fidelity.insertion

from evaluator.evaluate import Evaluate
import numpy as np
import copy
from torch.nn.functional import softmax
from scipy.integrate import simpson
from utils.plots import plot_logit_change_with_images


class Insertion(Evaluate):
    """Evaluate the faithfulness of an explainability method via insertion.

    Starting from a constant base image, the most important pixels (ranked by
    the heatmap scores) are progressively copied in from the input image and
    the model's confidence in its originally predicted class is recorded.
    The final metric is the Area Under the insertion Curve (AUC). A higher
    AUC indicates that the model's confidence is recovered after inserting
    only a few pixels, suggesting a more faithful explanation.
    """

    def __init__(self, model, processor, target_layers, targets,
                 cam_method='gradcamelementwise', explainability_method="cam"):
        """Forward all configuration unchanged to the ``Evaluate`` base class."""
        super().__init__(model, processor, target_layers, targets,
                         explainability_method=explainability_method,
                         cam_method=cam_method)

    def get_idx_list(self, image):
        """Return flattened heatmap indices ordered from highest to lowest score.

        Args:
            image: Input image passed to ``self._generate_heatmaps``.

        Returns:
            Index array into the flattened heatmap, most important first.
        """
        # The heatmap is the third element of the tuple returned by the base class.
        _, _, heatmap, _, _ = self._generate_heatmaps(image)
        return np.argsort(heatmap.flatten())[::-1]  # descending score order

    def modify_input(self, image, idx_list, insertion_ratio, base_value=None):
        """Build a base image with the top-ranked pixels of ``image`` inserted.

        Args:
            image: Source image; converted to a NumPy array if it is not one.
                NOTE(review): assumed channel-last (H, W, C) with H * W equal
                to ``len(idx_list)`` -- confirm against the heatmap resolution.
            idx_list: Pixel indices (into the flattened H * W grid) ordered
                from most to least important.
            insertion_ratio: Fraction of pixels to insert; clamped to [0, 1].
            base_value: Fill value of the base image. Defaults to the mean of
                ``image``.

        Returns:
            np.ndarray: Array with the shape of ``image``. Floating-point
            images keep their dtype; all other images are returned as ``int``.
        """
        # Clamp on both sides: a negative ratio would yield a negative slice
        # bound and silently insert almost every pixel.
        insertion_ratio = min(max(insertion_ratio, 0), 1)
        n_pixels_inserted = int(len(idx_list) * insertion_ratio)
        insertion_indices = idx_list[:n_pixels_inserted]

        # The source image is only read, so no defensive copy is needed.
        image = np.asarray(image)
        if base_value is None:
            base_value = image.mean()

        # Keep floating-point images in their own dtype; forcing ``int`` would
        # truncate both the base value and the inserted pixels (e.g. images
        # scaled to [0, 1] would collapse to zeros). Integer images keep the
        # historical ``int`` output dtype.
        out_dtype = image.dtype if np.issubdtype(image.dtype, np.floating) else int
        base_image = np.full(image.shape, base_value, dtype=out_dtype)

        # View both arrays as (n_pixels, channels) so one fancy-index
        # assignment copies whole pixels; the view writes through to base_image.
        base_flat = base_image.reshape(-1, base_image.shape[-1])
        image_flat = image.reshape(-1, image.shape[-1])
        base_flat[insertion_indices] = image_flat[insertion_indices]
        return base_image

    def _calculate_auc_insertion(self, logit_values, fractions_added):
        """Calculate the Area Under the insertion Curve with Simpson's rule.

        Args:
            logit_values (list): Confidence values, one per insertion step.
            fractions_added (list): Fractions of the image inserted at each step.

        Returns:
            float: The AUC score.
        """
        return simpson(logit_values, x=fractions_added)

    def insertion(self, image, interval=0.1, base_value=None):
        """Run the insertion experiment on a single image.

        Args:
            image: Input image.
            interval (float): Step between consecutive inserted fractions;
                must be positive.
            base_value: Fill value of the base image (see ``modify_input``).

        Returns:
            tuple: ``(modified_input_list, logit_list, fractions_added)``.
            Despite its name, ``logit_list`` holds the softmax probability of
            the originally predicted class at each step.

        Raises:
            ValueError: If ``interval`` is not positive.
        """
        if interval <= 0:
            raise ValueError("interval must be a positive number")

        # Class predicted on the unmodified image; tracked across all steps.
        logits, _ = self._logits_and_predictions(image)
        pred_idx = logits.argmax(-1).item()

        # The small epsilon keeps the end point 1.0 despite float rounding.
        fractions_added = np.arange(0, 1.00001, interval)
        idx_list = self.get_idx_list(image)

        modified_input_list = []
        logit_list = []
        for frac in fractions_added:
            modified_input = self.modify_input(image, idx_list,
                                               insertion_ratio=frac,
                                               base_value=base_value)
            modified_input_list.append(modified_input)

            mod_logits, _ = self._logits_and_predictions(modified_input)
            # NOTE(review): indexing [0, pred_idx] assumes a batch of one.
            probability = softmax(mod_logits, dim=-1)[0, pred_idx].item()
            logit_list.append(probability)

        return modified_input_list, logit_list, fractions_added

    def plot_insertion_curve(self, image, interval=0.1, base_value=None,
                             title="Insertion Curve", save_path=None):
        """Run the insertion experiment on ``image`` and plot the curve.

        Args:
            image: Input image.
            interval (float): Step between consecutive inserted fractions.
            base_value: Fill value of the base image (see ``modify_input``).
            title (str): Title forwarded to the plotting helper.
            save_path (str, optional): Path forwarded to the plotting helper.
        """
        images_list, logit_list, fractions_added = self.insertion(
            image, interval=interval, base_value=base_value)
        auc = self._calculate_auc_insertion(logit_list, fractions_added)
        plot_logit_change_with_images(images=images_list, logits=logit_list,
                                      percentages=fractions_added, auc=auc,
                                      title=title, save_path=save_path)

    def __call__(self, images_list, interval=0.1, base_value=None):
        """Compute the insertion AUC for every image in ``images_list``.

        Args:
            images_list: Iterable of input images.
            interval (float): Step between consecutive inserted fractions.
            base_value: Fill value of the base image (see ``modify_input``).

        Returns:
            list: One AUC score per image, in input order.
        """
        auc_scores = []
        for image in images_list:
            _, logit_list, fractions_added = self.insertion(
                image, interval=interval, base_value=base_value)
            auc_scores.append(
                self._calculate_auc_insertion(logit_list, fractions_added))
        return auc_scores