From 32228496d21935e26e12ec17d37455d3d36ab997 Mon Sep 17 00:00:00 2001 From: aman Date: Thu, 23 Oct 2025 08:46:48 +0200 Subject: [PATCH] initial framework; to be extended --- .gitignore | 30 ++++++ README.md | 1 + env.yml | 21 ++++ scripts/build_graph.py | 84 ++++++++++++++++ scripts/compare_embeddings.py | 86 ++++++++++++++++ scripts/encode_graph.py | 63 ++++++++++++ scripts/train_vgae.py | 184 ++++++++++++++++++++++++++++++++++ 7 files changed, 469 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 env.yml create mode 100644 scripts/build_graph.py create mode 100644 scripts/compare_embeddings.py create mode 100644 scripts/encode_graph.py create mode 100644 scripts/train_vgae.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..fcf296e --- /dev/null +++ b/.gitignore @@ -0,0 +1,30 @@ +# Python +__pycache__/ +*.pyc + +# Conda / mamba envs +.env/ +*.yml.lock + +# Data +*.hic +*.mcool +*.cool +*.bam +*.bw +*.bigwig +*.bed +*.pairs* +*.pt +*.npy +*.csv +*.png + +# Jupyter and logs +*.ipynb_checkpoints/ +*.log +.DS_Store + +# Results / temp +results/ +data/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..62b1af1 --- /dev/null +++ b/README.md @@ -0,0 +1 @@ +# Chromatin-GNN: Graph representation learning for 3D genome architecture diff --git a/env.yml b/env.yml new file mode 100644 index 0000000..42aec6d --- /dev/null +++ b/env.yml @@ -0,0 +1,21 @@ +name: chromatin_gnn_aman +channels: + - pytorch + - conda-forge + - defaults +dependencies: + - python=3.10 + - pytorch + - torchvision + - torchaudio + - cooler + - pybigwig + - pandas + - numpy + - scikit-learn + - matplotlib + - umap-learn + - pip + - pip: + - torch-geometric + diff --git a/scripts/build_graph.py b/scripts/build_graph.py new file mode 100644 index 0000000..40e169c --- /dev/null +++ b/scripts/build_graph.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 + +import argparse +import numpy as np +import pandas as pd +import torch 
#!/usr/bin/env python3
"""Build a PyTorch Geometric graph for one chromosome.

Nodes are the genomic bins of ``chrom`` at the requested resolution, node
features are the per-bin mean signal of each bigWig track, and edges are the
contact-matrix pixels no further apart than ``max_dist`` bp, weighted by
``log1p`` of the (balanced, when available) contact value.
"""

import argparse
import os

import numpy as np
import pandas as pd
import torch
import cooler
import pyBigWig
from torch_geometric.data import Data


def bin_bigwig(bw_path, chrom, bins):
    """Average a bigWig signal across each genomic bin.

    Args:
        bw_path: Path to the bigWig file.
        chrom: Chromosome name to read from.
        bins: Iterable of ``(start, end)`` pairs in bp.

    Returns:
        A 1-D float array with one mean value per bin. Bins that fall outside
        the chromosome, or for which the file reports no data, get ``0.0``.

    Raises:
        ValueError: If ``chrom`` is not present in the bigWig file.
    """
    bw = pyBigWig.open(bw_path)
    # try/finally so the file handle is released on every exit path,
    # including the "chromosome not found" error below.
    try:
        chrom_sizes = bw.chroms()
        if chrom not in chrom_sizes:
            raise ValueError(
                f"{chrom} not found in {bw_path}. Available: {list(chrom_sizes.keys())[:5]}..."
            )
        chrom_len = chrom_sizes[chrom]
        vals = []
        for s, e in bins:
            # Clip the bin to the chromosome extent known to this file.
            s = max(0, s)
            e = min(chrom_len, e)
            if s >= e:
                vals.append(0.0)
                continue
            v = bw.stats(chrom, s, e, type="mean")[0]
            vals.append(0.0 if v is None or np.isnan(v) else v)
    finally:
        bw.close()
    return np.array(vals, dtype=float)


