initial framework; to be extended

This commit is contained in:
2025-10-23 08:46:48 +02:00
commit 32228496d2
7 changed files with 469 additions and 0 deletions

30
.gitignore vendored Normal file
View File

@@ -0,0 +1,30 @@
# Python
__pycache__/
*.pyc
# Conda / mamba envs
.env/
*.yml.lock
# Data
*.hic
*.mcool
*.cool
*.bam
*.bw
*.bigwig
*.bed
*.pairs*
*.pt
*.npy
*.csv
*.png
# Jupyter and logs
*.ipynb_checkpoints/
*.log
.DS_Store
# Results / temp
results/
data/

1
README.md Normal file
View File

@@ -0,0 +1 @@
# Chromatin-GNN: Graph representation learning for 3D genome architecture

21
env.yml Normal file
View File

@@ -0,0 +1,21 @@
name: chromatin_gnn_aman
channels:
- pytorch
- conda-forge
- defaults
dependencies:
- python=3.10
- pytorch
- torchvision
- torchaudio
- cooler
- pybigwig
- pandas
- numpy
- scikit-learn
- matplotlib
- umap-learn
- pip
- pip:
- torch-geometric

84
scripts/build_graph.py Normal file
View File

@@ -0,0 +1,84 @@
#!/usr/bin/env python3
import argparse
import numpy as np
import pandas as pd
import torch
import cooler
import pyBigWig
from torch_geometric.data import Data
def bin_bigwig(bw_path, chrom, bins):
    """Average a bigWig signal across each genomic bin.

    Args:
        bw_path: Path to the bigWig file to read.
        chrom: Chromosome name; must be present in the bigWig header.
        bins: Iterable of ``(start, end)`` coordinate pairs.

    Returns:
        1-D ``numpy.ndarray`` holding one value per bin, in input order.
        A bin gets 0.0 when it is empty after clamping to the chromosome
        bounds, or when the bigWig reports ``None``/NaN for it.

    Raises:
        ValueError: If ``chrom`` is not present in the bigWig file.
    """
    bw = pyBigWig.open(bw_path)
    # try/finally so the handle is released even when the chromosome check
    # (or a stats call) raises.
    try:
        chrom_sizes = bw.chroms()
        if chrom not in chrom_sizes:
            raise ValueError(f"{chrom} not found in {bw_path}. Available: {list(chrom_sizes.keys())[:5]}...")
        chrom_len = chrom_sizes[chrom]
        vals = []
        for s, e in bins:
            # Clamp the bin to the chromosome extent before querying.
            s = max(0, s)
            e = min(chrom_len, e)
            if s >= e:
                vals.append(0.0)
                continue
            v = bw.stats(chrom, s, e, type="mean")[0]
            vals.append(0.0 if v is None or np.isnan(v) else v)
    finally:
        bw.close()
    return np.array(vals)
