LUAD33 — Protein + Xenium Cross-Modal Alignment¶
Paper Figures: Fig 5 (a–f), Fig S15–S19
This notebook reproduces the LUAD (lung adenocarcinoma) cross-modal alignment figures from the SAME paper. SAME aligns protein (PCF, 33 markers) and RNA (Xenium, 300+ genes) spatial data from adjacent serial sections.
Dataset:
- Template (RNA): ~100K cells — 10x Xenium spatial transcriptomics
- Query (Protein): ~94K cells — Polychromatic Flow (PCF) imaging
- 5 cell types: B cell, Epithelial, Mesenchymal, Myeloid, T cell
Main result: dp=10, MS=3, knn=8, window_size=13000
| Parameter | Value |
|---|---|
| delaunay_penalty | 10 (sweep: 0, 1, 5, 10, 25, 50) |
| window_size | 13000 |
| overlap | 250 |
| max_matches | 1 |
| radius / knn | 250 / 8 |
| metacell_size | 3 |
| lazy_constraints | True |
Setup¶
%load_ext autoreload
%autoreload 2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc
import anndata as ad
import pickle
from pathlib import Path
from scipy.spatial.distance import cdist
import sys

# Matplotlib defaults for every figure in this notebook. 'svg.fonttype': 'none'
# keeps SVG text as text, and PDF font type 42 embeds TrueType fonts.
plt.rcParams.update({
    'font.family': 'DejaVu Sans',
    'svg.fonttype': 'none',
    'pdf.fonttype': 42,
    'font.size': 12,
})

# The project root sits two levels above the notebook directory; it is put at
# the front of sys.path so the project's ``src`` package can be imported below.
NOTEBOOK_DIR = Path('.').resolve()
PROJECT_ROOT = NOTEBOOK_DIR.parents[1]
sys.path.insert(0, str(PROJECT_ROOT))

# Inputs, precomputed alignment results, and figure outputs all live next to
# the notebook. Only the figure directory is created on demand.
DATA_DIR, RESULTS_DIR, FIG_DIR = (
    NOTEBOOK_DIR / sub for sub in ('data', 'results', 'figures')
)
FIG_DIR.mkdir(exist_ok=True)

# Project imports must come after the sys.path update above.
from src.eval_utils import check_alignment, check_triangle_violations
from src.metacell_utils import unpack_metacell_matches

print(f'Project root: {PROJECT_ROOT}')
Project root: /hpc/group/singhlab/rawdata/ap756/1024_same/heart/SAME
# The five cell-type labels shared by both modalities. The order is reused for
# column selection and plotting throughout the notebook.
CELL_TYPES = ['B cell', 'Epithelial', 'Mesenchymal', 'Myeloid', 'T cell']

# RGB colour per cell type, listed in the same order as CELL_TYPES.
# Colors from original notebooks/luad_33/ analysis
CT_COLORS = dict(zip(CELL_TYPES, [
    (1.0, 0.0, 0.8117654238977766),   # B cell
    (0.0, 0.29411984742867114, 1.0),  # Epithelial
    (1.0, 0.741177211765447, 0.0),    # Mesenchymal
    (0.44705736433677606, 0.0, 1.0),  # Myeloid
    (0.0, 1.0, 0.9647031631761764),   # T cell
]))
1. Load Input Data¶
# Protein (PCF) is the query that gets aligned; RNA (Xenium) is the template.
alignDF = pd.read_csv(DATA_DIR / 'align_pcf.csv', index_col=0)
refDF = pd.read_csv(DATA_DIR / 'ref_xen.csv', index_col=0)

# Both tables get the same preparation, applied in place:
#   1. keep the original row index as a column,
#   2. multiply the per-cell-type columns by 100,
#   3. label each cell with the cell type whose column is largest.
for frame in (alignDF, refDF):
    frame['Cell_Num_Old'] = frame.index.values
    frame[CELL_TYPES] = frame[CELL_TYPES] * 100
    frame['cell_type'] = frame[CELL_TYPES].idxmax(axis=1)

