ISS Heart — Figure Reproduction
Paper Figures: Fig 3 (a–c), Fig S4, Fig S5, Fig S6, Fig S7
This notebook reproduces the ISS heart spatial alignment figures from the SAME paper. The optimization results are pre-computed; this notebook loads them and generates publication figures.
Dataset: Human embryonic heart, smFISH-based ISS profiling
- Template: ~3,800 cells, Query: ~3,180 cells (VALIS-aligned); ~3,160 query cells are matched in the main result
- 8 cell types: Atrial cardiomyocytes, Ventricular cardiomyocytes, Cardiomyocytes, Endothelium, Epicardium, Fibroblast, Smooth muscle cells, Schwan progenitors
Main result: dp=10, MS=1, knn=8 → 71.6% cell type accuracy, <5% triangle violations
| Parameter | Value |
|---|---|
| `delaunay_penalty` | 10 |
| `window_size` | 4000 |
| `overlap` | 100 |
| `max_matches` | 1 |
| `radius` / `knn` | 50 / 8 |
| `metacell_size` | 1 (also 3, 7 in sweep) |
| `lazy_constraints` | True |
| `ignore_same_type_triangles` | True |
Note: Fig 3d–f (method comparisons) require benchmarking baselines from a separate machine — deferred.
Setup
%load_ext autoreload
%autoreload 2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
import pickle
from pathlib import Path
from matplotlib.lines import Line2D
# -- Publication style --
# All matplotlib defaults used by the figures below, applied in one call:
# a single font family/size plus the font-type settings for SVG/PDF export.
plt.rcParams.update({
    'font.family': 'DejaVu Sans',
    'font.sans-serif': ['DejaVu Sans'],
    'svg.fonttype': 'none',
    'pdf.fonttype': 42,
    'font.size': 12,
})
# -- Path setup --
import sys

# The notebook's working directory anchors every path used below.
NOTEBOOK_DIR = Path('.').resolve()
PROJECT_ROOT = NOTEBOOK_DIR.parents[1]  # SAME/examples/heart -> SAME/

# Put the project root first on sys.path so the `src` package import
# further down resolves.
sys.path.insert(0, str(PROJECT_ROOT))

# Inputs, pre-computed results and generated figures all live next to the
# notebook; only the figure directory is created on demand.
DATA_DIR, RESULTS_DIR, FIG_DIR = (
    NOTEBOOK_DIR / sub for sub in ('data', 'results', 'figures')
)
FIG_DIR.mkdir(exist_ok=True)

# Must come after the sys.path insert above.
from src.eval_utils import check_alignment, check_triangle_violations

print(f"Project root: {PROJECT_ROOT}")
Project root: /hpc/group/singhlab/rawdata/ap756/1024_same/heart/SAME
# Cell types in the ISS heart dataset.
# NOTE(review): 'Schwan progenitors' is spelled as in this notebook's column
# names; presumably it mirrors the raw CSV headers, so do not "correct" it here.
CELL_TYPES = [
    'Smooth muscle cells',
    'Fibroblast',
    'Atrial cardiomyocytes',
    'Cardiomyocytes',
    'Endothelium',
    'Epicardium',
    'Schwan progenitors',
    'Ventricular cardiomyocytes',
]

# Column names in the raw CSV files have a _percentage suffix; CT_RENAME maps
# each suffixed column back to the bare cell type name.
CT_COLS_RAW = [f'{name}_percentage' for name in CELL_TYPES]
CT_RENAME = dict(zip(CT_COLS_RAW, CELL_TYPES))

# Cell type colors (sns Set1 palette, matching notebooks/ISS_Heart/ order)
CT_COLORS = {
    'Smooth muscle cells':        '#e41a1c',
    'Fibroblast':                 '#377eb8',
    'Epicardium':                 '#4daf4a',
    'Atrial cardiomyocytes':      '#984ea3',
    'Endothelium':                '#ff7f00',
    'Schwan progenitors':         '#ffff33',
    'Cardiomyocytes':             '#a65628',
    'Ventricular cardiomyocytes': '#f781bf',
}

