from evaluator.evaluate import Evaluate
import numpy as np
import copy
from scipy.integrate import simpson
from utils.plots import plot_logit_change_with_images
from torch.nn.functional import softmax
class Deletion(Evaluate):
    """
    Deletion-based faithfulness metric for explainability methods.

    Positions of the input are ranked by their heatmap score (highest first).
    Increasing fractions of the top-ranked positions are then removed via
    ``self.modify_input``. After each step, the model's softmax probability for
    the class it originally predicted is recorded. The final metric is the area
    under this deletion curve (AUC). A lower AUC indicates that the model's
    confidence drops more quickly, suggesting a more faithful explanation.
    """

    def __init__(self, model, processor, target_layers, targets, cam_method='gradcamelementwise',
                 explainability_method="cam"):
        """
        Args:
            model: Model under evaluation; forwarded to ``Evaluate``.
            processor: Input processor; forwarded to ``Evaluate``.
            target_layers: Layers used to build the explanation; forwarded to ``Evaluate``.
            targets: Explanation targets; forwarded to ``Evaluate``.
            cam_method (str): Name of the CAM variant, interpreted by ``Evaluate``.
            explainability_method (str): Explanation family, interpreted by ``Evaluate``.
        """
        super().__init__(model, processor, target_layers, targets,
                         explainability_method=explainability_method,
                         cam_method=cam_method)

    def get_idx_list(self, image):
        """
        Rank the flattened heatmap positions of ``image`` from most to least important.

        Args:
            image: Input image, passed unchanged to ``self._generate_heatmaps``.

        Returns:
            numpy.ndarray: Flat indices into the heatmap, sorted by descending score.
        """
        # _generate_heatmaps returns a 5-tuple; only the third element (the heatmap) is used.
        _, _, heatmap, _, _ = self._generate_heatmaps(image)
        heatmap_flat = heatmap.flatten()
        # argsort is ascending; reverse it so the highest-scoring positions come first.
        return np.argsort(heatmap_flat)[::-1]

    def _calculate_auc_deletion(self, logit_values, fractions_removed):
        """
        Calculate the area under the deletion curve with composite Simpson's rule.

        Args:
            logit_values (list): Scores measured after each deletion step
                (as produced by ``deletion``, these are softmax probabilities).
            fractions_removed (array-like): Fraction of the input removed at each
                step; used as the integration abscissa.

        Returns:
            float: The AUC score (``scipy.integrate.simpson`` of the scores over
            the removal fractions).
        """
        return simpson(logit_values, x=fractions_removed)

    def deletion(self, image, interval=0.1, base_value=None):
        """
        Run the deletion experiment on a single image.

        Args:
            image: Input image. Presumably a single (unbatched) example, since
                the code calls ``.item()`` on the argmax and reads row 0 of the
                softmax output. TODO confirm against ``_logits_and_predictions``.
            interval (float): Step between consecutive removal fractions. The
                fractions ``0, interval, 2*interval, ...`` are evaluated, up to and
                including 1.0 when ``interval`` divides 1 evenly.
            base_value: Replacement value forwarded to ``self.modify_input``.

        Returns:
            tuple: ``(modified_inputs, scores, fractions_removed)``.
            ``scores[i]`` is the softmax probability of the originally predicted
            class for ``modified_inputs[i]``, which has ``fractions_removed[i]`` of
            the ranked positions removed. Note that these scores are
            probabilities, not raw logits.

        Raises:
            ValueError: If ``interval`` is not positive.
        """
        if interval <= 0:
            raise ValueError(f"interval must be positive, got {interval!r}")
        logits, _ = self._logits_and_predictions(image)
        # The class predicted on the untouched image is tracked at every step.
        pred_idx = logits.argmax(-1).item()
        # np.arange excludes its stop value; the tiny margin above 1.0 keeps the
        # 1.0 endpoint despite floating-point accumulation in the step.
        fractions_removed = np.arange(0, 1.00001, interval)
        idx_list = self.get_idx_list(image)
        modified_input_list = []
        logit_list = []
        for frac in fractions_removed:
            modified_input = self.modify_input(image, idx_list, mask_ratio=frac, base_value=base_value)
            modified_input_list.append(modified_input)
            mod_logits, _ = self._logits_and_predictions(modified_input)
            probs = softmax(mod_logits, dim=-1)
            logit_list.append(probs[0, pred_idx].item())
        return modified_input_list, logit_list, fractions_removed

    def plot_deletion_curve(self, image, interval=0.2, base_value=None, title=None, save_path=None):
        """
        Run the deletion experiment on ``image`` and plot the resulting curve.

        Args:
            image: Input image, passed to ``deletion``.
            interval (float): Step between removal fractions, passed to ``deletion``.
            base_value: Replacement value, passed to ``deletion``.
            title (str, optional): Plot title, forwarded to the plotting helper.
            save_path (str, optional): Path to save the plot, forwarded to the
                plotting helper.
        """
        images_list, logit_list, fractions_removed = self.deletion(image, interval=interval, base_value=base_value)
        auc = self._calculate_auc_deletion(logit_list, fractions_removed)
        plot_logit_change_with_images(images=images_list, logits=logit_list, percentages=fractions_removed,
                                      auc=auc, title=title, save_path=save_path)

    def __call__(self, images_list, interval=0.1, base_value=None):
        """
        Compute the deletion AUC for each image.

        Args:
            images_list (iterable): Images to evaluate, one at a time.
            interval (float): Step between removal fractions, passed to ``deletion``.
            base_value: Replacement value, passed to ``deletion``.

        Returns:
            list[float]: One AUC score per input image, in input order.
        """
        auc_scores = []
        for image in images_list:
            # The modified inputs are not needed here; only the curve is integrated.
            _, logit_list, fractions_removed = self.deletion(image, interval=interval, base_value=base_value)
            auc_scores.append(self._calculate_auc_deletion(logit_list, fractions_removed))
        return auc_scores