# Synthetic 4-Quadrant Benchmark — Figure Reproduction
Paper Figures: Fig 2 (a–e), Fig S1 (a–d)
This notebook reproduces the synthetic benchmark figures from the SAME paper. It operates in two modes:
- **Visualization only** (default): loads pre-computed data and SAME results from `data/` and `results/`
- **Full reproduction**: regenerates the dataset and re-runs the SAME optimization (requires Gurobi)
Dataset: 4 quadrants (10×10 grid each), 3 cell types (c1, c2, c3), 411 template / 372 query cells. Each quadrant tests a different alignment challenge:
- Top-Left: Missing cell type (c3 removed in query)
- Top-Right: Noisy class probabilities (near-uniform)
- Bottom-Right: Space-tearing (point swaps + shear)
- Bottom-Left: Topological split (1 ellipse → 2 rings)
## Setup
%load_ext autoreload
%autoreload 2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from matplotlib import rcParams, gridspec, cm
from matplotlib.colors import Normalize
from scipy.spatial import Delaunay
from pathlib import Path
# -- Configure matplotlib for publication-quality SVGs --
rcParams['svg.fonttype'] = 'none'  # keep text as text in the SVG (editable), not glyph outlines
rcParams['font.size'] = 12

# -- Path setup --
# PROJECT_ROOT is derived from the kernel's *working directory*, so this assumes the
# kernel was started in SAME/examples/synthetic (two levels below SAME/).
NOTEBOOK_DIR = Path('.').resolve()
PROJECT_ROOT = NOTEBOOK_DIR.parents[1]  # SAME/examples/synthetic -> SAME/

import sys

# Put the project root first on sys.path exactly once, so re-running this cell
# does not keep growing sys.path but the project still takes import priority.
_project_root_str = str(PROJECT_ROOT)
if _project_root_str in sys.path:
    sys.path.remove(_project_root_str)
sys.path.insert(0, _project_root_str)

DATA_DIR = NOTEBOOK_DIR / 'data'
RESULTS_DIR = NOTEBOOK_DIR / 'results'
FIG_DIR = NOTEBOOK_DIR / 'figures'
FIG_DIR.mkdir(parents=True, exist_ok=True)

# -- SAME imports --
from src.synthetic_datagen import (
    create_full_benchmark,
    visualize_benchmark,
    visualize_space_tearing,
    visualize_topological_merger,
    print_statistics,
    check_triangle_violations_within_quadrants,
    CLASS_NAMES, CLASS_COLORS
)
from src.eval_utils import check_triangle_violations

print(f"Project root: {PROJECT_ROOT}")
print(f"Data dir: {DATA_DIR}")
print(f"Results dir: {RESULTS_DIR}")
The autoreload extension is already loaded. To reload it, use: %reload_ext autoreload Project root: /hpc/group/singhlab/rawdata/ap756/1024_same/heart/SAME Data dir: /hpc/group/singhlab/rawdata/ap756/1024_same/heart/SAME/examples/synthetic/data Results dir: /hpc/group/singhlab/rawdata/ap756/1024_same/heart/SAME/examples/synthetic/results
## 1. Load Dataset
# Load pre-generated dataset (seed=8899); see the optional reproduction section to rebuild it.
ref_df = pd.read_csv(DATA_DIR / 'ref.csv', index_col=0)
query_df = pd.read_csv(DATA_DIR / 'query.csv', index_col=0)
ground_truth_df = pd.read_csv(DATA_DIR / 'ground_truth.csv', index_col=0)
# Use a context manager so the file handle is closed deterministically.
# NOTE: pickle.load can execute arbitrary code — only load files you generated or trust.
with open(DATA_DIR / 'quadrants.pkl', 'rb') as fh:
    quadrants = pickle.load(fh)

