2,363
ADNI subjects
~70M
Model params
0.707
DX 3-class Bal. Acc.
0.933
CN vs Dem Bal. Acc.
A custom vision-language model that ingests T1 MRI, DTI FA maps, and clinical scores to predict six tasks, including preclinical amyloid detection. Extended with a FAISS-based retrieval pipeline and a three-way LLM comparison (Mistral 7B, Gemma 4 26B MoE, MedGemma 1.5 4B) for visual question answering over brain scans.
Under the NIA-AA A/T/N framework, cognitively normal (CN) individuals who test positive for amyloid are already on the Alzheimer's disease biological continuum. This is preclinical AD: no symptoms yet, but a 3 to 5x increased risk of progressing to MCI or dementia within 3 to 5 years. Anti-amyloid therapies (Lecanemab, Donanemab) work best at this stage.
The problem is finding these people. Amyloid PET scans cost $3,000 to $6,000 each. CSF draws are invasive. Neither works at screening scale. This project asks whether structural MRI + DTI + routine clinical scores can do that job. No PET infrastructure needed. Just the imaging and labs most patients already have.
Primary target
Detect amyloid positivity in CN subjects well enough to work as a rule-out triage tool. High specificity matters here: if the model predicts CN amyloid-negative, those subjects can skip the expensive PET scan.
All data comes from the Alzheimer's Disease Neuroimaging Initiative (ADNI). After filtering to subjects with valid DX labels and 9DOF T1 paths, deduplicating to one scan per subject, and recovering DTI paths from an earlier dataset cut, the combined v3 cohort is 2,363 subjects with DTI coverage of 39.4% (nearly doubled from the earlier 19.8%).
80 / 20 stratified split by diagnosis
| Class | Train | Test | Total |
|---|---|---|---|
| CN — Cognitively Normal | 669 | 168 | 837 |
| MCI — Mild Cognitive Impairment | 650 | 163 | 813 |
| Dementia | 570 | 143 | 713 |
| Total | 1,889 | 474 | 2,363 |
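A split like the one above can be reproduced with scikit-learn's stratified splitter. This is a minimal sketch over a placeholder label array (the real split uses the ADNI subject manifest, not these synthetic IDs):

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Placeholder label array standing in for the 2,363-subject cohort:
# 0 = CN, 1 = MCI, 2 = Dementia (class totals from the table above).
labels = np.array([0] * 837 + [1] * 813 + [2] * 713)
subjects = np.arange(len(labels))

# 80/20 split, stratified by diagnosis so each class keeps its proportion.
train_idx, test_idx = train_test_split(
    subjects, test_size=0.20, stratify=labels, random_state=42
)

print(len(train_idx), len(test_idx))
```

Stratification matters here because the three diagnosis classes are close but not equal in size; a plain random split could skew the test-set class balance and distort balanced accuracy.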
100%
T1 MRI (9DOF 2mm)
39.4%
DTI FA coverage
~100%
Clinical scores
The model is a multi-modal vision-language model with missing-modality support. Three modality-specific encoders produce ℓ2-normalized 512-d embeddings, each gated by a per-modality masking probability during training. The masked embeddings are fused via 8-head cross-attention with a learnable pool query, producing a fused representation z_f ∈ ℝ⁵¹² that feeds six MLP task heads.
Three modality-specific encoders produce ℓ2-normalized 512-d embeddings. Each passes through a masking gate that randomly drops modalities during training (T1: 10%, DTI: 30%, Clinical: 5%). Masked embeddings are fused via 8-head cross-attention with a learnable pool query, producing z_f ∈ ℝ⁵¹² that feeds six MLP task heads. Amyloid head upweighted as the primary clinical target.
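The fusion step described above can be sketched in PyTorch: a single learnable pool query cross-attends over the stacked modality embeddings, with missing modalities excluded via the attention mask. Dimensions (512-d, 8 heads) come from the text; everything else here is an assumption:

```python
import torch
import torch.nn as nn

class FusionPool(nn.Module):
    """Sketch: a learnable pool query cross-attends over the (possibly
    masked) modality embeddings to produce the fused z_f."""

    def __init__(self, dim: int = 512, heads: int = 8):
        super().__init__()
        self.query = nn.Parameter(torch.randn(1, 1, dim))  # learnable pool query
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, emb: torch.Tensor, present: torch.Tensor) -> torch.Tensor:
        # emb: (B, 3, 512) stacked T1 / DTI / Clinical embeddings
        # present: (B, 3) bool, False where a modality is missing or masked
        q = self.query.expand(emb.size(0), -1, -1)
        # key_padding_mask=True means "ignore this key", hence the negation
        z_f, _ = self.attn(q, emb, emb, key_padding_mask=~present)
        return z_f.squeeze(1)  # (B, 512) fused representation

fusion = FusionPool()
emb = torch.randn(4, 3, 512)
present = torch.tensor([[True, True, True]] * 4)
z = fusion(emb, present)
print(z.shape)  # torch.Size([4, 512])
```

Masking at the attention level (rather than zeroing inputs) lets the pool query redistribute attention over whatever modalities remain.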
Imaging encoders
Two independent 3D ResNet-18 networks for T1 MRI and DTI FA maps. Input volumes are 91×109×91, output is 512-d followed by a linear projection + LayerNorm.
Clinical encoder
MLP over 6 continuous features (CDR-SB, ADAS-11/13, MMSE, MoCA, AV45 SUVR) plus an APOE genotype embedding. Output: 512-d, ℓ2-normalized.
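A sketch of this branch, assuming a small hidden layer and a 6-way APOE genotype vocabulary (e2/e2 through e4/e4); the hidden sizes are illustrative, not the actual configuration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ClinicalEncoder(nn.Module):
    """Sketch: MLP over the 6 continuous scores plus a learned APOE
    genotype embedding, projected to an l2-normalized 512-d vector."""

    def __init__(self, n_cont: int = 6, n_apoe: int = 6, dim: int = 512):
        super().__init__()
        self.apoe = nn.Embedding(n_apoe, 32)   # one row per APOE genotype
        self.mlp = nn.Sequential(
            nn.Linear(n_cont + 32, 256), nn.ReLU(),
            nn.Linear(256, dim),
        )

    def forward(self, cont: torch.Tensor, apoe: torch.Tensor) -> torch.Tensor:
        h = torch.cat([cont, self.apoe(apoe)], dim=-1)
        return F.normalize(self.mlp(h), dim=-1)  # l2-normalized 512-d output

enc = ClinicalEncoder()
z = enc(torch.randn(4, 6), torch.randint(0, 6, (4,)))
print(z.shape)  # torch.Size([4, 512]), unit-norm rows
```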
Design decision: no label leakage
An earlier iteration fed DX_code, SEX_code, and Amyloid_code into the clinical encoder. Those are the same variables the task heads are trying to predict, so the resulting accuracy numbers (DX3 at 97.6%, sex at 98.3%) were meaningless. The v2/v3 model removes these inputs entirely: DX, sex, age, and amyloid are prediction targets only, and the clinical encoder receives just 7 inputs (the 6 continuous scores plus the APOE genotype embedding). Every number on this page is leakage-free.
Training runs in two stages. Stage 1 is contrastive pre-training: a pairwise CLIP/InfoNCE loss between all three modality pairs (T1–DTI, T1–Clinical, DTI–Clinical), computed only on pairs where both modalities are present. Stage 2B is multi-task fine-tuning across six heads with differential learning rates (backbone 10⁻⁵, heads 5×10⁻⁴).
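The "computed only on pairs where both modalities are present" detail is the important part of Stage 1. A sketch of one pairwise term, assuming l2-normalized embeddings and an assumed temperature of 0.07:

```python
import torch
import torch.nn.functional as F

def masked_infonce(za, zb, both_present, tau: float = 0.07):
    """Symmetric CLIP-style InfoNCE between two modality embeddings,
    restricted to subjects where both modalities exist. Assumes at
    least one such subject is in the batch."""
    za, zb = za[both_present], zb[both_present]
    logits = za @ zb.T / tau               # (n, n) similarity matrix
    targets = torch.arange(za.size(0))     # matching pairs sit on the diagonal
    return 0.5 * (F.cross_entropy(logits, targets)
                  + F.cross_entropy(logits.T, targets))

# Toy batch: 8 subjects, DTI present for only 3 of them.
z_t1 = F.normalize(torch.randn(8, 512), dim=-1)
z_dti = F.normalize(torch.randn(8, 512), dim=-1)
mask = torch.tensor([1, 0, 1, 0, 0, 0, 1, 0], dtype=torch.bool)
loss = masked_infonce(z_t1, z_dti, mask)
```

The full Stage 1 loss would sum this term over the three modality pairs (T1–DTI, T1–Clinical, DTI–Clinical), each with its own presence mask.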
During training, modalities are randomly dropped so the model learns to work with any subset at inference time. T1 is dropped 10% of the time, DTI 30%, Clinical 5%. DTI gets the highest drop rate because only 39.4% of subjects have it, so the model needs to handle "missing DTI" as the normal case, not the exception.
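The dropout gate itself is a per-modality Bernoulli mask. A minimal sketch using the quoted rates (the zeroing-based implementation is an assumption; the trained model may mask at the attention level instead):

```python
import torch

torch.manual_seed(0)

def modality_dropout(emb, drop_p=(0.10, 0.30, 0.05), training=True):
    """Sketch: each modality embedding (T1, DTI, Clinical) is
    independently zeroed with its own drop probability during training.
    emb: (B, 3, D). Returns the masked embeddings and the keep mask."""
    if not training:
        return emb, torch.ones(emb.shape[:2], dtype=torch.bool)
    p = torch.tensor(drop_p)
    keep = torch.rand(emb.size(0), 3) >= p       # per-sample, per-modality
    return emb * keep.unsqueeze(-1), keep

emb = torch.randn(1000, 3, 512)
masked, keep = modality_dropout(emb)
print(keep.float().mean(0))  # empirical keep rates, ~[0.90, 0.70, 0.95]
```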
Stage 1 · Contrastive
30 epochs · pairwise InfoNCE
lr 1×10⁻⁴ · AdamW · cosine anneal
Stage 2B · Multi-task
30 epochs · 6 joint heads
focal (γ=2) + smooth L1 · AMP FP16
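The focal term with γ=2 is worth writing out, since it explains how the classification heads stay focused on hard examples. A sketch of the unweighted form (the per-class α weighting, if any, is omitted):

```python
import numpy as np

def focal_loss(p_true: np.ndarray, gamma: float = 2.0) -> np.ndarray:
    """Focal loss on the probability assigned to the true class:
    FL = -(1 - p)^gamma * log(p). gamma=2 matches the Stage 2B config."""
    return -((1.0 - p_true) ** gamma) * np.log(p_true)

p = np.array([0.9, 0.5, 0.1])   # confident-correct, uncertain, confident-wrong
ce = -np.log(p)                 # plain cross-entropy, for comparison
fl = focal_loss(p)
print(ce / fl)  # down-weighting factor: 100x at p=0.9, 4x at p=0.5
```

The ratio ce/fl is 1/(1-p)², so confidently correct predictions contribute almost nothing while misclassified examples keep nearly their full gradient.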
Stage 2B training history. Best composite score at epoch 5; later epochs show overfitting, particularly on amyloid. The model used downstream is the epoch-5 checkpoint.
Evaluated on the held-out 474-subject test set using all available modalities. Headline metrics on the best model (Stage 2B, epoch 5):
| Task | Bal. Acc. | Macro F1 | AUC |
|---|---|---|---|
| DX 3-class (CN / MCI / Dem) · primary | 0.707 | 0.703 | 0.865 |
| DX Binary (CN vs Dem) | 0.933 | 0.932 | 0.981 |
| Sex | 0.575 | 0.563 | 0.597 |
| Amyloid (Aβ+ / Aβ−) | 0.733 | 0.733 | 0.806 |
| CN Amyloid (preclinical) · primary | 0.688 | 0.695 | 0.685 |
| MCI Amyloid | 0.604 | 0.595 | 0.723 |
| Dementia Amyloid | 0.484 | 0.456 | 0.906 |
| Age (years, MAE ↓) | 6.31 | n/a | n/a |
| CDR-SB (MAE ↓) | 0.97 | n/a | n/a |
DX 3-class confusion matrix · full test set (n=474)
Rows = true class, columns = predicted. Each cell shows count and row-normalized %. MCI is the hardest class at 55% recall. It sits between CN and Dementia so the model hedges in both directions. Dementia recall is the strongest at 86%, with only 2 misclassified as CN.
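Balanced accuracy is just the mean of the per-class recalls, i.e. the mean of the row-normalized diagonal of this matrix. A sketch with an illustrative matrix chosen to be consistent with the quoted recalls and test-set row totals (not the actual cell counts, which are in the figure):

```python
import numpy as np

def balanced_accuracy(cm: np.ndarray) -> float:
    """Mean of per-class recalls from a confusion matrix
    (rows = true class, columns = predicted class)."""
    recalls = np.diag(cm) / cm.sum(axis=1)
    return float(recalls.mean())

# Illustrative 3x3 matrix (rows = true CN/MCI/Dem), consistent with
# the quoted ~55% MCI and ~86% Dementia recalls and the 168/163/143
# test-set row totals -- NOT the actual published counts.
cm = np.array([[120, 40,  8],
               [ 35, 90, 38],
               [  2, 18, 123]])
print(balanced_accuracy(cm))  # ~0.709
```

This is why balanced accuracy is the right headline metric here: plain accuracy would let the easy Dementia class paper over the weak MCI recall.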
CN amyloid (preclinical screen)
0.846 spec · 0.529 sens
High specificity, moderate sensitivity, a tradeoff that is intentional for a triage screen: 84.6% of truly amyloid-negative CN subjects are correctly ruled out, and those correctly flagged negatives can skip the expensive PET scan.
Dementia amyloid
0.906 AUC
By the time disease is established, the signal is clear. Amyloid-positivity in dementia is correctly identified almost 97% of the time (sens = 0.969).
Every one of the seven possible modality subsets was evaluated using the same trained model, with non-selected modalities masked at inference time. This shows where the signal actually comes from and which combination works best for each task.
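The ablation loop over all seven non-empty subsets can be sketched with the standard library; `evaluate` below is a hypothetical stand-in for the actual evaluation call:

```python
from itertools import combinations

MODALITIES = ("T1", "DTI", "Clinical")

def modality_subsets():
    """All 7 non-empty subsets of the three modalities; non-selected
    modalities are masked at inference time."""
    subsets = []
    for r in range(1, len(MODALITIES) + 1):
        subsets.extend(combinations(MODALITIES, r))
    return subsets

for subset in modality_subsets():
    keep = [m in subset for m in MODALITIES]     # mask vector for the model
    # score = evaluate(model, test_set, keep)    # hypothetical evaluation call
    print("+".join(subset), keep)
```

Because the masking gate was trained for exactly this (modality dropout), the same checkpoint handles every subset without retraining.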
DX 3-class (Bal. Acc.)
CN / MCI / Dementia
DX Binary (Bal. Acc.)
CN vs Dementia
Amyloid (Bal. Acc.)
Overall positivity
CN Amyloid (Bal. Acc.)
Preclinical screen
DTI+Clin wins DX Binary and overall Amyloid outright and ties on DX 3-class. For preclinical CN amyloid, though, you need all three: T1 + DTI + Clinical reaches 0.688 Bal. Acc. vs. Clinical-only at 0.557, a 13.1-point gap on the hardest task. DTI alone is the weakest combination overall, but it still carries real signal for severe dementia cases.
The VLM gives you a prediction and a confidence score. What it doesn't give you is an explanation, or any way to ask follow-up questions in natural language. That's what the VQA extension adds. The frozen VLM encodes a query scan into a 512-d embedding. FAISS finds the 50 most similar training subjects by inner product. A cross-encoder reranks those 50 down to the top 5 most relevant matches. Those 5 captions become the context fed to a language model, which answers clinical questions about the scan.
Frozen multi-modal encoders produce the fused embedding z_f ∈ ℝ⁵¹². FAISS retrieves top-50 similar training subjects; a cross-encoder reranks down to top-5. The LLM generates answers from the retrieved context. Three LLM backbones are compared: Mistral 7B, Gemma 4 26B MoE, and MedGemma 1.5 4B.
Text encoder
all-MiniLM-L6-v2, MLM-pretrained on 26,889 clinical sentences (25K synthetic + 1,889 real captions), then contrastively aligned to the imaging embedding space. 384-d output projected up to 512-d.
Retrieval + rerank
FAISS IndexFlatIP over 1,889 ℓ2-normalized training vectors, exact inner-product search. Cross-encoder: ms-marco-MiniLM-L-6-v2, reranking top-50 to top-5.
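Since the index is exact inner-product search over l2-normalized vectors (equivalent to cosine similarity), the retrieve-then-rerank stages can be sketched in NumPy; `faiss.IndexFlatIP` computes the same stage-1 search, and the stand-in scorer below replaces the real ms-marco cross-encoder, which scores (question, caption) pairs:

```python
import numpy as np

rng = np.random.default_rng(0)

def l2norm(x):
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

# Stand-in for the 1,889 l2-normalized training embeddings in the index.
bank = l2norm(rng.standard_normal((1889, 512)))
query = l2norm(rng.standard_normal(512))

# Stage 1: exact inner-product search (what faiss.IndexFlatIP performs).
scores = bank @ query
top50 = np.argsort(-scores)[:50]

# Stage 2: rerank top-50 down to top-5. The real pipeline scores
# (question, caption) pairs with ms-marco-MiniLM-L-6-v2; this stand-in
# simply reuses the embedding similarity.
def cross_encoder_score(idx):
    return scores[idx]

top5 = sorted(top50, key=cross_encoder_score, reverse=True)[:5]
print(top5)  # indices of the 5 subjects whose captions become context
```

The two-stage design is the usual RAG tradeoff: cheap dense retrieval casts a wide net, and the slower pairwise reranker only has to score 50 candidates.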
Three models were given the same retrieved context: Mistral 7B Instruct v0.3 (general-purpose, dense), Gemma 4 26B MoE (larger, mixture-of-experts), and MedGemma 1.5 4B IT (smaller, fine-tuned on medical data). All quantized to 4-bit NF4. The question was simple: does a medical fine-tune beat a bigger general model on this task?
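Loading a backbone in 4-bit NF4 typically goes through a bitsandbytes quantization config in `transformers`. A sketch, assuming the Hugging Face stack; the compute dtype and the specific model ID are assumptions (swap in whichever backbone is being tested):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# 4-bit NF4 quantization config; float16 compute dtype is an assumption.
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model_id = "mistralai/Mistral-7B-Instruct-v0.3"  # or the other two backbones
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, quantization_config=bnb, device_map="auto"
)
```

NF4 quantization keeps the comparison fair on a single GPU: all three models see the same context and the same memory budget, so the variable under test is the backbone itself.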
Comparison grid (figure): VQA Diagnosis accuracy under standard full-modality and DTI-only queries, plus BERTScore (contextual similarity) and SBERT CosSim (sentence-level), for Mistral 7B (dense), Gemma 4 26B (MoE), and MedGemma 4B (medical fine-tune).
Same retrieved context across all three models; only the generation model changes. Mistral 7B wins every metric: diagnosis VQA accuracy and text quality (BERTScore, SBERT). MedGemma's medical fine-tune loses to a general-purpose 7B model. At this size, instruction-following matters more than domain knowledge.
Headline finding
Scale beats domain. Mistral 7B, a general-purpose dense model, outperforms both a 26B MoE and a medically fine-tuned 4B on every metric. The retrieved context already supplies the medical knowledge. What matters is whether the model can follow instructions and format its output correctly.
CN amyloid Bal. Acc. goes from 0.557 (Clinical-only) to 0.688 (T1 + DTI + Clinical), a 13.1 point gain on the hardest and most clinically useful task. Specificity sits at 0.846, which is what you want for a triage tool.
Modality dropout during training (T1 10%, DTI 30%, Clinical 5%) means the model handles any combination at inference. In practice it works for the 60% of subjects who only have T1 + Clinical, not just the 39.4% with full DTI coverage.
DTI + Clinical wins DX Binary (0.938) and overall Amyloid (0.746). For CN amyloid specifically though, T1 matters: DTI + Clinical drops to 0.623 while the full stack reaches 0.688.
A general-purpose 7B dense model beats a 26B MoE and a medically fine-tuned 4B on both diagnosis accuracy (94.7% vs 92.7% vs 50.7%) and text quality. The retrieved context does the medical heavy lifting. The model just needs to read it and respond clearly.
Under a DTI-only query, FAISS retrieval accuracy@5 drops to 40.9%, yet Mistral 7B still reaches 75.3% VQA accuracy on those same queries. The LLM can extract useful signal even from partially mismatched context.
Status
This work was completed at the Keck School of Medicine of USC. Once the paper is submitted I'll link the manuscript and GitHub repo here.