86 lines
3.5 KiB
Python
86 lines
3.5 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
H3 graph ablation: build distance-banded variants of a full contact graph.
|
|
|
|
From one input graph (containing all edges up to the max_dist used when it was built), this script produces two filtered copies:
|
|
*_local.pt — only edges with |Δbp| < short_cutoff (within-TAD scale)
|
|
*_longrange.pt — only edges with |Δbp| > long_cutoff (sub-compartment scale)
|
|
|
|
The full graph is reused as-is from H1; this script generates only the two
|
|
ablations. Node features and node count are preserved exactly.
|
|
|
|
Hypothesis 3 tests whether long-range edges (>1 Mb) encode non-trivial
|
|
topological structure beyond what is explained by local contact density.
|
|
|
|
Usage
|
|
-----
|
|
python experiments/h3_longrange/build_ablation_graphs.py \\
|
|
--graph data/processed/gm12878/chr1.pt \\
|
|
--res 25000 \\
|
|
--short_cutoff 250000 \\
|
|
--long_cutoff 1000000 \\
|
|
--out_local data/processed/gm12878/chr1_local.pt \\
|
|
--out_longrange data/processed/gm12878/chr1_longrange.pt
|
|
"""
|
|
|
|
import argparse
|
|
import os
|
|
|
|
import torch
|
|
|
|
|
|
def filter_edges_by_distance(data, min_bins=0, max_bins=int(1e9)):
    """Return a copy of ``data`` keeping only edges with ``min_bins <= |Δ| < max_bins``.

    ``Δ`` is ``src - dst`` for each edge, i.e. the separation between the
    two endpoint node indices (this script treats node index as genomic bin
    index).

    Parameters
    ----------
    data
        Graph object exposing ``edge_index`` (row 0 = sources, row 1 =
        destinations, one column per edge) and a ``clone()`` method.
        Optional per-edge tensors ``edge_weight`` and ``edge_attr`` are
        filtered with the same mask so they stay aligned with the edges.
    min_bins : int
        Inclusive lower bound on ``|Δ|``.
    max_bins : int or None
        Exclusive upper bound on ``|Δ|``; ``None`` means no upper bound.

    Returns
    -------
    A clone of ``data`` holding only the selected edges. Fields other than
    ``edge_index`` / ``edge_weight`` / ``edge_attr`` are left as cloned, and
    ``data`` itself is not modified.
    """
    ei = data.edge_index
    dist = (ei[0] - ei[1]).abs()
    mask = dist >= min_bins
    if max_bins is not None:
        mask &= dist < max_bins

    new = data.clone()
    new.edge_index = ei[:, mask]
    # Every per-edge tensor must be masked as well; otherwise its length no
    # longer matches the number of edges in the filtered edge_index.
    for key in ("edge_weight", "edge_attr"):
        value = getattr(data, key, None)
        if value is not None:
            setattr(new, key, value[mask])
    return new
|
|
|
|
|
|
def main():
    """Command-line entry point: build and save the two distance-banded ablations.

    Loads the full graph from ``--graph``, keeps edges with
    ``|Δbp| < --short_cutoff`` for the local variant and edges with
    ``|Δbp| > --long_cutoff`` for the long-range variant (node index is
    treated as genomic bin index, bin size ``--res`` bp), prints a short
    summary, and writes both variants with ``torch.save`` to ``--out_local``
    and ``--out_longrange``. Missing parent directories are created.
    """
    ap = argparse.ArgumentParser(description=__doc__,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument("--graph", required=True, help="Input full graph .pt")
    ap.add_argument("--res", type=int, default=25_000, help="Bin size bp")
    ap.add_argument("--short_cutoff", type=int, default=250_000,
                    help="Local variant: |Δbp| < this (default 250 kb, within-TAD)")
    ap.add_argument("--long_cutoff", type=int, default=1_000_000,
                    help="Long-range variant: |Δbp| > this (default 1 Mb)")
    ap.add_argument("--out_local", required=True)
    ap.add_argument("--out_longrange", required=True)
    args = ap.parse_args()

    # Validate up front. A non-positive bin size would divide by zero, and an
    # output path equal to --graph would overwrite the full graph reused
    # from H1.
    if args.res <= 0:
        ap.error("--res must be a positive bin size in bp")
    if args.short_cutoff < 0 or args.long_cutoff < 0:
        ap.error("--short_cutoff and --long_cutoff must be non-negative")
    graph_abs = os.path.abspath(args.graph)
    local_abs = os.path.abspath(args.out_local)
    longrange_abs = os.path.abspath(args.out_longrange)
    if graph_abs in (local_abs, longrange_abs):
        ap.error("output paths must differ from --graph")
    if local_abs == longrange_abs:
        ap.error("--out_local and --out_longrange must differ")

    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # Python objects. That is needed to load a whole graph object, so only
    # load graph files from a trusted source.
    data = torch.load(args.graph, weights_only=False)

    # Convert the bp cutoffs to bin offsets. This matches the documented bp
    # bounds exactly, even when a cutoff is not a multiple of --res:
    #   |Δbp| < short_cutoff  <=>  |Δbin| < ceil(short_cutoff / res)
    #   |Δbp| > long_cutoff   <=>  |Δbin| > floor(long_cutoff / res)
    short_bins = -(-args.short_cutoff // args.res)  # integer ceil division
    long_bins = args.long_cutoff // args.res
    n_full = data.edge_index.shape[1]

    print(f"Input graph: {data.num_nodes} nodes, {n_full} edges, "
          f"{data.x.shape[1]} features")
    print(f"Local variant: |Δbin| < {short_bins} (= {args.short_cutoff//1000} kb)")
    print(f"Long-range variant: |Δbin| > {long_bins} (= {args.long_cutoff//1000} kb)")

    local_data = filter_edges_by_distance(data, min_bins=0, max_bins=short_bins)
    # filter_edges_by_distance's lower bound is inclusive, so the strict
    # "|Δbin| > long_bins" contract needs long_bins + 1.
    longrange_data = filter_edges_by_distance(data, min_bins=long_bins + 1,
                                              max_bins=int(1e9))

    def edge_share(graph):
        """Return (edge count, percent of the full graph's edges); 0% if the full graph is empty."""
        n = graph.edge_index.shape[1]
        return n, (100 * n / n_full if n_full else 0.0)

    n_local, pct_local = edge_share(local_data)
    n_long, pct_long = edge_share(longrange_data)
    print(f"Local graph: {n_local:>9} edges ({pct_local:.1f}% of full)")
    print(f"Long-range graph: {n_long:>9} edges ({pct_long:.1f}% of full)")

    for graph, path in ((local_data, args.out_local),
                        (longrange_data, args.out_longrange)):
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        torch.save(graph, path)
        print(f"Saved → {path}")


if __name__ == "__main__":
    main()
|