#!/usr/bin/env python3
"""
H3 graph ablation: build distance-banded variants of a full contact graph.

From one input graph (all edges up to max_dist), produces two filtered copies:
  *_local.pt      — only edges with |Δbp| < short_cutoff  (within-TAD scale)
  *_longrange.pt  — only edges with |Δbp| > long_cutoff   (sub-compartment scale)

The full graph is reused as-is from H1; this script generates only the two
ablations. Node features and node count are preserved exactly.

Hypothesis 3 tests whether long-range edges (>1 Mb) encode non-trivial
topological structure beyond what is explained by local contact density.

Usage
-----
  python experiments/h3_longrange/build_ablation_graphs.py \\
      --graph data/processed/gm12878/chr1.pt \\
      --res 25000 \\
      --short_cutoff 250000 \\
      --long_cutoff 1000000 \\
      --out_local data/processed/gm12878/chr1_local.pt \\
      --out_longrange data/processed/gm12878/chr1_longrange.pt
"""

import argparse
import os

import torch


def filter_edges_by_distance(data, min_bins=0, max_bins=int(1e9)):
    """Return a copy of `data` keeping only edges with min_bins <= |Δ| < max_bins.

    The distance of an edge is the absolute difference between its source and
    destination node indices (rows 0 and 1 of ``data.edge_index``).

    Args:
        data: Graph object exposing ``edge_index`` (a 2 x E index tensor) and a
            ``clone()`` method. ``edge_weight`` and ``edge_attr`` are optional.
        min_bins: Inclusive lower bound on the index distance.
        max_bins: Exclusive upper bound on the index distance.

    Returns:
        A clone of `data` whose ``edge_index`` (and ``edge_weight`` /
        ``edge_attr`` when present) contain only the selected edges. The input
        object is not modified. Any other per-edge attribute is left untouched.
    """
    ei = data.edge_index
    dist = (ei[0] - ei[1]).abs()
    mask = (dist >= min_bins) & (dist < max_bins)

    new = data.clone()
    new.edge_index = ei[:, mask]

    edge_weight = getattr(data, "edge_weight", None)
    if edge_weight is not None:
        new.edge_weight = edge_weight[mask]

    # Keep per-edge attributes aligned with the filtered edge_index; only
    # touch edge_attr when its first dimension really is the edge dimension.
    edge_attr = getattr(data, "edge_attr", None)
    if (
        torch.is_tensor(edge_attr)
        and edge_attr.dim() > 0
        and edge_attr.shape[0] == ei.shape[1]
    ):
        new.edge_attr = edge_attr[mask]

    return new


def bp_cutoffs_to_bins(res, short_cutoff, long_cutoff):
    """Translate base-pair cutoffs into bounds for `filter_edges_by_distance`.

    Args:
        res: Bin size in bp; must be positive.
        short_cutoff: Local variant keeps edges with |Δbp| < short_cutoff.
        long_cutoff: Long-range variant keeps edges with |Δbp| > long_cutoff.

    Returns:
        ``(local_max_bins, longrange_min_bins)`` such that, for an integer bin
        distance d, ``d < local_max_bins`` iff ``d * res < short_cutoff`` and
        ``d >= longrange_min_bins`` iff ``d * res > long_cutoff``.

    Raises:
        ValueError: If `res` is not positive.
    """
    if res <= 0:
        raise ValueError("--res must be a positive integer")
    # Ceiling division: a cutoff that is not a multiple of `res` must still
    # admit every bin distance strictly below it.
    local_max_bins = -(-short_cutoff // res)
    # Strict '>' on bp: the first admissible bin distance is one past the floor.
    longrange_min_bins = long_cutoff // res + 1
    return local_max_bins, longrange_min_bins


def main():
    """Parse CLI arguments, build the local and long-range graphs, save both."""
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    ap.add_argument("--graph", required=True, help="Input full graph .pt")
    ap.add_argument("--res", type=int, default=25_000, help="Bin size bp")
    ap.add_argument("--short_cutoff", type=int, default=250_000,
                    help="Local variant: |Δbp| < this (default 250 kb, within-TAD)")
    ap.add_argument("--long_cutoff", type=int, default=1_000_000,
                    help="Long-range variant: |Δbp| > this (default 1 Mb)")
    ap.add_argument("--out_local", required=True)
    ap.add_argument("--out_longrange", required=True)
    args = ap.parse_args()

    try:
        local_max_bins, longrange_min_bins = bp_cutoffs_to_bins(
            args.res, args.short_cutoff, args.long_cutoff
        )
    except ValueError as exc:
        ap.error(str(exc))

    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load graph files that come from a trusted source.
    data = torch.load(args.graph, weights_only=False)
    total_edges = data.edge_index.shape[1]

    print(f"Input graph: {data.num_nodes} nodes, {total_edges} edges, "
          f"{data.x.shape[1]} features")
    print(f"Local variant: |Δbin| < {local_max_bins} (= {args.short_cutoff//1000} kb)")
    print(f"Long-range variant: |Δbin| > {longrange_min_bins - 1} (= {args.long_cutoff//1000} kb)")

    local_data = filter_edges_by_distance(data, min_bins=0, max_bins=local_max_bins)
    longrange_data = filter_edges_by_distance(
        data, min_bins=longrange_min_bins, max_bins=int(1e9)
    )

    n_local = local_data.edge_index.shape[1]
    n_longrange = longrange_data.edge_index.shape[1]
    # An input graph without edges must not crash the summary printout.
    pct_local = 100 * n_local / total_edges if total_edges else 0.0
    pct_longrange = 100 * n_longrange / total_edges if total_edges else 0.0
    print(f"Local graph: {n_local:>9} edges "
          f"({pct_local:.1f}% of full)")
    print(f"Long-range graph: {n_longrange:>9} edges "
          f"({pct_longrange:.1f}% of full)")

    for path in (args.out_local, args.out_longrange):
        # dirname is '' for a bare filename; fall back to the current directory.
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    torch.save(local_data, args.out_local)
    torch.save(longrange_data, args.out_longrange)
    print(f"Saved → {args.out_local}")
    print(f"Saved → {args.out_longrange}")


if __name__ == "__main__":
    main()