# Reference baselines (from evaluate_all_alignments.ipynb)
INITIAL_ACCURACY, INITIAL_VIOLATIONS = 57.60, 0.0          # Before alignment (image-based only)
EXPRESSION_ACCURACY, EXPRESSION_VIOLATIONS = 64.69, 6.15   # Expression-based matching (no phenotype labels)

# Color schemes for metacell sizes: a solid line color and a matching
# sequential colormap name per metacell size.
MS_COLORS = {1: '#3182bd', 3: '#e31a1c', 7: '#756bb1', 15: '#31a354'}
MS_CMAPS = {1: 'Blues', 3: 'Reds', 7: 'Purples', 15: 'Greens'}
1. Load Input Data
# Load ISS Heart spatial data (template = reference slice, query = slice to align)
refDF = pd.read_csv(DATA_DIR / 'refAD_valis.csv')
alignDF = pd.read_csv(DATA_DIR / 'queryAD_valis.csv')

# Identical preprocessing (matching the original pipeline) for both slices,
# applied in place.
for frame in (refDF, alignDF):
    # Rename _percentage columns to bare cell type names
    frame.rename(columns=CT_RENAME, inplace=True)
    # Plotting coordinates are the spot coordinates shifted by 75.
    # NOTE(review): the origin of the +75 offset is not shown in this notebook.
    frame['X'] = frame['spot_x'] + 75
    frame['Y'] = frame['spot_y'] + 75
    # Label each cell with the cell type column holding its largest value.
    frame['cell_type'] = frame[CELL_TYPES].idxmax(axis=1)

print(f"Template: {len(refDF)} cells")
print(f"Query: {len(alignDF)} cells")
print(f"\nCell type distribution (Template):")
print(refDF['cell_type'].value_counts())
Template: 3801 cells Query: 3184 cells Cell type distribution (Template): cell_type Atrial cardiomyocytes 1218 Cardiomyocytes 790 Fibroblast 621 Ventricular cardiomyocytes 480 Epicardium 279 Endothelium 276 Smooth muscle cells 129 Schwan progenitors 8 Name: count, dtype: int64
2. Fig 3a,b — Cell Types in Template and Query
Spatial distribution of cell types in template (a) and query (b) slices.
# Side-by-side scatter plots of the template and query slices, one color per
# cell type.
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
slice_panels = (('a. Template', refDF), ('b. Query', alignDF))

for ax, (title, df) in zip(axes, slice_panels):
    for ct in CELL_TYPES:
        mask = df['cell_type'] == ct
        if not mask.any():
            # Skip absent cell types so they get no legend entry.
            continue
        ax.scatter(df.loc[mask, 'X'], df.loc[mask, 'Y'],
                   s=8, alpha=0.7, label=ct, color=CT_COLORS[ct])
    ax.set_title(title, fontsize=14, loc='left', fontweight='bold')
    ax.set_aspect('equal')
    ax.invert_yaxis()
    ax.set_axis_off()

# A single legend, placed outside the right-hand panel.
axes[1].legend(bbox_to_anchor=(1.05, 1), loc='upper left', markerscale=3, fontsize=9)

plt.tight_layout()
plt.savefig(FIG_DIR / 'Fig3ab_cell_types.svg', dpi=300, bbox_inches='tight')
plt.show()
3. Fig 3c — Accuracy vs Triangle Violations
Sweep across Delaunay penalty (dp=0,1,5,10,25,50) for metacell sizes MS=1,3,7. Each point is one SAME run. Larger markers = higher penalty.
# Load pre-computed results table (one row per SAME run).
# The figures in this notebook read these columns from it: 'MS', 'dp',
# 'accuracy', 'violations', 'runtime_sec' and 'nodes_violating'.
results_df = pd.read_csv(RESULTS_DIR / 'results.csv')
def get_marker_size(dp):
    """Return the scatter marker size used for Delaunay penalty *dp*.

    The size grows with log10(dp + 1), so dp=0 maps to the base size of 50
    and each tenfold increase of (dp + 1) adds 60. Accepts anything
    ``np.log10`` accepts (scalars or arrays).
    """
    base_size = 50
    size_per_decade = 60
    return np.log10(dp + 1) * size_per_decade + base_size
