"""Source code for domid.trainers.trainer_ae."""

import torch.optim as optim
from domainlab.algos.trainers.a_trainer import AbstractTrainer
from tensorboardX import SummaryWriter

from domid.compos.predict_basic import Prediction
from domid.compos.tensorboard_fun import tensorboard_write
from domid.trainers.pretraining_KMeans import Pretraining
from domid.utils.perf_cluster import PerfCluster
from domid.utils.storing import Storing


class TrainerAE(AbstractTrainer):
    """Trainer for an autoencoder-based clustering model.

    Optionally runs an MSE pretraining phase for the first ``aconf.pre_tr``
    epochs before switching to the full model loss (with an ELBO beta warm-up).
    """

    def init_business(self, model, task, observer, device, aconf, flag_accept=True):
        """
        :param model: model to train
        :param task: task to train on
        :param observer: observer to notify
        :param device: device to use
        :param aconf: configuration parameters, including learning rate and
            the pretraining threshold ``pre_tr``
        :param flag_accept: kept for interface compatibility with the base class
        """
        super().__init__()
        super().init_business(model, task, observer, device, aconf)
        print(model)
        # Pretrain with MSE loss only when a positive threshold is configured.
        self.pretrain = aconf.pre_tr > 0
        self.pretraining_finished = not self.pretrain
        self.lr = aconf.lr
        self.warmup_beta = 0.1
        # The optimizer is identical in both phases; tr_epoch() re-creates it
        # once pretraining finishes.
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        if not self.pretraining_finished:
            print("".join(["#"] * 60) + "\nPretraining initialized.\n" + "".join(["#"] * 60))
        self.epo_loss_tr = None
        self.thres = aconf.pre_tr  # number of epochs for pretraining
        self.i_h, self.i_w = task.isize.h, task.isize.w
        self.args = aconf
        self.storage = Storing(self.args)
        self.writer = SummaryWriter(logdir="debug/" + self.storage.experiment_name)
        # NOTE(review): validation loader is aliased to the TRAINING loader —
        # looks intentional here but worth confirming upstream.
        self.loader_val = task.loader_tr
        self.aname = aconf.model

    def tr_epoch(self, epoch):
        """Run one training epoch (pretraining or full loss), then validate.

        :param epoch: epoch number
        :return: flag from the observer indicating whether to stop training
        """
        if self.pretraining_finished:
            print("Epoch {}.".format(epoch))
        else:
            print("Epoch {}. Pretraining.".format(epoch))
        self.model.train()
        self.epo_loss_tr = 0
        pretrain = Pretraining(self.model, self.device, self.loader_tr, self.loader_val, self.i_h, self.i_w, self.args)
        prediction = Prediction(
            self.model, self.device, self.loader_tr, self.loader_val, self.i_h, self.i_w, self.args.bs
        )
        acc_tr_y, _, acc_tr_d, _ = prediction.epoch_tr_acc()
        acc_val_y, _, acc_val_d, _ = prediction.epoch_val_acc()
        r_score_tr = "None"
        r_score_te = "None"
        if self.args.task == "her2":
            r_score_tr = prediction.epoch_tr_correlation()
            # validation set is used as a test set
            r_score_te = prediction.epoch_val_correlation()

        # ___________Define warm-up for ELBO loss_________
        if self.warmup_beta < 1 and self.pretraining_finished:
            self.warmup_beta = self.warmup_beta + 0.01

        # _____________one training epoch: start_______________________
        # Default the optional extras so batches without them cannot raise
        # NameError below (original code left these unbound in that case).
        inject_tensor, image_id = [], []
        loss = None
        for i, (tensor_x, vec_y, vec_d, *other_vars) in enumerate(self.loader_tr):
            if i == 0:
                self.model.batch_zero = True
            if len(other_vars) > 0:
                inject_tensor, image_id = other_vars
                if len(inject_tensor) > 0:
                    inject_tensor = inject_tensor.to(self.device)
            tensor_x, vec_y, vec_d = (
                tensor_x.to(self.device),
                vec_y.to(self.device),
                vec_d.to(self.device),
            )
            self.optimizer.zero_grad()

            # __________________Pretrain/ELBO loss____________
            if epoch < self.thres and not self.pretraining_finished:
                loss = pretrain.pretrain_loss(tensor_x, inject_tensor)
            else:
                if not self.pretraining_finished:
                    # First epoch after pretraining: mark pretraining done and
                    # reset the optimizer before using the full model loss.
                    self.pretraining_finished = True
                    self.model.counter = 1
                    self.optimizer = optim.Adam(
                        self.model.parameters(),
                        lr=self.lr,
                    )
                    print("".join(["#"] * 60))
                    print("Epoch {}: Finished pretraining and starting to use the full model loss.".format(epoch))
                    print("".join(["#"] * 60))
                loss = self.model.cal_loss(tensor_x, inject_tensor)
            loss = loss.sum()
            loss.backward()
            self.optimizer.step()
            self.epo_loss_tr += loss.cpu().detach().item()

        # After one epoch (all batches), the GMM is fit again so pi and mu_c
        # get updated.  Name convention: mu_c is the mean of a Gaussian-mixture
        # cluster, while mu alone is the mean of a decoded pixel.
        if not self.pretraining_finished:
            pretrain.model_fit()

        # __________________Validation_____________________
        inject_tensor_val, img_id_val = [], []
        loss_val = None
        for tensor_x_val, vec_y_val, vec_d_val, *other_vars in self.loader_val:
            if len(other_vars) > 0:
                inject_tensor_val, img_id_val = other_vars
                if len(inject_tensor_val) > 0:
                    inject_tensor_val = inject_tensor_val.to(self.device)
            tensor_x_val, vec_y_val, vec_d_val = (
                tensor_x_val.to(self.device),
                vec_y_val.to(self.device),
                vec_d_val.to(self.device),
            )
            if epoch < self.thres and not self.pretraining_finished:
                loss_val = pretrain.pretrain_loss(tensor_x_val, inject_tensor_val)
            else:
                loss_val = self.model.cal_loss(tensor_x_val, inject_tensor_val, self.warmup_beta)

        if self.writer is not None:
            tensorboard_write(
                self.writer,
                self.model,
                epoch,
                self.lr,
                self.warmup_beta,
                acc_tr_y,
                loss,
                self.pretraining_finished,
                tensor_x,
                inject_tensor,
            )

        # _____storing results and Z space__________
        self.storage.storing(
            epoch, acc_tr_y, acc_tr_d, self.epo_loss_tr, acc_val_y, acc_val_d, loss_val, r_score_tr, r_score_te
        )
        if epoch % 2 == 0:
            _, z_proj, predictions, vec_y_labels, vec_d_labels, image_id_labels = prediction.mk_prediction()
            self.storage.storing_z_space(z_proj, predictions, vec_y_labels, vec_d_labels, image_id_labels)
            self.storage.saving_model(self.model)

        flag_stop = self.observer.update(epoch)  # notify observer
        return flag_stop

    def before_tr(self):
        """Check the performance of randomly initialized weights."""
        acc = PerfCluster.cal_acc(self.model, self.loader_tr, self.device)  # FIXME change tr to te
        print("before training, model accuracy:", acc)

    def post_tr(self):
        """Notify the observer once training has finished."""
        print("training is done")
        self.observer.after_all()