def build_graph(mcool_path, chrom, res, bigwigs, out_path, max_dist=5_000_000):
    """Convert .mcool + bigWigs to a PyTorch Geometric Data object and save it.

    Args:
        mcool_path: Path to the multi-resolution cooler file.
        chrom: Chromosome to extract.
        res: Bin size in bp; selects ``::resolutions/{res}`` inside the file.
        bigwigs: Non-empty list of bigWig paths, one node feature per file.
        out_path: Destination ``.pt`` file (parent directory is created).
        max_dist: Maximum genomic distance (bp) between the two bins of an edge.

    Raises:
        ValueError: If ``bigwigs`` is empty or ``chrom`` is missing from a bigWig.
    """
    if not bigwigs:
        raise ValueError("At least one bigWig file is required to build node features")

    print(f"Processing {chrom} at {res} bp resolution...")

    # Load pixels
    c = cooler.Cooler(f"{mcool_path}::resolutions/{res}")
    # Requesting balanced values from a cooler without a "weight" column
    # raises, which would make the raw-count fallback below unreachable.
    has_weights = "weight" in c.bins().columns
    pixels = c.matrix(balance=has_weights, as_pixels=True, join=True).fetch(chrom)
    pixels = pixels.query(f"chrom1 == chrom2 and abs(start2 - start1) <= {max_dist}")

    # Map genomic coordinates to chromosome-local bin IDs
    bins_df = c.bins().fetch(chrom)
    starts = bins_df["start"].to_numpy()
    start_to_bin = dict(zip(starts, range(len(starts))))

    valid = pixels["start1"].isin(starts) & pixels["start2"].isin(starts)
    pixels = pixels.loc[valid]

    bin1 = pixels["start1"].map(start_to_bin).to_numpy(dtype=np.int64)
    bin2 = pixels["start2"].map(start_to_bin).to_numpy(dtype=np.int64)
    # Stack in NumPy first: torch.tensor() on a Python list of arrays takes a
    # very slow element-wise path for large edge lists.
    edge_index = torch.from_numpy(np.stack([bin1, bin2]))

    # Edge weights: balanced values when present, raw counts otherwise
    if "balanced" in pixels.columns and pixels["balanced"].notna().any():
        w = pixels["balanced"].fillna(0).to_numpy()
    else:
        w = pixels["count"].to_numpy()
    edge_weight = torch.tensor(np.log1p(w), dtype=torch.float)

    # Node features: one column per bigWig track
    bins = [(int(s), int(s + res)) for s in starts]
    node_feats = []
    for bw in bigwigs:
        print(f"  Adding feature from {bw}")
        node_feats.append(bin_bigwig(bw, chrom, bins))
    x = torch.tensor(np.stack(node_feats, axis=1), dtype=torch.float)

    # Save graph
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    data = Data(x=x, edge_index=edge_index, edge_weight=edge_weight)
    torch.save(data, out_path)
    print(f"Saved {chrom}: {x.shape[0]} nodes, {edge_index.shape[1]} edges → {out_path}")


if __name__ == "__main__":
    p = argparse.ArgumentParser(description="Build graph from Micro-C and bigWigs")
    p.add_argument("--mcool", required=True, help="Path to .mcool file")
    p.add_argument("--chrom", required=True, help="Chromosome name (e.g., chr21)")
    p.add_argument("--res", type=int, default=10000, help="Resolution (bp)")
    p.add_argument("--bigwigs", nargs="+", required=True, help="List of bigWig feature files")
    p.add_argument("--out", required=True, help="Output .pt file path")
    p.add_argument("--max_dist", type=int, default=5_000_000, help="Max genomic distance for edges")
    args = p.parse_args()

    build_graph(args.mcool, args.chrom, args.res, args.bigwigs, args.out, args.max_dist)
#!/usr/bin/env python3
r"""
Compare two latent embedding matrices (e.g., CTRL vs EED-i).

Computes per-bin similarity metrics (cosine, Euclidean, L1) and saves
a CSV plus an optional line plot of the cosine distance.

Usage:
    python scripts/compare_embeddings.py \
        --emb1 results/emb.npy \
        --emb2 results/emb_eedi.npy \
        --label1 CTRL --label2 EEDi \
        --prefix results/chr21
"""

