Files
chromatin-vgae-hic/experiments/h1_representation/compare.py

86 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Compares two latent embedding matrices,
computes similarity metrics (cosine, Euclidean, L1)
Usage:
python scripts/compare_embeddings_general.py \
--emb1 results/emb.npy \
--emb2 results/emb_eedi.npy \
--label1 CTRL --label2 EEDi \
--prefix results/chr21
"""
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.spatial.distance import cosine, euclidean, cityblock
import os
def compute_metrics(emb1, emb2):
"""Compute cosine similarity, cosine distance, L2, and L1 per row."""
cos_sims, cos_dists, l2_dists, l1_dists = [], [], [], []
for a, b in zip(emb1, emb2):
cos_sim = 1 - cosine(a, b)
cos_dist = 1 - cos_sim
l2 = euclidean(a, b)
l1 = cityblock(a, b)
cos_sims.append(cos_sim)
cos_dists.append(cos_dist)
l2_dists.append(l2)
l1_dists.append(l1)
return np.array(cos_sims), np.array(cos_dists), np.array(l2_dists), np.array(l1_dists)
def main():
p = argparse.ArgumentParser(description="Compare two embedding matrices")
p.add_argument("--emb1", required=True, help="Path to first embedding .npy")
p.add_argument("--emb2", required=True, help="Path to second embedding .npy")
p.add_argument("--label1", default="A", help="Label for first embedding")
p.add_argument("--label2", default="B", help="Label for second embedding")
p.add_argument("--prefix", default="results/compare", help="Prefix for output files")
p.add_argument("--no-plot", action="store_true", help="Skip generating the plot")
args = p.parse_args()
# Load
emb1 = np.load(args.emb1)
emb2 = np.load(args.emb2)
if emb1.shape != emb2.shape:
raise ValueError(f"Shape mismatch: {emb1.shape} vs {emb2.shape}")
prefix_dir = os.path.dirname(os.path.abspath(args.prefix))
os.makedirs(prefix_dir, exist_ok=True)
n_bins, n_dim = emb1.shape
print(f"Loaded embeddings: {n_bins} bins × {n_dim} dims")
# Compute metrics
cos_sims, cos_dists, l2_dists, l1_dists = compute_metrics(emb1, emb2)
df = pd.DataFrame({
"bin_id": np.arange(n_bins),
"cosine_similarity": cos_sims,
"cosine_distance": cos_dists,
"euclidean": l2_dists,
"manhattan": l1_dists
})
csv_path = f"{args.prefix}_delta.csv"
df.to_csv(csv_path, index=False)
print(f"Saved metrics → {csv_path}")
# Plot
if not args.no_plot:
plt.figure(figsize=(12, 4))
plt.plot(df["bin_id"], df["cosine_distance"], lw=0.8, color="steelblue")
plt.title(f"Δ-Embedding ({args.label1} vs {args.label2})")
plt.xlabel("Bin index")
plt.ylabel("Cosine distance (1 similarity)")
plt.tight_layout()
fig_path = f"{args.prefix}_delta.png"
plt.savefig(fig_path, dpi=300)
print(f"Saved plot → {fig_path}")
if __name__ == "__main__":
main()