Synthetic 4-Quadrant Benchmark — Figure Reproduction

Paper Figures: Fig 2 (a–e), Fig S1 (a–d)

This notebook reproduces the synthetic benchmark figures from the SAME paper. It operates in two modes:

  1. Visualization only (default): loads pre-computed data and SAME results from data/ and results/
  2. Full reproduction: regenerates the dataset and re-runs SAME optimization (requires Gurobi)

Dataset: 4 quadrants (10×10 grid each), 3 cell types (c1, c2, c3), 411 template / 372 query cells. Each quadrant tests a different alignment challenge:

  • Top-Left: Missing cell type (c3 removed in query)
  • Top-Right: Noisy class probabilities (near-uniform)
  • Bottom-Right: Space-tearing (point swaps + shear)
  • Bottom-Left: Topological split (1 ellipse → 2 rings)
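As an illustration of the space-tearing challenge (point swaps + shear), here is a minimal sketch on a 10×10 grid. It is not the generator used by `create_full_benchmark`: the function name `shear_and_swap`, the shear direction, and the parameter values are all assumptions chosen for illustration.

```python
import numpy as np

def shear_and_swap(points, shear=0.3, n_swaps=3, seed=0):
    """Hypothetical space-tearing: shear x by y, then swap positions of random point pairs."""
    rng = np.random.default_rng(seed)
    shear_matrix = np.array([[1.0, shear], [0.0, 1.0]])
    out = points @ shear_matrix.T  # x' = x + shear * y, y' = y
    # Pick 2 * n_swaps distinct points and exchange their positions pairwise
    idx = rng.choice(len(out), size=2 * n_swaps, replace=False)
    a, b = idx[:n_swaps], idx[n_swaps:]
    out[a], out[b] = out[b].copy(), out[a].copy()
    return out

# 10x10 grid, matching the per-quadrant layout described above
grid = np.array([[x, y] for y in range(10) for x in range(10)], dtype=float)
torn = shear_and_swap(grid)
print(torn.shape)
```

The shear alone is a smooth deformation that preserves neighborhood structure; the swaps are what break it locally, which is why an alignment method cannot recover this quadrant without some orientation flips.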

Setup

In [2]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from matplotlib import rcParams, gridspec, cm
from matplotlib.colors import Normalize
from scipy.spatial import Delaunay
from pathlib import Path

# -- Configure matplotlib for publication-quality SVGs --
rcParams['svg.fonttype'] = 'none'
rcParams['font.size'] = 12

# -- Path setup --
# Resolve project root (SAME/) from the current working directory.
# This assumes the kernel is started in SAME/examples/synthetic.
NOTEBOOK_DIR = Path('.').resolve()
PROJECT_ROOT = NOTEBOOK_DIR.parents[1]  # SAME/examples/synthetic -> SAME/
import sys
sys.path.insert(0, str(PROJECT_ROOT))

DATA_DIR = NOTEBOOK_DIR / 'data'
RESULTS_DIR = NOTEBOOK_DIR / 'results'
FIG_DIR = NOTEBOOK_DIR / 'figures'
FIG_DIR.mkdir(exist_ok=True)

# -- SAME imports --
from src.synthetic_datagen import (
    create_full_benchmark,
    visualize_benchmark,
    visualize_space_tearing,
    visualize_topological_merger,
    print_statistics,
    check_triangle_violations_within_quadrants,
    CLASS_NAMES, CLASS_COLORS
)
from src.eval_utils import check_triangle_violations

print(f"Project root: {PROJECT_ROOT}")
print(f"Data dir:     {DATA_DIR}")
print(f"Results dir:  {RESULTS_DIR}")
Project root: /hpc/group/singhlab/rawdata/ap756/1024_same/heart/SAME
Data dir:     /hpc/group/singhlab/rawdata/ap756/1024_same/heart/SAME/examples/synthetic/data
Results dir:  /hpc/group/singhlab/rawdata/ap756/1024_same/heart/SAME/examples/synthetic/results

1. Load Dataset

In [3]:
# Load pre-generated dataset (seed=8899)
ref_df = pd.read_csv(DATA_DIR / 'ref.csv', index_col=0)
query_df = pd.read_csv(DATA_DIR / 'query.csv', index_col=0)
ground_truth_df = pd.read_csv(DATA_DIR / 'ground_truth.csv', index_col=0)
with open(DATA_DIR / 'quadrants.pkl', 'rb') as f:
    quadrants = pickle.load(f)

print(f"Template: {len(ref_df)} cells,  Query: {len(query_df)} cells")
print(f"Ground truth pairs: {len(ground_truth_df)}")
print(f"Cell types: {sorted(ref_df['cell_type'].unique())}")
print(f"\nTemplate cell type counts:\n{ref_df['cell_type'].value_counts().sort_index()}")
print(f"\nQuery cell type counts:\n{query_df['cell_type'].value_counts().sort_index()}")
Template: 411 cells,  Query: 372 cells
Ground truth pairs: 372
Cell types: ['c1', 'c2', 'c3']

Template cell type counts:
cell_type
c1    142
c2    119
c3    150
Name: count, dtype: int64

Query cell type counts:
cell_type
c1    142
c2    119
c3    111
Name: count, dtype: int64

2. Fig 2a,c — Template & Query Overview

Four quadrants with distinct alignment challenges. Template (a) and Query (c) slices colored by cell type.

In [4]:
fig = visualize_benchmark(ref_df, query_df, quadrants)
plt.savefig(FIG_DIR / 'Fig2ac_benchmark_overview.svg', dpi=300, bbox_inches='tight')
plt.show()

3. Fig S1 — Per-Quadrant Transformation Details

S1a: Top-Left — Missing cell type (c3 absent in query)

In [5]:
q = quadrants['top_left']
fig, axes = plt.subplots(1, 2, figsize=(7, 3.5))

# Template
ax = axes[0]
for c in range(3):
    mask = q['ref_classes'] == c
    if mask.sum() > 0:
        ax.scatter(q['ref_points'][mask, 0], q['ref_points'][mask, 1],
                   c=list(CLASS_COLORS.values())[c], s=50,
                   label=CLASS_NAMES[c], edgecolors='black', linewidth=0.5)
ax.set_title(f'a. Template: {len(q["ref_points"])} cells', fontsize=12, loc='left')
ax.set_aspect('equal')

# Query (c3 removed)
ax = axes[1]
for c in range(2):
    mask = q['query_classes'] == c
    if mask.sum() > 0:
        ax.scatter(q['query_points'][mask, 0], q['query_points'][mask, 1],
                   c=list(CLASS_COLORS.values())[c], s=50, marker='P',
                   label=CLASS_NAMES[c], edgecolors='black', linewidth=0.5)
ax.set_title(f'b. Query: {len(q["query_points"])} cells', fontsize=12, loc='left')
ax.set_aspect('equal')

for ax in axes:
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

plt.tight_layout()
plt.savefig(FIG_DIR / 'FigS1a_missing_class.svg', dpi=300, bbox_inches='tight')
plt.show()
print(f"Ref: c1={sum(q['ref_classes']==0)}, c2={sum(q['ref_classes']==1)}, c3={sum(q['ref_classes']==2)}")
print(f"Query: c1={sum(q['query_classes']==0)}, c2={sum(q['query_classes']==1)} (c3 missing!)")
Ref: c1=30, c2=32, c3=38
Query: c1=30, c2=32 (c3 missing!)

S1b: Bottom-Left — Topological split (1 ellipse → 2 rings)

In [6]:
fig = visualize_topological_merger(quadrants)
plt.savefig(FIG_DIR / 'FigS1b_topological_split.svg', dpi=300, bbox_inches='tight')
plt.show()

S1c: Bottom-Right — Space-tearing (point swaps + shear)

In [7]:
fig = visualize_space_tearing(quadrants, 'bottom_right', recompute_triangulation=True, min_angle_deg=5)
plt.savefig(FIG_DIR / 'FigS1c_space_tearing.svg', dpi=300, bbox_inches='tight')
plt.show()

S1d: Top-Right — Noisy class probabilities

In [8]:
probs = ['c1', 'c2', 'c3']
cmap = 'viridis'
fixed_ticks = [0, 5, 10]

fig = plt.figure(figsize=(12, 7))
gs = gridspec.GridSpec(2, 4, width_ratios=[1, 1, 1, 0.05], wspace=0.35, hspace=0.35)
ax = np.array([[fig.add_subplot(gs[r, c]) for c in range(3)] for r in range(2)])
norm = Normalize(vmin=0, vmax=100)

labels = [['a', 'b', 'c'], ['d', 'e', 'f']]
row_labels = ['Template', 'Query']
dfs = [ref_df, query_df]

for row in range(2):
    for i, p in enumerate(probs):
        sc = ax[row, i].scatter(
            dfs[row]['X'], dfs[row]['Y'], c=dfs[row][p], cmap=cmap, norm=norm,
            s=20, alpha=0.85, edgecolors='black', linewidth=0.3)
        ax[row, i].set_title(f"{labels[row][i]}. {row_labels[row]}: Class probability: {p}", fontsize=12)
        ax[row, i].set_aspect('equal')
        ax[row, i].set_xticks(fixed_ticks)
        ax[row, i].set_yticks(fixed_ticks)
        ax[row, i].spines['top'].set_visible(False)
        ax[row, i].spines['right'].set_visible(False)

cax = fig.add_subplot(gs[:, 3])
cbar = fig.colorbar(sc, cax=cax, orientation='vertical')
cbar.set_label('Class Probability (%)', fontsize=12)
cbar.set_ticks(np.linspace(0, 100, 6))

plt.tight_layout(rect=[0, 0, 0.97, 1])
plt.savefig(FIG_DIR / 'FigS1d_noisy_probabilities.svg', dpi=300, bbox_inches='tight')
plt.show()
/hpc/group/singhlab/user/ap756/tmp/ipykernel_4034952/3965856167.py:31: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.
  plt.tight_layout(rect=[0, 0, 0.97, 1])

4. Load SAME Results

Load pre-computed matching results (dp=10) and metacell objects.

In [9]:
with open(RESULTS_DIR / 'mc_align.pkl', 'rb') as f:
    mc_align = pickle.load(f)
with open(RESULTS_DIR / 'mc_ref.pkl', 'rb') as f:
    mc_ref = pickle.load(f)
matches_df = pd.read_csv(RESULTS_DIR / 'matchedDF.csv')

# Annotate each match with the cell types of its query (aligned) and template (ref) metacells
matches_df['align_cell_type'] = mc_align.metacell_df.loc[
    matches_df['Aligned_metacell_id'].values, 'cell_type'].values
matches_df['ref_cell_type'] = mc_ref.metacell_df.loc[
    matches_df['Ref_metacell_id'].values, 'cell_type'].values

# 'cell_type' is the query-side cell type, used for coloring in later plots
matches_df['cell_type'] = matches_df['align_cell_type']
matches_df['SAME_X'] = matches_df['ref_X']
matches_df['SAME_Y'] = matches_df['ref_Y']

# Evaluate cell type matching accuracy
ct_accuracy = (matches_df.align_cell_type == matches_df.ref_cell_type).mean()

print(f"Matches: {len(matches_df)}")
print(f"Cell type matching accuracy: {ct_accuracy:.1%}")
Matches: 372
Cell type matching accuracy: 100.0%

5. Fig 2e — SAME Alignment Result

Displacement vectors showing query → template matching, colored by cell type.
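To make the displacement vectors concrete, the sketch below computes their lengths and a per-cell-type mean on a toy table. The column names (`X`, `Y`, `ref_X`, `ref_Y`, `cell_type`) follow `matches_df` as used in this notebook; the values and the `displacement` column are made up for illustration.

```python
import numpy as np
import pandas as pd

# Toy stand-in for matches_df: query position (X, Y) and matched template position (ref_X, ref_Y)
toy_matches = pd.DataFrame({
    'X': [0.0, 1.0, 2.0],
    'Y': [0.0, 0.0, 0.0],
    'ref_X': [0.0, 1.0, 5.0],
    'ref_Y': [0.0, 2.0, 4.0],
    'cell_type': ['c1', 'c2', 'c2'],
})

# Euclidean length of each query -> template displacement vector
dx = toy_matches['ref_X'] - toy_matches['X']
dy = toy_matches['ref_Y'] - toy_matches['Y']
toy_matches['displacement'] = np.hypot(dx, dy)

# Mean displacement per cell type
per_type = toy_matches.groupby('cell_type')['displacement'].mean()
print(per_type)
```

The same three lines applied to the real `matches_df` would give a numeric companion to the plot, for example to compare how far cells move in each quadrant.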

In [10]:
fig, ax = plt.subplots(1, 1, figsize=(8, 6))

# Background: query cells (faint)
for ct, color in CLASS_COLORS.items():
    mask = query_df['cell_type'] == ct
    if mask.sum() > 0:
        ax.scatter(query_df.loc[mask, 'X'], query_df.loc[mask, 'Y'],
                   c=[color], s=30, marker='P', alpha=0.2)

# Displacement lines
for _, row in matches_df.iterrows():
    ax.plot([row['X'], row['ref_X']], [row['Y'], row['ref_Y']],
            'k-', alpha=0.3, linewidth=1, zorder=3)

# Matched positions colored by cell type
query_colors = matches_df['cell_type'].map(CLASS_COLORS)
ax.scatter(matches_df['ref_X'], matches_df['ref_Y'], c=query_colors,
           s=35, edgecolors='black', linewidth=0.4, alpha=0.85)

ax.set_title(f'SAME Alignment ({ct_accuracy:.0%} cell type match)', fontsize=12)
ax.set_aspect('equal')
ax.set_axis_off()

plt.tight_layout()
plt.savefig(FIG_DIR / 'Fig2e_SAME_alignment.svg', dpi=300, bbox_inches='tight')
plt.show()

Fig 2e (zoom) — Bottom-left quadrant detail

In [11]:
fig, axes = plt.subplots(1, 2, figsize=(6, 4))

for ax_idx, (ax, title, show_arrows) in enumerate(zip(
        axes, ['Query', 'SAME'], [False, True])):
    for ct, color in CLASS_COLORS.items():
        mask = query_df['cell_type'] == ct
        if mask.sum() > 0:
            ax.scatter(query_df.loc[mask, 'X'], query_df.loc[mask, 'Y'],
                       c=[color], s=40, marker='P', alpha=(0.2 if show_arrows else 1.0))

    if show_arrows:
        ax.quiver(matches_df['X'], matches_df['Y'],
                  matches_df['ref_X'] - matches_df['X'],
                  matches_df['ref_Y'] - matches_df['Y'],
                  angles='xy', scale_units='xy', scale=1, alpha=0.2,
                  width=0.01, headwidth=3, headlength=4, color='gray')
        query_colors = matches_df['cell_type'].map(CLASS_COLORS)
        ax.scatter(matches_df['ref_X'], matches_df['ref_Y'], c=query_colors,
                   s=40, edgecolors='black', linewidth=0.4, alpha=0.85)

    ax.set_title(title, fontsize=12)
    ax.set_aspect('equal')
    ax.set_axis_off()
    ax.set_xlim(1.3, 5)
    ax.set_ylim(2, 5)

plt.tight_layout()
plt.savefig(FIG_DIR / 'Fig2e_zoom_bottom_left.svg', dpi=300, bbox_inches='tight')
plt.show()

6. Fig 2 — Triangle Violation Analysis

Spatial visualization of triangle orientation violations (expected in space-tearing and topological split quadrants).
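One common way to define a triangle orientation violation is a sign change in the triangle's signed area after mapping: a counter-clockwise triangle in the query becomes clockwise in the template, or vice versa. The sketch below illustrates that idea only. It is an assumption about the concept, not the implementation of `check_triangle_violations`, and the helper names are hypothetical.

```python
def signed_area(p, q, r):
    """Signed area of triangle (p, q, r): positive if counter-clockwise, negative if clockwise."""
    return 0.5 * ((q[0] - p[0]) * (r[1] - p[1]) - (r[0] - p[0]) * (q[1] - p[1]))

def orientation_flipped(tri_query, tri_mapped):
    """True if the mapped triangle has the opposite orientation to the query triangle.

    Degenerate (zero-area) triangles are treated as not flipped.
    """
    return signed_area(*tri_query) * signed_area(*tri_mapped) < 0

# Query triangle, listed counter-clockwise
tri_query = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]

# Mapping that keeps every vertex in place: orientation is preserved
same = orientation_flipped(tri_query, tri_query)

# Mapping that swaps the positions of the first two vertices: orientation flips
tri_swapped = [(1.0, 0.0), (0.0, 0.0), (0.0, 1.0)]
flipped = orientation_flipped(tri_query, tri_swapped)

print(same, flipped)
```

Under this definition a point swap between two neighbors flips every triangle that contains both of them, which is consistent with violations being expected in the space-tearing quadrant.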

In [12]:
# Check triangle violations
matches_df = check_triangle_violations_within_quadrants(matches_df, mc_align)

tri_violations, report = check_triangle_violations(
    matches_df, mc_align,
    aligned_id_col='Aligned_metacell_id',
    ref_id_col='Ref_metacell_id',
    mapped_x_col='ref_X',
    mapped_y_col='ref_Y',
    cell_type_col='cell_type',
    ignore_same_type_triangles=False,
    node_local=False)

print("Triangle violation summary:")
print(tri_violations[['triangle_violation', 'in_violating_triangle']].value_counts())
Triangle violation summary:
triangle_violation  in_violating_triangle
False               False                    324
True                True                      46
False               True                       2
Name: count, dtype: int64

In [13]:
fig, ax = plt.subplots(1, 1, figsize=(8, 6))

ax.scatter(ref_df['X'], ref_df['Y'], c='blue', marker='P', s=20, alpha=0.3)

good_matches = matches_df[~matches_df['triangle_violation']]
bad_matches = matches_df[matches_df['triangle_violation']]

# Good matches: thin black displacement lines
for _, row in good_matches.iterrows():
    ax.plot([row['X'], row['ref_X']], [row['Y'], row['ref_Y']],
            'k-', alpha=0.3, linewidth=1)

# Violations: thicker magenta displacement lines
for _, row in bad_matches.iterrows():
    ax.plot([row['X'], row['ref_X']], [row['Y'], row['ref_Y']],
            'm-', alpha=0.8, linewidth=1.5)

ax.scatter(good_matches['X'], good_matches['Y'], c='blue', s=30,
           label=f'Good ({len(good_matches)})')
ax.scatter(bad_matches['X'], bad_matches['Y'], c='magenta', s=50, marker='x',
           linewidths=2, label=f'Violation ({len(bad_matches)})')

ax.set_title('Triangle Violations', fontsize=12)
ax.legend()
ax.set_aspect('equal')
ax.set_axis_off()

plt.tight_layout()
plt.savefig(FIG_DIR / 'Fig2_triangle_violations.svg', dpi=300, bbox_inches='tight')
plt.show()

7. Supplementary — Delaunay Triangulation on Metacells

In [14]:
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

for ax, mc, title in zip(axes, [mc_align, mc_ref], ['Query (aligned)', 'Template (reference)']):
    for ct, color in CLASS_COLORS.items():
        mask = mc.metacell_df['cell_type'] == ct
        ax.scatter(mc.metacell_df.loc[mask, 'X'], mc.metacell_df.loc[mask, 'Y'],
                   c=color, s=25, alpha=0.85, edgecolors='black', linewidth=0.3, label=ct)

    for triangle in mc.metacell_delaunay:
        pts = mc.metacell_df.iloc[list(triangle)][['X', 'Y']].values
        poly = plt.Polygon(pts, closed=True, fill=False,
                           edgecolor='gray', linewidth=0.5, alpha=0.5)
        ax.add_patch(poly)

    ax.set_title(f'{title} ({len(mc.metacell_df)} cells, {len(mc.metacell_delaunay)} triangles)')
    ax.set_aspect('equal')
    ax.legend()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

plt.tight_layout()
plt.savefig(FIG_DIR / 'FigS_delaunay_triangulation.svg', dpi=300, bbox_inches='tight')
plt.show()

(Optional) Full Reproduction

Run this section to regenerate the data and re-run SAME from scratch. This requires a Gurobi license. See run_same.sh for the equivalent command-line parameters.

In [ ]:
# ============================================================
# UNCOMMENT TO REGENERATE DATA
# ============================================================
# np.random.seed(8899)
# ref_df, query_df, quadrants, ground_truth_df, ex_df = create_full_benchmark()
# print_statistics(ref_df, query_df, quadrants)
# 
# ref_df.to_csv(DATA_DIR / 'ref.csv', index=True)
# query_df.to_csv(DATA_DIR / 'query.csv', index=True)
# ground_truth_df.to_csv(DATA_DIR / 'ground_truth.csv', index=True)
# pickle.dump(quadrants, open(DATA_DIR / 'quadrants.pkl', 'wb'))
# 
# import anndata as ad
# refAD = ad.AnnData(ex_df['ref'], obs=ref_df)
# queryAD = ad.AnnData(ex_df['query'], obs=query_df)
# refAD.write_h5ad(DATA_DIR / 'ref.h5ad')
# queryAD.write_h5ad(DATA_DIR / 'query.h5ad')

In [ ]:
# ============================================================
# UNCOMMENT TO RE-RUN SAME OPTIMIZATION
# ============================================================
# from src.metacell_utils import greedy_triangle_collapse
# from src.same import init_gurobi_params, init_optim_params, sliding_window_matching
# import time
# 
# # -- Create metacells (MS=1 means no aggregation, just builds Delaunay) --
# mc_align = greedy_triangle_collapse(
#     query_df, cell_type_col='cell_type', original_idx_col='cell_idx',
#     x_col='X', y_col='Y', max_metacell_size=1, r_max=5, min_angle_deg=5,
#     use_alpha_shape=False, alpha=None, return_object=True)
# mc_ref = greedy_triangle_collapse(
#     ref_df, cell_type_col='cell_type', original_idx_col='cell_idx',
#     x_col='X', y_col='Y', max_metacell_size=1, r_max=5, min_angle_deg=5,
#     use_alpha_shape=False, alpha=None, return_object=True)
# 
# # -- SAME parameters --
# gurobi_params = init_gurobi_params()
# gurobi_params['mip_gap'] = 0.025
# gurobi_params['lazy_allowed_flip_fraction'] = 0.0
# 
# optim_params = init_optim_params()
# optim_params.update({
#     'window_size': 100, 'overlap': 0, 'min_cells_per_window': 30,
#     'max_matches': 2, 'radius': 5, 'knn': 8,
#     'no_match_penalty': 10000, 'dist_ct_coeff': 1, 'min_angle_deg': 5,
#     'penalty_coeff': 100, 'delaunay_penalty': 10,
#     'cell_id_col': 'metacell_id', 'ref_metacell_match_multiplier': 1,
#     'ignore_same_type_triangles': False, 'lazy_constraints': True,
# })
# 
# outprefix = str(RESULTS_DIR / 'dp10')
# start = time.time()
# matches_df = sliding_window_matching(
#     mc_ref, mc_align, outprefix=outprefix,
#     optim_params=optim_params, gurobi_params=gurobi_params,
#     ignore_precomputed_triangulation=False)
# print(f"SAME completed in {time.time()-start:.1f}s")
# 
# pickle.dump(mc_align, open(RESULTS_DIR / 'mc_align.pkl', 'wb'))
# pickle.dump(mc_ref, open(RESULTS_DIR / 'mc_ref.pkl', 'wb'))