from evaluator.evaluate import Evaluate
import numpy as np
import copy
from scipy.stats import pearsonr
from torch.nn.functional import softmax
class FaithfulnessCorrelation(Evaluate):
    """
    Faithfulness-correlation metric for explainability methods.

    For every image, random subsets of pixels are perturbed and, for each
    perturbation, two numbers are recorded:

    * the drop in the softmax probability of the originally predicted class, and
    * the mean heatmap score of the perturbed pixels.

    The metric for that image is the Pearson correlation between the two
    series. A high positive value means pixels with higher heatmap scores
    cause larger probability drops when perturbed.

    This assumes the heatmap and the input image have the same spatial
    dimensions (height, width), so a flattened pixel index addresses the same
    location in both.
    """

    def __init__(self, model, processor, target_layers, targets,
                 cam_method='gradcamelementwise', explainability_method="cam"):
        # All configuration is stored by the Evaluate base class.
        super().__init__(model, processor, target_layers, targets,
                         explainability_method=explainability_method,
                         cam_method=cam_method)

    def _get_image_dimensions(self, image):
        """
        Return ``(height, width)`` for a NumPy array or a PIL image.

        Args:
            image (np.ndarray | PIL.Image.Image): An ``(H, W)`` or
                ``(H, W, C)`` array, or any object exposing a PIL-style
                ``size`` attribute of ``(width, height)``.
        Returns:
            tuple[int, int]: ``(height, width)``.
        Raises:
            ValueError: If a NumPy array is neither 2-D nor 3-D.
        """
        if isinstance(image, np.ndarray):
            # Accept grayscale (H, W) as well as (H, W, C); other ranks are
            # ambiguous, so reject them instead of guessing an axis order.
            if image.ndim not in (2, 3):
                raise ValueError(
                    f"Expected an (H, W) or (H, W, C) array, got shape {image.shape}"
                )
            h, w = image.shape[:2]
        else:  # Assuming PIL.Image.Image, whose size is (width, height)
            w, h = image.size
        return h, w

    def idx_sample(self, image, size=10):
        """
        Draw ``size`` distinct flattened pixel indices uniformly at random.

        Uses NumPy's global random state, so results are reproducible only if
        the caller seeds ``np.random``.

        Args:
            image (np.ndarray | PIL.Image.Image): The input image; only its
                height and width are used.
            size (int): The number of indices to sample.
        Returns:
            np.ndarray: ``size`` unique indices in ``[0, height * width)``.
        Raises:
            ValueError: If ``size`` exceeds the number of pixels (raised by
                ``np.random.choice`` when sampling without replacement).
        """
        h, w = self._get_image_dimensions(image)
        # Sample without replacement so no pixel is counted twice.
        return np.random.choice(h * w, size=size, replace=False)

    def average_subset_explanation_score(self, heatmap, idx_list):
        """
        Average the heatmap values at a set of flattened pixel indices.

        Args:
            heatmap (array-like): The heatmap generated by the explainer.
            idx_list (np.ndarray): Flattened indices into the heatmap.
        Returns:
            float: The mean heatmap score over the selected pixels.
        """
        # Flatten in row-major order. This assumes modify_input interprets
        # idx_list the same way -- TODO confirm against the Evaluate base class.
        heatmap_flat = np.asarray(heatmap).ravel()
        return np.mean(heatmap_flat[idx_list])

    def calculate_pearsons(self, logit_list, exp_scores):
        """
        Compute Pearson's correlation coefficient between two sequences.

        Args:
            logit_list (list): Per-perturbation drops in predicted-class
                probability.
            exp_scores (list): Per-perturbation mean explanation scores.
        Returns:
            float: The Pearson correlation coefficient. scipy returns NaN
            (with a warning) when either sequence is constant, and raises
            ValueError when there are fewer than two points.
        """
        corr, _ = pearsonr(logit_list, exp_scores)
        return corr

    def _correlation_for_image(self, image, n_iter, sample_size):
        """Return the faithfulness correlation for a single image."""
        # Heatmap for the unperturbed image; the other return values are unused.
        _, _, heatmap, _, _ = self._generate_heatmaps(image)
        logits, _ = self._logits_and_predictions(image)
        probs = softmax(logits, dim=-1)
        pred_idx = probs.argmax(-1).item()
        # Baseline probability of the predicted class. Assumes a leading batch
        # dimension of size 1 -- TODO confirm against _logits_and_predictions.
        base_score = probs[0, pred_idx]

        score_drops = []
        exp_scores = []
        for _ in range(n_iter):
            idx_list = self.idx_sample(image, size=sample_size)
            mod_image = self.modify_input(image, idx_list)
            mod_logits, _ = self._logits_and_predictions(mod_image)
            mod_probs = softmax(mod_logits, dim=-1)
            score_drops.append((base_score - mod_probs[0, pred_idx]).item())
            exp_scores.append(self.average_subset_explanation_score(heatmap, idx_list))
        return self.calculate_pearsons(score_drops, exp_scores)

    def __call__(self, image_list, n_iter=100, sample_ratio=0.2):
        """
        Compute the faithfulness correlation for each image in ``image_list``.

        Args:
            image_list (list): A list of PIL Images or NumPy arrays.
            n_iter (int): Number of random perturbations per image. Must be at
                least 2, since a correlation needs two points.
            sample_ratio (float): Fraction of each image's pixels, in (0, 1],
                perturbed in every iteration.
        Returns:
            list: One Pearson correlation coefficient per image, in input
            order. An entry is NaN when either series is constant.
        Raises:
            ValueError: If ``n_iter`` < 2 or ``sample_ratio`` is outside (0, 1].
        """
        # Fail fast instead of after all the model calls.
        if n_iter < 2:
            raise ValueError(f"n_iter must be at least 2, got {n_iter}")
        if not 0 < sample_ratio <= 1:
            raise ValueError(f"sample_ratio must be in (0, 1], got {sample_ratio}")

        corr_list = []
        for image in image_list:
            h, w = self._get_image_dimensions(image)
            # Size the perturbation per image. This works for NumPy arrays
            # (whose .size is an int, not (W, H)) and for images of different
            # resolutions. At least one pixel is always perturbed.
            sample_size = max(1, int(sample_ratio * (h * w)))
            corr_list.append(self._correlation_for_image(image, n_iter, sample_size))
        return corr_list