print(f"Template: {len(ref_df)} cells, Query: {len(query_df)} cells")
print(f"Ground truth pairs: {len(ground_truth_df)}")
print(f"Cell types: {sorted(ref_df['cell_type'].unique())}")
print(f"\nTemplate cell type counts:\n{ref_df['cell_type'].value_counts().sort_index()}")
print(f"\nQuery cell type counts:\n{query_df['cell_type'].value_counts().sort_index()}")
Template: 411 cells, Query: 372 cells Ground truth pairs: 372 Cell types: ['c1', 'c2', 'c3'] Template cell type counts: cell_type c1 142 c2 119 c3 150 Name: count, dtype: int64 Query cell type counts: cell_type c1 142 c2 119 c3 111 Name: count, dtype: int64
## 2. Fig 2a,c — Template & Query Overview
Four quadrants with distinct alignment challenges. Template (a) and Query (c) slices colored by cell type.
# Fig 2a/2c: all four quadrants, template and query, coloured by cell type.
fig = visualize_benchmark(ref_df, query_df, quadrants)
overview_path = FIG_DIR / 'Fig2ac_benchmark_overview.svg'
plt.savefig(overview_path, bbox_inches='tight', dpi=300)
plt.show()
# Fig S1a: Top-Left quadrant — missing cell type (c3 is removed from the query).
q = quadrants['top_left']
# Class index i is drawn with the i-th colour of CLASS_COLORS (positional mapping).
class_colors = list(CLASS_COLORS.values())
n_classes = len(class_colors)


def _scatter_by_class(ax, points, classes, **scatter_kwargs):
    """Scatter one series per class index that is present in ``classes``.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    points : array indexed as ``points[mask, 0]`` / ``points[mask, 1]`` for x / y.
    classes : per-point class index, compared element-wise with each class index.
    **scatter_kwargs : extra ``Axes.scatter`` keywords (e.g. ``marker``).
    """
    for cls_idx in range(n_classes):
        mask = classes == cls_idx
        if mask.sum() > 0:  # absent classes get no artist and no legend entry
            # c=[colour] avoids matplotlib treating an RGB tuple as values to colour-map
            ax.scatter(points[mask, 0], points[mask, 1], c=[class_colors[cls_idx]], s=50,
                       label=CLASS_NAMES[cls_idx], edgecolors='black', linewidth=0.5,
                       **scatter_kwargs)


fig, axes = plt.subplots(1, 2, figsize=(7, 3.5))

# Template: every class is drawn
_scatter_by_class(axes[0], q['ref_points'], q['ref_classes'])
axes[0].set_title(f'a. Template: {len(q["ref_points"])} cells', fontsize=12, loc='left')

# Query: iterate over all classes (not a hard-coded range(2)) so the panel shows
# whatever the data contains; the removed class simply yields no points.
_scatter_by_class(axes[1], q['query_points'], q['query_classes'], marker='P')
axes[1].set_title(f'b. Query: {len(q["query_points"])} cells', fontsize=12, loc='left')

for ax in axes:
    ax.set_aspect('equal')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
plt.tight_layout()
plt.savefig(FIG_DIR / 'FigS1a_missing_class.svg', dpi=300, bbox_inches='tight')
plt.show()

# Per-class counts derived from the data; classes present in the template but
# absent from the query are flagged instead of being hard-coded.
ref_counts = [int(np.sum(q['ref_classes'] == i)) for i in range(n_classes)]
query_counts = [int(np.sum(q['query_classes'] == i)) for i in range(n_classes)]
missing = [f"c{i + 1}" for i in range(n_classes) if ref_counts[i] > 0 and query_counts[i] == 0]
print("Ref: " + ", ".join(f"c{i + 1}={n}" for i, n in enumerate(ref_counts)))
query_summary = ", ".join(f"c{i + 1}={n}" for i, n in enumerate(query_counts) if n > 0)
if missing:
    query_summary += f" ({', '.join(missing)} missing!)"
