Parallel Representative Tokens for Retrieval with Diffusion Language Models
Quick Start · Setup · Backbones · Reproducing · Layout · Citation
TL;DR: Diffusion language models (DLMs) generate responses through a masked-position prediction interface: they pre-allocate masked output positions and predict them jointly under bidirectional attention. DiffRetriever uses that interface directly for retrieval. It appends $K$ masked positions after a retrieval prompt and reads $K$ hidden states and $K$ logit vectors from a single forward pass. With $K{=}1$ this already beats encoder-style diffusion retrieval; with $K{>}1$ it gives ColBERT-style multi-representation retrieval at near single-pass encoding cost. The autoregressive equivalent must decode $K$ tokens sequentially, paying $\approx 15\times$ the encoding latency at zero shot.
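The input layout described in the TL;DR can be sketched in a few lines of plain Python. This is a hypothetical illustration, not code from src/models/: MASK_ID is a made-up placeholder (each backbone defines its own mask token), the prompt and suffix ids stand in for a tokenised retrieval prompt, and the "hidden states" are fake rows standing in for the model's per-position outputs.

```python
from typing import List, Sequence, Tuple

# Hypothetical mask-token id for illustration only; each backbone defines its own.
MASK_ID = 999


def build_masked_input(
    prompt_ids: Sequence[int],
    suffix_ids: Sequence[int],
    k: int,
    mask_id: int = MASK_ID,
) -> Tuple[List[int], List[int]]:
    """Append k mask slots after the retrieval prompt, followed by the fixed suffix.

    Returns the full id sequence and the indices of the k masked positions.
    """
    if k < 1:
        raise ValueError("k must be >= 1")
    start = len(prompt_ids)
    ids = list(prompt_ids) + [mask_id] * k + list(suffix_ids)
    mask_positions = list(range(start, start + k))
    return ids, mask_positions


def gather_rows(per_position: Sequence[Sequence[float]], positions: Sequence[int]) -> List[List[float]]:
    """Select the rows (hidden states or logit vectors) at the masked positions."""
    return [list(per_position[i]) for i in positions]


if __name__ == "__main__":
    ids, pos = build_masked_input(prompt_ids=[11, 12, 13], suffix_ids=[50, 51], k=4)
    print(ids)  # [11, 12, 13, 999, 999, 999, 999, 50, 51]
    print(pos)  # [3, 4, 5, 6]
    # Stand-in for one forward pass: one fake 2-d "hidden state" per input position.
    hidden = [[float(i), float(i) * 10] for i in range(len(ids))]
    print(gather_rows(hidden, pos))  # [[3.0, 30.0], [4.0, 40.0], [5.0, 50.0], [6.0, 60.0]]
```

Because attention in a DLM is bidirectional, all $K$ slots are predicted in the same pass; an autoregressive model would instead have to append and decode one token at a time, which is where its sequential cost comes from.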
Overview of DiffRetriever. A query and a passage are each formatted with a retrieval prompt that ends in $K_q$ (query) or $K_p$ (passage) [MASK] positions, capped by fixed suffix tokens. A diffusion language model reads all masked positions in a single bidirectional forward pass, yielding $K$ hidden states for dense retrieval and $K$ next-token logit vectors for sparse retrieval on each side. Scoring uses ColBERT-style MaxSim on the dense vectors and max-pooled logits on the sparse vectors, followed by hybrid score interpolation. PromptReps on an autoregressive backbone generates retrieval representations sequentially up to a maximum budget $N$, so wall-clock cost grows with the number generated; DiffRetriever pre-allocates $K$ masked positions and reads them in parallel from one pass.
BEIR-7 NDCG@10 vs. query encoding plus search latency (ms/query, 100K-document MS MARCO sample). Left: zero-shot. Right: fine-tuned. Dashed lines link single-representation and multi-representation variants of the same system. DiffRetriever gains from multi-representation retrieval at near single-representation cost in both panels; the autoregressive PromptReps equivalent pays much higher latency for the same number of representations. Across MS MARCO dev, TREC DL 19/20, and BEIR-7, DiffRetriever obtains the strongest aggregate effectiveness within our matched comparison.
Latency scaling on synthetic inputs and indices (single H100, same attention backend across backbones). Top row: encoding latency vs. input sequence length. Bottom row: search latency vs. index size (log scale). Left column: autoregressive PromptReps (Qwen2.5, LLaMA3). Right column: DiffRetriever on diffusion backbones (Dream, LLaDA). Open markers = single-representation, filled = multi-representation. DiffRetriever's multi-representation encoding stays close to its single-representation cost, while AR multi-representation encoding stays at 2–3× the AR single-representation cost across the input range.
Models on Hugging Face: trained checkpoints for DiffRetriever (Dream, LLaDA) and the re-trained baselines (PromptReps, DiffEmbed, RepLLaMA) will be released on the Hugging Face Hub. They are not available yet; this README will be updated with the model URLs when the release lands.
Zero-shot retrieval with Dream-7B, end-to-end in pure Python: no SLURM, no scripts, no data download.
After setting up the env (pip install -r requirements.txt), paste this whole block into a file (demo.py) at the repo root and run it. The model auto-downloads from Hugging Face on first run (~14 GB).

```python
"""DiffRetriever zero-shot demo: Dream-7B, K=4 masked positions, single forward pass."""
import sys

sys.path.insert(0, "src")  # repo-local import

import torch
import torch.nn.functional as F

from models.diffretriever_dream import DreamRetriever

# 1. Load the encoder (zero-shot, no fine-tuning required)
model = DreamRetriever(
    model_name="Dream-org/Dream-v0-Instruct-7B",
    max_length=512,
    n_gen_tokens=4,  # K = number of masked positions appended
    num_denoise_steps=1,  # S=1: a single forward pass
    query_prompt="prompts/default/query_prompt_few.yaml",
    passage_prompt="prompts/default/passage_prompt_few.yaml",
)
model.eval()

# 2. Tiny demo corpus
queries = [
    "what causes the seasons on earth?",
    "best way to learn guitar at home",
]
passages = [
    "The tilt of Earth's axis relative to its orbital plane causes seasonal variation in sunlight.",
    "Pick a beginner-friendly acoustic guitar and practice 15 minutes daily with online tutorials.",
    "Photosynthesis converts carbon dioxide and water into glucose using sunlight.",
    "Plate tectonics describes how Earth's lithosphere is divided into moving plates.",
    "Online video lessons are an efficient way to learn an instrument at your own pace.",
]

# 3. Encode: one forward pass returns repr_hidden ([N, K, H]) and sparse activations
with torch.inference_mode():
    q = model.encode(queries, is_query=True, encoding_mode="promptreps", encode_type="all_steps")
    p = model.encode(passages, is_query=False, encoding_mode="promptreps", encode_type="all_steps")

# 4. ColBERT MaxSim scoring on the K-vector outputs
q_vec = F.normalize(q["repr_hidden"].float(), dim=-1)  # [Q, K_q, H]
p_vec = F.normalize(p["repr_hidden"].float(), dim=-1)  # [P, K_p, H]
sim = torch.einsum("qkh,pdh->qkpd", q_vec, p_vec)  # [Q, K_q, P, K_p] cosine similarities
# Best passage vector per query vector (negatives clamped to 0), summed over query vectors
scores = sim.max(dim=-1).values.clamp(min=0).sum(dim=1)  # [Q, P]

# 5. Top-3 hits per query
for i, query in enumerate(queries):
    top = scores[i].topk(3)
    print(f"\nQ: {query}")
    for s, idx in zip(top.values.tolist(), top.indices.tolist()):
        print(f"  {s:.2f}  {passages[idx]}")
```

Expected output (scores will vary slightly across GPUs):
```
Q: what causes the seasons on earth?
  3.21  The tilt of Earth's axis relative to its orbital plane causes seasonal variation in sunlight.
  2.18  Plate tectonics describes how Earth's lithosphere is divided into moving plates.
  1.92  Photosynthesis converts carbon dioxide and water into glucose using sunlight.

Q: best way to learn guitar at home
  3.05  Pick a beginner-friendly acoustic guitar and practice 15 minutes daily with online tutorials.
  2.41  Online video lessons are an efficient way to learn an instrument at your own pace.
  1.55  Photosynthesis converts carbon dioxide and water into glucose using sunlight.
```
That's the whole pipeline. To use the LLaDA backbone instead, replace DreamRetriever with LLaDA2Retriever (from models.diffretriever_llada import LLaDA2Retriever). For sparse retrieval, use the sparse_indices / sparse_values keys that encode(...) also returns. For fusion (hybrid) retrieval, blend the dense and sparse scores with min-max normalisation, following PromptReps. The full sweep is wrapped in scripts/run_encode.sh + scripts/run_eval.sh.
```
src/
├── models/                         Retrievers (zero-shot + trainable)
│   ├── diffretriever_dream.py      Zero-shot DiffRetriever (Dream backbone)
│   ├── diffretriever_llada.py      Zero-shot DiffRetriever (LLaDA backbone)
│   ├── diffretriever_trainable.py  Fine-tunable DiffRetriever (Dream / LLaDA)
│   ├── promptreps.py               Zero-shot PromptReps (autoregressive)
│   ├── promptreps_trainable.py     Fine-tunable PromptReps (autoregressive)
│   ├── diffembed.py                DiffEmbed baseline (encoder-style DLM)
│   ├── repllama.py                 RepLLaMA baseline (autoregressive)
│   ├── bottleneck_retriever.py     Bottleneck / Semantic-Hub variant (ablation)
│   ├── block_schedule.py           Multi-step denoising schedule
│   ├── backbone_adapters.py        HF model loading / LoRA wiring
│   └── sparse_utils.py             Sparse score helpers (content-word filter)
└── evaluation/
    └── evaluator.py                Per-query scoring + metric aggregation
scripts/
├── train_diffretriever.py          Train DiffRetriever (Dream / LLaDA)
├── train_promptreps.py             Train PromptReps (LLaMA3 / Qwen2.5)
├── train_diffembed.py              Train DiffEmbed
├── train_repllama.py               Train RepLLaMA
├── encode.py                       Encode queries / passages (all retrievers go through this)
├── evaluate_sweep.py               Evaluate over the (K_q, K_p) sweep
├── eval_trec.py                    Compute MRR / NDCG with pytrec-eval
├── prepare_msmarco.py              MS MARCO data prep
├── preprocess_msmarco_aug.py       Augmented triples prep
├── shard_io.py                     Sharded encoding I/O
├── download_data.sh                Fetch MS MARCO + TREC DL + BEIR-7 + NLTK data
├── run_train.sh                    Portable launcher: training
├── run_encode.sh                   Portable launcher: encoding
└── run_eval.sh                     Portable launcher: evaluation
configs/
├── ds_zero2.json                   DeepSpeed ZeRO-2 config
├── ds_zero3.json                   DeepSpeed ZeRO-3 config
├── naming.sh                       Backbone / config naming helpers
└── dataset_config.sh               Dataset path helpers
prompts/default/                    Retrieval prompts (see §3.1 of the paper)
├── query_prompt_one.yaml           "one word" → single-representation (K=1)
├── query_prompt_few.yaml           "a few words" → multi-representation (K>1)
└── passage_prompt_{one,few}.yaml   same template with "Query" → "Passage"
```
Note: this repo bundles only what's needed to reproduce the paper. Internal analysis / plot scripts and benchmark drivers are kept in the research repository and are not redistributed here.
We use conda. The pinned requirements.txt is a freeze of the env used during development on a single H100 node (CUDA 12.6, Linux x86_64, Python 3.10).
```bash
# 1. Create env
conda create -n diffretriever python=3.10 -y
conda activate diffretriever

# 2. Install pinned dependencies (covers training + encoding + eval)
pip install -r requirements.txt

# 3. Download the datasets and the small NLTK corpora (stopwords + punkt)
bash scripts/download_data.sh   # MS MARCO + TREC DL19/DL20 + BEIR-7 + nltk
# or selectively:
# bash scripts/download_data.sh --msmarco
# bash scripts/download_data.sh --beir
```

requirements.txt is exhaustive: it covers training (DeepSpeed, accelerate, peft) as well as encoding and evaluation. Training uses the HuggingFace Trainer directly with the retriever classes under src/models/; there is no separate "training extras" file.
Optional but strongly recommended for speed: flash-attention 2. It is not pinned in requirements.txt because the prebuilt wheel is platform-specific. Install the matching wheel for your CUDA / torch / cxx11abi from the flash-attention releases, or:
```bash
pip install flash-attn --no-build-isolation
```

Core versions in the freeze:

- torch==2.6.0+cu126, transformers==4.54.0 (Dream / LLaDA require this exact version)
- accelerate==1.12.0, peft==0.18.1, deepspeed==0.18.8
- pytrec-eval-terrier==0.5.6 for retrieval metrics
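A small, optional sanity check for these pins is sketched below. It is not a script shipped with the repo (PINS and find_mismatches are names made up here); it reads installed distribution versions with the standard-library importlib.metadata and compares them by exact string equality, so a torch build with a different local tag (for example plain 2.6.0 instead of 2.6.0+cu126) is also reported.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Dict, Optional

# Core pins from the freeze, keyed by pip distribution name (hypothetical helper data).
PINS: Dict[str, str] = {
    "torch": "2.6.0+cu126",
    "transformers": "4.54.0",
    "accelerate": "1.12.0",
    "peft": "0.18.1",
    "deepspeed": "0.18.8",
    "pytrec-eval-terrier": "0.5.6",
}


def find_mismatches(
    pins: Dict[str, str],
    get_version: Callable[[str], str] = version,
) -> Dict[str, Optional[str]]:
    """Return {package: installed version, or None if missing} for every pin that differs."""
    bad: Dict[str, Optional[str]] = {}
    for name, wanted in pins.items():
        try:
            installed: Optional[str] = get_version(name)
        except PackageNotFoundError:
            installed = None
        if installed != wanted:
            bad[name] = installed
    return bad


if __name__ == "__main__":
    mismatches = find_mismatches(PINS)
    for name, got in mismatches.items():
        print(f"{name}: expected {PINS[name]}, found {got or 'not installed'}")
    if not mismatches:
        print("All core pins match.")
```

The get_version parameter is only there so the check can be exercised without touching the real environment.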
The four backbones used in the paper (~7–8B parameters each):

| Backbone | HF id | Family |
|---|---|---|
| LLaMA3-8B-Instruct | meta-llama/Meta-Llama-3-8B-Instruct | Autoregressive |
| Qwen2.5-7B-Instruct | Qwen/Qwen2.5-7B-Instruct | Autoregressive |
| Dream-v0-Instruct-7B | Dream-org/Dream-v0-Instruct-7B | Diffusion |
| LLaDA-8B-Instruct | GSAI-ML/LLaDA-8B-Instruct | Diffusion |
src/models/backbone_adapters.py handles the HF loading + tokenizer setup for all four. Dream is initialised from Qwen2.5 and then trained with masked-position denoising, which gives our tightest controlled comparison between autoregressive and diffusion decoding (same architecture and initialisation, different training objective).
```bash
bash scripts/download_data.sh               # MS MARCO + TREC DL 2019/2020 + BEIR-7 + NLTK
python scripts/prepare_msmarco.py           # Optional: HF-cached MS MARCO splits
python scripts/preprocess_msmarco_aug.py    # Pre-tokenize Tevatron/msmarco-passage-aug
```

All workflow scripts are minimal portable launchers: open them, edit the variables at the top for your setup (e.g. local paths), and run. They wrap scripts/*.py with the canonical arguments used in the paper.
```bash
# Encode queries and passages (zero-shot DiffRetriever / PromptReps)
MODEL_TYPE=dream K=4 PROMPT_VARIANT=few \
  bash scripts/run_encode.sh

# Score the encoded representations
RESULTS_DIR=results/dream_few_K4/msmarco \
QRELS=data/msmarco/qrels.dev.tsv \
  bash scripts/run_eval.sh
```

Or invoke the underlying scripts directly (this is what the launchers call):
```bash
# 1. Encode queries, one shard at a time (--input_file is required)
python scripts/encode.py \
  --model_type dream \
  --model_name_or_path Dream-org/Dream-v0-Instruct-7B \
  --input_file data/msmarco/queries.dev.jsonl \
  --output_dir results/dream_few_K4/msmarco/queries \
  --is_query \
  --query_prompt prompts/default/query_prompt_few.yaml \
  --passage_prompt prompts/default/passage_prompt_few.yaml \
  --n_gen_tokens 4 --num_denoise_steps 1 \
  --encode_type all_steps --sparse_topk 256 \
  --shard_id 0 --num_shards 1

# 2. Encode the corpus the same way (--input_file data/msmarco/corpus.jsonl,
#    drop --is_query, set --num_shards to fan out across the cluster)

# 3. Score: produces summary.json + {mode}.json + {mode}.trec per mode
python scripts/evaluate_sweep.py \
  --query_dir results/dream_few_K4/msmarco/queries \
  --corpus_dir results/dream_few_K4/msmarco/corpus \
  --qrels data/msmarco/qrels.dev.tsv \
  --output_dir results/dream_few_K4/msmarco/eval
```

For the
```bash
# DiffRetriever: Dream / LLaDA backbones
MODEL_TYPE=dream MODEL_NAME=Dream-org/Dream-v0-Instruct-7B \
K_Q=4 K_P=16 \
  bash scripts/run_train.sh

# PromptReps and the re-trained baselines call the matching Python scripts:
# python scripts/train_promptreps.py ...   # PromptReps (LLaMA3 / Qwen2.5)
# python scripts/train_diffembed.py ...    # DiffEmbed
# python scripts/train_repllama.py ...     # RepLLaMA
```

All fine-tuning uses LoRA (r=16, α=64) + DeepSpeed ZeRO-2, InfoNCE with temperature τ=0.01, one positive + 15 hard negatives, global batch 128, on the Tevatron MS MARCO augmented triples. Diffusion backbones train at their train-selected $(K_q^{*}, K_p^{*})$; autoregressive backbones use a generation cap of $N=4$ during fine-tuning (zero-shot uses $N=20$).
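For intuition, here is a minimal sketch of a per-query InfoNCE objective with τ = 0.01 over one positive and 15 hard negatives, written in plain Python on precomputed relevance scores. info_nce is a hypothetical helper, not the repo's loss: whether the training scripts also use in-batch negatives across the global batch, and how each backbone's scores are produced, are not stated here and are left out as assumptions.

```python
import math
from typing import Sequence


def info_nce(scores: Sequence[float], temperature: float = 0.01, positive_index: int = 0) -> float:
    """Cross-entropy of the positive against all candidates after dividing scores by tau.

    Assumes only the query's own group (1 positive + hard negatives); no in-batch negatives.
    """
    logits = [s / temperature for s in scores]
    m = max(logits)  # subtract the max for a numerically stable log-sum-exp
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[positive_index]


if __name__ == "__main__":
    group = [0.82] + [0.80] * 15  # 1 positive followed by 15 hard negatives
    print(info_nce(group))
```

With τ = 0.01 a cosine-score gap of 0.02 between the positive and each negative already becomes a logit gap of 2, so the loss is sharply sensitive to small score differences.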
```bash
# Sweep all five score modes against encoded representations in one pass:
#   single_dense, multi_dense, sparse_max,
#   fusion_single_sparse_max, fusion_multi_sparse_max
python scripts/evaluate_sweep.py \
  --query_dir <encoded queries dir> \
  --corpus_dir <encoded corpus dir> \
  --qrels <qrels.tsv> \
  --output_dir <output dir>

# Or score a single TREC run file with pytrec-eval (positional args: qrels first, run second)
python scripts/eval_trec.py <qrels> <runfile> --metrics mrr_cut_10 ndcg_cut_10
```

The five score modes map to the paper's scoring breakdown:
- single_dense: inner product on the $K{=}1$ representation, or on the mean-pooled representations when $K{>}1$.
- multi_dense: ColBERT MaxSim over the $K_q \times K_p$ representations (Equation 1 in the paper).
- sparse_max: max-pooled $\log(1 + \mathrm{ReLU}(\ell))$ on logits across the $K$ representations (Equation 2).
- fusion_*_sparse_max: equal-weight min-max-normalised hybrid of the corresponding dense + sparse scores (Equation 3).
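The score modes can be sketched in plain Python as below. This is a hedged illustration of the descriptions above, not the code in scripts/evaluate_sweep.py. Assumptions of this sketch: the sparse match function is an inner product over shared vocabulary terms, min-max normalisation returns all zeros when every candidate has the same score, and the content-word filter (src/models/sparse_utils.py) and --sparse_topk truncation are omitted. Unlike the Quick Start demo, multi_dense here neither L2-normalises nor clamps negative per-vector similarities; normalise inputs first if you want cosine MaxSim.

```python
import math
from typing import Dict, List, Sequence

Vec = Sequence[float]


def dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))


def single_dense(q: Sequence[Vec], p: Sequence[Vec]) -> float:
    """Mean-pool the K vectors on each side (a no-op when K=1), then take the inner product."""
    q_mean = [sum(col) / len(q) for col in zip(*q)]
    p_mean = [sum(col) / len(p) for col in zip(*p)]
    return dot(q_mean, p_mean)


def multi_dense(q: Sequence[Vec], p: Sequence[Vec]) -> float:
    """ColBERT MaxSim: each query vector keeps its best-matching passage vector; sum over query vectors."""
    return sum(max(dot(qv, pv) for pv in p) for qv in q)


def sparse_max(logits: Sequence[Vec]) -> Dict[int, float]:
    """Max-pool log(1 + ReLU(logit)) over the K positions; keep non-zero vocabulary entries."""
    weights: Dict[int, float] = {}
    for row in logits:
        for term, value in enumerate(row):
            w = math.log1p(max(value, 0.0))
            if w > weights.get(term, 0.0):
                weights[term] = w
    return weights


def sparse_score(q: Dict[int, float], p: Dict[int, float]) -> float:
    """Inner product over vocabulary terms present in both sparse vectors (assumed match function)."""
    return sum(w * p[t] for t, w in q.items() if t in p)


def min_max(scores: Sequence[float]) -> List[float]:
    lo, hi = min(scores), max(scores)
    if hi == lo:
        return [0.0] * len(scores)  # assumption: no spread means no signal
    return [(s - lo) / (hi - lo) for s in scores]


def fuse(dense: Sequence[float], sparse: Sequence[float]) -> List[float]:
    """Equal-weight hybrid of min-max-normalised dense and sparse scores for one query's candidates."""
    return [0.5 * d + 0.5 * s for d, s in zip(min_max(dense), min_max(sparse))]


if __name__ == "__main__":
    query = [[1.0, 0.0], [0.0, 1.0]]  # K_q = 2 dense vectors
    corpus = [[[1.0, 0.0]], [[0.6, 0.8]], [[0.0, -1.0]]]  # three passages, K_p = 1 each
    dense = [multi_dense(query, p) for p in corpus]
    q_sparse = sparse_max([[2.0, -1.0, 0.5]])  # toy 3-term vocabulary, K = 1
    p_sparse = [sparse_max([[1.0, 0.0, 0.0]]), sparse_max([[0.0, 0.0, 3.0]]), sparse_max([[0.0, 2.0, 0.0]])]
    sparse = [sparse_score(q_sparse, p) for p in p_sparse]
    print("dense :", dense)
    print("sparse:", sparse)
    print("fused :", fuse(dense, sparse))
```

Min-max normalisation is applied per query over its candidate list, so the two score families land on the same [0, 1] scale before they are averaged; summing instead of averaging would give the same ranking.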
See What's in this repo above for the full tree.
If you find this work useful, please cite:
```bibtex
@article{wang2026diffretriever,
  title={DiffRetriever: Parallel Representative Tokens for Retrieval with Diffusion Language Models},
  author={Wang, Shuai and Yin, Yu and Zhuang, Shengyao and Koopman, Bevan and Zuccon, Guido},
  journal={arXiv preprint arXiv:2605.07210},
  year={2026}
}
```

MIT; see LICENSE.

