# Source code for domid.algos.builder_vade

from domainlab.algos.a_algo_builder import NodeAlgoBuilder
from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor
from domainlab.algos.msels.c_msel_val import MSelValPerf
from domainlab.algos.observers.c_obvisitor_cleanup import ObVisitorCleanUp
from domainlab.utils.utils_cuda import get_device

from domid.algos.observers.b_obvisitor_clustering_only import ObVisitorClusteringOnly
from domid.models.model_vade import mk_vade
from domid.trainers.zoo_trainer import TrainerChainNodeGetter


class NodeAlgoBuilderVaDE(NodeAlgoBuilder):
    """Algorithm builder that assembles the VaDE model, its observer and its trainer."""

    def init_business(self, exp):
        """
        Build the model, the observer and the trainer for an experiment.

        :param exp: experiment object; its ``task`` and ``args`` attributes
            are read here, and ``exp`` itself is handed to the clustering
            observer.
        :return: the tuple ``(trainer, model, observer, device)``.
        """
        task, args = exp.task, exp.args
        device = get_device(args)

        # Read these up front so a missing command-line argument surfaces
        # before the model class is created.
        latent_dim = args.zd_dim
        domain_dim = args.d_dim
        arg_l = args.L  # forwarded unchanged as the model's ``L`` keyword

        # mk_vade() returns a callable which is then invoked with the
        # hyper-parameters to obtain the model.
        vade_factory = mk_vade()
        model = vade_factory(
            zd_dim=latent_dim,
            d_dim=domain_dim,
            device=device,
            i_c=task.isize.c,
            i_h=task.isize.h,
            i_w=task.isize.w,
            bs=args.bs,
            L=arg_l,
            dim_inject_y=args.dim_inject_y,
            prior=args.prior,
            random_batching=args.random_batching,
            model_method=args.model_method,
            pre_tr_weight_path=args.pre_tr_weight_path,
            feat_extract=args.feat_extract,
        )

        # Observer stack, innermost first: validation-performance model
        # selection (max_es taken from args.es), wrapped by the oracle
        # visitor, used by the clustering-only observer, wrapped by the
        # clean-up observer.
        model_selector = MSelOracleVisitor(MSelValPerf(max_es=args.es))
        clustering_observer = ObVisitorClusteringOnly(exp, model_selector, device)
        observer = ObVisitorCleanUp(clustering_observer)

        # Look up the trainer by the name given in args.trainer, then wire
        # it to the model, task and observer.
        trainer_getter = TrainerChainNodeGetter(args.trainer)
        trainer = trainer_getter()
        trainer.init_business(model, task, observer, device, args)

        return trainer, model, observer, device
def get_node_na():
    """Return the VaDE algorithm-builder class itself (not an instance)."""
    builder_cls = NodeAlgoBuilderVaDE
    return builder_cls