print(f"Query: {query_summary}")
Ref: c1=30, c2=32, c3=38 Query: c1=30, c2=32 (c3 missing!)
### S1b: Bottom-Left — Topological split (1 ellipse → 2 rings)
# Fig S1b: bottom-left quadrant — one template ellipse becomes two rings in the query.
fig = visualize_topological_merger(quadrants)
split_path = FIG_DIR / 'FigS1b_topological_split.svg'
plt.savefig(split_path, bbox_inches='tight', dpi=300)
plt.show()
### S1c: Bottom-Right — Space-tearing (point swaps + shear)
# Fig S1c: bottom-right quadrant — space-tearing (point swaps + shear).
fig = visualize_space_tearing(
    quadrants,
    'bottom_right',
    recompute_triangulation=True,
    min_angle_deg=5,
)
tearing_path = FIG_DIR / 'FigS1c_space_tearing.svg'
plt.savefig(tearing_path, bbox_inches='tight', dpi=300)
plt.show()
### S1d: Top-Right — Noisy class probabilities
# Fig S1d: per-class probability maps — template on the top row, query on the bottom row,
# one column per class, all sharing a 0–100 colour scale and a single colourbar.
prob_classes = ['c1', 'c2', 'c3']
tick_positions = [0, 5, 10]
prob_norm = Normalize(vmin=0, vmax=100)
# (row title, data frame, panel letters for the three columns)
panel_rows = [('Template', ref_df, 'abc'), ('Query', query_df, 'def')]

fig = plt.figure(figsize=(12, 7))
gs = gridspec.GridSpec(2, 4, width_ratios=[1, 1, 1, 0.05], wspace=0.35, hspace=0.35)
ax = np.array([[fig.add_subplot(gs[r, c]) for c in range(3)] for r in range(2)])

for row_idx, (row_title, frame, letters) in enumerate(panel_rows):
    for col_idx, cls in enumerate(prob_classes):
        panel = ax[row_idx, col_idx]
        sc = panel.scatter(frame['X'], frame['Y'], c=frame[cls], cmap='viridis', norm=prob_norm,
                           s=20, alpha=0.85, edgecolors='black', linewidth=0.3)
        panel.set_title(f"{letters[col_idx]}. {row_title}: Class probability: {cls}", fontsize=12)
        panel.set_aspect('equal')
        panel.set_xticks(tick_positions)
        panel.set_yticks(tick_positions)
        for side in ('top', 'right'):
            panel.spines[side].set_visible(False)

# All panels share prob_norm, so the last scatter is representative for the colourbar.
cax = fig.add_subplot(gs[:, 3])
cbar = fig.colorbar(sc, cax=cax, orientation='vertical')
cbar.set_label('Class Probability (%)', fontsize=12)
cbar.set_ticks(np.linspace(0, 100, 6))

plt.tight_layout(rect=[0, 0, 0.97, 1])
plt.savefig(FIG_DIR / 'FigS1d_noisy_probabilities.svg', dpi=300, bbox_inches='tight')
plt.show()
/hpc/group/singhlab/user/ap756/tmp/ipykernel_4034952/3965856167.py:31: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect. plt.tight_layout(rect=[0, 0, 0.97, 1])
## 4. Load SAME Results
Load the pre-computed matching results (the `dp10` run, i.e. `delaunay_penalty=10`) and the metacell objects.
# Load pre-computed SAME results (dp10 run) and the metacell objects they refer to.
# Context managers close the handles deterministically.
# NOTE: pickle.load can execute arbitrary code — only load result files you generated or trust.
with open(RESULTS_DIR / 'mc_align.pkl', 'rb') as fh:
    mc_align = pickle.load(fh)
with open(RESULTS_DIR / 'mc_ref.pkl', 'rb') as fh:
    mc_ref = pickle.load(fh)
matches_df = pd.read_csv(RESULTS_DIR / 'matchedDF.csv')

# Cell types of the matched query (aligned) and template (reference) metacells,
# each looked up once.
align_types = mc_align.metacell_df.loc[
    matches_df['Aligned_metacell_id'].values, 'cell_type'].values
ref_types = mc_ref.metacell_df.loc[
    matches_df['Ref_metacell_id'].values, 'cell_type'].values

# 'cell_type' (query side) is the colour key used by the figures below
matches_df['cell_type'] = align_types
# Aliases for the mapped (template-space) coordinates
matches_df['SAME_X'] = matches_df['ref_X']
matches_df['SAME_Y'] = matches_df['ref_Y']

# Evaluate cell type matching accuracy
matches_df['align_cell_type'] = matches_df['cell_type']
matches_df['ref_cell_type'] = ref_types
ct_accuracy = (matches_df['align_cell_type'] == matches_df['ref_cell_type']).mean()
print(f"Matches: {len(matches_df)}")
print(f"Cell type matching accuracy: {ct_accuracy:.1%}")
Matches: 372 Cell type matching accuracy: 100.0%
## 5. Fig 2e — SAME Alignment Result
Displacement vectors showing query → template matching, colored by cell type.
# Fig 2e: faint query cells, a segment from each query cell to its SAME-matched
# template position, and the matched positions coloured by cell type.
fig, ax = plt.subplots(1, 1, figsize=(8, 6))

