initial framework; to be extended
This commit is contained in:
30
.gitignore
vendored
Normal file
30
.gitignore
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.pyc
|
||||
|
||||
# Conda / mamba envs
|
||||
.env/
|
||||
*.yml.lock
|
||||
|
||||
# Data
|
||||
*.hic
|
||||
*.mcool
|
||||
*.cool
|
||||
*.bam
|
||||
*.bw
|
||||
*.bigwig
|
||||
*.bed
|
||||
*.pairs*
|
||||
*.pt
|
||||
*.npy
|
||||
*.csv
|
||||
*.png
|
||||
|
||||
# Jupyter and logs
|
||||
*.ipynb_checkpoints/
|
||||
*.log
|
||||
.DS_Store
|
||||
|
||||
# Results / temp
|
||||
results/
|
||||
data/
|
||||
1
README.md
Normal file
1
README.md
Normal file
@@ -0,0 +1 @@
|
||||
# Chromatin-GNN: Graph representation learning for 3D genome architecture
|
||||
21
env.yml
Normal file
21
env.yml
Normal file
@@ -0,0 +1,21 @@
|
||||
name: chromatin_gnn_aman
|
||||
channels:
|
||||
- pytorch
|
||||
- conda-forge
|
||||
- defaults
|
||||
dependencies:
|
||||
- python=3.10
|
||||
- pytorch
|
||||
- torchvision
|
||||
- torchaudio
|
||||
- cooler
|
||||
- pybigwig
|
||||
- pandas
|
||||
- numpy
|
||||
- scikit-learn
|
||||
- matplotlib
|
||||
- umap-learn
|
||||
- pip
|
||||
- pip:
|
||||
- torch-geometric
|
||||
|
||||
84
scripts/build_graph.py
Normal file
84
scripts/build_graph.py
Normal file
@@ -0,0 +1,84 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import cooler
|
||||
import pyBigWig
|
||||
from torch_geometric.data import Data
|
||||
|
||||
|
||||
def bin_bigwig(bw_path, chrom, bins):
|
||||
"""Average bigWig signal across each genomic bin"""
|
||||
bw = pyBigWig.open(bw_path)
|
||||
if chrom not in bw.chroms():
|
||||
raise ValueError(f"{chrom} not found in {bw_path}. Available: {list(bw.chroms().keys())[:5]}...")
|
||||
chrom_len = bw.chroms(chrom)
|
||||
vals = []
|
||||
for s, e in bins:
|
||||
s = max(0, s)
|
||||
e = min(chrom_len, e)
|
||||
if s >= e:
|
||||
vals.append(0.0)
|
||||
continue
|
||||
v = bw.stats(chrom, s, e, type="mean")[0]
|
||||
vals.append(0.0 if v is None or np.isnan(v) else v)
|
||||
bw.close()
|
||||
return np.array(vals)
|
||||
|
||||
|
||||
def build_graph(mcool_path, chrom, res, bigwigs, out_path, max_dist=5_000_000):
|
||||
"""Convert .mcool + bigWigs to PyTorch Geometric Data object."""
|
||||
print(f"Processing {chrom} at {res} bp resolution...")
|
||||
|
||||
# Load pixels
|
||||
c = cooler.Cooler(f"{mcool_path}::resolutions/{res}")
|
||||
pixels = c.matrix(balance=True, as_pixels=True, join=True).fetch(chrom)
|
||||
pixels = pixels.query(f"chrom1 == chrom2 and abs(start2 - start1) <= {max_dist}")
|
||||
|
||||
# Map genomic coordinates to bin IDs
|
||||
bins_df = c.bins().fetch(chrom)
|
||||
bins_df["bin_id"] = np.arange(len(bins_df))
|
||||
start_to_bin = dict(zip(bins_df["start"].values, bins_df["bin_id"].values))
|
||||
|
||||
valid = pixels["start1"].isin(start_to_bin) & pixels["start2"].isin(start_to_bin)
|
||||
pixels = pixels.loc[valid]
|
||||
|
||||
bin1 = pixels["start1"].map(start_to_bin).values
|
||||
bin2 = pixels["start2"].map(start_to_bin).values
|
||||
edge_index = torch.tensor([bin1, bin2], dtype=torch.long)
|
||||
|
||||
# Edge weights
|
||||
if "balanced" in pixels.columns and pixels["balanced"].notna().any():
|
||||
w = pixels["balanced"].fillna(0).values
|
||||
else:
|
||||
w = pixels["count"].values
|
||||
edge_weight = torch.tensor(np.log1p(w), dtype=torch.float)
|
||||
|
||||
# Node features
|
||||
starts = bins_df["start"].values
|
||||
bins = [(int(s), int(s + res)) for s in starts]
|
||||
node_feats = []
|
||||
for bw in bigwigs:
|
||||
print(f" Adding feature from {bw}")
|
||||
node_feats.append(bin_bigwig(bw, chrom, bins))
|
||||
x = torch.tensor(np.stack(node_feats, axis=1), dtype=torch.float)
|
||||
|
||||
# Save graph
|
||||
data = Data(x=x, edge_index=edge_index, edge_weight=edge_weight)
|
||||
torch.save(data, out_path)
|
||||
print(f"Saved {chrom}: {x.shape[0]} nodes, {edge_index.shape[1]} edges → {out_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
p = argparse.ArgumentParser(description="Build graph from Micro-C and bigWigs")
|
||||
p.add_argument("--mcool", required=True, help="Path to .mcool file")
|
||||
p.add_argument("--chrom", required=True, help="Chromosome name (e.g., chr21)")
|
||||
p.add_argument("--res", type=int, default=10000, help="Resolution (bp)")
|
||||
p.add_argument("--bigwigs", nargs="+", required=True, help="List of bigWig feature files")
|
||||
p.add_argument("--out", required=True, help="Output .pt file path")
|
||||
p.add_argument("--max_dist", type=int, default=5_000_000, help="Max genomic distance for edges")
|
||||
args = p.parse_args()
|
||||
|
||||
build_graph(args.mcool, args.chrom, args.res, args.bigwigs, args.out, args.max_dist)
|
||||
86
scripts/compare_embeddings.py
Normal file
86
scripts/compare_embeddings.py
Normal file
@@ -0,0 +1,86 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Compares two latent embedding matrices (e.g., CTRL vs EED-i),
|
||||
computes similarity metrics (cosine, Euclidean, L1),
|
||||
and saves both a CSV and an optional line plot.
|
||||
|
||||
Usage:
|
||||
python scripts/compare_embeddings_general.py \
|
||||
--emb1 results/emb.npy \
|
||||
--emb2 results/emb_eedi.npy \
|
||||
--label1 CTRL --label2 EEDi \
|
||||
--prefix results/chr21
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
from scipy.spatial.distance import cosine, euclidean, cityblock
|
||||
import os
|
||||
|
||||
|
||||
def compute_metrics(emb1, emb2):
|
||||
"""Compute cosine similarity, cosine distance, L2, and L1 per row."""
|
||||
cos_sims, cos_dists, l2_dists, l1_dists = [], [], [], []
|
||||
for a, b in zip(emb1, emb2):
|
||||
cos_sim = 1 - cosine(a, b)
|
||||
cos_dist = 1 - cos_sim
|
||||
l2 = euclidean(a, b)
|
||||
l1 = cityblock(a, b)
|
||||
cos_sims.append(cos_sim)
|
||||
cos_dists.append(cos_dist)
|
||||
l2_dists.append(l2)
|
||||
l1_dists.append(l1)
|
||||
return np.array(cos_sims), np.array(cos_dists), np.array(l2_dists), np.array(l1_dists)
|
||||
|
||||
|
||||
def main():
|
||||
p = argparse.ArgumentParser(description="Compare two embedding matrices")
|
||||
p.add_argument("--emb1", required=True, help="Path to first embedding .npy")
|
||||
p.add_argument("--emb2", required=True, help="Path to second embedding .npy")
|
||||
p.add_argument("--label1", default="A", help="Label for first embedding")
|
||||
p.add_argument("--label2", default="B", help="Label for second embedding")
|
||||
p.add_argument("--prefix", default="results/compare", help="Prefix for output files")
|
||||
p.add_argument("--no-plot", action="store_true", help="Skip generating the plot")
|
||||
args = p.parse_args()
|
||||
|
||||
# ---- Load ----
|
||||
emb1 = np.load(args.emb1)
|
||||
emb2 = np.load(args.emb2)
|
||||
if emb1.shape != emb2.shape:
|
||||
raise ValueError(f"Shape mismatch: {emb1.shape} vs {emb2.shape}")
|
||||
|
||||
os.makedirs(os.path.dirname(args.prefix), exist_ok=True)
|
||||
n_bins, n_dim = emb1.shape
|
||||
print(f"Loaded embeddings: {n_bins} bins × {n_dim} dims")
|
||||
|
||||
# ---- Compute metrics ----
|
||||
cos_sims, cos_dists, l2_dists, l1_dists = compute_metrics(emb1, emb2)
|
||||
|
||||
df = pd.DataFrame({
|
||||
"bin_id": np.arange(n_bins),
|
||||
"cosine_similarity": cos_sims,
|
||||
"cosine_distance": cos_dists,
|
||||
"euclidean": l2_dists,
|
||||
"manhattan": l1_dists
|
||||
})
|
||||
csv_path = f"{args.prefix}_delta.csv"
|
||||
df.to_csv(csv_path, index=False)
|
||||
print(f"Saved metrics → {csv_path}")
|
||||
|
||||
# ---- Plot ----
|
||||
if not args.no_plot:
|
||||
plt.figure(figsize=(12, 4))
|
||||
plt.plot(df["bin_id"], df["cosine_distance"], lw=0.8, color="steelblue")
|
||||
plt.title(f"Δ-Embedding ({args.label1} vs {args.label2})")
|
||||
plt.xlabel("Bin index")
|
||||
plt.ylabel("Cosine distance (1 – similarity)")
|
||||
plt.tight_layout()
|
||||
fig_path = f"{args.prefix}_delta.png"
|
||||
plt.savefig(fig_path, dpi=300)
|
||||
print(f"Saved plot → {fig_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
63
scripts/encode_graph.py
Normal file
63
scripts/encode_graph.py
Normal file
@@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Encode a new graph using a trained VGAE model.
|
||||
Automatically infers hidden/latent dimensions from saved weights.
|
||||
"""
|
||||
|
||||
import argparse, torch, numpy as np
|
||||
from torch_geometric.nn import GCNConv, VGAE
|
||||
|
||||
# Reuse your Encoder definition directly here for clarity
|
||||
class Encoder(torch.nn.Module):
|
||||
def __init__(self, in_dim, hidden, latent, dropout=0.2):
|
||||
super().__init__()
|
||||
self.gc1 = GCNConv(in_dim, hidden)
|
||||
self.gc_mu = GCNConv(hidden, latent)
|
||||
self.gc_log = GCNConv(hidden, latent)
|
||||
self.dropout = dropout
|
||||
|
||||
def forward(self, x, edge_index):
|
||||
import torch.nn.functional as F
|
||||
h = self.gc1(x, edge_index)
|
||||
h = F.relu(h)
|
||||
h = F.dropout(h, p=self.dropout, training=self.training)
|
||||
return self.gc_mu(h, edge_index), self.gc_log(h, edge_index)
|
||||
|
||||
|
||||
def main():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("--model", required=True)
|
||||
p.add_argument("--graph", required=True)
|
||||
p.add_argument("--out", required=True)
|
||||
args = p.parse_args()
|
||||
|
||||
# ---- Load data and model state ----
|
||||
data = torch.load(args.graph)
|
||||
model_state = torch.load(args.model, map_location="cpu")
|
||||
|
||||
# ---- Infer dimensions dynamically ----
|
||||
in_dim = data.x.size(1)
|
||||
# detect hidden and latent dimensions safely
|
||||
keys = list(model_state.keys())
|
||||
gc1_weight = [k for k in keys if "gc1" in k and "weight" in k][0]
|
||||
gc_mu_weight = [k for k in keys if "gc_mu" in k and "weight" in k][0]
|
||||
|
||||
hidden = model_state[gc1_weight].shape[0]
|
||||
latent = model_state[gc_mu_weight].shape[0]
|
||||
|
||||
print(f"Inferred dims: in={in_dim}, hidden={hidden}, latent={latent}")
|
||||
|
||||
enc = Encoder(in_dim=in_dim, hidden=hidden, latent=latent)
|
||||
model = VGAE(enc)
|
||||
model.load_state_dict(model_state)
|
||||
model.eval()
|
||||
|
||||
# ---- Encode ----
|
||||
with torch.no_grad():
|
||||
z = model.encode(data.x.float(), data.edge_index)
|
||||
np.save(args.out, z.cpu().numpy())
|
||||
print(f"Saved embeddings → {args.out} shape={z.shape}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
184
scripts/train_vgae.py
Normal file
184
scripts/train_vgae.py
Normal file
@@ -0,0 +1,184 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Train a Variational Graph Autoencoder (VGAE) on a chromatin contact graph.
|
||||
|
||||
Inputs:
|
||||
- A PyTorch Geometric Data object saved with torch.save(...) containing:
|
||||
x : [num_nodes, num_features] node features
|
||||
edge_index : [2, num_edges] undirected edges (will be coalesced)
|
||||
edge_weight : [num_edges] (optional, unused by VGAE)
|
||||
|
||||
- from build_graph.py
|
||||
---
|
||||
Outputs (under results/):
|
||||
- model.pt : trained VGAE state_dict
|
||||
- emb.npy : node embeddings (mean; shape [num_nodes, latent_dim])
|
||||
- metrics.json : train/val/test AUC/AP summary
|
||||
"""
|
||||
|
||||
import os, json, argparse, numpy as np, torch
|
||||
import torch.nn.functional as F
|
||||
from torch_geometric.nn import GCNConv
|
||||
from torch_geometric.nn.models import VGAE
|
||||
from torch_geometric.transforms import RandomLinkSplit
|
||||
from torch_geometric.utils import to_undirected, remove_self_loops
|
||||
from torch_geometric.utils import negative_sampling
|
||||
from sklearn.metrics import roc_auc_score, average_precision_score
|
||||
|
||||
|
||||
class Encoder(torch.nn.Module):
|
||||
def __init__(self, in_dim: int, hidden: int, latent: int, dropout: float = 0.2):
|
||||
super().__init__()
|
||||
self.gc1 = GCNConv(in_dim, hidden, add_self_loops=True, normalize=True)
|
||||
self.gc_mu = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
|
||||
self.gc_log = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
|
||||
self.dropout = dropout
|
||||
|
||||
def forward(self, x, edge_index):
|
||||
h = self.gc1(x, edge_index)
|
||||
h = F.relu(h)
|
||||
h = F.dropout(h, p=self.dropout, training=self.training)
|
||||
return self.gc_mu(h, edge_index), self.gc_log(h, edge_index)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def eval_linkpred(model, data_like, z):
|
||||
"""Compute AUROC/AP using provided positive/negative edges."""
|
||||
pos = data_like.pos_edge_index
|
||||
neg = data_like.neg_edge_index
|
||||
# model.test returns (auc, ap) but relies on torchmetrics in some versions;
|
||||
# compute explicitly for stability:
|
||||
def sigmoid(x): return 1 / (1 + torch.exp(-x))
|
||||
|
||||
# Inner product decoder scores
|
||||
def scores(edges):
|
||||
src, dst = edges
|
||||
s = (z[src] * z[dst]).sum(dim=1)
|
||||
return sigmoid(s).cpu().numpy()
|
||||
|
||||
y_true = np.concatenate([np.ones(pos.size(1)), np.zeros(neg.size(1))])
|
||||
y_pred = np.concatenate([scores(pos), scores(neg)])
|
||||
|
||||
auc = roc_auc_score(y_true, y_pred)
|
||||
ap = average_precision_score(y_true, y_pred)
|
||||
return auc, ap
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--graph", required=True, help="Path to Data .pt file")
|
||||
ap.add_argument("--epochs", type=int, default=100)
|
||||
ap.add_argument("--lr", type=float, default=1e-3)
|
||||
ap.add_argument("--hidden", type=int, default=128)
|
||||
ap.add_argument("--latent", type=int, default=64)
|
||||
ap.add_argument("--dropout", type=float, default=0.2)
|
||||
ap.add_argument("--seed", type=int, default=42)
|
||||
ap.add_argument("--outdir", default="results")
|
||||
args = ap.parse_args()
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
os.makedirs(args.outdir, exist_ok=True)
|
||||
|
||||
# ---- Load graph ----
|
||||
data = torch.load(args.graph)
|
||||
# Coalesce/clean edges
|
||||
ei, _ = remove_self_loops(data.edge_index)
|
||||
data.edge_index = to_undirected(ei, num_nodes=data.num_nodes)
|
||||
x = data.x.float()
|
||||
|
||||
# ---- Split edges for link prediction ----
|
||||
splitter = RandomLinkSplit(
|
||||
num_val=0.1,
|
||||
num_test=0.1,
|
||||
is_undirected=True,
|
||||
add_negative_train_samples=False,
|
||||
split_labels=False,
|
||||
)
|
||||
train_data, val_data, test_data = splitter(data)
|
||||
|
||||
# Positive edges are just the edges in each split
|
||||
train_data.pos_edge_index = train_data.edge_index
|
||||
val_data.pos_edge_index = val_data.edge_index
|
||||
test_data.pos_edge_index = test_data.edge_index
|
||||
|
||||
# Generate negative edges for validation and test manually
|
||||
for subset in [val_data, test_data]:
|
||||
subset.neg_edge_index = negative_sampling(
|
||||
edge_index=subset.edge_index,
|
||||
num_nodes=data.num_nodes,
|
||||
num_neg_samples=subset.edge_index.size(1),
|
||||
method='sparse'
|
||||
)
|
||||
|
||||
|
||||
# ---- Model ----
|
||||
enc = Encoder(in_dim=x.size(1), hidden=args.hidden, latent=args.latent, dropout=args.dropout)
|
||||
model = VGAE(enc)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
|
||||
|
||||
# ---- Training loop ----
|
||||
best_val_auc = -1.0
|
||||
best_state = None
|
||||
for epoch in range(1, args.epochs + 1):
|
||||
model.train()
|
||||
optimizer.zero_grad()
|
||||
# Encode using remaining training edges
|
||||
z = model.encode(x, train_data.edge_index)
|
||||
# Reconstruction loss on positive training edges (negatives sampled inside)
|
||||
loss_recon = model.recon_loss(z, train_data.pos_edge_index)
|
||||
# KL divergence regularizer
|
||||
loss_kl = (1.0 / data.num_nodes) * model.kl_loss()
|
||||
loss = loss_recon + loss_kl
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# ---- Validation ----
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
z_full = model.encode(x, data.edge_index) # use full graph for eval embeddings
|
||||
val_auc, val_ap = eval_linkpred(model, val_data, z_full)
|
||||
|
||||
if val_auc > best_val_auc:
|
||||
best_val_auc = val_auc
|
||||
best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
|
||||
|
||||
if epoch % 10 == 0 or epoch == 1:
|
||||
print(f"[{epoch:03d}/{args.epochs}] loss={loss.item():.4f} | val AUC={val_auc:.4f} AP={val_ap:.4f}")
|
||||
|
||||
# ---- Save best model ----
|
||||
model.load_state_dict(best_state)
|
||||
model_path = os.path.join(args.outdir, "model.pt")
|
||||
torch.save(model.state_dict(), model_path)
|
||||
|
||||
# ---- Final test metrics ----
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
z_final = model.encode(x, data.edge_index)
|
||||
test_auc, test_ap = eval_linkpred(model, test_data, z_final)
|
||||
|
||||
# ---- Save embeddings & metrics ----
|
||||
emb_path = os.path.join(args.outdir, "emb.npy")
|
||||
np.save(emb_path, z_final.cpu().numpy())
|
||||
|
||||
metrics = {
|
||||
"val_auc": float(best_val_auc),
|
||||
"test_auc": float(test_auc),
|
||||
"test_ap": float(test_ap),
|
||||
"epochs": args.epochs,
|
||||
"hidden": args.hidden,
|
||||
"latent": args.latent,
|
||||
"dropout": args.dropout,
|
||||
"lr": args.lr,
|
||||
"seed": args.seed
|
||||
}
|
||||
with open(os.path.join(args.outdir, "metrics.json"), "w") as f:
|
||||
json.dump(metrics, f, indent=2)
|
||||
|
||||
print(f"Saved model -> {model_path}")
|
||||
print(f"Saved embeddings -> {emb_path} (shape={z_final.shape})")
|
||||
print(f"Metrics: AUC(test)={test_auc:.4f}, AP(test)={test_ap:.4f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user