Summary
The prologue-fusion matcher in mlir_template.py:codegen_template_code decides which input buffer to reuse-as-spad by comparing the prologue node's get_numel() against each candidate read buffer's full get_size(). When the prologue reads a slice/chunk view of a larger buffer (numel of the view < numel of the parent buffer), the matcher finds no candidate and the bare assert(candidate_found) fires, killing the compile.
Triggered by llama4 MoE's SwiGLU FFN: gate_up = bmm(...) produces a [..., 2*E] tensor, chunk(2, dim=-1) splits it into gate and up, and silu(gate) * up becomes the prologue of the next bmm.
Repro
On develop @ 7b6daed (PR #231 merged), transformers 4.51.3:
python scripts/op_coverage.py --models llama4
(num_hidden_layers=2, interleave_moe_layer_step=2, num_local_experts=4, batch=1, seq_len=32, fp32.)
Original traceback
File ".../PyTorchSimFrontend/mlir/mlir_template.py", line 542, in codegen_template_code
assert(candidate_found)
torch._inductor.exc.InductorError: AssertionError:
Diagnostic (patched assert locally with shape printout)
[prologue fusion] no input buffer matches numel of prologue node
node: SchedulerNode(name='op95')
node.get_numel(): 393216
node.node.get_size(): [4, 32, 3072]
reads: ['buf90', 'buf90']
candidate buffers:
buf90: size=[4, 32, 6144] numel=786432
393216 * 2 == 786432: the prologue reads two slice views of buf90, each half the size of its parent. The matcher only checks the parent buffer's full numel and never considers the view.
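The arithmetic can be checked directly. The snippet below recomputes both numels with the same reduce(operator.mul, ..., 1) expression the matcher uses. The shapes are copied from the diagnostic output, and numel() is a local helper, not a PyTorchSim function.

```python
from functools import reduce
import operator


def numel(size):
    """Product of the dimensions, computed like the matcher in mlir_template.py."""
    return reduce(operator.mul, size, 1)


# Shapes copied from the diagnostic output above.
prologue_size = [4, 32, 3072]  # node.node.get_size() for op95
parent_size = [4, 32, 6144]    # buf90.get_size()

prologue_numel = numel(prologue_size)
parent_numel = numel(parent_size)

print(prologue_numel, parent_numel)        # 393216 786432
print(parent_numel == 2 * prologue_numel)  # True: each chunk view is half of buf90
print(parent_numel == prologue_numel)      # False: the matcher's equality check fails
```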
Source pattern (transformers/models/llama4/modeling_llama4.py)
gate_up = torch.bmm(hidden_states, self.gate_up_proj) # [..., 2 * expert_dim]
gate, up = gate_up.chunk(2, dim=-1) # each [..., expert_dim]
next_h = silu(gate) * up # prologue of down_proj
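For reference, this is what "slice view" means for this pattern. For a contiguous gate_up of size [B, S, 2*E], chunk(2, dim=-1) returns two views that keep the parent's strides and differ only in storage offset (0 and E). The sketch below models that layout in plain Python. It does not call PyTorch, contiguous_strides, chunk_last_dim and flat_offsets are hypothetical helpers, and it assumes the last dim divides evenly by the chunk count.

```python
from itertools import product


def contiguous_strides(size):
    """Row-major (contiguous) strides for a dense buffer of the given size."""
    strides = [1] * len(size)
    for d in range(len(size) - 2, -1, -1):
        strides[d] = strides[d + 1] * size[d + 1]
    return strides


def chunk_last_dim(size, chunks=2):
    """Model the views returned by chunk(chunks, dim=-1) on a contiguous buffer.

    Assumes the last dim divides evenly by `chunks` (true for 2 * expert_dim).
    Each view is (view_size, strides, storage_offset); the strides are the
    parent's, only the offset differs between views.
    """
    strides = contiguous_strides(size)
    width = size[-1] // chunks
    view_size = list(size[:-1]) + [width]
    return [(list(view_size), list(strides), i * width) for i in range(chunks)]


def flat_offsets(view_size, strides, offset):
    """Set of flat element offsets into the parent buffer that a view touches."""
    return {
        offset + sum(i * s for i, s in zip(idx, strides))
        for idx in product(*(range(n) for n in view_size))
    }


gate, up = chunk_last_dim([4, 32, 6144])
print(gate)  # ([4, 32, 3072], [196608, 6144, 1], 0)
print(up)    # ([4, 32, 3072], [196608, 6144, 1], 3072)
```

With a small shape the element sets can be enumerated: for a [2, 3, 8] parent, each view touches 24 of the 48 parent elements and the two sets are disjoint, so neither view's numel can ever equal the parent's.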
Root cause
mlir_template.py:537-542:
for candidate_read in read_list:
    if candidate_read in buf_dict and reduce(operator.mul, buf_dict[candidate_read].get_size(), 1) == node.node.get_numel():
        prologue_input_arg = candidate_read
        candidate_found = True
        break
assert(candidate_found)
- buf_dict[candidate_read].get_size() is the parent buffer's size, not the view's.
- read_list is derived from node.read_writes.reads (memdeps), which lose the view info before reaching this code.
- For any chunk/split/slice read, the numel comparison is therefore structurally guaranteed to mismatch.
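The failure can be reproduced without Inductor. The sketch below copies the comparison from the loop above into a standalone function. FakeBuffer, match_prologue_input and the second buffer name buf91 are hypothetical, and the function returns None where the real code would hit assert(candidate_found).

```python
from functools import reduce
import operator


class FakeBuffer:
    """Hypothetical stand-in for an Inductor buffer; only get_size() is modeled."""

    def __init__(self, size):
        self._size = list(size)

    def get_size(self):
        return self._size


def match_prologue_input(read_list, buf_dict, prologue_numel):
    """Same comparison as mlir_template.py:537-542, but returns None instead of asserting."""
    for candidate_read in read_list:
        if candidate_read in buf_dict and reduce(
            operator.mul, buf_dict[candidate_read].get_size(), 1
        ) == prologue_numel:
            return candidate_read
    return None


# op95: both chunk views of buf90 appear under the parent's name and size.
buf_dict = {"buf90": FakeBuffer([4, 32, 6144])}
print(match_prologue_input(["buf90", "buf90"], buf_dict, 393216))  # None

# A prologue reading a whole, unsliced buffer of the right size does match.
print(match_prologue_input(["buf91"], {"buf91": FakeBuffer([4, 32, 3072])}, 393216))  # buf91
```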
Suggested fix
Two options:
- Make the matcher view-aware: walk node.read_writes.reads and use each MemoryDep's actual access size (or node.node.layout's view) instead of the parent buffer's get_size(). The current comment "memdep.get_size() != data.get_size()" already acknowledges this gap.
- Bail out gracefully when no candidate matches: turn the bare assert into a fallback that skips prologue fusion for this node (codegen the prologue as a standalone kernel). At minimum this should be the behavior; currently a recoverable scheduling choice kills the whole compile.
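A rough sketch of how the two options could fit together, assuming each read can be annotated with the number of elements it actually accesses. ReadDep, find_prologue_input, plan_prologue and the "fuse:"/"standalone" return values are hypothetical; how to derive access_numel from a real MemoryDep (or from node.node.layout) is the open question in the first option and is not shown.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence


@dataclass(frozen=True)
class ReadDep:
    """Hypothetical per-read record: buffer name plus how many elements this
    particular read touches (for a chunk view, the view's numel, not the parent's).
    Deriving it from a real MemoryDep is an assumption left open here."""
    name: str
    access_numel: int


def find_prologue_input(
    reads: Sequence[ReadDep], buf_sizes: Dict[str, List[int]], prologue_numel: int
) -> Optional[str]:
    """Option 1 (view-aware): compare against each read's access size."""
    for dep in reads:
        if dep.name in buf_sizes and dep.access_numel == prologue_numel:
            return dep.name
    return None


def plan_prologue(reads, buf_sizes, prologue_numel) -> str:
    """Option 2 (graceful fallback): skip prologue fusion instead of asserting."""
    match = find_prologue_input(reads, buf_sizes, prologue_numel)
    if match is None:
        return "standalone"  # codegen the prologue as its own kernel
    return f"fuse:{match}"


buf_sizes = {"buf90": [4, 32, 6144]}
op95_reads = [ReadDep("buf90", 393216), ReadDep("buf90", 393216)]  # gate and up views
print(plan_prologue(op95_reads, buf_sizes, 393216))                  # fuse:buf90
print(plan_prologue([ReadDep("buf90", 786432)], buf_sizes, 393216))  # standalone
```

Matching on numel alone still cannot tell the gate view of buf90 from the up view, so a real view-aware matcher would likely also need each view's offset; the fallback path is what keeps any remaining mismatch from aborting the compile.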
The bare assert(candidate_found) is also worth replacing with the diagnostic message above so future shape mismatches surface their context instead of an empty AssertionError.
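One way to do that, sketched as a standalone helper whose name and parameters are hypothetical (in mlir_template.py the values would come from node, node.node, read_list and buf_dict). It raises AssertionError carrying the same fields as the local diagnostic patch.

```python
from functools import reduce
import operator


def check_candidate_found(candidate_found, node_repr, node_numel, node_size, read_list, buf_sizes):
    """Raise with the prologue context instead of an empty AssertionError."""
    if candidate_found:
        return
    lines = [
        "[prologue fusion] no input buffer matches numel of prologue node",
        f"  node: {node_repr}",
        f"  node.get_numel(): {node_numel}",
        f"  node.node.get_size(): {node_size}",
        f"  reads: {read_list}",
        "  candidate buffers:",
    ]
    for name in dict.fromkeys(read_list):  # de-duplicate while keeping order
        if name in buf_sizes:
            size = buf_sizes[name]
            lines.append(f"    {name}: size={size} numel={reduce(operator.mul, size, 1)}")
    raise AssertionError("\n".join(lines))


try:
    check_candidate_found(
        False, "SchedulerNode(name='op95')", 393216, [4, 32, 3072],
        ["buf90", "buf90"], {"buf90": [4, 32, 6144]},
    )
except AssertionError as err:
    print(err)
```

Raising explicitly rather than using an assert statement also keeps the check active under python -O, which strips assert statements.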
Scope
Blocks llama4 MoE end-to-end forward; the same pattern likely affects any future model that uses split/chunk immediately before a fused-prologue gemm/bmm.
Environment
CONFIG_FUSION_PROLOGUE enabled (default)