from evaluator.evaluate import Evaluate
import numpy as np
import copy
from torch.nn.functional import softmax
from scipy.integrate import simpson
from utils.plots import plot_logit_change_with_images
class Insertion(Evaluate):
    """
    Evaluate the faithfulness of an explainability method with an
    insertion-based experiment.

    Pixels of the input image are added to a base image in order of decreasing
    heatmap score. After each step the model's softmax probability for the
    class it predicted on the unmodified image is recorded.

    The final metric is the Area Under the Insertion Curve (AUC). A higher AUC
    indicates that the model's confidence recovers more quickly as the
    highest-scoring pixels are inserted, suggesting a more faithful
    explanation.
    """

    def __init__(self, model, processor, target_layers, targets,
                 cam_method='gradcamelementwise', explainability_method="cam"):
        """
        All arguments are forwarded unchanged to ``Evaluate.__init__``.

        Args:
            model: The model being evaluated.
            processor: Processor forwarded to ``Evaluate``.
            target_layers: Target layers forwarded to ``Evaluate``.
            targets: Targets forwarded to ``Evaluate``.
            cam_method (str): CAM variant name forwarded to ``Evaluate``.
            explainability_method (str): Explainability method name forwarded
                to ``Evaluate``.
        """
        super().__init__(model, processor, target_layers, targets,
                         explainability_method=explainability_method,
                         cam_method=cam_method)

    def get_idx_list(self, image):
        """
        Return flat heatmap indices ordered from highest to lowest score.

        Args:
            image: Input image; passed to ``_generate_heatmaps``.

        Returns:
            numpy.ndarray: Indices into the flattened heatmap, sorted by
            descending heatmap score.
        """
        # Only the third output of _generate_heatmaps (the heatmap) is used.
        _, _, heatmap, _, _ = self._generate_heatmaps(image)
        heatmap_flat = heatmap.flatten()
        # argsort is ascending; reverse so the most important pixels come first.
        return np.argsort(heatmap_flat)[::-1]

    def _calculate_auc_insertion(self, logit_values, fractions_added):
        """
        Calculate the Area Under the Insertion Curve (AUC) using Simpson's rule.

        Args:
            logit_values (list): Model confidence recorded at each step (the
                softmax probabilities produced by ``insertion``).
            fractions_added (array-like): Fraction of pixels inserted at each
                step; used as the x-coordinates of the curve.

        Returns:
            float: The AUC score.
        """
        return simpson(logit_values, x=fractions_added)

    def insertion(self, image, interval=0.1, base_value=None):
        """
        Run the insertion experiment on a single image.

        Args:
            image: Input image to explain.
            interval (float): Step between consecutive insertion fractions.
                Fractions run from 0 in steps of ``interval`` up to 1.0. A tiny
                tolerance ensures 1.0 itself is included when ``interval``
                divides it evenly. Otherwise the sequence stops at the last
                step below 1.0 (e.g. 0.3 gives 0, 0.3, 0.6, 0.9).
            base_value: Forwarded to ``modify_input``. Presumably the fill
                value of the base image — TODO confirm against ``Evaluate``.

        Returns:
            tuple: ``(modified_input_list, prob_list, fractions_added)``:
            the modified input produced at each step, the softmax probability
            of the originally predicted class at each step, and the array of
            insertion fractions used.

        Raises:
            ValueError: If ``interval`` is not positive.
        """
        if interval <= 0:
            raise ValueError(f"interval must be positive, got {interval!r}")

        logits, _ = self._logits_and_predictions(image)
        # Class predicted on the unmodified image; tracked at every step.
        pred_idx = logits.argmax(-1).item()

        # The small overshoot past 1.0 keeps the endpoint despite float rounding.
        fractions_added = np.arange(0, 1.00001, interval)
        idx_list = self.get_idx_list(image)

        modified_input_list = []
        prob_list = []
        for frac in fractions_added:
            modified_input = self.modify_input(image, idx_list,
                                               insertion_ratio=frac,
                                               base_value=base_value)
            modified_input_list.append(modified_input)
            mod_logits, _ = self._logits_and_predictions(modified_input)
            probs = softmax(mod_logits, dim=-1)
            # Index 0 assumes a single-item batch dimension — TODO confirm.
            prob_list.append(probs[0, pred_idx].item())
        return modified_input_list, prob_list, fractions_added

    def plot_insertion_curve(self, image, interval=0.1, base_value=None,
                             title="Insertion Curve", save_path=None):
        """
        Run the insertion experiment on one image and plot the resulting curve.

        Args:
            image: Input image to explain.
            interval (float): Step between insertion fractions (see
                ``insertion``).
            base_value: Forwarded to ``insertion``.
            title (str): Plot title.
            save_path (str, optional): Path to save the plot. If None, the
                plot is displayed.
        """
        images_list, prob_list, fractions_added = self.insertion(
            image, interval=interval, base_value=base_value)
        auc = self._calculate_auc_insertion(prob_list, fractions_added)
        plot_logit_change_with_images(images=images_list, logits=prob_list,
                                      percentages=fractions_added, auc=auc,
                                      title=title, save_path=save_path)

    def __call__(self, images_list, interval=0.1, base_value=None):
        """
        Compute the insertion AUC for each image.

        Args:
            images_list (iterable): Images to evaluate.
            interval (float): Step between insertion fractions (see
                ``insertion``).
            base_value: Forwarded to ``insertion``.

        Returns:
            list: One AUC score per image, in input order.
        """
        auc_scores = []
        for image in images_list:
            _, prob_list, fractions_added = self.insertion(
                image, interval=interval, base_value=base_value)
            auc_scores.append(
                self._calculate_auc_insertion(prob_list, fractions_added))
        return auc_scores