109 lines
4.1 KiB
Python
109 lines
4.1 KiB
Python
#!/usr/bin/env python3
|
||
|
||
import argparse
|
||
import numpy as np
|
||
import pandas as pd
|
||
import torch
|
||
import cooler
|
||
import pyBigWig
|
||
from torch_geometric.data import Data
|
||
|
||
|
||
def bin_bigwig(bw_path, chrom, bins):
    """Average bigWig signal across each genomic bin.

    Parameters
    ----------
    bw_path : str
        Path to the bigWig file (opened with ``pyBigWig.open``).
    chrom : str
        Chromosome name; must be present in the bigWig header.
    bins : iterable of (int, int)
        ``(start, end)`` intervals in bp, one per output value.

    Returns
    -------
    numpy.ndarray
        One mean signal value per bin, in input order. A bin yields 0.0
        when it is empty after clamping to ``[0, chrom_len]`` or when the
        file reports no value (``None`` or NaN) for it.

    Raises
    ------
    ValueError
        If ``chrom`` is not listed in the bigWig file.
    """
    bw = pyBigWig.open(bw_path)
    # try/finally so the file handle is released even when the chromosome
    # is missing or a stats query fails part-way through.
    try:
        chrom_sizes = bw.chroms()
        if chrom not in chrom_sizes:
            raise ValueError(
                f"{chrom} not found in {bw_path}. "
                f"Available: {list(chrom_sizes.keys())[:5]}..."
            )
        chrom_len = chrom_sizes[chrom]

        vals = []
        for s, e in bins:
            # Clamp to the chromosome so a trailing partial bin is still valid.
            s = max(0, s)
            e = min(chrom_len, e)
            if s >= e:
                vals.append(0.0)
                continue
            v = bw.stats(chrom, s, e, type="mean")[0]
            # Regions without coverage come back as None (or NaN); store 0.
            vals.append(0.0 if v is None or np.isnan(v) else v)
    finally:
        bw.close()
    return np.array(vals)