import argparse
import os

import numpy as np
import pandas as pd


def compute_metrics(emb1, emb2):
    """Compute cosine similarity, cosine distance, L2, and L1 per row.

    Args:
        emb1: 2-D array-like of shape ``(n_rows, n_dim)``.
        emb2: 2-D array-like with the same shape as ``emb1``.

    Returns:
        Tuple ``(cos_sims, cos_dists, l2_dists, l1_dists)`` of 1-D float
        arrays, one entry per row. Rows where either vector has zero norm
        yield ``nan`` for the two cosine metrics.

    Raises:
        ValueError: If the inputs are not 2-D or their shapes differ.
    """
    a = np.asarray(emb1, dtype=float)
    b = np.asarray(emb2, dtype=float)
    if a.ndim != 2 or a.shape != b.shape:
        raise ValueError(
            f"Expected two 2-D arrays of equal shape, got {a.shape} and {b.shape}"
        )

    diff = a - b
    l2_dists = np.sqrt((diff * diff).sum(axis=1))
    l1_dists = np.abs(diff).sum(axis=1)

    norms = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
    # 0/0 for zero-norm rows is an expected outcome (reported as nan), so
    # silence the floating-point warnings for this one division.
    with np.errstate(divide="ignore", invalid="ignore"):
        cos_sims = (a * b).sum(axis=1) / norms
    # Rounding can push |cos| marginally above 1; clip keeps nan as nan.
    cos_sims = np.clip(cos_sims, -1.0, 1.0)
    cos_dists = 1.0 - cos_sims
    return cos_sims, cos_dists, l2_dists, l1_dists


def _save_plot(df, label1, label2, fig_path):
    """Write a line plot of per-bin cosine distance to ``fig_path``."""
    # Imported lazily so that --no-plot runs do not need matplotlib at all.
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(12, 4))
    plt.plot(df["bin_id"], df["cosine_distance"], lw=0.8, color="steelblue")
    plt.title(f"Δ-Embedding ({label1} vs {label2})")
    plt.xlabel("Bin index")
    plt.ylabel("Cosine distance (1 – similarity)")
    plt.tight_layout()
    plt.savefig(fig_path, dpi=300)
    plt.close(fig)


def main(argv=None):
    """Command-line entry point.

    Args:
        argv: Optional argument list; ``None`` means ``sys.argv[1:]``.

    Raises:
        ValueError: If the two embeddings differ in shape or are not 2-D.
    """
    p = argparse.ArgumentParser(description="Compare two embedding matrices")
    p.add_argument("--emb1", required=True, help="Path to first embedding .npy")
    p.add_argument("--emb2", required=True, help="Path to second embedding .npy")
    p.add_argument("--label1", default="A", help="Label for first embedding")
    p.add_argument("--label2", default="B", help="Label for second embedding")
    p.add_argument("--prefix", default="results/compare", help="Prefix for output files")
    p.add_argument("--no-plot", action="store_true", help="Skip generating the plot")
    args = p.parse_args(argv)

    # ---- Load ----
    emb1 = np.load(args.emb1)
    emb2 = np.load(args.emb2)
    if emb1.shape != emb2.shape:
        raise ValueError(f"Shape mismatch: {emb1.shape} vs {emb2.shape}")
    if emb1.ndim != 2:
        raise ValueError(f"Expected 2-D embeddings (bins × dims), got shape {emb1.shape}")

    # A bare prefix such as "chr21" has no directory part; makedirs("") raises.
    out_dir = os.path.dirname(args.prefix)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    n_bins, n_dim = emb1.shape
    print(f"Loaded embeddings: {n_bins} bins × {n_dim} dims")

    # ---- Compute metrics ----
    cos_sims, cos_dists, l2_dists, l1_dists = compute_metrics(emb1, emb2)

    df = pd.DataFrame({
        "bin_id": np.arange(n_bins),
        "cosine_similarity": cos_sims,
        "cosine_distance": cos_dists,
        "euclidean": l2_dists,
        "manhattan": l1_dists,
    })
    csv_path = f"{args.prefix}_delta.csv"
    df.to_csv(csv_path, index=False)
    print(f"Saved metrics → {csv_path}")

    # ---- Plot ----
    if not args.no_plot:
        fig_path = f"{args.prefix}_delta.png"
        _save_plot(df, args.label1, args.label2, fig_path)
        print(f"Saved plot → {fig_path}")