def build_graph(mcool_path, chrom, res, bigwigs, out_path, max_dist=5_000_000):
    """Convert .mcool + bigWigs to PyTorch Geometric Data object.

    Nodes are the bins of ``chrom`` at resolution ``res``; edges are the
    intra-chromosomal pixels whose bin starts are at most ``max_dist`` apart.
    The resulting ``Data`` object is written to ``out_path`` with
    ``torch.save``.

    Args:
        mcool_path: Path to the .mcool file.
        chrom: Chromosome name to extract.
        res: Bin size in bp; selects ``::resolutions/{res}`` in the .mcool.
        bigwigs: Sequence of bigWig paths; each becomes one node-feature
            column, in the given order.
        out_path: Destination .pt file.
        max_dist: Maximum ``|start2 - start1|`` for a pixel to become an edge.

    Raises:
        ValueError: If ``bigwigs`` is empty (no node features could be built).
    """
    bigwigs = list(bigwigs)
    if not bigwigs:
        # Fail before the expensive cooler fetch rather than at np.stack.
        raise ValueError("At least one bigWig file is required to build node features")

    print(f"Processing {chrom} at {res} bp resolution...")

    # Load pixels
    c = cooler.Cooler(f"{mcool_path}::resolutions/{res}")
    pixels = c.matrix(balance=True, as_pixels=True, join=True).fetch(chrom)
    pixels = pixels.query(f"chrom1 == chrom2 and abs(start2 - start1) <= {max_dist}")

    # Map genomic coordinates to bin IDs (0-based within this chromosome)
    bins_df = c.bins().fetch(chrom)
    bins_df["bin_id"] = np.arange(len(bins_df))
    start_to_bin = dict(zip(bins_df["start"].values, bins_df["bin_id"].values))
    valid = pixels["start1"].isin(start_to_bin) & pixels["start2"].isin(start_to_bin)
    pixels = pixels.loc[valid]
    bin1 = pixels["start1"].map(start_to_bin).to_numpy(dtype=np.int64)
    bin2 = pixels["start2"].map(start_to_bin).to_numpy(dtype=np.int64)
    # Stack in NumPy first: building a tensor from a Python list of ndarrays
    # goes through a very slow element-wise path in PyTorch.
    edge_index = torch.tensor(np.stack([bin1, bin2]), dtype=torch.long)

    # Edge weights: balanced values when present (NaN -> 0), raw counts otherwise
    if "balanced" in pixels.columns and pixels["balanced"].notna().any():
        w = pixels["balanced"].fillna(0).values
    else:
        w = pixels["count"].values
    edge_weight = torch.tensor(np.log1p(w), dtype=torch.float)

    # Node features: one column per bigWig, averaged over [start, start + res)
    starts = bins_df["start"].values
    bins = [(int(s), int(s + res)) for s in starts]
    node_feats = []
    for bw in bigwigs:
        print(f" Adding feature from {bw}")
        node_feats.append(bin_bigwig(bw, chrom, bins))
    x = torch.tensor(np.stack(node_feats, axis=1), dtype=torch.float)

    # Save graph
    data = Data(x=x, edge_index=edge_index, edge_weight=edge_weight)
    torch.save(data, out_path)
    print(f"Saved {chrom}: {x.shape[0]} nodes, {edge_index.shape[1]} edges → {out_path}")
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to build_graph.
    parser = argparse.ArgumentParser(description="Build graph from Micro-C and bigWigs")
    parser.add_argument("--mcool", required=True, help="Path to .mcool file")
    parser.add_argument("--chrom", required=True, help="Chromosome name (e.g., chr21)")
    parser.add_argument("--res", type=int, default=10000, help="Resolution (bp)")
    parser.add_argument("--bigwigs", nargs="+", required=True, help="List of bigWig feature files")
    parser.add_argument("--out", required=True, help="Output .pt file path")
    parser.add_argument("--max_dist", type=int, default=5_000_000, help="Max genomic distance for edges")
    cli = parser.parse_args()
    build_graph(
        cli.mcool,
        cli.chrom,
        cli.res,
        cli.bigwigs,
        cli.out,
        cli.max_dist,
    )

View File

@@ -0,0 +1,86 @@
#!/usr/bin/env python3
"""
Compares two latent embedding matrices (e.g., CTRL vs EED-i),
computes similarity metrics (cosine, Euclidean, L1),
and saves both a CSV and an optional line plot.
Usage:
python scripts/compare_embeddings_general.py \
--emb1 results/emb.npy \
--emb2 results/emb_eedi.npy \
--label1 CTRL --label2 EEDi \
--prefix results/chr21
"""
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.spatial.distance import cosine, euclidean, cityblock
import os
def compute_metrics(emb1, emb2):
    """Compute cosine similarity, cosine distance, L2, and L1 per row.

    Args:
        emb1: Array-like of shape ``(n, d)`` (a 1-D input is treated as
            ``n`` rows of one dimension each).
        emb2: Array-like with the same shape as ``emb1``.

    Returns:
        Tuple ``(cos_sims, cos_dists, l2_dists, l1_dists)`` of 1-D float
        arrays of length ``n``. Rows in which either vector has zero norm
        yield NaN for the two cosine metrics.

    Raises:
        ValueError: If the shapes differ or the inputs have more than two
            dimensions.
    """
    a = np.asarray(emb1, dtype=np.float64)
    b = np.asarray(emb2, dtype=np.float64)
    # Pairing rows with zip() would silently drop the tail of the longer
    # input, so reject mismatched shapes explicitly.
    if a.shape != b.shape:
        raise ValueError(f"Shape mismatch: {a.shape} vs {b.shape}")
    if a.ndim == 1:
        a, b = a[:, None], b[:, None]
    elif a.ndim != 2:
        raise ValueError(f"Expected 2-D embeddings, got shape {a.shape}")

    diff = a - b
    l2_dists = np.sqrt(np.einsum("ij,ij->i", diff, diff))
    l1_dists = np.abs(diff).sum(axis=1)

    dots = np.einsum("ij,ij->i", a, b)
    norms = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
    # 0/0 for zero-norm rows is intentionally left as NaN (undefined angle).
    with np.errstate(divide="ignore", invalid="ignore"):
        cos_sims = dots / norms
    # Clip away rounding overshoot so similarity stays within [-1, 1].
    cos_sims = np.clip(cos_sims, -1.0, 1.0)
    cos_dists = 1.0 - cos_sims
    return cos_sims, cos_dists, l2_dists, l1_dists