# Metacell sizes and Delaunay penalties covered by the sweep.
MS_VALUES = [1, 3, 7]
DP_VALUES = [0, 1, 5, 10, 25, 50]
# Largest penalty in the sweep; normalises dp into [0, 1] for the color ramp
# so the shading stays in step with DP_VALUES instead of a hard-coded 50.
max_dp = max(DP_VALUES)

fig, ax = plt.subplots(figsize=(6, 4))

# One trajectory per metacell size: points ordered by penalty, shaded darker
# and drawn larger as the penalty grows, joined by a line.
for ms in MS_VALUES:
    ms_data = results_df[results_df['MS'] == ms].sort_values('dp')
    cmap = matplotlib.colormaps[MS_CMAPS[ms]]
    for _, row in ms_data.iterrows():
        norm_dp = row['dp'] / max_dp
        # Use only the 0.3-0.9 range of the colormap so the lightest
        # points remain visible on a white background.
        ax.scatter(row['accuracy'], row['violations'],
                   s=get_marker_size(row['dp']),
                   color=cmap(0.3 + 0.6 * norm_dp),
                   alpha=0.9, edgecolors='white', linewidths=0.8, zorder=5)
    ax.plot(ms_data['accuracy'].to_numpy(), ms_data['violations'].to_numpy(),
            color=MS_COLORS[ms], alpha=0.5, linewidth=2, zorder=3)

# MS labels, placed next to the most accurate run of each metacell size.
for i, ms in enumerate(MS_VALUES):
    ms_data = results_df[results_df['MS'] == ms]
    if ms_data.empty:
        # No runs for this metacell size: nothing to label (idxmax would raise).
        continue
    best = ms_data.loc[ms_data['accuracy'].idxmax()]
    # Alternate the label to the left/right of the point to reduce overlap.
    offset = -2 if i % 2 == 0 else 1
    ax.annotate(f'MS{ms}', (best['accuracy'] + offset, best['violations']),
                fontsize=10, fontweight='bold', va='center', color=MS_COLORS[ms])

# Reference points
ax.scatter(INITIAL_ACCURACY, INITIAL_VIOLATIONS, s=250, marker='*',
           color='red', zorder=10, edgecolors='black', linewidths=0.5)
ax.annotate('Initial', (INITIAL_ACCURACY + 0.8, INITIAL_VIOLATIONS + 1), fontsize=9, va='bottom')
ax.scatter(EXPRESSION_ACCURACY, EXPRESSION_VIOLATIONS, s=100, marker='D',
           color='limegreen', edgecolors='black', linewidths=1, zorder=10)
ax.annotate('no phenotype', (EXPRESSION_ACCURACY + 0.8, EXPRESSION_VIOLATIONS), fontsize=9, va='center')

# Legend for dp values. Line2D takes a marker size, scatter takes `s`;
# the square root converts between the two so legend markers match the plot.
legend_elements = [
    Line2D([0], [0], marker='o', color='w', markerfacecolor='gray',
           markersize=np.sqrt(get_marker_size(dp)), alpha=0.7, label=f'{dp}')
    for dp in DP_VALUES
]
ax.legend(title='Penalty', handles=legend_elements, loc='upper left', fontsize=11, ncol=2)

ax.set_xlabel('Cell type matches (%)', fontsize=11)
ax.set_ylabel('Triangle violations (%)', fontsize=11)
ax.set_xlim(55, 90)
ax.set_ylim(-3, 35)
ax.grid(True, linestyle='--', alpha=0.4)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

plt.tight_layout()
plt.savefig(FIG_DIR / 'Fig3c_accuracy_vs_violations.svg', bbox_inches='tight')
plt.savefig(FIG_DIR / 'Fig3c_accuracy_vs_violations.png', dpi=300, bbox_inches='tight')
plt.show()
4. Fig S4 — kNN Parameter Analysis (dp=5)
Effect of k-nearest neighbors on cell type matches, triangle violations, and runtime.
# kNN sweep results, ordered by k so bars and annotations share one order.
knn_results = pd.read_csv(RESULTS_DIR / 'knn_results.csv').sort_values('knn')
knn_order = knn_results['knn'].tolist()

fig, ax = plt.subplots(1, 3, figsize=(11.5, 4))

