#!/usr/bin/env python3
"""Shared VGAE encoder.

Imported by train_vgae.py and encode_graph.py.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv


class Encoder(nn.Module):
    """Two-layer GCN encoder for VGAE with input BatchNorm.

    Architecture:
        BatchNorm → GCN(hidden) → ReLU → Dropout → GCN_mu / GCN_logstd

    The BatchNorm layer normalises raw ChIP-seq signals and its running
    statistics are saved in model.pt, so encode_graph.py applies identical
    normalisation to held-out cell lines without a separate scaler file.

    NOTE: BatchNorm only uses its stored running statistics in eval mode;
    in training mode it normalises with the statistics of the current
    input. Callers that rely on the saved statistics should call
    ``model.eval()`` first.

    Args:
        in_dim: Number of input features per node.
        hidden: Output width of the first GCN layer.
        latent: Output width of each of the two head layers (mu and logstd).
        dropout: Dropout probability applied to the hidden activations
            while the module is in training mode.
    """

    def __init__(self, in_dim: int, hidden: int, latent: int,
                 dropout: float = 0.2):
        super().__init__()
        # Attribute names are part of the saved state_dict layout
        # (model.pt); renaming them would break loading of checkpoints.
        self.norm = nn.BatchNorm1d(in_dim)
        self.gc1 = GCNConv(in_dim, hidden,
                           add_self_loops=True, normalize=True)
        # Two parallel heads share the same hidden representation.
        self.gc_mu = GCNConv(hidden, latent,
                             add_self_loops=True, normalize=True)
        self.gc_log = GCNConv(hidden, latent,
                              add_self_loops=True, normalize=True)
        # Kept as a plain float and applied functionally in forward().
        self.dropout = dropout

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor):
        """Encode node features into the two VGAE head outputs.

        Args:
            x: Node feature matrix whose feature dimension is ``in_dim``
                (presumably shape ``(num_nodes, in_dim)`` — confirm
                against callers).
            edge_index: Graph connectivity in whatever format ``GCNConv``
                accepts (typically a ``(2, num_edges)`` index tensor).

        Returns:
            A ``(mu, logstd)`` tuple: the outputs of ``gc_mu`` and
            ``gc_log`` respectively, each with ``latent`` features per node.
        """
        x = self.norm(x)
        h = F.relu(self.gc1(x, edge_index))
        # Functional dropout is gated on self.training, so it is a no-op
        # after model.eval().
        h = F.dropout(h, p=self.dropout, training=self.training)
        mu = self.gc_mu(h, edge_index)
        logstd = self.gc_log(h, edge_index)
        return mu, logstd