def main():
    """Compare two embedding matrices from the command line.

    Loads the two ``.npy`` files, computes per-row metrics with
    ``compute_metrics``, writes ``{prefix}_delta.csv`` and, unless
    ``--no-plot`` is given, a cosine-distance line plot to
    ``{prefix}_delta.png``.

    Raises:
        ValueError: If the two embeddings differ in shape or are not 2-D.
    """
    p = argparse.ArgumentParser(description="Compare two embedding matrices")
    p.add_argument("--emb1", required=True, help="Path to first embedding .npy")
    p.add_argument("--emb2", required=True, help="Path to second embedding .npy")
    p.add_argument("--label1", default="A", help="Label for first embedding")
    p.add_argument("--label2", default="B", help="Label for second embedding")
    p.add_argument("--prefix", default="results/compare", help="Prefix for output files")
    p.add_argument("--no-plot", action="store_true", help="Skip generating the plot")
    args = p.parse_args()

    # ---- Load ----
    emb1 = np.load(args.emb1)
    emb2 = np.load(args.emb2)
    if emb1.shape != emb2.shape:
        raise ValueError(f"Shape mismatch: {emb1.shape} vs {emb2.shape}")
    if emb1.ndim != 2:
        raise ValueError(f"Expected 2-D embeddings, got shape {emb1.shape}")

    # A bare prefix such as "chr21" has no directory part; os.makedirs("")
    # would raise FileNotFoundError, so only create a directory when there
    # is one to create.
    out_dir = os.path.dirname(args.prefix)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    n_bins, n_dim = emb1.shape
    print(f"Loaded embeddings: {n_bins} bins × {n_dim} dims")

    # ---- Compute metrics ----
    cos_sims, cos_dists, l2_dists, l1_dists = compute_metrics(emb1, emb2)
    df = pd.DataFrame({
        "bin_id": np.arange(n_bins),
        "cosine_similarity": cos_sims,
        "cosine_distance": cos_dists,
        "euclidean": l2_dists,
        "manhattan": l1_dists,
    })
    csv_path = f"{args.prefix}_delta.csv"
    df.to_csv(csv_path, index=False)
    print(f"Saved metrics → {csv_path}")

    # ---- Plot ----
    if not args.no_plot:
        fig = plt.figure(figsize=(12, 4))
        plt.plot(df["bin_id"], df["cosine_distance"], lw=0.8, color="steelblue")
        plt.title(f"Δ-Embedding ({args.label1} vs {args.label2})")
        plt.xlabel("Bin index")
        # Axis label previously read "(1 similarity)" with the operator missing.
        plt.ylabel("Cosine distance (1 - similarity)")
        plt.tight_layout()
        fig_path = f"{args.prefix}_delta.png"
        plt.savefig(fig_path, dpi=300)
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
        print(f"Saved plot → {fig_path}")


if __name__ == "__main__":
    main()

63
scripts/encode_graph.py Normal file
View File