# a. Accuracy
sns.barplot(x='knn', y='accuracy', data=knn_results, order=knn_order, ax=ax[0], color='lightblue')
ax[0].set_xlabel('$k$-nearest neighbors')
ax[0].set_ylabel('Cell type matches (%)')
ax[0].set_title(r'$\mathbf{a.}$ Cell type matches', loc='left')
ax[0].spines['top'].set_visible(False)
ax[0].spines['right'].set_visible(False)

# Add improvement annotations: relative change in accuracy versus the previous
# (next smaller) k, written above each bar from the second one onward.
# Assumes one row per knn value, so row order matches bar order.
accuracies = knn_results['accuracy'].tolist()
for i, (prev_acc, curr_acc) in enumerate(zip(accuracies, accuracies[1:]), start=1):
    improvement = ((curr_acc - prev_acc) / prev_acc) * 100
    # The '+' format flag prints an explicit sign, so a drop in accuracy
    # shows as '-x.x%' rather than '+-x.x%'.
    ax[0].text(i, curr_acc + 1, f'{improvement:+.1f}%', ha='center', va='bottom', fontsize=8)

# b. Violations
sns.barplot(x='knn', y='violations', data=knn_results, order=knn_order, ax=ax[1], color='coral')
ax[1].set_xlabel('$k$-nearest neighbors')
ax[1].set_ylabel('Triangle violations (%)')
ax[1].set_title(r'$\mathbf{b.}$ Triangle violations', loc='left')
ax[1].spines['top'].set_visible(False)
ax[1].spines['right'].set_visible(False)

# c. Runtime (stored in seconds, plotted in minutes)
knn_results['runtime_min'] = knn_results['runtime_sec'] / 60
sns.barplot(x='knn', y='runtime_min', data=knn_results, order=knn_order, ax=ax[2], color='forestgreen')
ax[2].set_xlabel('$k$-nearest neighbors')
ax[2].set_ylabel('Runtime (minutes)')
ax[2].set_title(r'$\mathbf{c.}$ Total time to solve SAME ILP', loc='left')
ax[2].spines['top'].set_visible(False)
ax[2].spines['right'].set_visible(False)

plt.tight_layout()
plt.savefig(FIG_DIR / 'FigS4_knn_comparison.svg', bbox_inches='tight')
plt.savefig(FIG_DIR / 'FigS4_knn_comparison.png', dpi=300, bbox_inches='tight')
plt.show()
5. Fig S5 — Robustness to Phenotype Labeling Noise
SAME performance with Dirichlet mixture noise injected into cell type compositions (dp=10, knn=8). Noise level 0 = original labels, 1 = fully random.
# Build one summary row per noise level: accuracy of the SAME alignment, share
# of cell labels changed by the injected noise, and runtime in minutes.
noise_levels = [0, 0.2, 0.4, 0.6, 0.8, 1.0]
noiseDF = pd.DataFrame(index=noise_levels, columns=['Accuracy', 'Noise effect', 'Time'])

