domid.trainers package

Submodules

domid.trainers.pretraining_GMM module

class domid.trainers.pretraining_GMM.Pretraining(model, device, loader_tr, loader_val, i_h, i_w, args)

Bases: object

__init__(model, device, loader_tr, loader_val, i_h, i_w, args)
Parameters:
  • model – the model to train

  • device – the device to use

  • loader_tr – the training data loader

  • loader_val – the validation data loader

  • i_h – image height

  • i_w – image width

  • args – command-line arguments

pretrain_loss(tensor_x, inject_tensor)
Parameters:

tensor_x – the input image

Returns:

the loss
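The TrainerAE and TrainerCluster parameter lists in this package describe pretraining with an MSE loss. Assuming pretrain_loss follows that pattern, a minimal NumPy sketch of such a reconstruction loss is shown below. `reconstruct` is a hypothetical stand-in for the model's encode/decode pass; how the real model uses `inject_tensor` is not documented here, so the sketch only forwards it.

```python
import numpy as np

def pretrain_loss(reconstruct, tensor_x, inject_tensor=None):
    """Mean squared reconstruction error (sketch; `reconstruct` is hypothetical).

    reconstruct: callable mapping (x, inject_tensor) -> reconstruction of x.
    """
    x_hat = reconstruct(tensor_x, inject_tensor)  # inject_tensor is only passed through
    return float(np.mean((tensor_x - x_hat) ** 2))

# Usage with a toy "autoencoder" that halves its input:
x = np.ones((4, 1, 8, 8))                # batch of 4 single-channel 8x8 images
halve = lambda a, inject: 0.5 * a
print(pretrain_loss(halve, x))           # (1 - 0.5)^2 = 0.25
```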

GMM_fit()

During pre-training, the mixture weights pi, the cluster means mu_c, and the log-variances log_sigma2_c are estimated by fitting a Gaussian mixture model (GMM) at the end of each epoch. After pre-training, these values serve as the initial parameters in the ELBO loss and are then updated by backpropagation like all other network weights.
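To make the roles of pi, mu_c and log_sigma2_c concrete, here is a minimal NumPy sketch that fits a diagonal-covariance GMM to a matrix of latent codes with expectation-maximization and returns the three quantities in that form. It is an illustrative stand-in, not the package's code: the module may well rely on a library GMM (for example scikit-learn's GaussianMixture), and the function name `fit_diag_gmm`, the farthest-point initialization and the `reg` variance floor are assumptions.

```python
import numpy as np

def fit_diag_gmm(z, n_clusters, n_iter=100, reg=1e-6):
    """Fit a diagonal-covariance GMM to latent codes z (n x d) with EM (sketch).

    Returns (pi, mu_c, log_sigma2_c) with shapes (K,), (K, d), (K, d).
    """
    n, d = z.shape
    # Deterministic farthest-point initialization of the means.
    idx = [0]
    for _ in range(1, n_clusters):
        dist = ((z[:, None, :] - z[idx][None, :, :]) ** 2).sum(-1).min(axis=1)
        idx.append(int(dist.argmax()))
    mu = z[idx].astype(float)
    var = np.tile(z.var(axis=0) + reg, (n_clusters, 1))
    weights = np.full(n_clusters, 1.0 / n_clusters)

    for _ in range(n_iter):
        # E-step: log weight_k + log N(z_i | mu_k, diag(var_k)), shape (n, K).
        diff2 = (z[:, None, :] - mu[None, :, :]) ** 2
        log_p = np.log(weights)[None, :] - 0.5 * (
            np.log(2 * np.pi * var)[None, :, :] + diff2 / var[None, :, :]
        ).sum(-1)
        log_p -= log_p.max(axis=1, keepdims=True)  # numerical stability
        resp = np.exp(log_p)
        resp /= resp.sum(axis=1, keepdims=True)
        # M-step: responsibility-weighted weights, means and variances.
        nk = resp.sum(axis=0) + 1e-12
        weights = nk / n
        mu = (resp.T @ z) / nk[:, None]
        diff2 = (z[:, None, :] - mu[None, :, :]) ** 2
        var = (resp[:, :, None] * diff2).sum(axis=0) / nk[:, None] + reg
    return weights, mu, np.log(var)

# Usage: two well-separated blobs of 2-D latent codes.
rng = np.random.default_rng(0)
z = np.vstack([rng.normal(-5.0, 0.5, (100, 2)), rng.normal(5.0, 0.5, (100, 2))])
pi, mu_c, log_sigma2_c = fit_diag_gmm(z, n_clusters=2)
```

In a PyTorch model, the returned arrays would then typically be copied into the corresponding parameter tensors before ELBO training; that step is omitted here.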

domid.trainers.pretraining_KMeans module

class domid.trainers.pretraining_KMeans.Pretraining(model, device, loader_tr, loader_val, i_h, i_w, args)

Bases: object

__init__(model, device, loader_tr, loader_val, i_h, i_w, args)
Parameters:
  • model – the model to train

  • device – the device to use

  • loader_tr – the training data loader

  • loader_val – the validation data loader

  • i_h – image height

  • i_w – image width

  • args – command-line arguments

pretrain_loss(tensor_x, inject_tensor)
Parameters:

tensor_x – the input image

Returns:

the loss

model_fit()
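model_fit carries no docstring. Given the module name, one plausible reading is that it clusters latent representations with k-means, for example to initialize cluster parameters; that reading is an assumption. A minimal NumPy sketch of k-means (Lloyd's algorithm) on a matrix of latent codes follows, with hypothetical function names and a deterministic farthest-point initialization.

```python
import numpy as np

def _farthest_point_init(z, n_clusters):
    """Pick n_clusters rows of z, each as far as possible from those already chosen."""
    idx = [0]
    for _ in range(1, n_clusters):
        dist = ((z[:, None, :] - z[idx][None, :, :]) ** 2).sum(-1).min(axis=1)
        idx.append(int(dist.argmax()))
    return z[idx].astype(float)

def _assign(z, centers):
    """Index of the nearest center (squared Euclidean distance) for each row of z."""
    return ((z[:, None, :] - centers[None, :, :]) ** 2).sum(-1).argmin(axis=1)

def kmeans(z, n_clusters, n_iter=100):
    """Lloyd's algorithm on latent codes z (n x d); returns (centers, labels)."""
    centers = _farthest_point_init(z, n_clusters)
    for _ in range(n_iter):
        labels = _assign(z, centers)
        # Move each center to the mean of its points; keep it if the cluster is empty.
        new_centers = np.array([
            z[labels == k].mean(axis=0) if np.any(labels == k) else centers[k]
            for k in range(n_clusters)
        ])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return centers, _assign(z, centers)

# Usage: two well-separated blobs of 2-D latent codes.
rng = np.random.default_rng(0)
z = np.vstack([rng.normal(-5.0, 0.5, (100, 2)), rng.normal(5.0, 0.5, (100, 2))])
centers, labels = kmeans(z, n_clusters=2)
```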

domid.trainers.pretraining_sdcn module

class domid.trainers.pretraining_sdcn.PretrainingSDCN(model, device, loader_tr, loader_val, i_h, i_w, args)

Bases: object

__init__(model, device, loader_tr, loader_val, i_h, i_w, args)
Parameters:
  • model – the model to train

  • device – the device to use

  • loader_tr – the training data loader

  • loader_val – the validation data loader

  • i_h – image height

  • i_w – image width

  • args – command-line arguments

pretrain_loss(tensor_x)
kmeans_cluster_assignement()
model_fit()
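These methods are undocumented. In the SDCN paper (Structural Deep Clustering Network, WWW 2020), which this module is presumably named after, k-means on the pretrained autoencoder's embeddings initializes the cluster centers; each embedding is then softly assigned to the centers with a Student's t kernel (as in DEC), and a sharpened target distribution is derived from those assignments. Whether kmeans_cluster_assignement computes exactly this is an assumption. A NumPy sketch of the two formulas, with hypothetical function names and degrees of freedom v = 1:

```python
import numpy as np

def soft_assignment(z, centers, v=1.0):
    """Student's t soft assignment q_ij of embeddings z (n x d) to centers (K x d)."""
    dist2 = ((z[:, None, :] - centers[None, :, :]) ** 2).sum(-1)  # (n, K)
    q = (1.0 + dist2 / v) ** (-(v + 1.0) / 2.0)
    return q / q.sum(axis=1, keepdims=True)  # each row sums to 1

def target_distribution(q):
    """Sharpened target p_ij = (q_ij^2 / f_j) / sum_j' (q_ij'^2 / f_j'), f_j = sum_i q_ij."""
    weight = q ** 2 / q.sum(axis=0)
    return weight / weight.sum(axis=1, keepdims=True)

# Usage: three embeddings, two centers (e.g. taken from k-means).
z = np.array([[0.0, 0.0], [0.1, 0.0], [4.0, 4.0]])
centers = np.array([[0.0, 0.0], [4.0, 4.0]])
q = soft_assignment(z, centers)
p = target_distribution(q)
print(q.argmax(axis=1))  # [0 0 1]
```

Squaring q and dividing by the cluster frequency f_j pushes confident assignments closer to 1 while discouraging large clusters from dominating.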

domid.trainers.trainer_ae module

class domid.trainers.trainer_ae.TrainerAE(successor_node=None, extend=None)

Bases: AbstractTrainer

init_business(model, task, observer, device, aconf, flag_accept=True)
Parameters:
  • model – model to train

  • task – task to train on

  • observer – observer to notify

  • device – device to use

  • writer – tensorboard writer

  • pretrain – whether to pretrain the model with MSE loss

  • aconf – configuration parameters, including learning rate and pretrain threshold

tr_epoch(epoch)
Parameters:

epoch – epoch number

Returns:

before_tr()

Check the performance of the randomly initialized weights before training.

post_tr()

Called after training.
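These hooks are invoked by the surrounding framework rather than by user code. Assuming the usual order (before_tr once, tr_epoch once per epoch, post_tr once at the end), a minimal pure-Python sketch of such a driver loop follows. `ToyTrainer`, `run` and the boolean early-stop value returned by `tr_epoch` are hypothetical; the real trainer's return value is not documented here.

```python
class ToyTrainer:
    """Hypothetical stand-in that records which hooks were called."""

    def __init__(self, stop_at=None):
        self.calls = []
        self.stop_at = stop_at

    def before_tr(self):
        self.calls.append("before_tr")  # e.g. evaluate the randomly initialized weights

    def tr_epoch(self, epoch):
        self.calls.append(f"tr_epoch({epoch})")
        return epoch == self.stop_at  # hypothetical early-stop flag

    def post_tr(self):
        self.calls.append("post_tr")  # e.g. final evaluation or saving results


def run(trainer, epochs):
    """Hypothetical driver: before_tr, then one tr_epoch per epoch, then post_tr."""
    trainer.before_tr()
    for epoch in range(1, epochs + 1):
        if trainer.tr_epoch(epoch):
            break
    trainer.post_tr()


t = ToyTrainer()
run(t, epochs=2)
print(t.calls)  # ['before_tr', 'tr_epoch(1)', 'tr_epoch(2)', 'post_tr']
```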

domid.trainers.trainer_cluster module

class domid.trainers.trainer_cluster.TrainerCluster(successor_node=None, extend=None)

Bases: AbstractTrainer

init_business(model, task, observer, device, aconf, flag_accept=True)
Parameters:
  • model – model to train

  • task – task to train on

  • observer – observer to notify

  • device – device to use

  • writer – tensorboard writer

  • pretrain – whether to pretrain the model with MSE loss

  • aconf – configuration parameters, including learning rate and pretrain threshold

tr_epoch(epoch)
Parameters:

epoch – epoch number

Returns:

before_tr()

Check the performance of the randomly initialized weights before training.

post_tr()

Called after training.

domid.trainers.trainer_sdcn module

class domid.trainers.trainer_sdcn.TrainerSDCN(successor_node=None, extend=None)

Bases: AbstractTrainer

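init_business(model, task, observer, device, aconf, flag_accept=True)
Parameters:
  • model – model to train

  • task – task to train on

  • observer – observer to notify

  • device – device to use

  • aconf – configuration parameters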

tr_epoch(epoch)
Parameters:

epoch – epoch number

Returns:

before_tr()

Check the performance of the randomly initialized weights before training.

post_tr()

Called after training.

domid.trainers.zoo_trainer module

Select a trainer.

class domid.trainers.zoo_trainer.TrainerChainNodeGetter(str_trainer)

Bases: object

Chain of Responsibility: each node is named according to the pattern Trainer[XXX], where the string after ‘Trainer’ is the name passed to args.trainer.

__init__(str_trainer)
Parameters:

str_trainer – name of the trainer to select, i.e. the value passed to args.trainer
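A minimal sketch of how such a chain can resolve a name like "cluster" to TrainerCluster: each node compares the part of its class name after ‘Trainer’ with the requested name and otherwise defers to its successor. The classes below are hypothetical stand-ins that mirror only the `successor_node` argument of the documented constructors (`extend` is omitted); they are not the package's implementation, and the case-insensitive comparison is an assumption.

```python
class TrainerNode:
    """Hypothetical chain node; concrete subclasses are named Trainer<Name>."""

    def __init__(self, successor_node=None):
        self.successor_node = successor_node

    def _accepts(self, request):
        # Compare the class-name suffix after 'Trainer' with the request (case-insensitive).
        return type(self).__name__[len("Trainer"):].lower() == request.lower()

    def handle(self, request):
        """Return the first node in the chain that accepts the request."""
        if self._accepts(request):
            return self
        if self.successor_node is None:
            raise ValueError(f"no trainer named {request!r}")
        return self.successor_node.handle(request)


class TrainerAE(TrainerNode):
    """Hypothetical stand-in for trainer_ae.TrainerAE."""


class TrainerCluster(TrainerNode):
    """Hypothetical stand-in for trainer_cluster.TrainerCluster."""


class TrainerSDCN(TrainerNode):
    """Hypothetical stand-in for trainer_sdcn.TrainerSDCN."""


def get_trainer(str_trainer):
    """Build the chain AE -> Cluster -> SDCN and return the accepting node."""
    chain = TrainerAE(TrainerCluster(TrainerSDCN()))
    return chain.handle(str_trainer)


print(type(get_trainer("cluster")).__name__)  # TrainerCluster
```

Adding a trainer under this scheme only requires a new Trainer<Name> class linked into the chain; the selection code itself does not change.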

Module contents