@@ -0,0 +1,63 @@
#!/usr/bin/env python3
"""
Encode a new graph using a trained VGAE model.
Automatically infers hidden/latent dimensions from saved weights.
"""
import argparse, torch, numpy as np
from torch_geometric.nn import GCNConv, VGAE
# Encoder duplicated from scripts/train_vgae.py; keep the layer names (gc1, gc_mu, gc_log) in sync so saved weights load
class Encoder(torch.nn.Module):
    """GCN encoder with one shared hidden layer and two output heads.

    ``gc1`` maps ``in_dim`` -> ``hidden``; ``gc_mu`` and ``gc_log`` each map
    ``hidden`` -> ``latent``. The pair of head outputs is what gets wrapped
    by ``VGAE`` in ``main`` (presumably mean and log-std; confirm against the
    PyG VGAE documentation).
    """

    def __init__(self, in_dim, hidden, latent, dropout=0.2):
        super().__init__()
        # Attribute names are part of the checkpoint key names; do not rename.
        self.gc1 = GCNConv(in_dim, hidden)
        self.gc_mu = GCNConv(hidden, latent)
        self.gc_log = GCNConv(hidden, latent)
        self.dropout = dropout

    def forward(self, x, edge_index):
        """Return ``(gc_mu(h), gc_log(h))`` for the shared hidden state ``h``."""
        nn_f = torch.nn.functional
        # ReLU then dropout; dropout is only active while self.training is set.
        shared = nn_f.dropout(
            nn_f.relu(self.gc1(x, edge_index)),
            p=self.dropout,
            training=self.training,
        )
        mu_out = self.gc_mu(shared, edge_index)
        log_out = self.gc_log(shared, edge_index)
        return mu_out, log_out
def main():
    """Encode a saved graph with a trained VGAE and write the embeddings.

    Command-line options:
        --model: Path to the saved VGAE ``state_dict``.
        --graph: Path to the saved graph object (must expose ``x`` and
            ``edge_index``).
        --out: Destination ``.npy`` file for the embeddings.

    The hidden and latent sizes are read from the shapes of the ``gc1`` and
    ``gc_mu`` weight tensors in the state dict.

    Raises:
        KeyError: If the state dict has no ``gc1``/``gc_mu`` weight entry.
        ValueError: If the graph's feature count differs from the input size
            the saved weights were trained with.
    """
    p = argparse.ArgumentParser(description="Encode a graph using a trained VGAE model")
    p.add_argument("--model", required=True)
    p.add_argument("--graph", required=True)
    p.add_argument("--out", required=True)
    args = p.parse_args()

    # ---- Load data and model state ----
    # The graph file is a pickled object, not a plain tensor dict, so it
    # cannot be read with the weights_only=True default of torch >= 2.6.
    # NOTE: weights_only=False unpickles arbitrary objects; only load graph
    # files you produced yourself or otherwise trust.
    data = torch.load(args.graph, map_location="cpu", weights_only=False)
    model_state = torch.load(args.model, map_location="cpu")

    # ---- Infer dimensions dynamically ----
    def find_weight_key(layer_tag):
        """Return the first state-dict key naming a weight of ``layer_tag``."""
        for key in model_state:
            if layer_tag in key and "weight" in key:
                return key
        raise KeyError(
            f"No weight entry for layer '{layer_tag}' in {args.model}; "
            f"found keys: {list(model_state.keys())}"
        )

    in_dim = data.x.size(1)
    gc1_weight = find_weight_key("gc1")
    gc_mu_weight = find_weight_key("gc_mu")
    # Weight tensors are read as (out_features, in_features).
    hidden = model_state[gc1_weight].shape[0]
    latent = model_state[gc_mu_weight].shape[0]
    trained_in_dim = model_state[gc1_weight].shape[1]
    if trained_in_dim != in_dim:
        raise ValueError(
            f"Graph has {in_dim} node features but the model was trained with {trained_in_dim}"
        )
    print(f"Inferred dims: in={in_dim}, hidden={hidden}, latent={latent}")

    enc = Encoder(in_dim=in_dim, hidden=hidden, latent=latent)
    model = VGAE(enc)
    model.load_state_dict(model_state)
    model.eval()

    # ---- Encode ----
    with torch.no_grad():
        z = model.encode(data.x.float(), data.edge_index)
    np.save(args.out, z.cpu().numpy())
    print(f"Saved embeddings → {args.out} shape={z.shape}")


if __name__ == "__main__":
    main()

184
scripts/train_vgae.py Normal file
View File