# Background: query cells (faint)
for ct, color in CLASS_COLORS.items():
    mask = query_df['cell_type'] == ct
    if mask.sum() > 0:
        ax.scatter(query_df.loc[mask, 'X'], query_df.loc[mask, 'Y'],
                   c=[color], s=30, marker='P', alpha=0.2)

# Displacement lines: given 2×N arrays, ax.plot draws one line per column, i.e. one
# (X, Y) -> (ref_X, ref_Y) segment per match, without a per-row Python loop.
if len(matches_df) > 0:
    seg_x = np.vstack([matches_df['X'].to_numpy(), matches_df['ref_X'].to_numpy()])
    seg_y = np.vstack([matches_df['Y'].to_numpy(), matches_df['ref_Y'].to_numpy()])
    ax.plot(seg_x, seg_y, 'k-', alpha=0.3, linewidth=1, zorder=3)

# Matched positions colored by cell type
query_colors = matches_df['cell_type'].map(CLASS_COLORS)
ax.scatter(matches_df['ref_X'], matches_df['ref_Y'], c=query_colors,
           s=35, edgecolors='black', linewidth=0.4, alpha=0.85)

# Report the accuracy computed when loading the results, rather than a hard-coded "100%"
ax.set_title(f'SAME Alignment ({ct_accuracy:.0%} cell type match)', fontsize=12)
ax.set_aspect('equal')
ax.set_axis_off()
plt.tight_layout()
plt.savefig(FIG_DIR / 'Fig2e_SAME_alignment.svg', dpi=300, bbox_inches='tight')
plt.show()
### Fig 2e (zoom) — Bottom-left quadrant detail
# Zoom on the bottom-left quadrant: raw query cells (left) vs. SAME mapping (right),
# where query cells are faded and arrows point to the matched template positions.
match_colors = matches_df['cell_type'].map(CLASS_COLORS)
fig, axes = plt.subplots(1, 2, figsize=(6, 4))
for ax, title, show_arrows in zip(axes, ['Query', 'SAME'], [False, True]):
    query_alpha = 0.2 if show_arrows else 1.0
    for ct, color in CLASS_COLORS.items():
        mask = query_df['cell_type'] == ct
        if mask.sum() == 0:
            continue
        ax.scatter(query_df.loc[mask, 'X'], query_df.loc[mask, 'Y'],
                   c=[color], s=40, marker='P', alpha=query_alpha)

    if show_arrows:
        start_x, start_y = matches_df['X'], matches_df['Y']
        ax.quiver(start_x, start_y,
                  matches_df['ref_X'] - start_x,
                  matches_df['ref_Y'] - start_y,
                  angles='xy', scale_units='xy', scale=1, alpha=0.2,
                  width=0.01, headwidth=3, headlength=4, color='gray')
        query_colors = match_colors
        ax.scatter(matches_df['ref_X'], matches_df['ref_Y'], c=query_colors,
                   s=40, edgecolors='black', linewidth=0.4, alpha=0.85)

    ax.set_title(title, fontsize=12)
    ax.set_aspect('equal')
    ax.set_axis_off()
    # Fixed window around the bottom-left quadrant
    ax.set_xlim(1.3, 5)
    ax.set_ylim(2, 5)

plt.tight_layout()
plt.savefig(FIG_DIR / 'Fig2e_zoom_bottom_left.svg', dpi=300, bbox_inches='tight')
plt.show()
## 6. Fig 2 — Triangle Violation Analysis
Spatial visualization of triangle orientation violations (expected in space-tearing and topological split quadrants).
# Triangle-orientation checks: the within-quadrant check replaces matches_df with
# its returned frame; check_triangle_violations is then run on that frame and its
# per-match flags are summarised below.
# NOTE(review): the plot in the next cell colours by matches_df['triangle_violation']
# (presumably set by the within-quadrant check), while this summary reports
# tri_violations — confirm the two are meant to agree.
matches_df = check_triangle_violations_within_quadrants(matches_df, mc_align)

