Files
chromatin-vgae-hic/experiments/h3_longrange/build_ablation_graphs.py

86 lines
3.5 KiB
Python

#!/usr/bin/env python3
"""
H3 graph ablation: build distance-banded variants of a full contact graph.
From one input graph (all edges up to max_dist), produces two filtered copies:
*_local.pt — only edges with |Δbp| < short_cutoff (within-TAD scale)
*_longrange.pt — only edges with |Δbp| > long_cutoff (sub-compartment scale)
The full graph is reused as-is from H1; this script generates only the two
ablations. Node features and node count are preserved exactly.
Hypothesis 3 tests whether long-range edges (>1 Mb) encode non-trivial
topological structure beyond what is explained by local contact density.
Usage
-----
python experiments/h3_longrange/build_ablation_graphs.py \\
--graph data/processed/gm12878/chr1.pt \\
--res 25000 \\
--short_cutoff 250000 \\
--long_cutoff 1000000 \\
--out_local data/processed/gm12878/chr1_local.pt \\
--out_longrange data/processed/gm12878/chr1_longrange.pt
"""
import argparse
import os
import torch
def filter_edges_by_distance(data, min_bins=0, max_bins=int(1e9)):
    """Return a copy of `data` keeping only edges with min_bins <= |Δ| < max_bins.

    The distance of an edge is the absolute difference between its source and
    destination node indices (rows 0 and 1 of ``data.edge_index``).

    Parameters
    ----------
    data
        Graph object exposing ``edge_index`` (indexed as ``[0]`` / ``[1]`` and
        ``[:, mask]``) and a ``clone()`` method. Presumably a PyG ``Data``
        object -- confirm against the graph builder.
    min_bins : int
        Inclusive lower bound on the index distance.
    max_bins : int
        Exclusive upper bound on the index distance. The default (1e9) is
        effectively "no upper bound".

    Returns
    -------
    A clone of ``data`` whose ``edge_index`` -- and ``edge_weight`` /
    ``edge_attr`` when present -- contain only the selected edges. Everything
    else on the clone is left as produced by ``clone()``; the input object is
    not modified.
    """
    edge_index = data.edge_index
    src, dst = edge_index[0], edge_index[1]
    dist = (src - dst).abs()
    mask = (dist >= min_bins) & (dist < max_bins)

    new = data.clone()
    new.edge_index = edge_index[:, mask]
    # Every per-edge tensor must be masked together with edge_index; otherwise
    # the copy ends up with more edge features than edges.
    for name in ("edge_weight", "edge_attr"):
        values = getattr(data, name, None)
        if values is not None:
            setattr(new, name, values[mask])
    return new
def main():
    """Command-line entry point: build the local and long-range ablation graphs.

    Loads the full graph given by ``--graph``, converts the two base-pair
    cutoffs into bin-index cutoffs using ``--res``, filters the edges with
    ``filter_edges_by_distance`` and saves the two variants to ``--out_local``
    and ``--out_longrange``. Edge counts are printed along the way.

    Exits through ``argparse`` (status 2) when ``--res`` is not positive.
    """
    ap = argparse.ArgumentParser(description=__doc__,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument("--graph", required=True, help="Input full graph .pt")
    ap.add_argument("--res", type=int, default=25_000, help="Bin size bp")
    ap.add_argument("--short_cutoff", type=int, default=250_000,
                    help="Local variant: |Δbp| < this (default 250 kb, within-TAD)")
    ap.add_argument("--long_cutoff", type=int, default=1_000_000,
                    help="Long-range variant: |Δbp| > this (default 1 Mb)")
    ap.add_argument("--out_local", required=True)
    ap.add_argument("--out_longrange", required=True)
    args = ap.parse_args()
    if args.res <= 0:
        # A non-positive bin size would divide by zero or yield meaningless bins.
        ap.error("--res must be a positive integer")

    # NOTE(review): weights_only=False unpickles arbitrary Python objects, so
    # only load graph files from a trusted source.
    data = torch.load(args.graph, weights_only=False)

    # The conversions below assume node indices are genomic bin indices, so
    # that |Δbin| * res == |Δbp| -- TODO confirm against the graph builder.
    # |Δbp| < short_cutoff  <=>  |Δbin| < ceil(short_cutoff / res); flooring
    # would drop qualifying edges when the cutoff is not a multiple of res.
    short_bins = -(-args.short_cutoff // args.res)
    # |Δbp| > long_cutoff  <=>  |Δbin| > floor(long_cutoff / res).
    long_bins = args.long_cutoff // args.res

    total_edges = data.edge_index.shape[1]
    print(f"Input graph: {data.num_nodes} nodes, {total_edges} edges, "
          f"{data.x.shape[1]} features")
    print(f"Local variant: |Δbin| < {short_bins} (= {args.short_cutoff//1000} kb)")
    print(f"Long-range variant: |Δbin| > {long_bins} (= {args.long_cutoff//1000} kb)")

    local_data = filter_edges_by_distance(data, min_bins=0, max_bins=short_bins)
    # The filter's lower bound is inclusive; "+ 1" makes the long-range cut
    # strictly greater than the cutoff, as documented and printed above.
    longrange_data = filter_edges_by_distance(data, min_bins=long_bins + 1,
                                              max_bins=int(1e9))

    n_local = local_data.edge_index.shape[1]
    n_long = longrange_data.edge_index.shape[1]
    # Guard the percentages so an input graph without edges does not crash.
    pct_local = 100 * n_local / total_edges if total_edges else 0.0
    pct_long = 100 * n_long / total_edges if total_edges else 0.0
    print(f"Local graph: {n_local:>9} edges "
          f"({pct_local:.1f}% of full)")
    print(f"Long-range graph: {n_long:>9} edges "
          f"({pct_long:.1f}% of full)")

    for path in (args.out_local, args.out_longrange):
        # dirname is "" for a bare filename; fall back to the current directory.
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    torch.save(local_data, args.out_local)
    torch.save(longrange_data, args.out_longrange)
    print(f"Saved → {args.out_local}")
    print(f"Saved → {args.out_longrange}")
# Run the CLI only when this file is executed as a script, so importing the
# module (e.g. to reuse filter_edges_by_distance) has no side effects.
if __name__ == "__main__":
    main()