@@ -0,0 +1,184 @@
#!/usr/bin/env python3
"""
Train a Variational Graph Autoencoder (VGAE) on a chromatin contact graph.
Inputs:
- A PyTorch Geometric Data object saved with torch.save(...) containing:
x : [num_nodes, num_features] node features
edge_index : [2, num_edges] undirected edges (will be coalesced)
edge_weight : [num_edges] (optional, unused by VGAE)
- from build_graph.py
---
Outputs (under results/):
- model.pt : trained VGAE state_dict
- emb.npy : node embeddings (mean; shape [num_nodes, latent_dim])
- metrics.json : best validation AUC, test AUC/AP, and the hyperparameters used
"""
import os, json, argparse, numpy as np, torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn.models import VGAE
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import to_undirected, remove_self_loops
from torch_geometric.utils import negative_sampling
from sklearn.metrics import roc_auc_score, average_precision_score
class Encoder(torch.nn.Module):
    """GCN encoder with one shared hidden layer and two output heads.

    ``gc1`` maps ``in_dim`` -> ``hidden``; ``gc_mu`` and ``gc_log`` each map
    ``hidden`` -> ``latent``. ``forward`` returns both head outputs as a
    tuple, which ``main`` wraps in ``VGAE``.
    """

    def __init__(self, in_dim: int, hidden: int, latent: int, dropout: float = 0.2):
        super().__init__()
        # Same explicit options for every layer; attribute names double as
        # checkpoint key names, so they must not change.
        conv_opts = {"add_self_loops": True, "normalize": True}
        self.gc1 = GCNConv(in_dim, hidden, **conv_opts)
        self.gc_mu = GCNConv(hidden, latent, **conv_opts)
        self.gc_log = GCNConv(hidden, latent, **conv_opts)
        self.dropout = dropout

    def forward(self, x, edge_index):
        """Return ``(gc_mu(h), gc_log(h))`` for the shared hidden state ``h``."""
        # ReLU then dropout; dropout is only active while self.training is set.
        shared = F.dropout(
            F.relu(self.gc1(x, edge_index)),
            p=self.dropout,
            training=self.training,
        )
        mu_out = self.gc_mu(shared, edge_index)
        log_out = self.gc_log(shared, edge_index)
        return mu_out, log_out
@torch.no_grad()
def eval_linkpred(model, data_like, z):
    """Compute AUROC/AP using provided positive/negative edges.

    Args:
        model: Unused; kept so existing call sites keep working.
        data_like: Object exposing ``pos_edge_index`` and ``neg_edge_index``,
            each indexable as ``(src, dst)`` with ``size(1)`` edges.
        z: Node embedding matrix indexed by node id along dim 0.

    Returns:
        Tuple ``(auc, ap)`` from scikit-learn's ``roc_auc_score`` and
        ``average_precision_score``.
    """
    pos = data_like.pos_edge_index
    neg = data_like.neg_edge_index

    # Metrics are computed explicitly with scikit-learn rather than through
    # model.test, which depends on torchmetrics in some versions.
    # Inner-product decoder: score(u, v) = sigmoid(<z_u, z_v>). Positives and
    # negatives are scored in one pass, positives first.
    edges = torch.cat([pos, neg], dim=1)
    logits = (z[edges[0]] * z[edges[1]]).sum(dim=1)
    y_pred = torch.sigmoid(logits).cpu().numpy()
    y_true = np.concatenate([np.ones(pos.size(1)), np.zeros(neg.size(1))])

    auc = roc_auc_score(y_true, y_pred)
    ap = average_precision_score(y_true, y_pred)
    return auc, ap
