initial framework; to be extended
This commit is contained in:
@@ -1,8 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Compares two latent embedding matrices (e.g., CTRL vs EED-i),
|
||||
computes similarity metrics (cosine, Euclidean, L1),
|
||||
and saves both a CSV and an optional line plot.
|
||||
Compares two latent embedding matrices,
|
||||
computes similarity metrics (cosine, Euclidean, L1)
|
||||
|
||||
Usage:
|
||||
python scripts/compare_embeddings_general.py \
|
||||
@@ -45,7 +44,7 @@ def main():
|
||||
p.add_argument("--no-plot", action="store_true", help="Skip generating the plot")
|
||||
args = p.parse_args()
|
||||
|
||||
# ---- Load ----
|
||||
# Load
|
||||
emb1 = np.load(args.emb1)
|
||||
emb2 = np.load(args.emb2)
|
||||
if emb1.shape != emb2.shape:
|
||||
@@ -55,7 +54,7 @@ def main():
|
||||
n_bins, n_dim = emb1.shape
|
||||
print(f"Loaded embeddings: {n_bins} bins × {n_dim} dims")
|
||||
|
||||
# ---- Compute metrics ----
|
||||
# Compute metrics
|
||||
cos_sims, cos_dists, l2_dists, l1_dists = compute_metrics(emb1, emb2)
|
||||
|
||||
df = pd.DataFrame({
|
||||
@@ -69,7 +68,7 @@ def main():
|
||||
df.to_csv(csv_path, index=False)
|
||||
print(f"Saved metrics → {csv_path}")
|
||||
|
||||
# ---- Plot ----
|
||||
# Plot
|
||||
if not args.no_plot:
|
||||
plt.figure(figsize=(12, 4))
|
||||
plt.plot(df["bin_id"], df["cosine_distance"], lw=0.8, color="steelblue")
|
||||
|
||||
Reference in New Issue
Block a user