if __name__ == "__main__":
    main()
#!/usr/bin/env python3
"""
Encode a new graph using a trained VGAE model.
Automatically infers hidden/latent dimensions from saved weights.
"""

import argparse
import os

import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, VGAE


# Same architecture and attribute names as the Encoder used for training,
# so that the saved state_dict keys line up.
class Encoder(torch.nn.Module):
    """Two-layer GCN encoder returning the ``(mu, logstd)`` pair VGAE expects."""

    def __init__(self, in_dim, hidden, latent, dropout=0.2):
        super().__init__()
        self.gc1 = GCNConv(in_dim, hidden)
        self.gc_mu = GCNConv(hidden, latent)
        self.gc_log = GCNConv(hidden, latent)
        self.dropout = dropout

    def forward(self, x, edge_index):
        h = self.gc1(x, edge_index)
        h = F.relu(h)
        h = F.dropout(h, p=self.dropout, training=self.training)
        return self.gc_mu(h, edge_index), self.gc_log(h, edge_index)


def _load_graph(path):
    """Load a pickled PyG ``Data`` object from ``path`` onto the CPU."""
    # NOTE(security): this unpickles arbitrary objects; only load graph files
    # you produced yourself. weights_only=False is required because recent
    # PyTorch releases default to weights_only=True, which rejects Data objects.
    try:
        return torch.load(path, map_location="cpu", weights_only=False)
    except TypeError:
        # Older PyTorch without the weights_only keyword.
        return torch.load(path, map_location="cpu")


def _layer_weight(state, layer):
    """Return the first weight tensor in ``state`` whose key mentions ``layer``."""
    matches = [k for k in state if layer in k and "weight" in k]
    if not matches:
        raise KeyError(
            f"No weight for layer '{layer}' in checkpoint; keys start with {list(state)[:5]}"
        )
    return state[matches[0]]


def main(argv=None):
    """Command-line entry point.

    Args:
        argv: Optional argument list; ``None`` means ``sys.argv[1:]``.

    Raises:
        KeyError: If the checkpoint lacks the expected encoder weights.
        ValueError: If the graph's feature width differs from the checkpoint's.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--model", required=True)
    p.add_argument("--graph", required=True)
    p.add_argument("--out", required=True)
    args = p.parse_args(argv)

    # ---- Load data and model state ----
    data = _load_graph(args.graph)
    model_state = torch.load(args.model, map_location="cpu")

    # ---- Infer dimensions dynamically ----
    in_dim = data.x.size(1)
    # Like the original code, this reads dim 0 of each weight as the layer's
    # output width (assumes an [out, in] weight layout — verify if the
    # checkpoint came from a different torch-geometric version).
    gc1_weight = _layer_weight(model_state, "gc1")
    gc_mu_weight = _layer_weight(model_state, "gc_mu")
    hidden = gc1_weight.shape[0]
    latent = gc_mu_weight.shape[0]

    print(f"Inferred dims: in={in_dim}, hidden={hidden}, latent={latent}")

    if gc1_weight.dim() == 2 and gc1_weight.shape[1] != in_dim:
        raise ValueError(
            f"Graph has {in_dim} node features but the model was trained "
            f"with {gc1_weight.shape[1]}"
        )

    enc = Encoder(in_dim=in_dim, hidden=hidden, latent=latent)
    model = VGAE(enc)
    model.load_state_dict(model_state)
    model.eval()

    # ---- Encode ----
    with torch.no_grad():
        z = model.encode(data.x.float(), data.edge_index)

    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    np.save(args.out, z.cpu().numpy())
    print(f"Saved embeddings → {args.out}  shape={z.shape}")


if __name__ == "__main__":
    main()
#!/usr/bin/env python3
"""
Train a Variational Graph Autoencoder (VGAE) on a chromatin contact graph.