print(f'Template (Xenium RNA): {len(refDF)} cells')
print(f'Query (PCF Protein): {len(alignDF)} cells')
print(f'\nCell type distribution (PCF):')
print(alignDF['cell_type'].value_counts())
Template (Xenium RNA): 99827 cells Query (PCF Protein): 94442 cells Cell type distribution (PCF): cell_type Mesenchymal 44240 Epithelial 25697 Myeloid 9213 T cell 8052 B cell 7240 Name: count, dtype: int64
2. Fig 5a,b — Cell Types in Xenium Template and PCF Query¶
# Side-by-side spatial scatter of both sections, coloured by cell type.
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
panels = [
    (refDF, 'a. Xenium Template (RNA)'),
    (alignDF, 'b. PCF Query (Protein)'),
]
for ax, (df, title) in zip(axes, panels):
    for ct in CELL_TYPES:
        mask = df['cell_type'] == ct
        if not mask.any():
            # Skip absent types so they do not add an empty legend entry.
            continue
        cells = df.loc[mask]
        ax.scatter(cells['X'], cells['Y'],
                   s=0.1, alpha=0.5, label=ct, color=CT_COLORS[ct])
    ax.set_title(title, fontsize=14, loc='left', fontweight='bold')
    ax.set_aspect('equal')
    ax.invert_yaxis()
    ax.set_axis_off()

# One shared legend, attached to the right-hand panel.
axes[1].legend(bbox_to_anchor=(1.05, 1), loc='upper left', markerscale=6, fontsize=9)
plt.tight_layout()
plt.savefig(FIG_DIR / 'Fig5ab_cell_types.png', dpi=300, bbox_inches='tight')
plt.show()
3. Penalty Sweep — Runtime (Fig S18)¶
Total optimization time across dp = 0, 1, 5, 10, 25, 50.
# Load the pickled metacell objects for the query (align) and template (ref).
# Context managers close the file handles, which the bare
# ``pickle.load(open(...))`` form left open.
# NOTE(review): pickle can execute arbitrary code on load; only load files
# produced by this project's own pipeline.
with open(RESULTS_DIR / 'mc_align.pkl', 'rb') as fh:
    mc_align = pickle.load(fh)
with open(RESULTS_DIR / 'mc_ref.pkl', 'rb') as fh:
    mc_ref = pickle.load(fh)

# Total optimisation time for each Delaunay penalty in the sweep.
dp_values = [0, 1, 5, 10, 25, 50]
sweep_results = []
for dp in dp_values:
    folder = RESULTS_DIR / f'dp{dp}_knn8_MS3'
    outputDF = pd.read_csv(folder / 'matchedDF.csv')

    # Runtime per window: take the first run_time seen for each window_id
    # (presumably run_time is repeated on every row of a window), then sum.
    per_window = outputDF.groupby('window_id')['run_time'].first()
    total_seconds = per_window.sum()
    total_hours = total_seconds / 3600
    total_minutes = total_seconds / 60

    sweep_results.append({
        'dp': dp,
        'total_time_hours': total_hours,
        'total_time_minutes': total_minutes,
    })
    print(f'dp={dp:>2d}: time={total_hours:.2f}h ({total_minutes:.1f} min)')

sweep_df = pd.DataFrame(sweep_results)
sweep_df
dp= 0: time=0.01h (0.5 min) dp= 1: time=0.01h (0.6 min) dp= 5: time=0.01h (0.7 min) dp=10: time=0.03h (1.8 min) dp=25: time=4.17h (249.9 min) dp=50: time=10.15h (608.9 min)
| | dp | total_time_hours | total_time_minutes |
|---|---|---|---|
| 0 | 0 | 0.007676 | 0.460565 |
| 1 | 1 | 0.009880 | 0.592800 |
| 2 | 5 | 0.012190 | 0.731406 |
| 3 | 10 | 0.029466 | 1.767981 |
| 4 | 25 | 4.165308 | 249.918498 |
| 5 | 50 | 10.148551 | 608.913084 |
# Left panel: log-scale bar chart of total runtime per penalty.
# Right panel: the same numbers in minutes, as a styled table.
fig, axs = plt.subplots(1, 2, figsize=(9, 3.5))
bar_ax, table_ax = axs