for rn in noise_levels:
    # Result directories are named by noise percentage (noise_rn0 ... noise_rn100).
    # round() before int() guards against float truncation (e.g. 57.999... -> 57).
    rn_int = int(round(rn * 100))
    noise_dir = RESULTS_DIR / f'noise_rn{rn_int}'
    try:
        # NOTE: pickle can execute arbitrary code while loading; only load
        # result files produced by a trusted run. Context managers ensure
        # the file handles are closed.
        with open(noise_dir / 'mc_align.pkl', 'rb') as fh:
            mc_align = pickle.load(fh)
        with open(noise_dir / 'mc_ref.pkl', 'rb') as fh:
            mc_ref = pickle.load(fh)
        matchesDF = pd.read_csv(noise_dir / 'matchedDF.csv')
    except Exception as e:
        # Best effort: a missing or unreadable run leaves NaNs in its row.
        print(f'Skipping rn{rn_int}: {e}')
        continue

    # Cell type from original labels (before noise)
    matchesDF['cell_type'] = mc_align.metacell_df.loc[
        matchesDF['Aligned_metacell_id'], 'cell_type'].values

    # Cell type after noise (for measuring label change)
    if rn_int == 0:
        matchesDF['cell_type_noise'] = matchesDF['cell_type']
    else:
        matchesDF['cell_type_noise'] = mc_align.metacell_df.loc[
            matchesDF['Aligned_metacell_id'], 'cell_type_noise'].values

    # check_alignment is given explicit coordinate columns; use the matched
    # reference position for the query and the native position for the template.
    matchesDF['SAME_X'] = matchesDF['ref_X']
    matchesDF['SAME_Y'] = matchesDF['ref_Y']

    mc_ref_df = mc_ref.metacell_df.copy()
    mc_ref_df['SAME_X'] = mc_ref_df['X']
    mc_ref_df['SAME_Y'] = mc_ref_df['Y']

    alignDF_copy, _ = check_alignment(matchesDF, mc_ref_df,
                                      xcol='SAME_X', ycol='SAME_Y',
                                      ctype_col='cell_type', kNN=1)

    n_cells = len(alignDF_copy)
    noiseDF.loc[rn, 'Accuracy'] = 100 * alignDF_copy['_1NN_match'].sum() / n_cells
    noiseDF.loc[rn, 'Noise effect'] = 100 * (alignDF_copy['cell_type_noise'] != alignDF_copy['cell_type']).sum() / n_cells
    # 'run_time' is read from the first row and converted from seconds to minutes.
    noiseDF.loc[rn, 'Time'] = matchesDF['run_time'].iloc[0] / 60

noiseDF['Noise'] = noiseDF.index.values
# The frame was created without a dtype, so cast everything to float for plotting.
noiseDF = noiseDF.astype(float)
print(noiseDF[['Accuracy', 'Noise effect', 'Time']])
Accuracy Noise effect Time 0.0 71.600253 0.000000 6.396518 0.2 70.936116 6.135357 20.615231 0.4 69.345144 9.933565 20.932100 0.6 66.181588 24.011389 7.257369 0.8 60.044290 65.295792 8.446541 1.0 54.856058 86.744701 20.317359
# Three bar charts over the noise level: accuracy, label change, runtime.
fig, ax = plt.subplots(1, 3, figsize=(9, 3))

noise_panels = [
    # (axes, noiseDF column, bar color, y-label, y-limits or None, title)
    (ax[0], 'Accuracy', 'blue', 'Cell type matches (%)', (0, 80),
     r'$\mathbf{a.}$ SAME robustness'),
    (ax[1], 'Noise effect', 'red', '% of input cell labels \naltered by noise', (0, 90),
     r'$\mathbf{b.}$ Input cell type changes'),
    (ax[2], 'Time', 'green', 'Time (mins.)', None,
     r'$\mathbf{c.}$ Runtime (mins.)'),
]

for panel, column, bar_color, y_label, y_limits, panel_title in noise_panels:
    sns.barplot(data=noiseDF, x='Noise', y=column, ax=panel, color=bar_color)
    panel.set_xlabel(r'Noise ($\eta$)')
    panel.set_ylabel(y_label)
    if y_limits is not None:
        panel.set_ylim(*y_limits)
    panel.set_title(panel_title, loc='left')
    panel.spines['top'].set_visible(False)
    panel.spines['right'].set_visible(False)

# Dashed reference line: accuracy before alignment.
ax[0].axhline(INITIAL_ACCURACY, color='black', linestyle='--')

plt.tight_layout()
plt.savefig(FIG_DIR / 'FigS5_noise_robustness.svg', dpi=300, bbox_inches='tight')
plt.savefig(FIG_DIR / 'FigS5_noise_robustness.png', dpi=300, bbox_inches='tight')
plt.show()
6. Fig S6 — Effect of Metacell Size and Penalty
Heatmaps of accuracy and violations, plus runtime, across the MS x dp grid.
# Accuracy and violation heatmaps over the (metacell size, penalty) grid.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

df_plot = results_df[results_df['MS'].isin([1, 3, 7])]
# Rows = metacell size, columns = Delaunay penalty.
pivot_acc = df_plot.pivot(index='MS', columns='dp', values='accuracy')
pivot_viol = df_plot.pivot(index='MS', columns='dp', values='violations')

