#!/usr/bin/env python3
"""
Encode a new graph using a trained VGAE model.

Automatically infers hidden/latent dimensions from saved weights.
"""
|
|
|
|
import argparse, torch, numpy as np
|
|
from torch_geometric.nn import GCNConv, VGAE
|
|
|
|
# Encoder definition repeated here for clarity; keep it identical to the one the weights were saved from.
|
|
class Encoder(torch.nn.Module):
    """Graph-convolutional encoder with one shared layer and two output heads.

    The input passes through ``gc1``, a ReLU and dropout; the result is then
    fed to two parallel ``GCNConv`` heads, ``gc_mu`` and ``gc_log``.
    (Presumably these are the mean and log-std expected by ``VGAE`` --
    confirm against the torch_geometric documentation.)

    The attribute names matter: ``main`` loads a saved ``state_dict`` into
    this module and looks up weights by the substrings ``gc1`` and ``gc_mu``.

    Args:
        in_dim: Input width passed to the first ``GCNConv``.
        hidden: Output width of ``gc1`` and input width of both heads.
        latent: Output width of ``gc_mu`` and ``gc_log``.
        dropout: Probability handed to ``dropout`` after the ReLU.
    """

    def __init__(self, in_dim, hidden, latent, dropout=0.2):
        super().__init__()
        # Creation order is kept (gc1, gc_mu, gc_log) so parameter
        # registration and initialisation order stay the same.
        self.gc1 = GCNConv(in_dim, hidden)
        self.gc_mu = GCNConv(hidden, latent)
        self.gc_log = GCNConv(hidden, latent)
        self.dropout = dropout

    def forward(self, x, edge_index):
        """Return ``(gc_mu(h), gc_log(h))`` for the hidden representation ``h``.

        Dropout is only active while the module is in training mode.
        """
        functional = torch.nn.functional
        activated = functional.relu(self.gc1(x, edge_index))
        hidden_repr = functional.dropout(
            activated, p=self.dropout, training=self.training
        )
        mu_out = self.gc_mu(hidden_repr, edge_index)
        log_out = self.gc_log(hidden_repr, edge_index)
        return mu_out, log_out
|
|
|
|
|
|
def _find_weight_key(state_dict, layer_name):
    """Return the first key mentioning *layer_name* and "weight", or ``None``."""
    return next(
        (key for key in state_dict if layer_name in key and "weight" in key),
        None,
    )


def main(argv=None):
    """Encode a saved graph with a trained VGAE and write the embeddings.

    Loads the graph object and the model ``state_dict``, infers the hidden
    and latent sizes from the saved ``gc1`` / ``gc_mu`` weight tensors,
    rebuilds ``VGAE(Encoder(...))``, runs ``model.encode`` in eval mode with
    gradients disabled, and stores the result with ``np.save``.

    Args:
        argv: Optional list of command-line arguments without the program
            name. ``None`` (the default) makes argparse read ``sys.argv``,
            which is what the script entry point relies on.

    Raises:
        SystemExit: If the graph has no feature matrix ``x``, or the
            checkpoint has no ``gc1`` / ``gc_mu`` weight entry. argparse
            also exits this way on invalid arguments.
    """
    parser = argparse.ArgumentParser(
        description="Encode a graph using a trained VGAE model.",
    )
    parser.add_argument(
        "--model", required=True, help="path to the saved model state_dict"
    )
    parser.add_argument(
        "--graph",
        required=True,
        help="path to the saved graph object (must expose .x and .edge_index)",
    )
    parser.add_argument(
        "--out", required=True, help="output path for the embeddings (.npy)"
    )
    args = parser.parse_args(argv)

    # ---- Load data and model state ----
    # SECURITY: torch.load unpickles arbitrary Python objects, so only pass
    # graph files from a trusted source. weights_only=False is spelled out
    # because PyTorch 2.6+ defaults to True, which rejects a pickled graph
    # object (the parameter exists since PyTorch 1.13).
    # map_location="cpu" keeps the graph on the same device as the model.
    data = torch.load(args.graph, map_location="cpu", weights_only=False)
    model_state = torch.load(args.model, map_location="cpu")

    features = getattr(data, "x", None)
    if features is None:
        raise SystemExit(f"{args.graph}: graph has no node feature matrix 'x'")

    # ---- Infer dimensions dynamically ----
    in_dim = features.size(1)
    gc1_weight = _find_weight_key(model_state, "gc1")
    gc_mu_weight = _find_weight_key(model_state, "gc_mu")
    missing = [
        layer
        for layer, key in (("gc1", gc1_weight), ("gc_mu", gc_mu_weight))
        if key is None
    ]
    if missing:
        raise SystemExit(
            f"{args.model}: no weight entry found for layer(s) "
            f"{', '.join(missing)}; is this a state_dict saved from the "
            "VGAE encoder?"
        )

    # shape[0] is taken as each layer's output width.
    # NOTE(review): assumes weights are stored as (out, in) -- confirm for
    # the installed torch_geometric version.
    hidden = model_state[gc1_weight].shape[0]
    latent = model_state[gc_mu_weight].shape[0]

    print(f"Inferred dims: in={in_dim}, hidden={hidden}, latent={latent}")

    enc = Encoder(in_dim=in_dim, hidden=hidden, latent=latent)
    model = VGAE(enc)
    model.load_state_dict(model_state)
    model.eval()

    # ---- Encode ----
    with torch.no_grad():
        z = model.encode(features.float(), data.edge_index)

    np.save(args.out, z.cpu().numpy())
    # np.save appends ".npy" when the name lacks it; report the real path.
    written = args.out if args.out.endswith(".npy") else args.out + ".npy"
    print(f"Saved embeddings → {written}  shape={z.shape}")
|
|
|
|
|
|
# Script entry point: run the CLI only when this file is executed directly.
if __name__ == "__main__":
    main()