JigsawStack/diffusion-gemma-asr

Audio-native ASR on DiffusionGemma

Speech-to-text that decodes through DiffusionGemma's diffusion decoder instead of an autoregressive one. Audio is projected into the Gemma embedding space, and the transcript is produced by parallel diffusion denoising over a fixed-length canvas in about 8 to 16 steps, so the number of decoder passes is set by the step count rather than by the transcript length. Multilingual: English, German, French, Spanish, Hindi, Mandarin.
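
As a rough illustration of why the number of decoder passes is fixed, here is a toy confidence-based unmasking loop in the spirit of masked-diffusion samplers such as MaskGIT's. It assumes a fixed-length canvas, a linear unmasking schedule, a hypothetical `MASK` id, and a hypothetical `oracle` denoiser; it is not the diffusion generate in `src/model.py`, whose schedule and remasking rules may differ.

```python
import math
from typing import Callable, List, Sequence, Tuple

MASK = -1  # hypothetical mask-token id used only in this sketch

# A denoiser looks at the whole canvas and returns a (token, confidence)
# guess for every position in a single forward pass.
Denoiser = Callable[[Sequence[int]], List[Tuple[int, float]]]


def diffusion_decode(denoise: Denoiser, canvas_len: int, steps: int) -> Tuple[List[int], int]:
    """Fill a fully masked canvas using at most `steps` parallel passes.

    Every pass predicts all positions at once; the most confident masked
    positions are committed and the rest stay masked for the next pass.
    Returns the final canvas and the number of denoiser calls made.
    """
    canvas = [MASK] * canvas_len
    calls = 0
    for step in range(steps):
        masked = [i for i, tok in enumerate(canvas) if tok == MASK]
        if not masked:
            break
        preds = denoise(canvas)
        calls += 1
        # Linear schedule: spread the remaining masked slots evenly over the
        # remaining passes so the canvas is full after the last one.
        k = math.ceil(len(masked) / (steps - step))
        chosen = sorted(masked, key=lambda i: preds[i][1], reverse=True)[:k]
        for i in chosen:
            canvas[i] = preds[i][0]
    return canvas, calls


# Toy denoiser that already "knows" the answer; a real one would condition
# on the audio embeddings and on the partially unmasked canvas.
target = [5, 9, 2, 7, 7, 1, 0, 3]


def oracle(canvas: Sequence[int]) -> List[Tuple[int, float]]:
    return [(tok, 1.0 / (i + 1)) for i, tok in enumerate(target)]


decoded, passes = diffusion_decode(oracle, canvas_len=len(target), steps=4)
print(decoded, passes)  # [5, 9, 2, 7, 7, 1, 0, 3] 4
```

The loop calls the denoiser once per step whether the canvas holds 8 tokens or 256; what grows with canvas size is the cost of each pass, not the number of passes.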

It is an adapter on a frozen backbone: about 42M trainable parameters (0.16% of the backbone) on top of a frozen 26B DiffusionGemma and a frozen whisper-small encoder.
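
A quick arithmetic check of the 0.16% figure, using the approximate counts from this README. Attributing the ~23M non-projector parameters to LoRA is an inference from the training command ("projector + LoRA"), not a documented breakdown.

```python
# Back-of-envelope check of the trainable fraction quoted above.
# All figures are the approximate ones given in this README.
BACKBONE_PARAMS = 26e9    # frozen DiffusionGemma
TRAINABLE_PARAMS = 42e6   # trainable adapter (projector + LoRA)
PROJECTOR_PARAMS = 19e6   # projector estimate from "How it works"

fraction = TRAINABLE_PARAMS / BACKBONE_PARAMS
non_projector = TRAINABLE_PARAMS - PROJECTOR_PARAMS  # assumed to be mostly LoRA

print(f"trainable fraction: {fraction:.2%}")                    # trainable fraction: 0.16%
print(f"non-projector trainable: ~{non_projector / 1e6:.0f}M")  # non-projector trainable: ~23M
```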

How it works

audio (16 kHz)
  -> frozen whisper-small encoder      acoustic features, 1500 x 768
  -> trainable projector               conv subsample 8x + linear to 2816, ~19M params
  -> 188 audio embeddings scattered into the prompt
  -> DiffusionGemma encoder            causal, builds a read-only KV cache
  -> DiffusionGemma decoder            bidirectional, denoises a 256-token canvas
  -> transcript
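
The frame counts above follow from Whisper's front end (16 kHz audio padded to 30 s, a 10 ms mel hop, and a stride-2 conv in the encoder), and "scattered into the prompt" amounts to overwriting placeholder positions in the prompt embeddings. The sketch below assumes a hypothetical `AUDIO_SLOT` placeholder id, a projector that rounds the last partial group up (so 1500 / 8 gives 188), and plain Python lists in place of tensors; how `src/model.py` actually marks audio positions is not documented here.

```python
import math
from typing import List, Sequence

SAMPLE_RATE = 16_000      # Hz, input audio
HOP_LENGTH = 160          # Whisper log-mel hop: 10 ms per frame
WINDOW_SECONDS = 30       # Whisper pads/trims every input to 30 s
ENCODER_STRIDE = 2        # Whisper's conv front end halves the frame rate
PROJECTOR_SUBSAMPLE = 8   # projector subsampling factor quoted above
AUDIO_SLOT = -100         # hypothetical placeholder id marking audio positions


def num_audio_embeddings() -> int:
    mel_frames = WINDOW_SECONDS * SAMPLE_RATE // HOP_LENGTH  # 3000
    encoder_frames = mel_frames // ENCODER_STRIDE            # 1500 (x 768 features)
    # Assumption: the projector pads the last partial group, hence ceil.
    return math.ceil(encoder_frames / PROJECTOR_SUBSAMPLE)   # 188


def scatter_audio(prompt_ids: Sequence[int],
                  token_embeds: Sequence[List[float]],
                  audio_embeds: Sequence[List[float]]) -> List[List[float]]:
    """Overwrite the embedding at each AUDIO_SLOT position, in order, with an audio embedding."""
    slots = [i for i, tok in enumerate(prompt_ids) if tok == AUDIO_SLOT]
    if len(slots) != len(audio_embeds):
        raise ValueError(f"{len(slots)} audio slots but {len(audio_embeds)} audio embeddings")
    merged = [list(e) for e in token_embeds]  # copy; inputs are left untouched
    for pos, emb in zip(slots, audio_embeds):
        merged[pos] = list(emb)
    return merged


# Tiny example with 2-dim "embeddings" (the real ones are 2816-dim).
prompt = [1, AUDIO_SLOT, AUDIO_SLOT, 7]
text_embeds = [[0.1, 0.1], [0.0, 0.0], [0.0, 0.0], [0.7, 0.7]]
audio = [[9.0, 9.0], [8.0, 8.0]]
print(num_audio_embeddings())                    # 188
print(scatter_audio(prompt, text_embeds, audio))
# [[0.1, 0.1], [9.0, 9.0], [8.0, 8.0], [0.7, 0.7]]
```

In PyTorch the same step is usually a boolean-mask assignment on the embedding tensor rather than a Python loop.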

Training uses three losses: the diffusion objective, which trains the generator used at inference; an autoregressive auxiliary loss; and a CTC loss applied to the projector outputs through the frozen lm_head. The CTC loss is what gets the adapter on the frozen backbone to actually use the audio. It is computed only during training and dropped at inference.
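
For reference, this is the textbook CTC forward recursion in pure Python: the quantity a CTC loss on lm_head logits over the 188 projected audio frames would compute. It assumes token id 0 is the blank (which token the repo uses as the blank is not stated here), and `total_loss` is only a guess at how the `--ar-weight` and `--ctc-weight` flags combine the three terms. One practical consequence: CTC needs at least as many frames as target tokens plus one extra frame per adjacent repeated token, so 188 frames bounds how long a transcript each 30-second segment can be supervised with.

```python
import math
from typing import Sequence

NEG_INF = float("-inf")


def logsumexp(*xs: float) -> float:
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))


def ctc_nll(log_probs: Sequence[Sequence[float]], target: Sequence[int], blank: int = 0) -> float:
    """CTC negative log-likelihood of `target` given per-frame log-probs (T x V).

    Standard forward (alpha) recursion over the blank-extended target.
    Returns inf when no alignment is possible (too few frames).
    """
    ext = [blank]
    for tok in target:
        ext += [tok, blank]                  # e.g. [b, y1, b, y2, b]
    S = len(ext)
    alpha = [NEG_INF] * S
    alpha[0] = log_probs[0][blank]
    if S > 1:
        alpha[1] = log_probs[0][ext[1]]
    for t in range(1, len(log_probs)):
        prev = alpha
        alpha = [NEG_INF] * S
        for s in range(S):
            paths = [prev[s]]                # stay on the same symbol
            if s >= 1:
                paths.append(prev[s - 1])    # advance by one
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                paths.append(prev[s - 2])    # skip the blank between distinct labels
            alpha[s] = logsumexp(*paths) + log_probs[t][ext[s]]
    final = alpha[-1] if S == 1 else logsumexp(alpha[-1], alpha[-2])
    return -final


def total_loss(diffusion: float, ar: float, ctc: float,
               ar_weight: float = 1.0, ctc_weight: float = 1.0) -> float:
    # Assumed combination; the weights mirror the --ar-weight / --ctc-weight flags.
    return diffusion + ar_weight * ar + ctc_weight * ctc


# Two frames, vocab {0: blank, 1: "a"}, uniform probabilities.
# Paths that collapse to "a": (a,a), (a,blank), (blank,a) -> 3 * 0.25 = 0.75.
uniform = [[math.log(0.5), math.log(0.5)]] * 2
print(round(ctc_nll(uniform, [1]), 4))  # 0.2877
```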

Results

Whisper-normalized WER/CER, 16 diffusion steps.

benchmark                    metric   score
LibriSpeech test-clean (en)  WER      6.6%
FLEURS English               WER      15.7%
VoxPopuli English            WER      18.5%
FLEURS Hindi                 CER      15.8%
FLEURS Mandarin              CER      29.6%
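
The table reports WER for English and CER for Hindi and Mandarin (`src/evaluate.py` selects between them with `--metric wer|cer`). A minimal stdlib sketch of both metrics follows. It substitutes a hypothetical `basic_normalize` for the Whisper normalizer used in the table, and the repo's own `--normalizer basic` option may behave differently, so expect its numbers to differ from the scores above.

```python
import unicodedata
from typing import Sequence


def basic_normalize(text: str) -> str:
    """Lowercase, turn punctuation/symbols into spaces, collapse whitespace.

    A simplified stand-in; it is NOT Whisper's English normalizer.
    Combining marks (e.g. Devanagari vowel signs) are kept.
    """
    chars = (" " if unicodedata.category(ch)[0] in "PS" else ch for ch in text.lower())
    return " ".join("".join(chars).split())


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance (substitutions + deletions + insertions)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            cur[j] = min(prev[j] + 1,             # delete r
                         cur[j - 1] + 1,          # insert h
                         prev[j - 1] + (r != h))  # substitute or match
        prev = cur
    return prev[-1]


def wer(ref: str, hyp: str) -> float:
    r, h = basic_normalize(ref).split(), basic_normalize(hyp).split()
    return edit_distance(r, h) / max(len(r), 1)


def cer(ref: str, hyp: str) -> float:
    # Spaces are dropped so unsegmented scripts such as Mandarin are scored per character.
    r = list(basic_normalize(ref).replace(" ", ""))
    h = list(basic_normalize(hyp).replace(" ", ""))
    return edit_distance(r, h) / max(len(r), 1)


print(wer("The cat sat on the mat.", "the cat sat on mat"))  # 0.16666666666666666
print(cer("你好世界", "你好世"))                               # 0.25
```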

Decoding runs at roughly 11 to 17x realtime. Eight steps is the fastest setting and comes close to the best accuracy; more steps barely help.
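
The speed figure is a realtime factor: audio duration divided by wall-clock processing time. A small helper for converting it, assuming that definition (the README does not say whether the timing includes feature extraction or model loading):

```python
def realtime_factor(audio_seconds: float, wall_seconds: float) -> float:
    """Seconds of audio transcribed per second of wall-clock time."""
    return audio_seconds / wall_seconds


def wall_seconds_for(audio_seconds: float, rtfx: float) -> float:
    """Wall-clock seconds needed for `audio_seconds` of audio at a given realtime factor."""
    return audio_seconds / rtfx


# What the quoted 11-17x range means for one hour of audio.
for rtfx in (11, 17):
    minutes = wall_seconds_for(3600, rtfx) / 60
    print(f"{rtfx}x realtime: {minutes:.1f} min per hour of audio")
# 11x realtime: 5.5 min per hour of audio
# 17x realtime: 3.5 min per hour of audio
```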

Layout

modal_app.py              Modal app: image, volumes, constants
src/audio.py              whisper features + trainable projector
src/model.py              audio injection, the three losses, diffusion generate
src/data.py               dataset + collator (dynamic canvas, CTC targets)
src/train.py              training entrypoint
src/evaluate.py           WER/CER + latency (--normalizer whisper|basic, --metric wer|cer)
src/serve.py              FastAPI inference endpoint (audio_url or raw bytes)
scripts/download_data.py  LibriSpeech / FLEURS / VoxPopuli -> Modal volume
scripts/probe.py          load and introspect the base model
scripts/publish_hf.py     push the adapter to the Hub
scripts/publish_space.py  build the Gradio demo Space
space/                    the Gradio demo (calls the served endpoint)

Usage

Runs on Modal. A single H100 80 GB holds the 26B model.
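
Why one 80 GB card is enough: a sketch of the weight-memory arithmetic, assuming the backbone is loaded in a 16-bit format such as bf16 (the README does not state the dtype). The remaining ~28 GB has to cover the whisper-small encoder, activations, the KV cache, and, during training, gradients and optimizer state for the ~42M trainable parameters.

```python
PARAMS = 26e9  # DiffusionGemma backbone
GPU_GB = 80    # H100 80 GB


def weights_gb(params: float, bytes_per_param: int) -> float:
    """Weight memory in decimal GB for a given storage width."""
    return params * bytes_per_param / 1e9


for name, width in (("bf16", 2), ("fp32", 4)):
    size = weights_gb(PARAMS, width)
    verdict = "fits" if size < GPU_GB else "does not fit"
    print(f"{name}: {size:.0f} GB of weights, {verdict} in {GPU_GB} GB")
# bf16: 52 GB of weights, fits in 80 GB
# fp32: 104 GB of weights, does not fit in 80 GB
```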

# data (once)
modal run scripts/download_data.py::prepare --subset train.clean.100
modal run scripts/download_data.py::prepare --subset test.clean

# train (frozen backbone, projector + LoRA, three losses)
modal run src/train.py::main --lora --ar-weight 1.0 --ctc-weight 1.0 \
  --epochs 10 --batch-size 3 --grad-accum 4 --lr 5e-4 --run-name run1

# evaluate
modal run src/evaluate.py::main --ckpt final --run-name run1 \
  --subset test.clean --normalizer whisper --max-steps 16

# serve
modal deploy src/serve.py
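
The training flags imply an effective batch of 3 x 4 = 12 utterances per optimizer step. The arithmetic below assumes `--grad-accum` is standard gradient accumulation and that training uses the `train.clean.100` subset downloaded above (LibriSpeech train-clean-100 has 28,539 utterances); the exact step count in `src/train.py` may differ depending on how it handles the last partial batch.

```python
import math

BATCH_SIZE = 3              # --batch-size
GRAD_ACCUM = 4              # --grad-accum
EPOCHS = 10                 # --epochs
TRAIN_UTTERANCES = 28_539   # LibriSpeech train-clean-100 utterance count


def optimizer_steps(n_examples: int, batch_size: int, grad_accum: int, epochs: int):
    effective_batch = batch_size * grad_accum
    # Assumes one optimizer step per `grad_accum` micro-batches and that a
    # trailing partial group still triggers a step (hence ceil).
    return effective_batch, math.ceil(n_examples / effective_batch) * epochs


effective, steps = optimizer_steps(TRAIN_UTTERANCES, BATCH_SIZE, GRAD_ACCUM, EPOCHS)
print(effective, steps)  # 12 23790
```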

Notes

The encoder has a 30-second window, so longer audio is split at silences into segments of at most 30 seconds, each segment is transcribed separately, and the transcripts are concatenated, much like chunked long-form transcription with Whisper. The base models download to a Modal volume on the first run; the adapter ships separately. Checkpoints and datasets live on Modal volumes, not in git.
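
A minimal energy-based splitter in that spirit, assuming each cut is placed at the quietest frame in the second half of the current 30-second window. The actual silence detection, frame size, and thresholds used by this repo are not documented here; a production version would more likely use a voice-activity detector and vectorized energy computation.

```python
from typing import List, Sequence, Tuple


def split_at_silence(samples: Sequence[float], sample_rate: int = 16_000,
                     max_seconds: float = 30.0, frame_ms: int = 20) -> List[Tuple[int, int]]:
    """Split audio into [start, end) sample ranges no longer than max_seconds.

    While the remaining audio is too long, cut at the lowest-energy frame in
    the second half of the current window, so cuts tend to land in pauses.
    """
    frame = max(1, int(sample_rate * frame_ms / 1000))
    max_len = int(sample_rate * max_seconds)
    if max_len < 2 * frame:
        raise ValueError("max_seconds must span at least two frames")
    segments: List[Tuple[int, int]] = []
    start, n = 0, len(samples)
    while n - start > max_len:
        window_end = start + max_len
        best_cut, best_energy = window_end, float("inf")
        for f in range(start + max_len // 2, window_end - frame + 1, frame):
            energy = sum(x * x for x in samples[f:f + frame]) / frame  # mean power
            if energy < best_energy:
                best_cut, best_energy = f + frame // 2, energy
        segments.append((start, best_cut))
        start = best_cut
    segments.append((start, n))
    return segments


# 70 s of toy "speech" at 100 Hz with pauses at 25-26 s and 50-51 s.
rate = 100
audio = [1.0] * (70 * rate)
for pause_start in (25 * rate, 50 * rate):
    audio[pause_start:pause_start + rate] = [0.0] * rate
print(split_at_silence(audio, sample_rate=rate))
# [(0, 2501), (2501, 5002), (5002, 7000)]
```

Each range would then go through the pipeline on its own, with the resulting transcripts joined in order (with spaces for space-delimited languages).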
