Stateless broadcast optimization#190
Draft
cavusmustafa wants to merge 3 commits into
Draft
Conversation
Collaborator
cavusmustafa
commented
May 26, 2026
- Stateless translation updated so gpu plugin can capture ov::op::internal::RoPE
- Temporary performance solution for stateless gpu: Bypass v13 SDPA entirely. Express attention manually so MatMul's NUMPY broadcast handles the GQA expansion at kernel level. K and V stay at n_heads_kv shape; the GEMM kernel reads them once and broadcasts via stride trick. We can revert this once we can utilize internal SDPA kernel which supports GQA broadcasting.
a56fb28 to
115a310
Compare
wine99
approved these changes
Jun 1, 2026
| auto q_5d_shape = ov::op::v0::Constant::create( | ||
| ov::element::i64, {5}, | ||
| std::vector<int64_t>{1, num_heads_kv, factor, -1, head_size}); | ||
|
|
Collaborator
There was a problem hiding this comment.
I believe qkv arrive as [B, n_heads, S, head_size] where B is the extra input n_seq_active, so this code does not work correctly with llama-perplexity or llama-server -np > 1.
If the ov pattern supports multiple sequences, i.e. B != 1, we can change the shape to {0, num_heads_kv, 1, -1, head_size} and set special_zero = true in Reshape. If the ov pattern does not support multiple sequences, we can set use_manual_gqa_attention to false if n_seq > 1 or manually run perplexity with GGML_OPENVINO_MANUAL_GQA_ATTN=0. Otherwise LGTM
Collaborator
There was a problem hiding this comment.
FYI to run llama-server -np > 1 or llama-perplexity you need to include the commit from #199
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.