Inputs:
  - A PyTorch Geometric Data object saved with torch.save(...) (as written
    by build_graph.py) containing:
      x           : [num_nodes, num_features] node features
      edge_index  : [2, num_edges] edges (made undirected and coalesced here)
      edge_weight : [num_edges] (optional, unused by VGAE)

Outputs (under --outdir, default results/):
  - model.pt      : state_dict of the VGAE with the best validation AUC
  - emb.npy       : node embeddings from the full graph [num_nodes, latent_dim]
  - metrics.json  : best validation AUC/AP, test AUC/AP and hyper-parameters
"""

import argparse
import json
import os

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import average_precision_score, roc_auc_score
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.nn.models import VGAE
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import remove_self_loops, to_undirected


class Encoder(torch.nn.Module):
    """Two-layer GCN encoder returning the ``(mu, logstd)`` pair VGAE expects."""

    def __init__(self, in_dim: int, hidden: int, latent: int, dropout: float = 0.2):
        super().__init__()
        self.gc1 = GCNConv(in_dim, hidden, add_self_loops=True, normalize=True)
        self.gc_mu = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
        self.gc_log = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
        self.dropout = dropout

    def forward(self, x, edge_index):
        h = self.gc1(x, edge_index)
        h = F.relu(h)
        h = F.dropout(h, p=self.dropout, training=self.training)
        return self.gc_mu(h, edge_index), self.gc_log(h, edge_index)


@torch.no_grad()
def eval_linkpred(model, data_like, z):
    """Compute AUROC/AP using provided positive/negative edges.

    Args:
        model: Unused; kept so existing call sites keep working.
        data_like: Object exposing ``pos_edge_index`` and ``neg_edge_index``,
            each a ``[2, n]`` index tensor.
        z: Node embeddings indexed by those edge endpoints.

    Returns:
        ``(auc, ap)`` as computed by scikit-learn on inner-product scores.
    """
    pos = data_like.pos_edge_index
    neg = data_like.neg_edge_index

    # Inner-product decoder: sigmoid(<z_src, z_dst>)
    def scores(edges):
        src, dst = edges
        return torch.sigmoid((z[src] * z[dst]).sum(dim=1)).cpu().numpy()

    y_true = np.concatenate([np.ones(pos.size(1)), np.zeros(neg.size(1))])
    y_pred = np.concatenate([scores(pos), scores(neg)])

    auc = roc_auc_score(y_true, y_pred)
    ap = average_precision_score(y_true, y_pred)
    return auc, ap


def _load_graph(path):
    """Load a pickled PyG ``Data`` object from ``path`` onto the CPU."""
    # NOTE(security): this unpickles arbitrary objects; only load graph files
    # you produced yourself. weights_only=False is required because recent
    # PyTorch releases default to weights_only=True, which rejects Data objects.
    try:
        return torch.load(path, map_location="cpu", weights_only=False)
    except TypeError:
        # Older PyTorch without the weights_only keyword.
        return torch.load(path, map_location="cpu")


def main(argv=None):
    """Train the VGAE, keep the best-validation weights, and write all outputs.

    Args:
        argv: Optional argument list; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--graph", required=True, help="Path to Data .pt file")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--hidden", type=int, default=128)
    parser.add_argument("--latent", type=int, default=64)
    parser.add_argument("--dropout", type=float, default=0.2)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--outdir", default="results")
    args = parser.parse_args(argv)
    if args.epochs < 1:
        # Without at least one epoch there is no best checkpoint to save.
        parser.error("--epochs must be >= 1")

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    os.makedirs(args.outdir, exist_ok=True)

    # ---- Load graph ----
    data = _load_graph(args.graph)
    x = data.x.float()
    num_nodes = x.size(0)
    # Coalesce/clean edges
    ei, _ = remove_self_loops(data.edge_index)
    full_edge_index = to_undirected(ei, num_nodes=num_nodes)
    # Split a fresh Data holding only what VGAE uses. The stored edge_weight
    # no longer matches the symmetrised edge list and must not be carried
    # into the splitter.
    graph = Data(x=x, edge_index=full_edge_index)

    # ---- Split edges for link prediction ----
    # With split_labels=True the held-out supervision edges are exposed as
    # pos_/neg_edge_label_index. Each split's edge_index is the
    # message-passing graph visible at that stage, NOT the held-out edges.
    splitter = RandomLinkSplit(
        num_val=0.1,
        num_test=0.1,
        is_undirected=True,
        add_negative_train_samples=False,
        split_labels=True,
    )
    train_data, val_data, test_data = splitter(graph)

    # Alias to the attribute names eval_linkpred reads.
    for subset in (train_data, val_data, test_data):
        subset.pos_edge_index = subset.pos_edge_label_index
    for subset in (val_data, test_data):
        subset.neg_edge_index = subset.neg_edge_label_index

    # ---- Model ----
    enc = Encoder(in_dim=x.size(1), hidden=args.hidden, latent=args.latent, dropout=args.dropout)
    model = VGAE(enc)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # ---- Training loop ----
    best_val_auc = -1.0
    best_val_ap = -1.0
    best_epoch = 0
    best_state = None
    for epoch in range(1, args.epochs + 1):
        model.train()
        optimizer.zero_grad()
        # Encode using remaining training edges
        z = model.encode(x, train_data.edge_index)
        # Reconstruction loss on positive training edges (negatives sampled inside)
        loss_recon = model.recon_loss(z, train_data.pos_edge_index)
        # KL divergence regularizer
        loss_kl = (1.0 / num_nodes) * model.kl_loss()
        loss = loss_recon + loss_kl
        loss.backward()
        optimizer.step()

        # ---- Validation ----
        model.eval()
        with torch.no_grad():
            # Encode with the validation message-passing graph only, so the
            # held-out edges being scored are never seen by the encoder.
            z_val = model.encode(x, val_data.edge_index)
        val_auc, val_ap = eval_linkpred(model, val_data, z_val)

        if val_auc > best_val_auc:
            best_val_auc = val_auc
            best_val_ap = val_ap
            best_epoch = epoch
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

        if epoch % 10 == 0 or epoch == 1:
            print(f"[{epoch:03d}/{args.epochs}] loss={loss.item():.4f} | val AUC={val_auc:.4f} AP={val_ap:.4f}")

    # ---- Save best model ----
    model.load_state_dict(best_state)
    model_path = os.path.join(args.outdir, "model.pt")
    torch.save(model.state_dict(), model_path)

    # ---- Final test metrics ----
    model.eval()
    with torch.no_grad():
        # Test edges are scored from embeddings that exclude them ...
        z_test = model.encode(x, test_data.edge_index)
        # ... while the exported embeddings use every edge of the graph.
        z_final = model.encode(x, full_edge_index)
    test_auc, test_ap = eval_linkpred(model, test_data, z_test)

    # ---- Save embeddings & metrics ----
    emb_path = os.path.join(args.outdir, "emb.npy")
    np.save(emb_path, z_final.cpu().numpy())

    metrics = {
        "val_auc": float(best_val_auc),
        "val_ap": float(best_val_ap),
        "best_epoch": best_epoch,
        "test_auc": float(test_auc),
        "test_ap": float(test_ap),
        "epochs": args.epochs,
        "hidden": args.hidden,
        "latent": args.latent,
        "dropout": args.dropout,
        "lr": args.lr,
        "seed": args.seed,
    }
    with open(os.path.join(args.outdir, "metrics.json"), "w") as f:
        json.dump(metrics, f, indent=2)

    print(f"Saved model -> {model_path}")
    print(f"Saved embeddings -> {emb_path} (shape={z_final.shape})")
    print(f"Metrics: AUC(test)={test_auc:.4f}, AP(test)={test_ap:.4f}")


if __name__ == "__main__":
    main()