import os
import pickle
import numpy as np
import scipy.sparse as sp
import torch
from sklearn.metrics import pairwise_distances as pair
from sklearn.preprocessing import normalize
class GraphConstructor:
"""
Class to construct a graph from features. This is only used in training for the SDCN model.
"""
def __init__(self, graph_method, topk=7):
"""
Initializer of GraphConstructor.
:param graph_method: the method to calculate distance between features; one of 'heat', 'cos', 'ncos'.
:param topk: number of connections per image
"""
self.graph_method = graph_method
self.topk = topk
def sparse_mx_to_torch_sparse_tensor(self, sparse_mx):  # FIXME move to utils
"""Convert a scipy sparse matrix to a torch sparse tensor."""
sparse_mx = sparse_mx.tocoo().astype(np.float32)
indices = torch.from_numpy(np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
values = torch.from_numpy(sparse_mx.data)
shape = torch.Size(sparse_mx.shape)
return torch.sparse_coo_tensor(indices, values, shape)
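# Illustration (not part of the original code): for a 2x2 scipy matrix with a single
# non-zero entry, e.g. sp.coo_matrix(([1.0], ([0], [1])), shape=(2, 2)), the method above
# stacks the COO row/col indices into a (2, nnz) index tensor and returns a torch sparse
# tensor with the same shape and values.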
def get_features_labels(self, dataset):
"""
This function is used to get features and labels from the dataset.
:param dataset: image dataset, batched or unbatched
:return: X: features of the images (flattened images); labels: domain labels
"""
num_batches = len(dataset)
num_img, i_c, i_w, i_h = next(iter(dataset))[0].shape
X = torch.zeros((num_batches, num_img, i_c * i_w * i_h))
labels = torch.zeros((num_batches, num_img, 1))
counter = 0
for tensor_x, vec_y, vec_d, inj_tensor, img_ids in dataset:
X[counter, :, :] = torch.reshape(tensor_x, (tensor_x.shape[0], i_c * i_w * i_h))
labels[counter, :, 0] = torch.argmax(vec_d, dim=1)
counter += 1
return X.type(torch.float32), labels.type(torch.int32)
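# Note (assumption based on the unpacking above): each element of `dataset` is expected to
# be a 5-tuple (tensor_x, vec_y, vec_d, inj_tensor, img_ids), where tensor_x holds a batch
# of images of identical shape and vec_d holds one-hot domain labels; only tensor_x and
# vec_d are actually used here.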
def normalize(self, mx):  # FIXME move to utils
"""
Row-normalize a sparse matrix (divide each row by its row sum); used to normalize the adjacency matrix in mk_adj_mat.
:param mx: sparse matrix
:return: row-normalized sparse matrix
"""
rowsum = np.array(mx.sum(1))
r_inv = np.power(rowsum, -1).flatten()
r_inv[np.isinf(r_inv)] = 0.0  # i.e., when the row sum is 0, keep that row at 0 in the multiplication below
r_mat_inv = sp.diags(r_inv)
mx = r_mat_inv.dot(mx)
return mx
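# Worked example (illustrative only): for rows [[1, 3], [0, 0]] the row sums are [4, 0],
# so r_inv = [0.25, 0] (the infinite 1/0 is reset to 0) and the result is
# [[0.25, 0.75], [0, 0]] -- every non-empty row now sums to 1.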
def distance_calc(self, features):
"""
This function is used to calculate pairwise similarities (affinities) between features.
:param features: the batch of features from the dataset
:return: similarity matrix between the features of the batch of images, with shape (num_img, num_img)
"""
if self.graph_method == "heat":
# heat kernel: exp(-0.5 * squared euclidean distance)
dist = -0.5 * pair(features) ** 2
dist = np.exp(dist)
elif self.graph_method == "cos":
# binarize the features, then count shared non-zero positions
features[features > 0] = 1
dist = np.dot(features, features.T)
elif self.graph_method == "ncos":
# binarize and L1-normalize the features before the dot product
features[features > 0] = 1
features = normalize(features, axis=1, norm="l1")
dist = np.dot(features, features.T)
else:
raise ValueError(f"unknown graph_method: {self.graph_method}")
return dist
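# Worked example (illustrative only): with the 'heat' method, two identical feature vectors
# have euclidean distance 0 and therefore similarity exp(0) = 1, while far-apart vectors
# approach 0; with 'cos' the binarized dot product counts shared non-zero positions, and
# 'ncos' scales that count by the number of non-zero entries in each of the two rows.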
def connection_calc(self, features):
"""
This function is used to calculate the connection pairs between images for one batch of the dataset.
:param features: flattened images from the batch of the dataset
:return: dist: the similarity matrix; inds: indices of the top k (+ self) entries per image; connection_pairs: [i, j] pairs excluding self-connections (shape: (num_img * self.topk, 2))
"""
dist = self.distance_calc(features)
connection_pairs = []
inds = []
for i in range(dist.shape[0]):
ind = np.argpartition(dist[i, :], -(self.topk + 1))[-(self.topk + 1) :]
inds.append(ind)
for i, v in enumerate(inds):
for vv in v:
if vv != i:  # skip self-connections
connection_pairs.append([i, vv])
return dist, inds, connection_pairs
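# Note (illustrative): np.argpartition keeps the self.topk + 1 largest similarities per row
# (the extra one accounts for the image's own maximal self-similarity), and the loop above
# drops the self-connections, so each image contributes roughly self.topk [i, j] pairs.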
def mk_adj_mat(self, n, connection_pairs):
"""
This function is used to make the adjacency matrix of the graph for one batch of the dataset.
:param n: batch size (number of images in the batch)
:param connection_pairs: top k connections per image in the batch (shape: (num_img * self.topk, 2))
:return: the symmetric, row-normalized adjacency matrix with self-loops
"""
idx = np.arange(n, dtype=np.int32)
idx_map = {j: i for i, j in enumerate(idx)}
edges_unordered = np.array(connection_pairs, dtype=np.int32)
edges_mapped = [idx_map.get(val, -1) for val in edges_unordered.flatten()]
if -1 in edges_mapped:
raise ValueError("Some keys in edges_unordered do not exist in idx_map.")
edges = np.array(edges_mapped, dtype=np.int32).reshape(edges_unordered.shape)
adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])), shape=(n, n), dtype=np.float32)
# build symmetric adjacency matrix
adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
adj = adj + sp.eye(adj.shape[0])
adj = self.normalize(adj)
return adj
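# Note (illustrative): the three steps above -- symmetrise by keeping the larger of
# adj[i, j] and adj[j, i], add the identity for self-loops, then row-normalise -- follow
# the usual adjacency preprocessing for graph neural networks, so every row of the
# returned matrix sums to 1.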
def construct_graph(self, dataset, experiment_folder):
"""
This function is used to construct the graph for all the batches of the dataset. It is called in the trainer function of the SDCN model.
:param dataset: dataset containing all the batches of data (or unbatched data)
:param experiment_folder: folder under "notebooks" where connection pairs, features, and labels are pickled per batch; if None, nothing is saved
:return: the adjacency matrices and their torch sparse tensors for all the batches of data
"""
sparse_matrices = []
adjacency_matrices = []
features, domain_labels = self.get_features_labels(dataset)
batch_num = features.shape[0]
num_imgs = features.shape[1]  # number of images per batch, i.e. nodes per graph
for i in range(batch_num):
dist, inds, connection_pairs = self.connection_calc(features[i, :, :])
adj_mat = self.mk_adj_mat(num_imgs, connection_pairs)
adjacency_matrices.append(adj_mat)
sparse_mx = self.sparse_mx_to_torch_sparse_tensor(adj_mat)
sparse_matrices.append(sparse_mx)
if experiment_folder is not None:
connect_path = os.path.join(
"notebooks", experiment_folder, "connection_pairs_" + str(i) + ".pkl"
) # FIXME move to zout
feat_path = os.path.join("notebooks", experiment_folder, "features_" + str(i) + ".pkl")
label_path = os.path.join("notebooks", experiment_folder, "labels_" + str(i) + ".pkl")
with open(connect_path, "wb") as file:
pickle.dump(connection_pairs, file)
with open(feat_path, "wb") as file:
pickle.dump(features[i, :, :], file)
with open(label_path, "wb") as file:
pickle.dump(domain_labels[i, :], file)
return adjacency_matrices, sparse_matrices
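# Example usage (a minimal sketch, not part of the original module): build a tiny synthetic
# "dataset" shaped like the 5-tuples get_features_labels expects and construct the graphs.
# The batch layout, sizes, and experiment_folder=None are assumptions for illustration only;
# the snippet relies on the torch import at the top of this module.
if __name__ == "__main__":
    batch_size, channels, width, height, num_domains = 16, 3, 8, 8, 2
    dataset = []
    for _ in range(2):  # two synthetic batches
        tensor_x = torch.rand(batch_size, channels, width, height)  # fake images
        vec_y = torch.zeros(batch_size, 10)  # class one-hots (unused by the graph code)
        vec_d = torch.nn.functional.one_hot(
            torch.randint(num_domains, (batch_size,)), num_domains
        ).float()  # one-hot domain labels
        inj_tensor, img_ids = torch.zeros(batch_size, 0), ["img"] * batch_size
        dataset.append((tensor_x, vec_y, vec_d, inj_tensor, img_ids))

    constructor = GraphConstructor(graph_method="heat", topk=7)
    adj_mats, sparse_mats = constructor.construct_graph(dataset, experiment_folder=None)
    print(len(adj_mats), adj_mats[0].shape, sparse_mats[0].shape)  # 2 (16, 16) torch.Size([16, 16])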