violation_kwargs = {
    'aligned_id_col': 'Aligned_metacell_id',
    'ref_id_col': 'Ref_metacell_id',
    'mapped_x_col': 'ref_X',
    'mapped_y_col': 'ref_Y',
    'cell_type_col': 'cell_type',
    'ignore_same_type_triangles': False,
    'node_local': False,
}
tri_violations, report = check_triangle_violations(matches_df, mc_align, **violation_kwargs)

summary_cols = ['triangle_violation', 'in_violating_triangle']
print("Triangle violation summary:")
print(tri_violations[summary_cols].value_counts())
Triangle violation summary: triangle_violation in_violating_triangle False False 324 True True 46 False True 2 Name: count, dtype: int64
# Triangle-violation map: faint template cells, then a segment from each query cell
# to its matched position — black for consistent matches, magenta for violations.
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
ax.scatter(ref_df['X'], ref_df['Y'], c='blue', marker='P', s=20, alpha=0.3)

is_violation = matches_df['triangle_violation']
good_matches = matches_df[~is_violation]
bad_matches = matches_df[is_violation]

# With 2×N inputs, ax.plot draws one line per column: (X, Y) -> (ref_X, ref_Y).
for subset, line_style in ((good_matches, dict(fmt='k-', alpha=0.3, linewidth=1)),
                           (bad_matches, dict(fmt='m-', alpha=0.8, linewidth=1.5))):
    if len(subset) == 0:
        continue
    seg_x = np.vstack([subset['X'].to_numpy(), subset['ref_X'].to_numpy()])
    seg_y = np.vstack([subset['Y'].to_numpy(), subset['ref_Y'].to_numpy()])
    ax.plot(seg_x, seg_y, line_style['fmt'],
            alpha=line_style['alpha'], linewidth=line_style['linewidth'])

ax.scatter(good_matches['X'], good_matches['Y'], c='blue', s=30,
           label=f'Good ({len(good_matches)})')
ax.scatter(bad_matches['X'], bad_matches['Y'], c='magenta', s=50, marker='x',
           linewidths=2, label=f'Violation ({len(bad_matches)})')

ax.set_title('Triangle Violations', fontsize=12)
ax.legend()
ax.set_aspect('equal')
ax.set_axis_off()
plt.tight_layout()
plt.savefig(FIG_DIR / 'Fig2_triangle_violations.svg', dpi=300, bbox_inches='tight')
plt.show()
## 7. Supplementary — Delaunay Triangulation on Metacells
# Metacells coloured by cell type with their Delaunay triangles outlined,
# for the aligned query (left) and the template (right).
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
for ax, mc, title in zip(axes, [mc_align, mc_ref], ['Query (aligned)', 'Template (reference)']):
    mc_df = mc.metacell_df
    for ct, color in CLASS_COLORS.items():
        mask = mc_df['cell_type'] == ct
        if mask.sum() == 0:
            continue  # skip absent types so the legend has no empty entries
        # c=[colour] (as in the other figures) avoids matplotlib treating an RGB
        # tuple as values to colour-map when a group has 3 or 4 points
        ax.scatter(mc_df.loc[mask, 'X'], mc_df.loc[mask, 'Y'],
                   c=[color], s=25, alpha=0.85, edgecolors='black', linewidth=0.3, label=ct)

    # Triangle vertices are positional row indices into metacell_df (as with .iloc);
    # fetch the coordinate array once instead of one .iloc lookup per triangle.
    coords = mc_df[['X', 'Y']].to_numpy()
    for triangle in mc.metacell_delaunay:
        poly = plt.Polygon(coords[list(triangle)], closed=True, fill=False,
                           edgecolor='gray', linewidth=0.5, alpha=0.5)
        ax.add_patch(poly)

    ax.set_title(f'{title} ({len(mc_df)} cells, {len(mc.metacell_delaunay)} triangles)')
    ax.set_aspect('equal')
    ax.legend()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