# Bar chart
sns.barplot(data=sweep_df, x='dp', y='total_time_hours', ax=bar_ax, color='steelblue')
bar_ax.set_xlabel('Delaunay Penalty')
bar_ax.set_ylabel('Total time (hours)')
bar_ax.set_yscale('log')
for bar in bar_ax.patches:
    height = bar.get_height()
    # Very short runs get a third decimal so the label is not "0.00".
    if height < 0.01:
        text = f'{height:.3f}'
    else:
        text = f'{height:.2f}'
    bar_ax.annotate(text, (bar.get_x() + bar.get_width()/2., height),
                    ha='center', va='bottom', fontsize=8, xytext=(0, 2),
                    textcoords='offset points')
for side in ('top', 'right'):
    bar_ax.spines[side].set_visible(False)

# Table
table_ax.axis('off')
table_data = [
    [int(dp), f"{minutes:.2f}"]
    for dp, minutes in zip(sweep_df['dp'], sweep_df['total_time_minutes'])
]
table = table_ax.table(cellText=table_data,
                       colLabels=['Penalty', 'Total time (min)'],
                       cellLoc='center', loc='center', colWidths=[0.4, 0.4])
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1, 2)

n_cols = 2
# Header row (row 0): white bold text on a blue background.
for col in range(n_cols):
    header_cell = table[(0, col)]
    header_cell.set_facecolor('steelblue')
    header_cell.set_text_props(weight='bold', color='white')
# Body rows: alternate grey and white stripes.
for row in range(1, len(table_data) + 1):
    shade = '#E7E6E6' if row % 2 == 0 else '#FFFFFF'
    for col in range(n_cols):
        table[(row, col)].set_facecolor(shade)

plt.tight_layout()
plt.savefig(FIG_DIR / 'FigS18_time_vs_dp.svg', dpi=300, bbox_inches='tight')
plt.savefig(FIG_DIR / 'FigS18_time_vs_dp.png', dpi=300, bbox_inches='tight')
plt.show()
4. Main Result (dp=10) — Spatial Alignment + Top-k Match (Fig S19)¶
# Main result: metacell-level matches for dp=10, expanded to single cells and
# annotated with cell types, template coordinates and identifiers.
main_dir = RESULTS_DIR / 'dp10_knn8_MS3'
matchesDF = pd.read_csv(main_dir / 'matchedDF.csv')

# Unpack to individual cells
individual_matches = unpack_metacell_matches(
    matchesDF, mc_align.metacell_df, mc_ref.metacell_df,
    aligned_df=mc_align.original_df, ref_df=mc_ref.original_df,
    strategy='nearest',
    aligned_original_idx_col='Cell_Num_Old',
    ref_original_idx_col='Cell_Num_Old')

# Per-match identifiers; both are 'Cell_Num_Old' values of their table.
aligned_ids = individual_matches['Aligned_cell_id']
ref_ids = individual_matches['Ref_cell_id']

# Cell type labels, looked up by original cell id on each side.
aligned_ct = mc_align.original_df.set_index('Cell_Num_Old')['cell_type']
ref_ct = mc_ref.original_df.set_index('Cell_Num_Old')['cell_type']
individual_matches['aligned_celltype'] = aligned_ids.map(aligned_ct)
individual_matches['ref_celltype'] = ref_ids.map(ref_ct)
individual_matches['celltype_match'] = (
    individual_matches['aligned_celltype'] == individual_matches['ref_celltype']
)

# Map to Xenium Cell_Num for top-k evaluation. The template table is
# re-indexed in place by Cell_Num_Old so matches can be looked up with .loc.
mc_ref.original_df.index = mc_ref.original_df.Cell_Num_Old.values
individual_matches['Cell_Num'] = mc_ref.original_df.loc[ref_ids.values, 'Cell_Num'].values

