Source code for domid.dsets.make_graph_wsi
import numpy as np
from domid.dsets.make_graph import GraphConstructor
[docs]class GraphConstructorWSI(GraphConstructor):
"""
Class to construct graph from features from WSI images.
This is only used in training for SDCN model and for WSI dataset.
"""
[docs] def __init__(self, graph_method, topk=7):
"""
Initializer of GraphConstructor.
:param graph_method: the method to calculate distance between features; one of 'heat', 'cos', 'ncos', 'patch_distance'.
:param topk: number of connections per image
"""
super().__init__(graph_method, topk)
[docs] def distance_calc_wsi(self, features=None, coordinates=None):
"""
This function is used to calculate distance between features.
:param features: the batch of features from the dataset
:param coordinates: if the image(patch in the batch) has the coordinates specified, then the distance between can be calculated based on the coordinates
:return: distance matrix between features of the batch of images with the shape of (num_img, num_img)
"""
if self.graph_method == "patch_distance":
num_coords = len(coordinates)
dist = np.zeros((num_coords, num_coords))
for i in range(num_coords):
for j in range(i, num_coords):
distance = np.sqrt(
(coordinates[i][0] - coordinates[j][0]) ** 2 + (coordinates[i][1] - coordinates[j][1]) ** 2
)
dist[i, j] = distance
dist[j, i] = distance
else:
dist = super().distance_calc(features)
return dist
[docs] def connection_calc(self, features, region_labels):
"""
This function is used to calculate the connection pairs between images for all the batches of dataset.
:param features: flattened image from the batch of dataset
:param region_labels: spacial information between patches used to calculate the distance between them (e.g. of the string '1Carcinoma_coord_39100_39573_patchnumber_98_xy_0_0.png')
:return: indecies of top k connections per each image in the batch (shape: (num_img*self.topk, 2))
"""
dist = []
if len(region_labels) > 0:
# sample region_label that is passed to this function is '1Carcinoma_coord_39100_39573_patchnumber_98_xy_0_0.png'
# the coordinated of the region would then be (39100, 39573) and the coordinates of the patch would be (0, 0)
coordinates = [
[
int(reg_lab.split("_")[2]) + int(reg_lab.split("_")[-1][:-4]),
int(reg_lab.split("_")[3]) + int(reg_lab.split("_")[-2][2:]),
]
for reg_lab in region_labels
]
d = self.distance_calc_wsi(features, coordinates) # within each region calculate distance between patches
dist.append(d)
else:
dist.append(self.distance_calc_wsi(features))
connection_pairs = []
inds = []
counter = 0
for region in dist:
for i in range(region.shape[0]):
ind = np.argpartition(region[i, :], -(self.topk + 1))[-(self.topk + 1) :]
ind = ind + np.ones(len(ind)) * counter * region.shape[0]
ind = ind.astype(np.int32)
inds.append(ind) # each patch's 10 connections
counter += 1
for i, v in enumerate(inds):
for vv in v:
if vv == i:
pass
else:
connection_pairs.append([i, vv])
return dist, inds, connection_pairs
[docs] def construct_graph(self, features, img_ids, experiment_folder):
"""
This function is used to construct the graph for all the batches of dataset. This is called in the trainer function of SDCN model.
:param features: flattened image from the batch of dataset
:img_ids:
:experiment_folder:
:return: the adjacency matrix for one batch of data
"""
coordinates = ["_".join(img_id.split("/")[-1].split("_")[-8:]) for img_id in img_ids]
num_features = features.shape[0]
dist, inds, connection_pairs = self.connection_calc(features, coordinates)
adj_mx = self.mk_adj_mat(num_features, connection_pairs)
sparse_mx = self.sparse_mx_to_torch_sparse_tensor(adj_mx)
return adj_mx, sparse_mx