initial framework; to be extended

This commit is contained in:
2025-10-23 08:56:46 +02:00
parent 32228496d2
commit 430e0a10ba
2 changed files with 15 additions and 20 deletions

View File

@@ -1,8 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
Compares two latent embedding matrices (e.g., CTRL vs EED-i), Compares two latent embedding matrices,
computes similarity metrics (cosine, Euclidean, L1), computes similarity metrics (cosine, Euclidean, L1)
and saves both a CSV and an optional line plot.
Usage: Usage:
python scripts/compare_embeddings_general.py \ python scripts/compare_embeddings_general.py \
@@ -45,7 +44,7 @@ def main():
p.add_argument("--no-plot", action="store_true", help="Skip generating the plot") p.add_argument("--no-plot", action="store_true", help="Skip generating the plot")
args = p.parse_args() args = p.parse_args()
# ---- Load ---- # Load
emb1 = np.load(args.emb1) emb1 = np.load(args.emb1)
emb2 = np.load(args.emb2) emb2 = np.load(args.emb2)
if emb1.shape != emb2.shape: if emb1.shape != emb2.shape:
@@ -55,7 +54,7 @@ def main():
n_bins, n_dim = emb1.shape n_bins, n_dim = emb1.shape
print(f"Loaded embeddings: {n_bins} bins × {n_dim} dims") print(f"Loaded embeddings: {n_bins} bins × {n_dim} dims")
# ---- Compute metrics ---- # Compute metrics
cos_sims, cos_dists, l2_dists, l1_dists = compute_metrics(emb1, emb2) cos_sims, cos_dists, l2_dists, l1_dists = compute_metrics(emb1, emb2)
df = pd.DataFrame({ df = pd.DataFrame({
@@ -69,7 +68,7 @@ def main():
df.to_csv(csv_path, index=False) df.to_csv(csv_path, index=False)
print(f"Saved metrics → {csv_path}") print(f"Saved metrics → {csv_path}")
# ---- Plot ---- # Plot
if not args.no_plot: if not args.no_plot:
plt.figure(figsize=(12, 4)) plt.figure(figsize=(12, 4))
plt.plot(df["bin_id"], df["cosine_distance"], lw=0.8, color="steelblue") plt.plot(df["bin_id"], df["cosine_distance"], lw=0.8, color="steelblue")

View File

@@ -1,13 +1,9 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
Train a Variational Graph Autoencoder (VGAE) on a chromatin contact graph. Train a Variational Graph Autoencoder (VGAE) on a chromatin contact graph.
---
Inputs: Inputs:
- A PyTorch Geometric Data object saved with torch.save(...) containing: - A PyTorch Geometric Data object saved with torch.save()
x : [num_nodes, num_features] node features
edge_index : [2, num_edges] undirected edges (will be coalesced)
edge_weight : [num_edges] (optional, unused by VGAE)
- from build_graph.py - from build_graph.py
--- ---
Outputs (under results/): Outputs (under results/):
@@ -80,14 +76,14 @@ def main():
np.random.seed(args.seed) np.random.seed(args.seed)
os.makedirs(args.outdir, exist_ok=True) os.makedirs(args.outdir, exist_ok=True)
# ---- Load graph ---- # Load graph
data = torch.load(args.graph) data = torch.load(args.graph)
# Coalesce/clean edges # Coalesce/clean edges
ei, _ = remove_self_loops(data.edge_index) ei, _ = remove_self_loops(data.edge_index)
data.edge_index = to_undirected(ei, num_nodes=data.num_nodes) data.edge_index = to_undirected(ei, num_nodes=data.num_nodes)
x = data.x.float() x = data.x.float()
# ---- Split edges for link prediction ---- # Split edges for link prediction
splitter = RandomLinkSplit( splitter = RandomLinkSplit(
num_val=0.1, num_val=0.1,
num_test=0.1, num_test=0.1,
@@ -112,12 +108,12 @@ def main():
) )
# ---- Model ---- # Model
enc = Encoder(in_dim=x.size(1), hidden=args.hidden, latent=args.latent, dropout=args.dropout) enc = Encoder(in_dim=x.size(1), hidden=args.hidden, latent=args.latent, dropout=args.dropout)
model = VGAE(enc) model = VGAE(enc)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
# ---- Training loop ---- # Training loop
best_val_auc = -1.0 best_val_auc = -1.0
best_state = None best_state = None
for epoch in range(1, args.epochs + 1): for epoch in range(1, args.epochs + 1):
@@ -133,7 +129,7 @@ def main():
loss.backward() loss.backward()
optimizer.step() optimizer.step()
# ---- Validation ---- # Validation
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
z_full = model.encode(x, data.edge_index) # use full graph for eval embeddings z_full = model.encode(x, data.edge_index) # use full graph for eval embeddings
@@ -146,18 +142,18 @@ def main():
if epoch % 10 == 0 or epoch == 1: if epoch % 10 == 0 or epoch == 1:
print(f"[{epoch:03d}/{args.epochs}] loss={loss.item():.4f} | val AUC={val_auc:.4f} AP={val_ap:.4f}") print(f"[{epoch:03d}/{args.epochs}] loss={loss.item():.4f} | val AUC={val_auc:.4f} AP={val_ap:.4f}")
# ---- Save best model ---- # Save best model
model.load_state_dict(best_state) model.load_state_dict(best_state)
model_path = os.path.join(args.outdir, "model.pt") model_path = os.path.join(args.outdir, "model.pt")
torch.save(model.state_dict(), model_path) torch.save(model.state_dict(), model_path)
# ---- Final test metrics ---- # Final test metrics
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
z_final = model.encode(x, data.edge_index) z_final = model.encode(x, data.edge_index)
test_auc, test_ap = eval_linkpred(model, test_data, z_final) test_auc, test_ap = eval_linkpred(model, test_data, z_final)
# ---- Save embeddings & metrics ---- # Save embeddings & metrics
emb_path = os.path.join(args.outdir, "emb.npy") emb_path = os.path.join(args.outdir, "emb.npy")
np.save(emb_path, z_final.cpu().numpy()) np.save(emb_path, z_final.cpu().numpy())