# SAME-aligned coordinates are the coordinates of the matched template cell.
ref_coords = mc_ref.original_df.set_index('Cell_Num_Old')[['X', 'Y']]
for axis_name in ('X', 'Y'):
    individual_matches[f'SAME_{axis_name}'] = ref_ids.map(ref_coords[axis_name])

# Map to QuPathID for h5ad subsetting later (query table re-indexed in place).
mc_align.original_df.index = mc_align.original_df.Cell_Num_Old.values
individual_matches['QuPathID'] = mc_align.original_df.loc[aligned_ids.values, 'QuPathID'].values

# Dominant cell type (from PCF)
individual_matches['Dominant_Cell_Type'] = individual_matches['aligned_celltype']

accuracy = individual_matches['celltype_match'].mean() * 100
print(f'dp=10: {len(individual_matches)} matches, {accuracy:.1f}% cell type accuracy')
dp=10: 91684 matches, 72.5% cell type accuracy
# Top-k cell type match using Xenium reference probabilities: a match counts
# as correct at k when the PCF cell's dominant type is among the k
# highest-scoring cell-type columns of the matched Xenium cell.
xenDF = pd.read_csv(DATA_DIR / 'ref_xen.csv', index_col=0)
xenDF.index = xenDF['Cell_Num'].astype(str)

# Vectorized top-k evaluation
ref_probs = xenDF[CELL_TYPES].astype(float)
match_ids = individual_matches['Cell_Num'].astype(str)  # converted once, reused below
valid = match_ids.isin(ref_probs.index)
ref_rows = ref_probs.loc[match_ids[valid]].values
ct_array = np.array(CELL_TYPES)
dom_types = individual_matches.loc[valid, 'Dominant_Cell_Type'].values

for k in [1, 2, 3]:
    # Matches whose Cell_Num is missing from the reference stay False.
    matches = np.zeros(len(individual_matches), dtype=bool)
    # argpartition places the k largest columns of each row in the last k slots.
    top_k_idx = np.argpartition(ref_rows, -k, axis=1)[:, -k:]
    top_k_types = ct_array[top_k_idx]
    matches[valid.values] = np.any(top_k_types == dom_types[:, np.newaxis], axis=1)
    individual_matches[f'top_{k}_match'] = matches

print(f"Top-1 match: {individual_matches['top_1_match'].mean():.1%}")
print(f"Top-2 match: {individual_matches['top_2_match'].mean():.1%}")
print(f"Top-3 match: {individual_matches['top_3_match'].mean():.1%}")

# Plot: one panel per k, matches drawn at their SAME-aligned position and
# coloured blue (correct) or red (incorrect). The unused per-panel
# `correct` / `incorrect` frames were dropped because nothing read them.
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ix, (ax, k) in enumerate(zip(axes, [1, 2, 3])):
    col = f'top_{k}_match'
    sns.scatterplot(data=individual_matches, x='SAME_X', y='SAME_Y', s=1, alpha=0.5,
                    hue=col, palette={True: 'blue', False: 'red'}, ax=ax)
    rate = individual_matches[col].mean() * 100
    # chr(97 + ix) gives the panel letters a, b, c.
    ax.set_title(f'{chr(97+ix)}. Correct in top-{k} ({rate:.1f}%)',
                 loc='left', fontweight='bold')
    ax.invert_yaxis()
    ax.set_aspect('equal')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.legend(bbox_to_anchor=(0, 0.99), loc='upper left', markerscale=5)
    ax.set_xticks([])
    ax.set_yticks([])