def main():
    """Train a VGAE on a saved graph and write model, embeddings and metrics.

    Steps: load the graph, drop self-loops and symmetrise the edges, split
    the edges into train/val/test for link prediction, train while tracking
    the best validation AUC, then restore the best weights and report test
    AUC/AP.

    Outputs written under ``--outdir``:
        model.pt: state_dict of the best-validation model.
        emb.npy: node embeddings encoded from the full graph.
        metrics.json: best validation AUC/AP, test AUC/AP, hyperparameters.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--graph", required=True, help="Path to Data .pt file")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--hidden", type=int, default=128)
    parser.add_argument("--latent", type=int, default=64)
    parser.add_argument("--dropout", type=float, default=0.2)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--outdir", default="results")
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    os.makedirs(args.outdir, exist_ok=True)

    # ---- Load graph ----
    # The graph file is a pickled object, which the weights_only=True default
    # of torch >= 2.6 refuses to load.
    # NOTE: weights_only=False unpickles arbitrary objects; only load graph
    # files you produced yourself or otherwise trust.
    data = torch.load(args.graph, weights_only=False)

    # Coalesce/clean edges
    ei, _ = remove_self_loops(data.edge_index)
    data.edge_index = to_undirected(ei, num_nodes=data.num_nodes)
    x = data.x.float()

    # ---- Split edges for link prediction ----
    # After RandomLinkSplit, each split's `edge_index` is the graph that may
    # be used for message passing (train edges for val; train+val edges for
    # test). The held-out edges to score live in the *label* attributes, so
    # request them split into positive and negative sets.
    splitter = RandomLinkSplit(
        num_val=0.1,
        num_test=0.1,
        is_undirected=True,
        add_negative_train_samples=False,
        split_labels=True,
    )
    train_data, val_data, test_data = splitter(data)

    # Expose the supervision edges under the attribute names eval_linkpred
    # reads. Negatives for val/test come from the splitter, which samples
    # them against the full edge set so no held-out positive is mislabelled.
    train_data.pos_edge_index = train_data.pos_edge_label_index
    for subset in (val_data, test_data):
        subset.pos_edge_index = subset.pos_edge_label_index
        subset.neg_edge_index = subset.neg_edge_label_index

    # ---- Model ----
    enc = Encoder(in_dim=x.size(1), hidden=args.hidden, latent=args.latent, dropout=args.dropout)
    model = VGAE(enc)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # ---- Training loop ----
    best_val_auc = -1.0
    best_val_ap = -1.0
    best_state = None
    for epoch in range(1, args.epochs + 1):
        model.train()
        optimizer.zero_grad()
        # Encode using remaining training edges
        z = model.encode(x, train_data.edge_index)
        # Reconstruction loss on positive training edges (negatives sampled inside)
        loss_recon = model.recon_loss(z, train_data.pos_edge_index)
        # KL divergence regularizer
        loss_kl = (1.0 / data.num_nodes) * model.kl_loss()
        loss = loss_recon + loss_kl
        loss.backward()
        optimizer.step()

        # ---- Validation ----
        model.eval()
        with torch.no_grad():
            # Encode with the validation split's own message-passing graph so
            # the held-out edges being scored are not visible to the encoder.
            z_val = model.encode(x, val_data.edge_index)
        val_auc, val_ap = eval_linkpred(model, val_data, z_val)
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            best_val_ap = val_ap
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        if epoch % 10 == 0 or epoch == 1:
            print(f"[{epoch:03d}/{args.epochs}] loss={loss.item():.4f} | val AUC={val_auc:.4f} AP={val_ap:.4f}")

    # ---- Save best model ----
    # best_state stays None when no epoch ran (--epochs 0); keep the current
    # weights in that case instead of crashing in load_state_dict.
    if best_state is not None:
        model.load_state_dict(best_state)
    model_path = os.path.join(args.outdir, "model.pt")
    torch.save(model.state_dict(), model_path)

    # ---- Final test metrics ----
    model.eval()
    with torch.no_grad():
        # Test edges are scored from embeddings that never saw them ...
        z_test = model.encode(x, test_data.edge_index)
        # ... while the exported embeddings use every edge of the graph.
        z_final = model.encode(x, data.edge_index)
    test_auc, test_ap = eval_linkpred(model, test_data, z_test)

    # ---- Save embeddings & metrics ----
    emb_path = os.path.join(args.outdir, "emb.npy")
    np.save(emb_path, z_final.cpu().numpy())
    metrics = {
        "val_auc": float(best_val_auc),
        "val_ap": float(best_val_ap),
        "test_auc": float(test_auc),
        "test_ap": float(test_ap),
        "epochs": args.epochs,
        "hidden": args.hidden,
        "latent": args.latent,
        "dropout": args.dropout,
        "lr": args.lr,
        "seed": args.seed,
    }
    with open(os.path.join(args.outdir, "metrics.json"), "w") as f:
        json.dump(metrics, f, indent=2)

    print(f"Saved model -> {model_path}")
    print(f"Saved embeddings -> {emb_path} (shape={z_final.shape})")
    print(f"Metrics: AUC(test)={test_auc:.4f}, AP(test)={test_ap:.4f}")


if __name__ == "__main__":
    main()