Source code for evaluator.fidelity.faithfulness_corr

from evaluator.evaluate import Evaluate
import numpy as np
import copy
from scipy.stats import pearsonr
from torch.nn.functional import softmax


class FaithfulnessCorrelation(Evaluate):
    """Evaluate explanation faithfulness via a Pearson correlation.

    For every image, random subsets of pixels are overwritten with a
    baseline value. The resulting drop in the softmax probability of the
    originally predicted class is correlated (Pearson) with the mean
    heatmap score over the perturbed pixels.

    This assumes the heatmap and the input image have the same spatial
    dimensions (height, width), because flattened pixel indices sampled
    from the image are used directly to index the flattened heatmap.
    """

    def __init__(self, model, processor, target_layers, targets,
                 cam_method='gradcamelementwise',
                 explainability_method="cam"):
        """Forward all configuration unchanged to the ``Evaluate`` base."""
        super().__init__(
            model,
            processor,
            target_layers,
            targets,
            explainability_method=explainability_method,
            cam_method=cam_method,
        )

    def _get_image_dimensions(self, image):
        """Return ``(height, width)`` for a NumPy array or a PIL image.

        Args:
            image: Either an ``np.ndarray`` laid out as (H, W) or
                (H, W, C), or an object exposing a PIL-style ``size``
                attribute ordered as ``(width, height)``.

        Returns:
            tuple: ``(height, width)``.

        Raises:
            ValueError: If a NumPy image is neither 2-D nor 3-D.
        """
        if isinstance(image, np.ndarray):
            if image.ndim not in (2, 3):
                raise ValueError(
                    "Expected an image array of shape (H, W) or (H, W, C), "
                    f"got shape {image.shape}."
                )
            # Slicing accepts grayscale (H, W) as well as (H, W, C).
            h, w = image.shape[:2]
            return h, w
        # Anything else is assumed to be a PIL.Image.Image.
        w, h = image.size
        return h, w

    def idx_sample(self, image, size=10):
        """Draw ``size`` distinct flattened pixel indices from ``image``.

        Args:
            image: The input image (NumPy array or PIL image).
            size (int): The number of indices to sample.

        Returns:
            np.ndarray: Unique indices in ``[0, height * width)``.
        """
        h, w = self._get_image_dimensions(image)
        num_pixels = h * w
        # Sampling without replacement guarantees unique indices.
        return np.random.choice(num_pixels, size=size, replace=False)

    def average_subset_explanation_score(self, heatmap, idx_list):
        """Average the heatmap values at a set of flattened indices.

        Args:
            heatmap: The heatmap generated by the explainer; converted
                to a NumPy array if it is not one already.
            idx_list (np.ndarray): Flattened pixel indices.

        Returns:
            float: The mean heatmap score over the selected pixels.
        """
        if not isinstance(heatmap, np.ndarray):
            heatmap = np.array(heatmap)
        # Flatten so the indices from ``idx_sample`` address it directly.
        return np.mean(heatmap.flatten()[idx_list])

    def modify_input(self, image, idx_list, base_value=None):
        """Return a copy of ``image`` with selected pixels overwritten.

        Args:
            image: The input image (NumPy array or PIL image), laid out
                as (H, W) or (H, W, C).
            idx_list (np.ndarray): Flattened pixel indices to overwrite.
            base_value (int or tuple, optional): Replacement value; a
                tuple is broadcast across channels. If None, the mean of
                the whole image is used. Defaults to None.

        Returns:
            np.ndarray: The modified image, same shape and dtype as
            ``np.array(image)``. The input is never modified.
        """
        # np.array() already yields an independent copy, so no extra
        # deepcopy is needed to protect the caller's image.
        modified_image = np.array(image)
        original_shape = modified_image.shape
        h, w = original_shape[:2]

        if base_value is None:
            base_value = modified_image.mean()

        # One row per pixel; the -1 column count keeps all channels of a
        # pixel together and also handles channel-less grayscale input.
        pixels = modified_image.reshape(h * w, -1)
        # The assignment casts base_value to the image dtype.
        pixels[idx_list, :] = base_value

        # Reshape the (possibly copied) pixel table back rather than
        # relying on ``pixels`` being a view of ``modified_image``.
        return pixels.reshape(original_shape)

    def calculate_pearsons(self, logit_list, exp_scores):
        """Compute the Pearson correlation between two sequences.

        Args:
            logit_list (list): Prediction-score differences.
            exp_scores (list): Average explanation scores.

        Returns:
            float: The Pearson correlation coefficient as returned by
            ``scipy.stats.pearsonr`` (its p-value is discarded).
        """
        corr, _ = pearsonr(logit_list, exp_scores)
        return corr

    def __call__(self, image_list, n_iter=100, sample_ratio=0.2):
        """Compute one faithfulness correlation per image.

        Args:
            image_list (list): PIL images or NumPy arrays.
            n_iter (int): Number of random perturbations per image; at
                least 2 are needed to compute a correlation.
            sample_ratio (float): Fraction of each image's pixels to
                perturb per iteration, in the interval (0, 1].

        Returns:
            list: One Pearson correlation coefficient per input image,
            in input order (empty if ``image_list`` is empty).

        Raises:
            ValueError: If ``sample_ratio`` is outside (0, 1] or
                ``n_iter`` is smaller than 2.
        """
        # Fail fast, before any expensive heatmap computation.
        if not 0 < sample_ratio <= 1:
            raise ValueError(
                f"sample_ratio must be in (0, 1], got {sample_ratio!r}."
            )
        if n_iter < 2:
            raise ValueError(
                f"n_iter must be at least 2 to correlate, got {n_iter!r}."
            )

        corr_list = []
        for image in image_list:
            # Size the sample per image so the ratio holds for every
            # image, whether it is a PIL image or a NumPy array.
            h, w = self._get_image_dimensions(image)
            sample_size = max(1, int(sample_ratio * (h * w)))

            # Heatmap and class probabilities of the unperturbed image.
            _, _, heatmap, _, _ = self._generate_heatmaps(image)
            logits, _ = self._logits_and_predictions(image)
            probs = softmax(logits, dim=-1)
            pred_idx = probs.argmax(-1).item()
            base_prob = probs[0, pred_idx]

            prob_drops = []
            exp_scores = []
            for _ in range(n_iter):
                # Perturb a fresh random subset of pixels.
                idx_list = self.idx_sample(image, size=sample_size)
                mod_image = self.modify_input(image, idx_list)

                mod_logits, _ = self._logits_and_predictions(mod_image)
                mod_probs = softmax(mod_logits, dim=-1)

                # Pair the score drop of the originally predicted class
                # with the mean attribution of the perturbed pixels.
                prob_drops.append(
                    (base_prob - mod_probs[0, pred_idx]).item()
                )
                exp_scores.append(
                    self.average_subset_explanation_score(heatmap, idx_list)
                )

            corr_list.append(self.calculate_pearsons(prob_drops, exp_scores))

        return corr_list