grid_panels = [
    # (axes, pivoted table, colormap name, title)
    (axes[0], pivot_acc, 'YlGnBu',
     r'$\mathbf{a.}$ Cell type matches (%) by MS and Penalty'),
    (axes[1], pivot_viol, 'YlOrRd',
     r'$\mathbf{b.}$ Triangle Violations (%) by MS and Penalty'),
]
for panel, table, cmap_name, panel_title in grid_panels:
    sns.heatmap(table, annot=True, fmt='.1f', cmap=cmap_name, ax=panel)
    panel.set_title(panel_title, loc='left')
    panel.set_xlabel('Delaunay Penalty')
    panel.set_ylabel('Metacell Size')

plt.tight_layout()
plt.savefig(FIG_DIR / 'FigS6_heatmap_ms_dp.svg', bbox_inches='tight')
plt.savefig(FIG_DIR / 'FigS6_heatmap_ms_dp.png', dpi=300, bbox_inches='tight')
plt.show()
# Runtime heatmap
# Axis labels are handled as strings so rows/columns can be put in an explicit order.
penaltyOrder = ['0', '1', '5', '10', '25', '50']
ms_order = [str(ms) for ms in sorted([1, 3, 7])]

df_time = results_df[results_df['MS'].isin([1, 3, 7])].copy()
df_time['Penalty'] = df_time['dp'].astype(str)
df_time['MS'] = df_time['MS'].astype(str)

# First recorded runtime per (MS, Penalty) cell, converted from seconds to minutes.
runtime_minutes = df_time.pivot_table(
    index='MS', columns='Penalty', values='runtime_sec', aggfunc='first') / 60
heatmap_data = runtime_minutes.reindex(index=ms_order, columns=penaltyOrder)

plt.figure(figsize=(5, 3))
sns.heatmap(
    heatmap_data,
    annot=True,
    fmt='.1f',
    cmap='viridis',
    cbar_kws={'label': 'Total time taken (minutes)'},
    linewidths=0.5,
    linecolor='white',
)
plt.xlabel('Penalty')
plt.ylabel('MS')
plt.title('Total time taken (minutes)', loc='left', fontweight='bold')
plt.xticks(fontweight='bold')
plt.yticks(fontweight='bold', rotation=0)

plt.tight_layout()
plt.savefig(FIG_DIR / 'FigS6_time_heatmap.svg', dpi=300, bbox_inches='tight')
plt.savefig(FIG_DIR / 'FigS6_time_heatmap.png', dpi=300, bbox_inches='tight')
plt.show()
7. Fig S7 — Nodes in Violating Triangles vs Penalty
Bar plot showing the percentage of nodes participating in violated triangles.
# Grouped bars per penalty, one hue per metacell size: nodes touching a
# violated triangle (left) and violated triangles (right).
fig, ax = plt.subplots(1, 2, figsize=(9, 4))
df_plot = results_df[results_df['MS'].isin([1, 3, 7])].copy()

violation_panels = [
    # (axes, results column, y-label, title)
    (ax[0], 'nodes_violating', '% Nodes violating triangle constraints',
     'Nodes Violating vs Penalty'),
    (ax[1], 'violations', '% Triangle violations',
     'Triangle violations'),
]
for panel, column, y_label, panel_title in violation_panels:
    sns.barplot(x='dp', y=column, hue='MS', data=df_plot, palette='bright', ax=panel)
    panel.set_xlabel('Penalty')
    panel.set_ylabel(y_label)
    panel.set_title(panel_title)
    panel.legend(title='Metacell size')
    panel.spines['top'].set_visible(False)
    panel.spines['right'].set_visible(False)

plt.tight_layout()
plt.savefig(FIG_DIR / 'FigS7_nodes_violating.svg', dpi=300, bbox_inches='tight')
plt.savefig(FIG_DIR / 'FigS7_nodes_violating.png', dpi=300, bbox_inches='tight')
plt.show()
8. Fig 3 — Spatial Alignment Visualization (dp=10)
Load the main dp=10 result and show matched cells spatially.
# Load dp=10 knn=8 MS=1 result (paper's reported configuration)
main_dir = RESULTS_DIR / 'dp10_knn8_MS1'
# NOTE: pickle can execute arbitrary code while loading; only load result
# files produced by a trusted run. Context managers ensure the file handles
# are closed.
with open(main_dir / 'mc_align.pkl', 'rb') as fh:
    mc_align = pickle.load(fh)