|
||
|
||
|
||
def build_graph(mcool_path, chrom, res, bigwigs, out_path,
                max_dist=5_000_000, edge_top_pct=100):
    """Convert .mcool + bigWigs to PyTorch Geometric Data object.

    Nodes are the genomic bins of ``chrom`` at resolution ``res``; edges are
    the contact pixels whose bin starts are at most ``max_dist`` bp apart.
    The resulting ``Data`` object is written to ``out_path`` with
    ``torch.save``; nothing is returned.

    Parameters
    ----------
    mcool_path : str
        Path to the .mcool file; opened at ``::resolutions/{res}``.
    chrom : str
        Chromosome to extract.
    res : int
        Bin size in bp.
    bigwigs : sequence of str
        bigWig files; each contributes one node-feature column (mean signal
        per bin). Must contain at least one file.
    out_path : str
        Destination passed to ``torch.save``.
    max_dist : int
        Maximum distance in bp between the two bin starts of an edge.
    edge_top_pct : int, 1–100
        Keep only the top N% of edges ranked by balanced contact weight.
        Default 100 = keep all. Use e.g. 50 to keep the strongest half,
        which removes low-weight noise from the ICE-balanced matrix.
        Values of 100 or more disable the filter.

    Raises
    ------
    ValueError
        If ``edge_top_pct`` is not positive, or ``bigwigs`` is empty.
    """
    # Fail fast on bad arguments, before any expensive I/O.
    if edge_top_pct <= 0:
        raise ValueError(
            f"edge_top_pct must be in the range 1-100, got {edge_top_pct}"
        )
    bigwigs = list(bigwigs)
    if not bigwigs:
        raise ValueError("At least one bigWig file is required for node features")

    print(f"Processing {chrom} at {res} bp resolution...")

    # Load pixels. Requesting balance=True on a cooler without a 'weight'
    # column raises, so only ask for balanced values when weights exist and
    # otherwise fall through to the raw-count path below.
    c = cooler.Cooler(f"{mcool_path}::resolutions/{res}")
    has_weights = "weight" in c.bins().columns
    pixels = c.matrix(balance=has_weights, as_pixels=True, join=True).fetch(chrom)
    pixels = pixels.query(f"chrom1 == chrom2 and abs(start2 - start1) <= {max_dist}")

    # Map genomic coordinates to chromosome-local bin IDs (0..n_bins-1)
    bins_df = c.bins().fetch(chrom)
    bins_df["bin_id"] = np.arange(len(bins_df))
    starts = bins_df["start"].values
    start_to_bin = dict(zip(starts, bins_df["bin_id"].values))

    valid = pixels["start1"].isin(starts) & pixels["start2"].isin(starts)
    pixels = pixels.loc[valid]

    # Force an integer dtype: with zero pixels the mapped Series may not be
    # integer-typed, which torch.tensor(..., dtype=torch.long) rejects.
    bin1 = pixels["start1"].map(start_to_bin).to_numpy(dtype=np.int64)
    bin2 = pixels["start2"].map(start_to_bin).to_numpy(dtype=np.int64)

    # Edge weights: balanced values when available (NaN -> 0, edge is kept),
    # raw counts otherwise.
    if "balanced" in pixels.columns and pixels["balanced"].notna().any():
        w = pixels["balanced"].fillna(0).values
    else:
        w = pixels["count"].values

    # Optional edge thresholding: keep top edge_top_pct% by weight.
    # The cutoff is computed over strictly positive weights only.
    if edge_top_pct < 100:
        nonzero = w[w > 0]
        if len(nonzero):
            cutoff = np.percentile(nonzero, 100 - edge_top_pct)
            mask = w >= cutoff
            bin1, bin2, w = bin1[mask], bin2[mask], w[mask]
            print(f"  Edge filter (top {edge_top_pct}%): "
                  f"{mask.sum()} / {len(mask)} edges kept "
                  f"(weight ≥ {cutoff:.4f})")

    edge_index = torch.tensor(np.stack([bin1, bin2]), dtype=torch.long)
    # Weights are stored on a log1p scale.
    edge_weight = torch.tensor(np.log1p(w), dtype=torch.float)

    # Node features: one column per bigWig, one row per bin.
    # bin_bigwig clamps the end of the last bin to the chromosome length.
    bins = [(int(s), int(s + res)) for s in starts]
    node_feats = []
    for bw in bigwigs:
        print(f"  Adding feature from {bw}")
        node_feats.append(bin_bigwig(bw, chrom, bins))
    x = torch.tensor(np.stack(node_feats, axis=1), dtype=torch.float)

    data = Data(x=x, edge_index=edge_index, edge_weight=edge_weight)
    torch.save(data, out_path)
    print(f"Saved {chrom}: {x.shape[0]} nodes, {edge_index.shape[1]} edges → {out_path}")
|
||
|
||
|
||
if __name__ == "__main__":
    # Command-line entry point: declare the options in one table, parse
    # them, and build the graph for a single chromosome.
    parser = argparse.ArgumentParser(
        description="Build graph from Micro-C and bigWigs"
    )
    option_specs = [
        ("--mcool", {"required": True, "help": "Path to .mcool file"}),
        ("--chrom", {"required": True, "help": "Chromosome name (e.g., chr21)"}),
        ("--res", {"type": int, "default": 10000, "help": "Resolution (bp)"}),
        ("--bigwigs", {"nargs": "+", "required": True,
                       "help": "List of bigWig feature files"}),
        ("--out", {"required": True, "help": "Output .pt file path"}),
        ("--max_dist", {"type": int, "default": 5_000_000,
                        "help": "Max genomic distance for edges (bp)"}),
        ("--edge_top_pct", {"type": int, "default": 100,
                            "help": "Keep only top N%% of edges by balanced weight (1-100, default 100=all)"}),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    cli = parser.parse_args()

    build_graph(
        cli.mcool,
        cli.chrom,
        cli.res,
        cli.bigwigs,
        cli.out,
        max_dist=cli.max_dist,
        edge_top_pct=cli.edge_top_pct,
    )
|