plt.tight_layout()
plt.savefig(FIG_DIR / 'FigS19_top_k_match.png', dpi=300, bbox_inches='tight')
plt.show()
Top-1 match: 72.5% Top-2 match: 81.1% Top-3 match: 88.3%
5. Cross-Modal Integration — Combined Protein + RNA AnnData¶
Create combined AnnData with PCF protein and Xenium RNA features for each matched cell.
# Annotated single-modality AnnData objects: PCF protein and Xenium RNA.
pcfAD = sc.read_h5ad(DATA_DIR / '0325_pcf_annotated.h5ad')
xenAD = sc.read_h5ad(DATA_DIR / '0325_tsu33_annotated.h5ad')

# Filter out 'Other' cell type from PCF.
# Boolean indexing returns an AnnData *view*; .copy() makes it a real object
# so that later edits to pcfAD.obs (e.g. re-indexing) do not trigger
# ImplicitModificationWarning.
pcfAD = pcfAD[pcfAD.obs['Annotations'] != 'Other'].copy()

print(f'PCF AnnData: {pcfAD.shape}')
print(f'Xenium AnnData: {xenAD.shape}')
PCF AnnData: (94442, 19) Xenium AnnData: (99827, 194)
def _with_prefix(prefix, names):
    """Return ``names`` as a list with ``prefix`` prepended to each entry."""
    return [prefix + name for name in names]


# Index h5ad by QuPathID / cell_id (matching original postprocess.ipynb)
pcfAD.obs.index = pcfAD.obs.QuPathID
xenAD.obs.index = xenAD.obs.cell_id

# Protein side: one row per match, in match order. Variable names and obs
# columns get a 'pcf_' prefix so they stay distinguishable after merging.
pcfAD_subset = pcfAD[individual_matches.QuPathID.values, :].copy()
pcfAD_subset.var_names = _with_prefix('pcf_', pcfAD_subset.var_names)
# xen_id is added before the obs columns are prefixed, so it becomes 'pcf_xen_id'.
pcfAD_subset.obs.loc[:, 'xen_id'] = individual_matches.Cell_Num.values
pcfAD_subset.obs.columns = _with_prefix('pcf_', pcfAD_subset.obs.columns)

# RNA side: the matched Xenium cell for every protein row, re-indexed by the
# protein cell's QuPathID so both subsets share the same obs index.
xenAD_subset = xenAD[pcfAD_subset.obs.pcf_xen_id.values, :].copy()
xenAD_subset.var_names = _with_prefix('xen_', xenAD_subset.var_names)
xenAD_subset.obs.loc[:, 'QuPathID'] = pcfAD_subset.obs.pcf_QuPathID.values
xenAD_subset.obs.index = xenAD_subset.obs.QuPathID
xenAD_subset.obs.columns = _with_prefix('xen_', xenAD_subset.obs.columns)

# Combine protein + RNA feature-wise, then attach the merged var/obs tables
# and the Xenium spatial coordinates.
combinedAD = sc.concat([pcfAD_subset, xenAD_subset], axis=1)
combinedAD.var = pd.concat([pcfAD_subset.var, xenAD_subset.var], axis=0)
combinedAD.obs = pd.concat([pcfAD_subset.obs, xenAD_subset.obs], axis=1)
combinedAD.obsm['X_spatial'] = xenAD_subset.obsm['spatial'].copy()

print(f'Combined AnnData: {combinedAD.shape}')
print(f'Cell type match rate: {individual_matches["celltype_match"].mean():.1%}')
/hpc/group/singhlab/user/ap756/tmp/ipykernel_1727814/3056355313.py:2: ImplicitModificationWarning: Trying to modify index of attribute `.obs` of view, initializing view as actual.
pcfAD.obs.index = pcfAD.obs.QuPathID
/hpc/group/singhlab/rawdata/ap756/rapidsc/lib/python3.12/site-packages/anndata/_core/anndata.py:1756: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
utils.warn_names_duplicates("obs")
Combined AnnData: (91684, 213) Cell type match rate: 72.5%
6. Fig 5 — Cross-Modal Matrixplot¶
Protein and RNA markers grouped by cell type, showing concordance between matched modalities. Protein labels in blue, RNA in red.
# (feature in combinedAD, tick label shown on the plot), listed in plot order.
marker_panel = [
    ('pcf_Pan-Cytokeratin', 'PanCK'), ('xen_KRT8', 'KRT8'), ('xen_SFTPB', 'SFTPB'),
    ('pcf_a-SMA/ACTA2', 'SMA'), ('xen_ACTA2', 'ACTA2'),
    ('pcf_CD19', 'CD19'), ('xen_CD19', 'CD19'),
    ('pcf_CD68', 'CD68'), ('xen_CD68', 'CD68'),
    ('pcf_CD4', 'CD4'), ('xen_CD4', 'CD4'), ('xen_CD8A', 'CD8A'),
]
marker_genes = [feature for feature, _ in marker_panel]
xlab = [tick_text for _, tick_text in marker_panel]