plt.tight_layout()
plt.savefig(FIG_DIR / 'FigS_delaunay_triangulation.svg', dpi=300, bbox_inches='tight')
plt.show()
## (Optional) Full Reproduction
Run this section to regenerate data and re-run SAME from scratch. Requires Gurobi license.
See run_same.sh for the equivalent command-line parameters.
# ============================================================
# REGENERATE DATA — set REGENERATE_DATA = True to rebuild the synthetic
# dataset. This overwrites the files in DATA_DIR that Section 1 loads.
# ============================================================
REGENERATE_DATA = False

if REGENERATE_DATA:
    import anndata as ad

    np.random.seed(8899)  # same seed as the shipped dataset (see Section 1)
    ref_df, query_df, quadrants, ground_truth_df, ex_df = create_full_benchmark()
    print_statistics(ref_df, query_df, quadrants)

    ref_df.to_csv(DATA_DIR / 'ref.csv', index=True)
    query_df.to_csv(DATA_DIR / 'query.csv', index=True)
    ground_truth_df.to_csv(DATA_DIR / 'ground_truth.csv', index=True)
    with open(DATA_DIR / 'quadrants.pkl', 'wb') as fh:
        pickle.dump(quadrants, fh)

    refAD = ad.AnnData(ex_df['ref'], obs=ref_df)
    queryAD = ad.AnnData(ex_df['query'], obs=query_df)
    refAD.write_h5ad(DATA_DIR / 'ref.h5ad')
    queryAD.write_h5ad(DATA_DIR / 'query.h5ad')
# ============================================================
# RE-RUN SAME OPTIMIZATION — set RERUN_SAME = True to rebuild the metacells
# and re-run SAME (requires a Gurobi license). Outputs replace the files that
# Section 4 loads; re-run Sections 4–7 afterwards to refresh the figures.
# ============================================================
RERUN_SAME = False

if RERUN_SAME:
    import time

    from src.metacell_utils import greedy_triangle_collapse
    from src.same import init_gurobi_params, init_optim_params, sliding_window_matching

    # -- Create metacells (MS=1 means no aggregation, just builds Delaunay) --
    mc_align = greedy_triangle_collapse(
        query_df, cell_type_col='cell_type', original_idx_col='cell_idx',
        x_col='X', y_col='Y', max_metacell_size=1, r_max=5, min_angle_deg=5,
        use_alpha_shape=False, alpha=None, return_object=True)
    mc_ref = greedy_triangle_collapse(
        ref_df, cell_type_col='cell_type', original_idx_col='cell_idx',
        x_col='X', y_col='Y', max_metacell_size=1, r_max=5, min_angle_deg=5,
        use_alpha_shape=False, alpha=None, return_object=True)

    # -- SAME parameters --
    gurobi_params = init_gurobi_params()
    gurobi_params['mip_gap'] = 0.025
    gurobi_params['lazy_allowed_flip_fraction'] = 0.0

    optim_params = init_optim_params()
    optim_params.update({
        'window_size': 100, 'overlap': 0, 'min_cells_per_window': 30,
        'max_matches': 2, 'radius': 5, 'knn': 8,
        'no_match_penalty': 10000, 'dist_ct_coeff': 1, 'min_angle_deg': 5,
        'penalty_coeff': 100, 'delaunay_penalty': 10,
        'cell_id_col': 'metacell_id', 'ref_metacell_match_multiplier': 1,
        'ignore_same_type_triangles': False, 'lazy_constraints': True,
    })

    outprefix = str(RESULTS_DIR / 'dp10')
    start = time.time()
    matches_df = sliding_window_matching(
        mc_ref, mc_align, outprefix=outprefix,
        optim_params=optim_params, gurobi_params=gurobi_params,
        ignore_precomputed_triangulation=False)
    print(f"SAME completed in {time.time()-start:.1f}s")

    with open(RESULTS_DIR / 'mc_align.pkl', 'wb') as fh:
        pickle.dump(mc_align, fh)
    with open(RESULTS_DIR / 'mc_ref.pkl', 'wb') as fh:
        pickle.dump(mc_ref, fh)
    # Section 4 reads the matches from RESULTS_DIR / 'matchedDF.csv'; write the fresh
    # result there so a re-run is actually picked up.
    # NOTE(review): sliding_window_matching may also write output under `outprefix`
    # — confirm which file is canonical.
    matches_df.to_csv(RESULTS_DIR / 'matchedDF.csv', index=False)