"""Trainer for the SDCN model (domid.trainers.trainer_sdcn)."""

import torch
import torch.nn.parallel
import torch.optim as optim
from domainlab.algos.trainers.a_trainer import AbstractTrainer

from domid.compos.predict_basic import Prediction
from domid.compos.tensorboard_fun import tensorboard_write
from domid.dsets.make_graph import GraphConstructor
from domid.dsets.make_graph_wsi import GraphConstructorWSI
from domid.trainers.pretraining_sdcn import PretrainingSDCN
from domid.utils.perf_cluster import PerfCluster
from domid.utils.storing import Storing


class TrainerSDCN(AbstractTrainer):
    """Trainer for the SDCN clustering model: MSE pretraining phase followed by the full model loss."""

    def init_business(self, model, task, observer, device, aconf, flag_accept=True):
        """Set up optimizer, pretraining state, storage, and the cluster graph(s).

        :param model: model to train; must expose ``graph_method``, ``adj``, ``parameters()``
        :param task: task providing the data loaders and image size (``isize``)
        :param observer: observer notified after each epoch
        :param device: torch device to use
        :param aconf: configuration (``lr``, ``pre_tr``, ``bs``, ``task``, ``model``, ...)
        :param flag_accept: kept for interface compatibility; not used in this method
        """
        super().init_business(model, task, observer, device, aconf)
        # Pretrain with the reconstruction (MSE) loss for the first ``pre_tr`` epochs.
        self.pretrain = aconf.pre_tr > 0
        self.pretraining_finished = not self.pretrain
        self.lr = aconf.lr
        self.warmup_beta = 0.1  # ELBO warm-up coefficient, annealed in tr_epoch
        # FIX: the original created the identical Adam optimizer in both branches
        # of an if/else; only the banner print depends on the pretraining flag.
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        if not self.pretraining_finished:
            print("".join(["#"] * 60) + "\nPretraining initialized.\n" + "".join(["#"] * 60))
        self.epo_loss_tr = None
        self.writer = None
        self.thres = aconf.pre_tr  # number of epochs for pretraining
        self.i_h, self.i_w = task.isize.h, task.isize.w
        self.args = aconf
        self.storage = Storing(self.args)
        # NOTE(review): validation reuses the training loader — confirm this is intended.
        self.loader_val = task.loader_tr
        self.aname = aconf.model
        self.graph_method = self.model.graph_method
        assert self.graph_method, "Graph calculation method should be specified"
        print("Graph calculation method is", self.graph_method)
        if "her" in self.args.task:
            # Graph is computed once here and reused for all epochs.
            # FIXME: GraphConstructorHER2 is not imported in this module, so this
            # branch raises NameError as written — add the missing import.
            self.adj_mx, self.spar_mx = GraphConstructorHER2(self.graph_method).construct_graph(
                self.loader_tr, self.storage.experiment_name
            )
            # Initialize the GNN with the first batch's graph; per-batch graphs
            # come from ``self.spar_mx`` during training.
            self.model.adj = self.spar_mx[0]
        if "mnist" in self.args.task:
            # Graph is computed once here and reused for all epochs.
            self.adj_mx, self.spar_mx = GraphConstructor(self.graph_method).construct_graph(
                self.loader_tr, self.storage.experiment_name
            )
            self.model.adj = self.spar_mx[0]
        if "wsi" in self.args.task:
            # For the "wsi" task graphs are rebuilt on the fly every epoch
            # (see domid/trainers/pretraining_sdcn.py); only a sample graph is built here.
            self.graph_constr = GraphConstructorWSI(self.graph_method)
            # FIX: fetch the first batch once instead of calling next(iter(...)) twice.
            first_batch = next(iter(self.loader_tr))
            num_patches = int(self.args.bs / 3)  # a third of the batch, as in tr_epoch
            _, init_spar_mx = self.graph_constr.construct_graph(
                first_batch[0][:num_patches, :, :, :],
                first_batch[-1][:num_patches],
                self.storage.experiment_name,
            )
            self.model.adj = init_spar_mx
