Source code for domid.compos.GNN_layer

import torch
import torch.nn.functional as F
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter


class GNNLayer(Module):
    """
    Single graph layer computing ``activation(adj @ (features @ weight))``.

    The only learnable state is ``weight``, an ``in_features x out_features``
    float matrix that is Xavier-uniform initialised at construction time.
    """

    def __init__(self, in_features, out_features, device):
        """
        :param in_features: number of columns expected in the input features
        :param out_features: number of columns produced by the layer
        :param device: device on which the weight matrix is allocated
        """
        super(GNNLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Allocate on the target device first, then overwrite the zeros
        # in place with Xavier-uniform samples.
        initial_weight = torch.zeros(
            [in_features, out_features], dtype=torch.float, device=device
        )
        self.weight = Parameter(initial_weight)
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, features, adj, activation=torch.nn.ReLU()):
        """
        :param features: features from a specific layer of the encoder; a 2-D
            tensor whose second dimension equals ``in_features``
        :param adj: adjacency matrix of the constructed graph (presumably a
            sparse tensor, since it is multiplied with ``torch.spmm`` -- confirm
            against the caller)
        :param activation: callable applied to the aggregated output; any
            falsy value (e.g. ``None``) skips the activation
        :return: hidden layer of the GNN
        """
        # Project the features first, then aggregate over the graph.
        aggregated = torch.spmm(adj, torch.mm(features, self.weight))
        return activation(aggregated) if activation else aggregated