# Color map: protein (pcf) = blue, RNA (xen) = red
protein_color, mrna_color = '#1f77b4', '#d62728'
is_protein = [feature.startswith('pcf_') for feature in marker_genes]

gx = sc.pl.matrixplot(
    combinedAD, var_names=marker_genes,
    groupby='pcf_Annotations',
    categories_order=['Epithelial', 'Mesenchymal', 'B cell', 'Myeloid', 'T cell'],
    standard_scale='var', cmap='cividis', return_fig=True)

# Replace the prefixed feature names with short labels, then colour each
# x-axis label by modality.
main_ax = gx.get_axes()['mainplot_ax']
main_ax.set_xticklabels(xlab)
for i, label in enumerate(main_ax.get_xticklabels()):
    if is_protein[i]:
        label.set_color(protein_color)
    else:
        label.set_color(mrna_color)
    label.set_fontweight('bold')

plt.savefig(FIG_DIR / 'Fig5_matrixplot.svg', bbox_inches='tight')
plt.savefig(FIG_DIR / 'Fig5_matrixplot.png', dpi=300, bbox_inches='tight')
plt.show()
7. T Cell Exhaustion Analysis¶
Subset T cells, compute distance to SFTPB+ tumor cells, bin into Tumor Infiltrating / Boundary / Stroma zones, and show exhaustion marker expression by zone.
from scipy.spatial import cKDTree

# Distance to SFTPB+ epithelial cells (tumor proxy): for every matched cell,
# the Euclidean distance to the nearest epithelial cell whose SFTPB RNA
# expression is above the mean over all matched cells.
sftpb_expr = combinedAD[:, combinedAD.var_names == 'xen_SFTPB'].X
if hasattr(sftpb_expr, 'toarray'):
    # Sparse matrices are densified before flattening to a 1-D vector.
    sftpb_expr = sftpb_expr.toarray()
sftpb_expr = sftpb_expr.flatten()

mean_sftp = sftpb_expr.mean()
sftp_pos_mask = (sftpb_expr > mean_sftp) & (combinedAD.obs.pcf_Annotations == 'Epithelial').values
sftp_pos_coords = combinedAD.obsm['X_spatial'][sftp_pos_mask]
if len(sftp_pos_coords) == 0:
    raise ValueError('No SFTPB+ epithelial cells found; cannot compute distance_to_sftp')

# A KD-tree nearest-neighbour query gives the same minimum distances as
# cdist(...).min(axis=1) without building the full
# (n_cells x n_SFTPB+ cells) distance matrix in memory.
distances, _ = cKDTree(sftp_pos_coords).query(combinedAD.obsm['X_spatial'], k=1)
combinedAD.obs['distance_to_sftp'] = distances
print(f'SFTPB+ epithelial cells: {sftp_pos_mask.sum()}')
SFTPB+ epithelial cells: 20009
# Subset T cells with CD3E expression (matching original postprocess.ipynb):
# keep cells annotated 'T cell' by PCF whose Xenium CD3E value exceeds 0.01.
cd3e_expr = combinedAD[:, combinedAD.var_names == 'xen_CD3E'].X
if hasattr(cd3e_expr, 'toarray'):
    cd3e_expr = cd3e_expr.toarray()