[docs] def tr_epoch(self, epoch): """ :param epoch: epoch number :return: """ print("Epoch {}.".format(epoch)) if self.pretraining_finished else print("Epoch {}. Pretraining.".format(epoch)) # import pdb; pdb.set_trace() # for name, param in self.model.named_parameters(): # if 'weight' in name: # weights = param.data.cpu().numpy() # print(weights.shape) # pdb.set_trace() # self.model.train() self.epo_loss_tr = 0 pretrain = PretrainingSDCN( self.model, self.device, self.loader_tr, self.loader_val, self.i_h, self.i_w, self.args ) prediction = Prediction( self.model, self.device, self.loader_tr, self.loader_val, self.i_h, self.i_w, self.args.bs ) acc_tr_y, _, acc_tr_d, _ = prediction.epoch_tr_acc() acc_val_y, _, acc_val_d, _ = prediction.epoch_val_acc() r_score_tr = "None" r_score_te = "None" kl_total = 0 ce_total = 0 re_total = 0 # if self.args.task == "her2": # r_score_tr = prediction.epoch_tr_correlation() # r_score_te = prediction.epoch_val_correlation() # validation set is used as a test set # ___________Define warm-up for ELBO loss_________ if self.warmup_beta < 1 and self.pretraining_finished: self.warmup_beta = self.warmup_beta + 0.01 # _____________one training epoch: start_______________________ for i, (tensor_x, vec_y, vec_d, *other_vars) in enumerate(self.loader_tr): if len(other_vars) > 0: inject_tensor, image_id = other_vars if len(inject_tensor) > 0: inject_tensor = inject_tensor.to(self.device) if i == 0: self.model.batch_zero = True if self.args.random_batching: patches_idx = self.model.random_ind[i] # torch.randint(0, len(vec_y), (int(self.args.bs/3),)) tensor_x = tensor_x[patches_idx, :, :, :] vec_y = vec_y[patches_idx, :] vec_d = vec_d[patches_idx, :] image_id = [image_id[patch_idx_num] for patch_idx_num in patches_idx] init_adj_mx, init_spar_mx = self.graph_constr.construct_graph( tensor_x, image_id, self.storage.experiment_name ) self.model.adj = init_spar_mx else: self.model.adj = self.spar_mx[i] # .to(self.device) if i < 3: print("i_" + str(i), 
vec_y.argmax(dim=1).unique(), vec_d.argmax(dim=1).unique()) tensor_x, vec_y, vec_d = ( tensor_x.to(self.device), vec_y.to(self.device), vec_d.to(self.device), ) self.optimizer.zero_grad() # __________________Pretrain/ELBO loss____________ if epoch < self.thres and not self.pretraining_finished: loss = pretrain.pretrain_loss(tensor_x) else: if not self.pretraining_finished: # i.e., this is the first epoch after pre-training # So we need to set the pretraining_finishend flag to True, and to reset the optimizer: self.pretraining_finished = True self.model.counter = 1 self.optimizer = optim.Adam( self.model.parameters(), lr=self.lr, # betas=(0.5, 0.9), # weight_decay=0.0001, ) print("".join(["#"] * 60)) print("Epoch {}: Finished pretraining and starting to use the full model loss.".format(epoch)) print("".join(["#"] * 60)) loss = self.model.cal_loss(tensor_x) # print('loss', loss) loss = loss.sum() loss.backward() self.optimizer.step() self.epo_loss_tr += loss.cpu().detach().item() kl_batch, ce_batch, re_batch = self.model.cal_loss_for_tensorboard() kl_total += kl_batch ce_total += ce_batch re_total += re_batch # after one epoch (all batches), GMM is calculated again and pi, mu_c # will get updated via this line. 
# name convention: mu_c is the mean for the Gaussian mixture cluster, # but mu alone means mean for decoded pixel if not self.pretraining_finished: pretrain.model_fit() # __________________Validation_____________________ for i, (tensor_x_val, vec_y_val, vec_d_val, *other_vars) in enumerate(self.loader_val): if self.args.random_batching: if len(other_vars) > 0: inject_tensor, img_id_val = other_vars patches_idx = self.model.random_ind[i] # torch.randint(0, len(vec_y), (int(self.args.bs/3),)) tensor_x_val = tensor_x_val[patches_idx, :, :, :] vec_y_val = vec_y_val[patches_idx, :] vec_d_val = vec_d_val[patches_idx, :] img_id_val = [img_id_val[patch_idx_num] for patch_idx_num in patches_idx] init_adj_mx, init_spar_mx = self.graph_constr.construct_graph( tensor_x_val, img_id_val, self.storage.experiment_name ) self.model.adj = init_spar_mx tensor_x_val, vec_y_val, vec_d_val = ( tensor_x_val.to(self.device), vec_y_val.to(self.device), vec_d_val.to(self.device), ) if epoch < self.thres and not self.pretraining_finished: loss_val = pretrain.pretrain_loss(tensor_x_val) else: loss_val = self.model.cal_loss(tensor_x_val) if self.writer != None: tensorboard_write( self.writer, self.model, epoch, self.lr, self.warmup_beta, acc_tr_y, loss, self.pretraining_finished, tensor_x, other_info=(kl_total, ce_total, re_total), ) if self.args.task == "wsi": self.model.random_ind = [torch.randint(0, self.args.bs, (int(self.args.bs / 3),)) for i in range(0, 65)] if epoch == self.args.epos - 1 or epoch == self.args.epos: self.model.random_ind = [ torch.range(0, int(self.args.bs / 3) - 1, step=1, dtype=torch.long) for i in range(0, 65) ] # FIXME # arg.bs/3 =900, as a 1/3 of all of the patchs per subject # TODO:assert statement that all images from one region # _____storing results and Z space__________ self.storage.storing( epoch, acc_tr_y, acc_tr_d, self.epo_loss_tr, acc_val_y, acc_val_d, loss_val, r_score_tr, r_score_te ) if epoch % 1 == 0: _, z_proj, predictions, vec_y_labels, vec_d_labels, 
image_id_labels = prediction.mk_prediction() self.storage.storing_z_space(z_proj, predictions, vec_y_labels, vec_d_labels, image_id_labels) if epoch % 1 == 0: self.storage.saving_model(self.model) flag_stop = self.observer.update(epoch) # notify observer # self.storage.csv_dump(epoch) return flag_stop
[docs] def before_tr(self): """ check the performance of randomly initialized weight """ acc = PerfCluster.cal_acc(self.model, self.loader_tr, self.device) # FIXME change tr to te print("before training, model accuracy:", acc)
[docs] def post_tr(self): print("training is done") self.observer.after_all()