with open(main_dir / 'mc_ref.pkl', 'rb') as fh:
    mc_ref = pickle.load(fh)
matches_df = pd.read_csv(main_dir / 'matchedDF.csv')

# Add cell type and spatial info. The query is placed at its matched reference
# position (ref_X/ref_Y); the template keeps its own X/Y.
matches_df['cell_type'] = matches_df[CELL_TYPES].idxmax(axis=1)
matches_df['SAME_X'] = matches_df['ref_X']
matches_df['SAME_Y'] = matches_df['ref_Y']

mc_ref_df = mc_ref.metacell_df.copy()
mc_ref_df['cell_type'] = mc_ref_df[CELL_TYPES].idxmax(axis=1)
mc_ref_df['SAME_X'] = mc_ref_df['X']
mc_ref_df['SAME_Y'] = mc_ref_df['Y']

# Compute accuracy: percentage of rows flagged in the '_1NN_match' column
# returned by check_alignment.
alignDF_eval, _ = check_alignment(matches_df, mc_ref_df,
                                  xcol='SAME_X', ycol='SAME_Y',
                                  ctype_col='cell_type', kNN=1)
accuracy = 100 * alignDF_eval['_1NN_match'].sum() / len(alignDF_eval)

# Triangle violations. The frame is re-indexed by aligned metacell id first.
# NOTE(review): presumably check_triangle_violations looks rows up by that id — confirm.
matches_df.index = matches_df['Aligned_metacell_id'].values
tri_df, stats = check_triangle_violations(
    matches_df, mc_align,
    aligned_id_col='Aligned_metacell_id', ref_id_col='Ref_metacell_id',
    mapped_x_col='ref_X', mapped_y_col='ref_Y',
    cell_type_col='cell_type', ignore_same_type_triangles=True, verbose=False)
violations = 100 * stats['triangles_flipped'] / stats['total_triangles']

print(f"dp=10, knn=8, MS=1:")
print(f" Cell type accuracy: {accuracy:.1f}%")
print(f" Triangle violations: {violations:.1f}%")
print(f" Matched cells: {len(matches_df)}")
dp=10, knn=8, MS=1: Cell type accuracy: 71.6% Triangle violations: 5.0% Matched cells: 3162
# Spatial plot: Template vs SAME-aligned Query
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

spatial_panels = [
    # (axes, frame, x column, y column, title)
    # Template: metacells at their own coordinates.
    (axes[0], mc_ref_df, 'X', 'Y', 'a. Template'),
    # SAME aligned: query cells drawn at their matched reference coordinates.
    (axes[1], matches_df, 'SAME_X', 'SAME_Y',
     f'b. SAME (dp=10, {accuracy:.1f}% CT match)'),
]

for ax, frame, x_col, y_col, panel_title in spatial_panels:
    for ct in CELL_TYPES:
        mask = frame['cell_type'] == ct
        if not mask.any():
            # Skip absent cell types so they get no legend entry.
            continue
        ax.scatter(frame.loc[mask, x_col], frame.loc[mask, y_col],
                   s=8, alpha=0.7, label=ct, color=CT_COLORS[ct])
    ax.set_title(panel_title, fontsize=14, loc='left', fontweight='bold')
    ax.set_aspect('equal')
    ax.invert_yaxis()
    ax.set_axis_off()

# A single legend, placed outside the right-hand panel.
axes[1].legend(bbox_to_anchor=(1.05, 1), loc='upper left', markerscale=3, fontsize=9)

plt.tight_layout()
plt.savefig(FIG_DIR / 'Fig3_spatial_alignment.svg', dpi=300, bbox_inches='tight')
plt.savefig(FIG_DIR / 'Fig3_spatial_alignment.png', dpi=300, bbox_inches='tight')
plt.show()
(Optional) Full Reproduction
To re-run the SAME optimization from scratch, use the provided shell scripts:
# Main dp/MS sweep (produces results.csv)
bash run_same.sh --dp 10 --knn 8 --ms 1
# kNN sweep (dp=5 fixed)
bash run_parameter_sweep.sh
# Noise robustness sweep
bash run_robustness.sh
See each .sh file for the full parameter specifications.