initial framework; to be extended
This commit is contained in:
@@ -1,8 +1,7 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""
|
"""
|
||||||
Compares two latent embedding matrices (e.g., CTRL vs EED-i),
|
Compares two latent embedding matrices,
|
||||||
computes similarity metrics (cosine, Euclidean, L1),
|
computes similarity metrics (cosine, Euclidean, L1)
|
||||||
and saves both a CSV and an optional line plot.
|
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
python scripts/compare_embeddings_general.py \
|
python scripts/compare_embeddings_general.py \
|
||||||
@@ -45,7 +44,7 @@ def main():
|
|||||||
p.add_argument("--no-plot", action="store_true", help="Skip generating the plot")
|
p.add_argument("--no-plot", action="store_true", help="Skip generating the plot")
|
||||||
args = p.parse_args()
|
args = p.parse_args()
|
||||||
|
|
||||||
# ---- Load ----
|
# Load
|
||||||
emb1 = np.load(args.emb1)
|
emb1 = np.load(args.emb1)
|
||||||
emb2 = np.load(args.emb2)
|
emb2 = np.load(args.emb2)
|
||||||
if emb1.shape != emb2.shape:
|
if emb1.shape != emb2.shape:
|
||||||
@@ -55,7 +54,7 @@ def main():
|
|||||||
n_bins, n_dim = emb1.shape
|
n_bins, n_dim = emb1.shape
|
||||||
print(f"Loaded embeddings: {n_bins} bins × {n_dim} dims")
|
print(f"Loaded embeddings: {n_bins} bins × {n_dim} dims")
|
||||||
|
|
||||||
# ---- Compute metrics ----
|
# Compute metrics
|
||||||
cos_sims, cos_dists, l2_dists, l1_dists = compute_metrics(emb1, emb2)
|
cos_sims, cos_dists, l2_dists, l1_dists = compute_metrics(emb1, emb2)
|
||||||
|
|
||||||
df = pd.DataFrame({
|
df = pd.DataFrame({
|
||||||
@@ -69,7 +68,7 @@ def main():
|
|||||||
df.to_csv(csv_path, index=False)
|
df.to_csv(csv_path, index=False)
|
||||||
print(f"Saved metrics → {csv_path}")
|
print(f"Saved metrics → {csv_path}")
|
||||||
|
|
||||||
# ---- Plot ----
|
# Plot
|
||||||
if not args.no_plot:
|
if not args.no_plot:
|
||||||
plt.figure(figsize=(12, 4))
|
plt.figure(figsize=(12, 4))
|
||||||
plt.plot(df["bin_id"], df["cosine_distance"], lw=0.8, color="steelblue")
|
plt.plot(df["bin_id"], df["cosine_distance"], lw=0.8, color="steelblue")
|
||||||
|
|||||||
@@ -1,13 +1,9 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""
|
"""
|
||||||
Train a Variational Graph Autoencoder (VGAE) on a chromatin contact graph.
|
Train a Variational Graph Autoencoder (VGAE) on a chromatin contact graph.
|
||||||
|
---
|
||||||
Inputs:
|
Inputs:
|
||||||
- A PyTorch Geometric Data object saved with torch.save(...) containing:
|
- A PyTorch Geometric Data object saved with torch.save()
|
||||||
x : [num_nodes, num_features] node features
|
|
||||||
edge_index : [2, num_edges] undirected edges (will be coalesced)
|
|
||||||
edge_weight : [num_edges] (optional, unused by VGAE)
|
|
||||||
|
|
||||||
- from build_graph.py
|
- from build_graph.py
|
||||||
---
|
---
|
||||||
Outputs (under results/):
|
Outputs (under results/):
|
||||||
@@ -80,14 +76,14 @@ def main():
|
|||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
os.makedirs(args.outdir, exist_ok=True)
|
os.makedirs(args.outdir, exist_ok=True)
|
||||||
|
|
||||||
# ---- Load graph ----
|
# Load graph
|
||||||
data = torch.load(args.graph)
|
data = torch.load(args.graph)
|
||||||
# Coalesce/clean edges
|
# Coalesce/clean edges
|
||||||
ei, _ = remove_self_loops(data.edge_index)
|
ei, _ = remove_self_loops(data.edge_index)
|
||||||
data.edge_index = to_undirected(ei, num_nodes=data.num_nodes)
|
data.edge_index = to_undirected(ei, num_nodes=data.num_nodes)
|
||||||
x = data.x.float()
|
x = data.x.float()
|
||||||
|
|
||||||
# ---- Split edges for link prediction ----
|
# Split edges for link prediction
|
||||||
splitter = RandomLinkSplit(
|
splitter = RandomLinkSplit(
|
||||||
num_val=0.1,
|
num_val=0.1,
|
||||||
num_test=0.1,
|
num_test=0.1,
|
||||||
@@ -112,12 +108,12 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---- Model ----
|
# Model
|
||||||
enc = Encoder(in_dim=x.size(1), hidden=args.hidden, latent=args.latent, dropout=args.dropout)
|
enc = Encoder(in_dim=x.size(1), hidden=args.hidden, latent=args.latent, dropout=args.dropout)
|
||||||
model = VGAE(enc)
|
model = VGAE(enc)
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
|
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
|
||||||
|
|
||||||
# ---- Training loop ----
|
# Training loop
|
||||||
best_val_auc = -1.0
|
best_val_auc = -1.0
|
||||||
best_state = None
|
best_state = None
|
||||||
for epoch in range(1, args.epochs + 1):
|
for epoch in range(1, args.epochs + 1):
|
||||||
@@ -133,7 +129,7 @@ def main():
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
# ---- Validation ----
|
# Validation
|
||||||
model.eval()
|
model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
z_full = model.encode(x, data.edge_index) # use full graph for eval embeddings
|
z_full = model.encode(x, data.edge_index) # use full graph for eval embeddings
|
||||||
@@ -146,18 +142,18 @@ def main():
|
|||||||
if epoch % 10 == 0 or epoch == 1:
|
if epoch % 10 == 0 or epoch == 1:
|
||||||
print(f"[{epoch:03d}/{args.epochs}] loss={loss.item():.4f} | val AUC={val_auc:.4f} AP={val_ap:.4f}")
|
print(f"[{epoch:03d}/{args.epochs}] loss={loss.item():.4f} | val AUC={val_auc:.4f} AP={val_ap:.4f}")
|
||||||
|
|
||||||
# ---- Save best model ----
|
# Save best model
|
||||||
model.load_state_dict(best_state)
|
model.load_state_dict(best_state)
|
||||||
model_path = os.path.join(args.outdir, "model.pt")
|
model_path = os.path.join(args.outdir, "model.pt")
|
||||||
torch.save(model.state_dict(), model_path)
|
torch.save(model.state_dict(), model_path)
|
||||||
|
|
||||||
# ---- Final test metrics ----
|
# Final test metrics
|
||||||
model.eval()
|
model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
z_final = model.encode(x, data.edge_index)
|
z_final = model.encode(x, data.edge_index)
|
||||||
test_auc, test_ap = eval_linkpred(model, test_data, z_final)
|
test_auc, test_ap = eval_linkpred(model, test_data, z_final)
|
||||||
|
|
||||||
# ---- Save embeddings & metrics ----
|
# Save embeddings & metrics
|
||||||
emb_path = os.path.join(args.outdir, "emb.npy")
|
emb_path = os.path.join(args.outdir, "emb.npy")
|
||||||
np.save(emb_path, z_final.cpu().numpy())
|
np.save(emb_path, z_final.cpu().numpy())
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user