LUAD33: Protein + Xenium Cross-Modal Alignment

Paper Figures: Fig 5 (a–f), Fig S15–S18

This notebook reproduces the LUAD (lung adenocarcinoma) cross-modal alignment figures from the SAME paper. SAME aligns protein (PCF, 33 markers) and RNA (Xenium, 300+ genes) spatial data from adjacent serial sections.

Dataset:

  • Template (RNA): ~100K cells, 10x Xenium spatial transcriptomics
  • Query (Protein): ~94K cells, PhenoCycler-Fusion (PCF) imaging
  • 5 cell types: B cell, Epithelial, Mesenchymal, Myeloid, T cell

Main result: dp=10 (delaunay_penalty), MS=3 (metacell_size), knn=8, window_size=13000

Parameter          Value
delaunay_penalty   10 (sweep: 0, 1, 5, 10, 25, 50)
window_size        13000
overlap            250
max_matches        1
radius / knn       250 / 8
metacell_size      3
lazy_constraints   True
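
To make window_size and overlap concrete, here is a minimal sketch of how a 1-D extent could be tiled into overlapping windows. It is an illustration only: the helper `tile_windows` is hypothetical, and how SAME actually builds its windows is not shown in this notebook.

```python
def tile_windows(extent, window_size=13000, overlap=250):
    """Return (start, end) intervals covering [0, extent] with a fixed overlap.

    Hypothetical helper for illustration; not part of the SAME code base.
    """
    if overlap >= window_size:
        raise ValueError("overlap must be smaller than window_size")
    step = window_size - overlap  # consecutive windows share `overlap` units
    windows = []
    start = 0
    while True:
        end = min(start + window_size, extent)
        windows.append((start, end))
        if end >= extent:
            break
        start += step
    return windows


windows = tile_windows(30000)
print(windows)  # [(0, 13000), (12750, 25750), (25500, 30000)]
```

With these values each window starts 12750 units after the previous one, so neighbouring windows share a 250-unit strip.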

Setup

In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc
import anndata as ad
import pickle
from pathlib import Path
from scipy.spatial.distance import cdist

plt.rcParams['font.family'] = 'DejaVu Sans'
plt.rcParams['svg.fonttype'] = 'none'
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['font.size'] = 12

NOTEBOOK_DIR = Path('.').resolve()
PROJECT_ROOT = NOTEBOOK_DIR.parents[1]
import sys
sys.path.insert(0, str(PROJECT_ROOT))

DATA_DIR = NOTEBOOK_DIR / 'data'
RESULTS_DIR = NOTEBOOK_DIR / 'results'
FIG_DIR = NOTEBOOK_DIR / 'figures'
FIG_DIR.mkdir(exist_ok=True)

from src.eval_utils import check_alignment, check_triangle_violations
from src.metacell_utils import unpack_metacell_matches

print(f'Project root: {PROJECT_ROOT}')
Project root: /hpc/group/singhlab/rawdata/ap756/1024_same/heart/SAME
In [2]:
CELL_TYPES = ['B cell', 'Epithelial', 'Mesenchymal', 'Myeloid', 'T cell']

# Colors from original notebooks/luad_33/ analysis
CT_COLORS = {
    'B cell':       (1.0, 0.0, 0.8117654238977766),
    'Epithelial':   (0.0, 0.29411984742867114, 1.0),
    'Mesenchymal':  (1.0, 0.741177211765447, 0.0),
    'Myeloid':      (0.44705736433677606, 0.0, 1.0),
    'T cell':       (0.0, 1.0, 0.9647031631761764),
}

1. Load Input Data

In [3]:
# Protein (PCF) = query (aligned), RNA (Xenium) = template (ref)
alignDF = pd.read_csv(DATA_DIR / 'align_pcf.csv', index_col=0)
refDF = pd.read_csv(DATA_DIR / 'ref_xen.csv', index_col=0)

alignDF['Cell_Num_Old'] = alignDF.index.values
refDF['Cell_Num_Old'] = refDF.index.values

alignDF[CELL_TYPES] = alignDF[CELL_TYPES] * 100
refDF[CELL_TYPES] = refDF[CELL_TYPES] * 100
alignDF['cell_type'] = alignDF[CELL_TYPES].idxmax(axis=1)
refDF['cell_type'] = refDF[CELL_TYPES].idxmax(axis=1)

print(f'Template (Xenium RNA):  {len(refDF)} cells')
print(f'Query (PCF Protein):    {len(alignDF)} cells')
print(f'\nCell type distribution (PCF):')
print(alignDF['cell_type'].value_counts())
Template (Xenium RNA):  99827 cells
Query (PCF Protein):    94442 cells

Cell type distribution (PCF):
cell_type
Mesenchymal    44240
Epithelial     25697
Myeloid         9213
T cell          8052
B cell          7240
Name: count, dtype: int64
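
The cell_type column above is the name of the highest-scoring cell-type column in each row. A minimal sketch on made-up scores (not values from the LUAD data) shows the mechanics; note that multiplying every score by 100 does not change which column wins.

```python
import pandas as pd

CELL_TYPES = ['B cell', 'Epithelial', 'Mesenchymal', 'Myeloid', 'T cell']

# Toy per-cell scores (made-up values, not from the LUAD data)
toy = pd.DataFrame(
    [[0.05, 0.70, 0.10, 0.10, 0.05],
     [0.10, 0.10, 0.20, 0.15, 0.45]],
    columns=CELL_TYPES, index=['cell_0', 'cell_1'])

toy[CELL_TYPES] = toy[CELL_TYPES] * 100            # fractions -> percentages
toy['cell_type'] = toy[CELL_TYPES].idxmax(axis=1)  # column name of the row maximum
print(toy['cell_type'].tolist())  # ['Epithelial', 'T cell']
```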

2. Fig 5a,b: Cell Types in Xenium Template and PCF Query

In [4]:
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

for ax, df, title in zip(axes,
                          [refDF, alignDF],
                          ['a. Xenium Template (RNA)', 'b. PCF Query (Protein)']):
    for ct in CELL_TYPES:
        mask = df['cell_type'] == ct
        if mask.sum() > 0:
            ax.scatter(df.loc[mask, 'X'], df.loc[mask, 'Y'],
                       s=0.1, alpha=0.5, label=ct, color=CT_COLORS[ct])
    ax.set_title(title, fontsize=14, loc='left', fontweight='bold')
    ax.set_aspect('equal')
    ax.invert_yaxis()
    ax.set_axis_off()

axes[1].legend(bbox_to_anchor=(1.05, 1), loc='upper left', markerscale=6, fontsize=9)
plt.tight_layout()
plt.savefig(FIG_DIR / 'Fig5ab_cell_types.png', dpi=300, bbox_inches='tight')
plt.show()
[Figure: cell-type scatter plots, panel a. Xenium Template (RNA), panel b. PCF Query (Protein)]

3. Penalty Sweep: Runtime (Fig S18)

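Total optimization time, summed over windows, for each delaunay_penalty (dp) value: 0, 1, 5, 10, 25, 50. In the recorded run below, runtime stays under 2 minutes up to dp=10, then rises to about 4.2 hours at dp=25 and 10.1 hours at dp=50.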

In [5]:
with open(RESULTS_DIR / 'mc_align.pkl', 'rb') as f:
    mc_align = pickle.load(f)
with open(RESULTS_DIR / 'mc_ref.pkl', 'rb') as f:
    mc_ref = pickle.load(f)

dp_values = [0, 1, 5, 10, 25, 50]
sweep_results = []

for dp in dp_values:
    folder = RESULTS_DIR / f'dp{dp}_knn8_MS3'
    outputDF = pd.read_csv(folder / 'matchedDF.csv')

    # Runtime per window
    per_window = outputDF[['window_id', 'run_time']].groupby('window_id').first()['run_time']
    total_hours = per_window.sum() / 3600
    total_minutes = per_window.sum() / 60

    sweep_results.append({
        'dp': dp,
        'total_time_hours': total_hours,
        'total_time_minutes': total_minutes,
    })
    print(f'dp={dp:>2d}: time={total_hours:.2f}h ({total_minutes:.1f} min)')

sweep_df = pd.DataFrame(sweep_results)
sweep_df
dp= 0: time=0.01h (0.5 min)
dp= 1: time=0.01h (0.6 min)
dp= 5: time=0.01h (0.7 min)
dp=10: time=0.03h (1.8 min)
dp=25: time=4.17h (249.9 min)
dp=50: time=10.15h (608.9 min)
Out[5]:
dp total_time_hours total_time_minutes
0 0 0.007676 0.460565
1 1 0.009880 0.592800
2 5 0.012190 0.731406
3 10 0.029466 1.767981
4 25 4.165308 249.918498
5 50 10.148551 608.913084
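
The sweep loop takes the first run_time per window_id before summing. A minimal sketch of why, assuming (as the code above implies) that run_time is in seconds and is repeated on every match row of a window; the toy table is made up.

```python
import pandas as pd

# Toy matchedDF: run_time (seconds) is repeated on every row of a window
toy = pd.DataFrame({
    'window_id': [0, 0, 0, 1, 1],
    'run_time':  [12.0, 12.0, 12.0, 30.0, 30.0],
})

per_window = toy.groupby('window_id')['run_time'].first()
total_seconds = per_window.sum()       # 42.0: each window counted once
naive_seconds = toy['run_time'].sum()  # 96.0: each window counted once per match row

print(f'per-window total: {total_seconds / 60:.2f} min, naive sum: {naive_seconds / 60:.2f} min')
```

Summing the raw column would inflate the total in proportion to the number of matches per window.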
In [ ]:
fig, axs = plt.subplots(1, 2, figsize=(9, 3.5))

# Bar chart
sns.barplot(data=sweep_df, x='dp', y='total_time_hours', ax=axs[0], color='steelblue')
axs[0].set_xlabel('Delaunay Penalty')
axs[0].set_ylabel('Total time (hours)')
axs[0].set_yscale('log')
for p in axs[0].patches:
    h = p.get_height()
    fmt = f'{h:.3f}' if h < 0.01 else f'{h:.2f}'
    axs[0].annotate(fmt, (p.get_x() + p.get_width()/2., h),
                    ha='center', va='bottom', fontsize=8, xytext=(0, 2),
                    textcoords='offset points')
axs[0].spines['top'].set_visible(False)
axs[0].spines['right'].set_visible(False)

# Table
axs[1].axis('off')
table_data = [[int(row['dp']), f"{row['total_time_minutes']:.2f}"]
              for _, row in sweep_df.iterrows()]
table = axs[1].table(cellText=table_data,
                     colLabels=['Penalty', 'Total time (min)'],
                     cellLoc='center', loc='center', colWidths=[0.4, 0.4])
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1, 2)
for i in range(2):
    table[(0, i)].set_facecolor('steelblue')
    table[(0, i)].set_text_props(weight='bold', color='white')
for i in range(1, len(table_data) + 1):
    for j in range(2):
        table[(i, j)].set_facecolor('#E7E6E6' if i % 2 == 0 else '#FFFFFF')

plt.tight_layout()
plt.savefig(FIG_DIR / 'FigS18_time_vs_dp.svg', dpi=300, bbox_inches='tight')
plt.savefig(FIG_DIR / 'FigS18_time_vs_dp.png', dpi=300, bbox_inches='tight')
plt.show()

4. Main Result (dp=10): Spatial Alignment + Top-k Match (Fig S19)

In [13]:
main_dir = RESULTS_DIR / 'dp10_knn8_MS3'
matchesDF = pd.read_csv(main_dir / 'matchedDF.csv')

# Unpack to individual cells
individual_matches = unpack_metacell_matches(
    matchesDF, mc_align.metacell_df, mc_ref.metacell_df,
    aligned_df=mc_align.original_df, ref_df=mc_ref.original_df,
    strategy='nearest',
    aligned_original_idx_col='Cell_Num_Old',
    ref_original_idx_col='Cell_Num_Old')

# Cell type labels
aligned_ct = mc_align.original_df.set_index('Cell_Num_Old')['cell_type']
ref_ct = mc_ref.original_df.set_index('Cell_Num_Old')['cell_type']
individual_matches['aligned_celltype'] = individual_matches['Aligned_cell_id'].map(aligned_ct)
individual_matches['ref_celltype'] = individual_matches['Ref_cell_id'].map(ref_ct)
individual_matches['celltype_match'] = individual_matches['aligned_celltype'] == individual_matches['ref_celltype']

# Map to Xenium Cell_Num for top-k evaluation
mc_ref.original_df.index = mc_ref.original_df.Cell_Num_Old.values
individual_matches['Cell_Num'] = mc_ref.original_df.loc[
    individual_matches['Ref_cell_id'].values, 'Cell_Num'].values

# Get SAME-aligned coords from ref
ref_coords = mc_ref.original_df.set_index('Cell_Num_Old')[['X', 'Y']]
individual_matches['SAME_X'] = individual_matches['Ref_cell_id'].map(ref_coords['X'])
individual_matches['SAME_Y'] = individual_matches['Ref_cell_id'].map(ref_coords['Y'])

# Map to QuPathID for h5ad subsetting later
mc_align.original_df.index = mc_align.original_df.Cell_Num_Old.values
individual_matches['QuPathID'] = mc_align.original_df.loc[
    individual_matches['Aligned_cell_id'].values, 'QuPathID'].values

# Dominant cell type (from PCF)
individual_matches['Dominant_Cell_Type'] = individual_matches['aligned_celltype']

accuracy = individual_matches['celltype_match'].mean() * 100
print(f'dp=10: {len(individual_matches)} matches, {accuracy:.1f}% cell type accuracy')
dp=10: 91684 matches, 72.5% cell type accuracy
In [24]:
# Top-k cell type match using Xenium reference probabilities
xenDF = pd.read_csv(DATA_DIR / 'ref_xen.csv', index_col=0)
xenDF.index = xenDF['Cell_Num'].astype(str)

# Vectorized top-k evaluation
ref_probs = xenDF[CELL_TYPES].astype(float)
valid = individual_matches['Cell_Num'].astype(str).isin(ref_probs.index)
ref_rows = ref_probs.loc[individual_matches.loc[valid, 'Cell_Num'].astype(str)].values
ct_array = np.array(CELL_TYPES)
dom_types = individual_matches.loc[valid, 'Dominant_Cell_Type'].values

for k in [1, 2, 3]:
    matches = np.zeros(len(individual_matches), dtype=bool)
    top_k_idx = np.argpartition(ref_rows, -k, axis=1)[:, -k:]
    top_k_types = ct_array[top_k_idx]
    matches[valid.values] = np.any(top_k_types == dom_types[:, np.newaxis], axis=1)
    individual_matches[f'top_{k}_match'] = matches

print(f"Top-1 match: {individual_matches['top_1_match'].mean():.1%}")
print(f"Top-2 match: {individual_matches['top_2_match'].mean():.1%}")
print(f"Top-3 match: {individual_matches['top_3_match'].mean():.1%}")

# Plot
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ix, (ax, k) in enumerate(zip(axes, [1, 2, 3])):
    col = f'top_{k}_match'
    sns.scatterplot(data=individual_matches, x='SAME_X', y='SAME_Y', s=1, alpha=0.5,
                    hue=col, palette={True: 'blue', False: 'red'}, ax=ax)
    rate = individual_matches[col].mean() * 100
    ax.set_title(f'{chr(97+ix)}. Correct in top-{k} ({rate:.1f}%)',
                 loc='left', fontweight='bold')
    ax.invert_yaxis()
    ax.set_aspect('equal')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.legend(bbox_to_anchor=(0, 0.99), loc='upper left', markerscale=5)
    ax.set_xticks([])
    ax.set_yticks([])
plt.tight_layout()

plt.savefig(FIG_DIR / 'FigS19_top_k_match.png', dpi=300, bbox_inches='tight')
plt.show()
Top-1 match: 72.5%
Top-2 match: 81.1%
Top-3 match: 88.3%
[Figure: three panels of matched cells at SAME-aligned coordinates, colored by whether the PCF cell type is in the Xenium top-1, top-2, or top-3]
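
The top-k check above can be reproduced on a small example. This sketch uses made-up reference probabilities and query labels; `top_k_match` is a hypothetical helper that wraps the same `np.argpartition` logic as the cell above.

```python
import numpy as np

cell_types = np.array(['B cell', 'Epithelial', 'Mesenchymal', 'Myeloid', 'T cell'])

# Toy reference probabilities for three matched cells (made-up values)
ref_rows = np.array([
    [0.05, 0.60, 0.20, 0.10, 0.05],
    [0.10, 0.15, 0.40, 0.30, 0.05],
    [0.50, 0.05, 0.05, 0.10, 0.30],
])
query_types = np.array(['Epithelial', 'Myeloid', 'Mesenchymal'])


def top_k_match(probs, labels, k):
    """True where the query label is among the k most probable reference types."""
    # argpartition places the k largest entries of each row in the last k slots
    top_k_idx = np.argpartition(probs, -k, axis=1)[:, -k:]
    return np.any(cell_types[top_k_idx] == labels[:, np.newaxis], axis=1)


results = {k: top_k_match(ref_rows, query_types, k).tolist() for k in (1, 2, 3)}
print(results)
# {1: [True, False, False], 2: [True, True, False], 3: [True, True, False]}
```

The second cell only matches once k reaches 2, because Myeloid is its second most probable reference type; the third never matches within the top 3.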

5. Cross-Modal Integration: Combined Protein + RNA AnnData

Create combined AnnData with PCF protein and Xenium RNA features for each matched cell.

In [27]:
pcfAD = sc.read_h5ad(DATA_DIR / '0325_pcf_annotated.h5ad')
xenAD = sc.read_h5ad(DATA_DIR / '0325_tsu33_annotated.h5ad')

# Filter out 'Other' cell type from PCF
pcfAD = pcfAD[pcfAD.obs['Annotations'] != 'Other']

print(f'PCF AnnData:    {pcfAD.shape}')
print(f'Xenium AnnData: {xenAD.shape}')
PCF AnnData:    (94442, 19)
Xenium AnnData: (99827, 194)
In [28]:
# Index h5ad by QuPathID / cell_id (matching original postprocess.ipynb)
pcfAD.obs.index = pcfAD.obs.QuPathID
xenAD.obs.index = xenAD.obs.cell_id

# Subset PCF to matched cells
pcfAD_subset = pcfAD[individual_matches.QuPathID.values, :].copy()
pcfAD_subset.var_names = ['pcf_' + x for x in pcfAD_subset.var_names]

# Add xen_id, prefix all PCF obs columns
pcfAD_subset.obs.loc[:, 'xen_id'] = individual_matches.Cell_Num.values
pcfAD_subset.obs.columns = ['pcf_' + x for x in pcfAD_subset.obs.columns]

# Subset Xenium via matched xen_id
xenAD_subset = xenAD[pcfAD_subset.obs.pcf_xen_id.values, :].copy()
xenAD_subset.var_names = ['xen_' + x for x in xenAD_subset.var_names]
xenAD_subset.obs.loc[:, 'QuPathID'] = pcfAD_subset.obs.pcf_QuPathID.values
xenAD_subset.obs.index = xenAD_subset.obs.QuPathID
xenAD_subset.obs.columns = ['xen_' + x for x in xenAD_subset.obs.columns]

# Combine protein + RNA
combinedAD = sc.concat([pcfAD_subset, xenAD_subset], axis=1)
combinedAD.var = pd.concat([pcfAD_subset.var, xenAD_subset.var], axis=0)
combinedAD.obs = pd.concat([pcfAD_subset.obs, xenAD_subset.obs], axis=1)
combinedAD.obsm['X_spatial'] = xenAD_subset.obsm['spatial'].copy()

print(f'Combined AnnData: {combinedAD.shape}')
print(f'Cell type match rate: {individual_matches["celltype_match"].mean():.1%}')
/hpc/group/singhlab/user/ap756/tmp/ipykernel_1727814/3056355313.py:2: ImplicitModificationWarning: Trying to modify index of attribute `.obs` of view, initializing view as actual.
  pcfAD.obs.index = pcfAD.obs.QuPathID
/hpc/group/singhlab/rawdata/ap756/rapidsc/lib/python3.12/site-packages/anndata/_core/anndata.py:1756: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
  utils.warn_names_duplicates("obs")
Combined AnnData: (91684, 213)
Cell type match rate: 72.5%
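
The pairing logic in this cell can be sketched with plain pandas tables instead of AnnData objects. All IDs and values below are made up; the point is only the order of operations: reorder both modalities to match-table order, prefix the columns, then share one row key before concatenating.

```python
import pandas as pd

# Toy feature tables (made-up values), one per modality
protein = pd.DataFrame({'CD19': [0.1, 2.3, 0.4]}, index=['q1', 'q2', 'q3'])
rna = pd.DataFrame({'CD19': [0, 7], 'KRT8': [5, 0]}, index=['x10', 'x11'])

# Toy match table: one row per matched (protein cell, RNA cell) pair
matches = pd.DataFrame({'QuPathID': ['q1', 'q2'], 'Cell_Num': ['x10', 'x11']})

# Reorder both tables to match-table order, then prefix columns by modality
protein_m = protein.loc[matches['QuPathID']].add_prefix('pcf_')
rna_m = rna.loc[matches['Cell_Num']].add_prefix('xen_')
rna_m.index = protein_m.index  # share the protein cell ID as the row key

combined = pd.concat([protein_m, rna_m], axis=1)
print(combined)
```

Unmatched cells (here `q3`) drop out, and the prefixes keep same-named features such as CD19 apart across modalities.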

6. Fig 5: Cross-Modal Matrixplot

Protein and RNA markers grouped by cell type, showing concordance between matched modalities. Protein labels in blue, RNA in red.

In [29]:
marker_genes = [
    'pcf_Pan-Cytokeratin', 'xen_KRT8', 'xen_SFTPB',
    'pcf_a-SMA/ACTA2', 'xen_ACTA2',
    'pcf_CD19', 'xen_CD19',
    'pcf_CD68', 'xen_CD68',
    'pcf_CD4', 'xen_CD4', 'xen_CD8A',
]
xlab = [
    'PanCK', 'KRT8', 'SFTPB',
    'SMA', 'ACTA2',
    'CD19', 'CD19',
    'CD68', 'CD68',
    'CD4', 'CD4', 'CD8A',
]

# Color map: protein (pcf) = blue, RNA (xen) = red
protein_color, mrna_color = '#1f77b4', '#d62728'
is_protein = [g.startswith('pcf_') for g in marker_genes]

gx = sc.pl.matrixplot(
    combinedAD, var_names=marker_genes,
    groupby='pcf_Annotations',
    categories_order=['Epithelial', 'Mesenchymal', 'B cell', 'Myeloid', 'T cell'],
    standard_scale='var', cmap='cividis', return_fig=True)

# Color x-axis labels by modality
main_ax = gx.get_axes()['mainplot_ax']
main_ax.set_xticklabels(xlab)
for i, label in enumerate(main_ax.get_xticklabels()):
    label.set_color(protein_color if is_protein[i] else mrna_color)
    label.set_fontweight('bold')

plt.savefig(FIG_DIR / 'Fig5_matrixplot.svg', bbox_inches='tight')
plt.savefig(FIG_DIR / 'Fig5_matrixplot.png', dpi=300, bbox_inches='tight')
plt.show()
[Figure: matrixplot of protein (blue labels) and RNA (red labels) markers, grouped by PCF cell type]
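
The matrixplot is drawn with standard_scale='var', which the scanpy documentation describes as subtracting each variable's minimum and dividing by its maximum. A minimal sketch of that rescaling on made-up group means (not values from this dataset):

```python
import pandas as pd

# Toy group means: rows are cell-type groups, columns are markers (made-up values)
group_means = pd.DataFrame(
    {'pcf_CD19': [1.0, 5.0, 2.0], 'xen_CD19': [0.0, 10.0, 5.0]},
    index=['Epithelial', 'B cell', 'T cell'])

scaled = group_means - group_means.min(axis=0)    # per-marker minimum becomes 0
scaled = (scaled / scaled.max(axis=0)).fillna(0)  # per-marker maximum becomes 1
print(scaled)
```

Because each marker is rescaled on its own, protein intensities and RNA counts land on the same 0 to 1 range, so colors compare groups within a column rather than magnitudes across columns.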

7. T Cell Exhaustion Analysis

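Subset T cells, compute each cell's distance to the nearest SFTPB+ epithelial cell (used as a tumor proxy), bin the distances into Tumor Infiltrating (up to 15), Boundary (15 to 30), and Stroma (above 30) zones, and show exhaustion marker expression by zone. Distances are in the units of the Xenium spatial coordinates.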

In [30]:
# Distance to SFTPB+ epithelial cells (tumor proxy)
sftpb_expr = combinedAD[:, combinedAD.var_names == 'xen_SFTPB'].X
if hasattr(sftpb_expr, 'toarray'):
    sftpb_expr = sftpb_expr.toarray()
sftpb_expr = sftpb_expr.flatten()

mean_sftp = sftpb_expr.mean()
sftp_pos_mask = (sftpb_expr > mean_sftp) & (combinedAD.obs.pcf_Annotations == 'Epithelial').values
sftp_pos_coords = combinedAD.obsm['X_spatial'][sftp_pos_mask]

distances = cdist(combinedAD.obsm['X_spatial'], sftp_pos_coords).min(axis=1)
combinedAD.obs['distance_to_sftp'] = distances

print(f'SFTPB+ epithelial cells: {sftp_pos_mask.sum()}')
SFTPB+ epithelial cells: 20009
In [31]:
# Subset T cells with CD3E expression (matching original postprocess.ipynb)
cd3e_expr = combinedAD[:, combinedAD.var_names == 'xen_CD3E'].X
if hasattr(cd3e_expr, 'toarray'):
    cd3e_expr = cd3e_expr.toarray()
cd3e_expr = cd3e_expr.flatten()

t_cell_mask = (combinedAD.obs['pcf_Annotations'] == 'T cell').values & (cd3e_expr > 0.01)
t_cell_ad = combinedAD[t_cell_mask].copy()

# Keep only RNA genes for downstream analysis
xen_genes = [g for g in t_cell_ad.var_names if g.startswith('xen_')]
t_cell_ad = t_cell_ad[:, xen_genes]

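# Bin by distance to tumor. pd.cut bins are right-closed: (0, 15], (15, 30], (30, inf].
# A distance of exactly 0 falls outside every bin and becomes NaN, so the bin counts
# can sum to less than the number of T cells (include_lowest=True would keep such cells).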
bins = [0, 15, 30, float('inf')]
labels = ['Tumor Infiltrating', 'Boundary', 'Stroma']
t_cell_ad.obs['distance_bins'] = pd.cut(
    t_cell_ad.obs['distance_to_sftp'], bins=bins, labels=labels)

print(f'T cells (CD3E+): {t_cell_ad.shape[0]}')
print(t_cell_ad.obs['distance_bins'].value_counts())
T cells (CD3E+): 1746
distance_bins
Boundary              769
Stroma                501
Tumor Infiltrating    467
Name: count, dtype: int64
/hpc/group/singhlab/user/ap756/tmp/ipykernel_1727814/1873578975.py:17: ImplicitModificationWarning: Trying to modify attribute `.obs` of view, initializing view as actual.
  t_cell_ad.obs['distance_bins'] = pd.cut(
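
The nearest-tumor distance and the zone binning can be checked on a toy example. The coordinates below are made up, and the pairwise distances are computed with NumPy broadcasting rather than `cdist` so the sketch needs only NumPy and pandas. It also shows how `pd.cut` treats a distance of exactly 0 with and without `include_lowest=True`.

```python
import numpy as np
import pandas as pd

# Toy coordinates (made-up); units are whatever the spatial coordinates use
tumor_xy = np.array([[0.0, 0.0], [100.0, 0.0]])
t_cell_xy = np.array([[0.0, 0.0], [9.0, 12.0], [100.0, 20.0], [50.0, 0.0]])

# Pairwise Euclidean distances via broadcasting, then distance to the nearest tumor cell
diff = t_cell_xy[:, None, :] - tumor_xy[None, :, :]
nearest = np.sqrt((diff ** 2).sum(axis=2)).min(axis=1)  # [0., 15., 20., 50.]

bins = [0, 15, 30, float('inf')]
labels = ['Tumor Infiltrating', 'Boundary', 'Stroma']
default_bins = pd.cut(nearest, bins=bins, labels=labels)
closed_bins = pd.cut(nearest, bins=bins, labels=labels, include_lowest=True)

print(list(default_bins))  # first entry is NaN: 0 is not inside (0, 15]
print(list(closed_bins))   # first entry is 'Tumor Infiltrating'
```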
In [32]:
exhaustion_markers = [
    'xen_SDC4', 'xen_PTP4A1', 'xen_WSB1', 'xen_ICAM1',
    'xen_GZMA', 'xen_LAG3', 'xen_PIK3IP1', 'xen_FOS',
    'xen_TIGIT', 'xen_CTLA4', 'xen_CXCL13',
]
# Filter to genes present in the data
exhaustion_markers = [g for g in exhaustion_markers if g in t_cell_ad.var_names]
short_names = [g.replace('xen_', '') for g in exhaustion_markers]

fig, ax = plt.subplots(figsize=(6, 3))
gx = sc.pl.matrixplot(
    t_cell_ad, var_names=exhaustion_markers,
    groupby='distance_bins',
    standard_scale='var', cmap='cividis',
    swap_axes=False, ax=ax, show=False, return_fig=True)

main_ax = gx.get_axes()['mainplot_ax']
main_ax.set_xticklabels(short_names)

plt.tight_layout()
plt.savefig(FIG_DIR / 'Fig5_t_cell_exhaustion.svg', bbox_inches='tight')
plt.savefig(FIG_DIR / 'Fig5_t_cell_exhaustion.png', dpi=300, bbox_inches='tight')
plt.show()
[Figure: matrixplot of exhaustion markers in CD3E+ T cells, grouped by distance zone]

(Optional) Full Reproduction

# Main result
bash run_same.sh --dp 10 --knn 8 --ms 3

# Full penalty sweep
for dp in 0 1 5 10 25 50; do
    bash run_same.sh --dp "$dp"
done

See run_same.sh for the full parameter specification.
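
Before rerunning the sweep cells, it can help to check that every expected output exists. The sketch below is an assumption-laden convenience: `expected_result_dirs` is a hypothetical helper, and it assumes run_same.sh writes to the same `dp{dp}_knn8_MS3` folders under `results/` that this notebook reads.

```python
from pathlib import Path


def expected_result_dirs(results_dir, dp_values=(0, 1, 5, 10, 25, 50), knn=8, ms=3):
    """Map each delaunay_penalty to the folder name this notebook reads.

    Hypothetical helper; the folder pattern is taken from the sweep cell above.
    """
    return {dp: Path(results_dir) / f'dp{dp}_knn{knn}_MS{ms}' for dp in dp_values}


dirs = expected_result_dirs('results')
missing = [dp for dp, folder in dirs.items() if not (folder / 'matchedDF.csv').exists()]
print(f'Sweep outputs missing for dp = {missing}' if missing else 'All sweep outputs found')
```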