cd3e_expr = cd3e_expr.flatten()
t_cell_mask = (combinedAD.obs['pcf_Annotations'] == 'T cell').values & (cd3e_expr > 0.01)
t_cell_ad = combinedAD[t_cell_mask].copy()

# Keep only RNA genes for downstream analysis. The column subset is an
# AnnData view; .copy() makes it a real object so the .obs assignment below
# does not trigger ImplicitModificationWarning.
xen_genes = [g for g in t_cell_ad.var_names if g.startswith('xen_')]
t_cell_ad = t_cell_ad[:, xen_genes].copy()

# Bin by distance to tumor (units are those of obsm['X_spatial'] —
# presumably microns; confirm).
# include_lowest=True makes the first bin [0, 15]. With the default
# half-open (0, 15], T cells at distance exactly 0 fell outside every bin
# and became NaN, so they were silently missing from the zone counts and
# from the matrixplot grouped by 'distance_bins'.
bins = [0, 15, 30, float('inf')]
labels = ['Tumor Infiltrating', 'Boundary', 'Stroma']
t_cell_ad.obs['distance_bins'] = pd.cut(
    t_cell_ad.obs['distance_to_sftp'], bins=bins, labels=labels,
    include_lowest=True)

print(f'T cells (CD3E+): {t_cell_ad.shape[0]}')
print(t_cell_ad.obs['distance_bins'].value_counts())
T cells (CD3E+): 1746 distance_bins Boundary 769 Stroma 501 Tumor Infiltrating 467 Name: count, dtype: int64
/hpc/group/singhlab/user/ap756/tmp/ipykernel_1727814/1873578975.py:17: ImplicitModificationWarning: Trying to modify attribute `.obs` of view, initializing view as actual. t_cell_ad.obs['distance_bins'] = pd.cut(
# Candidate exhaustion-related RNA features, in plot order.
exhaustion_markers = [
    'xen_SDC4', 'xen_PTP4A1', 'xen_WSB1', 'xen_ICAM1',
    'xen_GZMA', 'xen_LAG3', 'xen_PIK3IP1', 'xen_FOS',
    'xen_TIGIT', 'xen_CTLA4', 'xen_CXCL13',
]

# Filter to genes present in the data; tick labels drop the 'xen_' text.
available_genes = set(t_cell_ad.var_names)
exhaustion_markers = [gene for gene in exhaustion_markers if gene in available_genes]
short_names = [gene.replace('xen_', '') for gene in exhaustion_markers]

# Mean expression per distance zone, scaled per gene, drawn on a
# pre-sized axis.
fig, ax = plt.subplots(figsize=(6, 3))
gx = sc.pl.matrixplot(
    t_cell_ad, var_names=exhaustion_markers,
    groupby='distance_bins',
    standard_scale='var', cmap='cividis',
    swap_axes=False, ax=ax, show=False, return_fig=True)

main_ax = gx.get_axes()['mainplot_ax']
main_ax.set_xticklabels(short_names)

plt.tight_layout()
plt.savefig(FIG_DIR / 'Fig5_t_cell_exhaustion.svg', bbox_inches='tight')
plt.savefig(FIG_DIR / 'Fig5_t_cell_exhaustion.png', dpi=300, bbox_inches='tight')
plt.show()
(Optional) Full Reproduction¶
# Main result: Delaunay penalty 10, knn 8, --ms 3 (presumably metacell size,
# matching the dp10_knn8_MS3 results folder read by this notebook).
bash run_same.sh --dp 10 --knn 8 --ms 3
# Full penalty sweep over the dp values loaded in the runtime section.
# NOTE(review): --knn and --ms are not passed here, so the sweep presumably
# relies on run_same.sh defaulting to knn=8 and ms=3 (the notebook reads
# results from dp*_knn8_MS3) — confirm against run_same.sh.
for dp in 0 1 5 10 25 50; do
bash run_same.sh --dp $dp
done
See run_same.sh for the full parameter specification.