Skip to content

[common] Grouped gemm update - nvfp4 for blackwell and fp8 blockwise hopper#2971

Open
pggPL wants to merge 33 commits into
NVIDIA:mainfrom
pggPL:grouped_gemm_nvfp4_and_hopper
Open

[common] Grouped gemm update - nvfp4 for blackwell and fp8 blockwise hopper#2971
pggPL wants to merge 33 commits into
NVIDIA:mainfrom
pggPL:grouped_gemm_nvfp4_and_hopper

Conversation

@pggPL
Copy link
Copy Markdown
Collaborator

@pggPL pggPL commented May 8, 2026

Description

Adds Hopper (SM90) support to cuBLAS grouped GEMM and enables NVFP4 / FP8 block scaling recipes.

Fixes #2455

Type of change

  • Documentation change
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change
  • Infra/Build change
  • Code refactoring

Checklist

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

pggPL and others added 13 commits March 16, 2026 11:36
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Use existing nvte_set_grouped_tensor_param with kNVTEGroupedWithGEMMSwizzledScales
instead of the dedicated set/get functions.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Add CUBLAS_NVFP4_GROUPED_GEMM_VERSION and CUBLAS_FP8_BLOCK_GROUPED_GEMM_VERSION macros (13.4+)
- Update check_grouped_gemm_requirements to allow SM90 with cuBLAS 13.4+
- Refactor execute_grouped_gemm to use GroupedGemmConfig struct
- Add divisibility-by-128 validation for FP8 block scaling in setup kernel and quantizer
- Support scalar alpha/beta for Hopper (no per-group alpha/beta)
- Expose get_grouped_gemm_setup_workspace_size to PyTorch via pybind
- Update PyTorch tests to run grouped GEMM on Hopper with cuBLAS 13.4+

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
… scaling tests on Hopper

Extend nvte_grouped_gemm_with_discrete_inputA to handle NVFP4 (Float4E2M1)
inputs: accept kFloat4E2M1 dtype, propagate scale_inv pointers, collect
contiguous amax from discrete tensors, and enforce swizzled-scales checks
for NVFP4 alongside MXFP8. Also add GTEST_SKIP for FP8 tensor scaling
grouped GEMM on Hopper since cuBLAS does not support it there.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
…M tests

The setup kernel computes per-tensor scale pointers as data_offset /
block_size, which assumes no padding in the scale buffer. This is only
correct when first_dim % 128 == 0 and last_dim % 128 == 0 (MXFP8) or
last_dim % 64 == 0 (NVFP4). Add explicit assertions in
build_grouped_tensor to catch any future test shapes that violate this.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
…d_hopper

Conflicts resolved (3 files):

* tests/pytorch/test_numerics.py
  test_grouped_gemm_grouped_tensor: combined skip rules — Hopper (SM90) requires
  cuBLAS 13.4+, Blackwell+ (SM100) requires cuBLAS 13.3+. Kept main's
  use_bias_scale parametrization.

* transformer_engine/pytorch/cpp_extensions/gemm.py
  general_grouped_gemm_for_grouped_tensor: combined HEAD's num_alphabeta logic
  (single scalar on Hopper, per-group on Blackwell+) with main's cached
  _get_fp32_ones_tensor / _get_fp32_zeros_tensor helpers.

* transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
  - validate_grouped_gemm_inputs: kept HEAD's NVFP4 / FP8 block-scaling
    consistency checks, wrapped in main's nullptr-guard / continue-on-no-data
    pattern.
  - GroupedGemmConfig struct retained; added sm_count from main and
    propagated config_.sm_count -> gemm_config.sm_count in all three
    public APIs.
  - kMaxTensorsPerKernel rename to kMaxGroups (= 64) adopted from main.
  - execute_grouped_gemm signature uses GroupedGemmConfig (HEAD); body uses
    config.sm_count for CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET (from main).
  - Dropped HEAD's simple grouped_bias_add_kernel (dead code); kept main's
    advanced grouped_bias_add_kernel + find_tensor_for_row helper.
  - Replaced inline SM/cuBLAS preambles with check_grouped_gemm_requirements()
    calls in nvte_grouped_gemm, nvte_grouped_gemm_with_discrete_inputA, and
    nvte_grouped_gemm_with_discrete_out. The helper supports both
    Hopper (SM90 + cuBLAS 13.4+) and Blackwell+ (SM100 + cuBLAS 13.3+).
  - Kept HEAD's validate_grouped_gemm_inputs(..., use_per_group_alpha_beta)
    signature for proper alpha/beta validation across architectures.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
…or swizzle tests

cublaslt_grouped_gemm.cu:
- Fix incorrect handling of NVFP4/MXFP8 columnwise data in
  build_grouped_gemm_multi_inputA_args by adding a swap_dims flag
  consistent with choose_grouped_operand_storage. Use A_sel.trans
  (post-flip) for gemm_config.avg_k so K is selected from the
  correct dim with discrete A_list.

tests/cpp/test_common.{h,cu}:
- Add enforce_grouped_gemm_alignment parameter (default true) to
  build_grouped_tensor; the MXFP8/NVFP4 first/last_dim 128/64
  alignment asserts are only relevant for the grouped GEMM setup
  kernel, so callers that bypass it (swizzle/unswizzle) opt out.

tests/cpp/operator/test_swizzle.cu:
- Pass enforce_grouped_gemm_alignment=false to build_grouped_tensor
  in MXFP8 swizzle/unswizzle/roundtrip tests, which intentionally
  exercise non-padded shapes.

tests/cpp/operator/test_grouped_gemm.cu:
- Sync GPU/cuBLAS skip rules across all 3 sub-tests, add
  cudaDeviceSynchronize() after nvte_multi_tensor_gemm reference for
  defensive sync, and skip NVFP4 + AllDifferent in all 3 sub-tests
  due to a known flaky bug in the nvte_multi_tensor_gemm reference.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
…and_hopper

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>

# Conflicts:
#	transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
@pggPL pggPL force-pushed the grouped_gemm_nvfp4_and_hopper branch 2 times, most recently from 335627f to 3f523e7 Compare May 11, 2026 14:04
pggPL and others added 4 commits May 11, 2026 16:13
Apply the same fix as upstream PR NVIDIA#2954 (MXFP8 unaligned dims) to the
analogous NVFP4 / FP8 block scaling paths in setup_grouped_gemm_kernel.

Background: cuBLAS grouped GEMM expects each expert's scale_inv to live
at a specific offset in the contiguous grouped buffer. The quantizer
allocates each per-expert scale_inv tensor padded to the layout cuBLAS
needs (swizzled 128x4 for MX/NV; ceildiv(., 128) x roundup(., 4) for
block scaling). The setup kernel was computing these offsets as
data_offset / block_size for everything except MXFP8 — silently correct
when dims align to 128, but pointing at the middle of the previous
expert's scale tile when they do not. In MoE forward this is reachable
through variable per-expert token counts.

Add three device helpers mirroring compute_grouped_tensor_mxfp8_-
scale_inv_offset:
- compute_grouped_tensor_nvfp4_scale_inv_offset
- compute_grouped_tensor_block_1d_scale_inv_offset
- compute_grouped_tensor_block_2d_scale_inv_offset
Each sums the same padded per-tensor sizes the quantizer uses at alloc
time (Float8BlockQuantizer::get_scale_shape, NVFP4Quantizer::get_scale_-
shape).

NVFP4 columnwise data is set up via use_columnwise(swap_dims=true), so
sel.shape is already pre-transposed for that recipe — the rowwise
formula on (first, last) recovers the colwise alloc. For block scaling
the formula depends on the canonical orientation, so propagate a new
swap_dims field on GroupedOperandSelection and pass effective_rowwise
(sel.rowwise || sel.swap_dims) into the kernel. MXFP8 is invariant
under this change because swap_dims is always false there and its
helper's byte count is invariant under the rowwise flag anyway.

Test: add ShapeCase::kUnalignedAllSame with (M, N, K) = (160, 288, 416)
— all multiples of 32/16 (per-recipe block size) but none multiples of
128, so each expert's scale tile is padded. Exercise it across MXFP8 /
NVFP4 / FP8 block scaling and the three transpose configs that match
the existing parameter grid. Relax build_grouped_tensor's defensive
%128 / %64 alignment assertions to %32 / %16 (block-size only), which
is the actual quantizer requirement now that the offset arithmetic no
longer assumes zero padding.

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…st cleanup

Production:
- nvte_grouped_gemm_with_discrete_inputA no longer requires per-expert amax
  buffers to be contiguous. Add `amax_ptrs[kMaxGroups]` to MultiTensorGroupGemmInputArgs
  and read each tensor's amax via indirection in setup_grouped_gemm_kernel
  (mirrors the existing scale_inv_ptrs pattern). The launcher enables the
  NVFP4 alpha computation when amax is available from either source.
- Consolidate four near-identical
  compute_grouped_tensor_{mxfp8,nvfp4,block_1d,block_2d}_scale_inv_offset
  into a single template `compute_grouped_scale_inv_offset<PaddedFn>` and
  collapse the A/B recipe-switch in setup_grouped_gemm_kernel into a local
  `fill_scale_ptr` lambda.

Tests:
- Drop the per-test amax staging workaround in run_grouped_gemm_discrete_in_case
  (no longer needed after the contiguity relax).
- Fix amax management in make_nvfp4_operand: copy values into result's own
  amax buffers instead of aliasing pointers (prevents double-free).
- Extract the three duplicated cuBLAS-version/compute-capability skip blocks
  into a shared `grouped_gemm_skip_reason` helper.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Silences -Wunused-variable (NVIDIA#177-D in nvcc).

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
@pggPL pggPL force-pushed the grouped_gemm_nvfp4_and_hopper branch from fcefde1 to ce0e4d2 Compare May 11, 2026 14:14
pggPL added 2 commits May 11, 2026 16:17
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL force-pushed the grouped_gemm_nvfp4_and_hopper branch from ce0e4d2 to a4df7bd Compare May 11, 2026 14:20
pggPL and others added 2 commits May 11, 2026 16:44
- nvte_grouped_gemm and nvte_grouped_gemm_with_discrete_out now validate
  per-operand amax for NVFP4 (previously silently dropped the global-scale
  factor when amax was missing). discrete_inputA path also checks B's amax.
- Remove unused ShapeCase::kUnalignedAllSameNVFP4 enum and its comment.
- OperandStorageChoice::swap_dims now defaults to false; rowwise returns
  no longer pass spurious swap_dims=true.
- Unify GroupedGemmSetupWorkspace layout: from_buffers(nullptr, n) returns
  the total byte count, and required_setup_size derives its result from it
  so the layout cannot drift between the two.
- test_common.cu: consolidate the three gather_*_scales lambdas into a
  single gather_scale_inv(bytes_per_elem, get_shape, get_cpu_ptr) helper.
- test_grouped_gemm.cu: extract make_grouped_gemm_ref / make_alpha_beta /
  compare_grouped_d_to_multi helpers; the three run_* variants drop from
  ~1029 to 774 lines with no behavior change.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
@pggPL pggPL marked this pull request as ready for review May 11, 2026 15:20
@pggPL
Copy link
Copy Markdown
Collaborator Author

pggPL commented May 11, 2026

/te-ci pytorch

@pggPL pggPL requested a review from vthumbe1503 May 11, 2026 15:23
Comment thread tests/cpp/operator/test_grouped_gemm.cu Outdated
Comment on lines +144 to +163
Tensor make_nvfp4_rowwise(const std::string& name, const std::vector<size_t>& shape) {
Tensor input_bf16(name + "_bf16", shape, DType::kBFloat16);
fillUniform(&input_bf16);

Tensor nvfp4(name, shape, DType::kFloat4E2M1, /*rowwise=*/true, /*columnwise=*/false,
NVTE_NVFP4_1D_SCALING);

QuantizationConfigWrapper quant_config;
nvte_quantize_v2(input_bf16.data(), nvfp4.data(), quant_config, 0);

Tensor nvfp4_sw(name + "_sw", shape, DType::kFloat4E2M1,
/*rowwise=*/true, /*columnwise=*/false, NVTE_NVFP4_1D_SCALING);
nvfp4_sw.set_with_gemm_swizzled_scales(true);
size_t data_bytes = test::bytes(nvfp4.rowwise_shape(), nvfp4.dtype());
NVTE_CHECK_CUDA(cudaMemcpy(nvfp4_sw.rowwise_dptr(), nvfp4.rowwise_dptr(),
data_bytes, cudaMemcpyDeviceToDevice));
nvte_swizzle_scaling_factors(nvfp4.data(), nvfp4_sw.data(), 0);
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
return nvfp4_sw;
}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't we just swizzle the scales in the first nvte_quantize_v2 call instead of going through 2 tensors?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not supported for nvfp4 (and it is not easy to workaround, for example swizlle in both directions also does not work). I refactored the nvfp4 tests to use only rowwise or only columnwise.

Comment thread tests/cpp/operator/test_grouped_gemm.cu Outdated
Copy link
Copy Markdown
Member

@ptrendx ptrendx left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are issues with the test code (most notably the rowwise and columnwise taken from different inputs).

Comment thread tests/cpp/operator/test_grouped_gemm.cu Outdated
Comment thread tests/cpp/operator/test_grouped_gemm.cu Outdated
Comment thread tests/cpp/operator/test_grouped_gemm.cu Outdated
Comment thread tests/cpp/operator/test_grouped_gemm.cu Outdated
std::vector<size_t>{M, N},
DType::kBFloat16));
s.D_multi.emplace_back(Tensor("D_multi" + std::to_string(i),
std::vector<size_t>{M, N}, DType::kBFloat16));
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we only support BF16 output?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added tests for fp16 and fp32, I haven't received any example from cublas people with quantized output. We can ofc ask them.

AlphaBetaTensors ab = make_alpha_beta(num_gemms);

constexpr size_t cublas_ws_bytes = 32ull * 1024 * 1024;
const size_t setup_ws_bytes = nvte_get_grouped_gemm_setup_workspace_size(num_gemms);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TBH not a fan of the name of this function, but I guess that ship has sailed already.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I add it in this PR, but I think the name is ok. Any proposition?

Comment thread tests/cpp/operator/test_grouped_gemm.cu Outdated
Comment thread tests/cpp/operator/test_grouped_gemm.cu
Comment thread tests/cpp/test_common.cu
Comment on lines +1285 to +1287
NVTE_CHECK(last_dims[i] % 32 == 0,
"MXFP8 grouped GEMM test: last_dim must be divisible by 32, got ",
last_dims[i]);
Copy link
Copy Markdown
Member

@ptrendx ptrendx May 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be able to have 16 here, but probably not yet, see my PR #2894.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, but currently we have 32 in docs.

Comment thread transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
Comment on lines +329 to +331
NVTE_CHECK(sm >= 100, api_name, " requires Blackwell (SM100) or newer architecture.");
NVTE_CHECK(cublas_ver >= CUBLAS_GROUPED_GEMM_VERSION, api_name,
" requires cuBLAS 13.3+, but run-time cuBLAS version is ", cublas_ver);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems wrong - what about the situation where we compiled against the cublas version that is not enough for any grouped gemm support (and so some stuff is not compiled I think?) but then run it on a system with newer cublas? This would pass, but the functionality would still not be there.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have compile time checks, they are in different places in the code.

- discrete-A path: force non_tn_fp8_ok=false for FP8 block scaling to match
  select_grouped_operand logic for B.
- Setup workspace: explicit 16-byte base alignment NVTE_CHECK before
  GroupedGemmSetupWorkspace::from_buffers; matches contract every standard
  allocator (cudaMalloc / PyTorch / XLA) already satisfies.
- Tests: remove compile-time cuBLAS gating, run-time check via
  cuda::cublas_version() suffices since the test doesn't call cuBLAS directly.
- Tests: replace InputCase enum with std::optional<NVTEScalingMode>;
  nullopt = BF16, otherwise scaling mode drives dispatch.
- Tests: NVFP4 operand uses one shared BF16 input transposed via
  nvte_transpose for the columnwise direction (no more duplicate fillUniform
  with different seeds across row/col).
- Tests: parametrize output dtype (BF16 default, plus FP16 and FP32 cases on
  BF16/FP8/NVFP4 recipes); implementation already accepts all three.
- Tests: add NVFP4 2D-quantization coverage (verifies VEC16 scale layout
  fed to cuBLAS is unchanged vs 1D).
- Tests: tighten NVFP4 alignment check from %16 to %32 (TMA requirement of
  the optimized BF16 quantize path), fix misleading kUnalignedAllSame
  comment ({160,288,416} are multiples of 32, common to MXFP8 and NVFP4).

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL requested a review from ksivaman as a code owner May 15, 2026 11:52
Comment thread transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
cuBLAS scaled-GEMM kernels all run in TN, so each operand only needs the
single direction matching its transpose flag. Mapping (is_A, transposed) ->
use_rowwise is uniform across MXFP8 / NVFP4 / FP8 block scaling; move it to
make_grouped_gemm_ref once instead of duplicating the if/else in each helper.

For NVFP4 specifically: drop nvte_transpose + the two-step intermediate
tensor assembly. nvte_quantize_v2 with a single-direction NVFP4 output
(rowwise XOR columnwise) directly produces what cuBLAS needs — columnwise
output goes via fallback quantize kernel (rowwise-only NVFP4 quantize kernel
hard-fails when output has no rowwise data, see quantize.cuh:111).

Matches the production pattern used by PyTorch and JAX bindings: never
allocate both NVFP4 directions on a single tensor (swizzle hard-fails when
both scale_invs are set, see swizzle.cu:985).

Net change: -112 lines.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Comment thread transformer_engine/common/gemm/cublaslt_grouped_gemm.cu Outdated
pggPL added 2 commits May 15, 2026 16:09
- test_common.cu: move NVFP4_1D_SCALING to the DELAYED/BLOCK_SCALING switch
  arm that allocates the columnwise data buffer as transpose(passed_shape).
  Required by TE's NVFP4 convention (column-wise data is transposed-then-
  quantized; see Tensor::shape() in common/common.h). Was incorrectly grouped
  with MXFP8 (whose columnwise data shape matches the logical shape).
  Without this, allocating columnwise-only NVFP4 in a test wires up a
  scale_inv shape that disagrees with what TE's CheckScaleTensorShape
  expects.

- libtransformer_engine.version: export
  transformer_engine::cuda::cublas_version so the test can read the run-time
  cuBLAS version (analogous to the already-exported sm_arch, sm_count,
  current_device).

- test_grouped_gemm.cu: simplify make_*_operand signatures to take
  use_rowwise directly, with the (is_A, transposed) -> use_rowwise mapping
  centralized in make_grouped_gemm_ref. NVFP4 now uses single-direction
  allocation (no nvte_transpose, no two-step assembly) — same pattern as
  MXFP8 / FP8 block scaling, working thanks to the test_common.cu fix.

- test_grouped_gemm.cu: tighten skip logic via grouped_gemm_skip_reason
  (now takes TestParams):
    * Skip NVFP4 + FP16 output  (cuBLAS hard-errors in cublaslt_gemm.cu)
    * Skip NVFP4 2D quantization in non-TN layouts (fallback quantize path
      doesn't support 2D; production weight quantization always allocates
      both directions and hits the optimized path).
  Drop the BF16 + FP16 output test case (cuBLAS grouped GEMM has no
  algorithm for that combination, even in TN).

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Drop multi-line rationale comments that were conversation residue (linking
back to specific other files/lines, restating GEMM TN convention, listing
PR numbers) and keep just the short ones that name what the code does.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Comment thread transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
pggPL and others added 4 commits May 15, 2026 18:36
Two related bugs in the grouped GEMM path were causing FP8 block scaling
NT and NN cases to fail on Hopper for non-Mul128 dims (Mul32 tests).

1. build_grouped_gemm_multi_inputA_args (discrete-A path) used
   t->shape() (LOGICAL) for cuBLAS rows/cols, while the symmetric
   grouped path (build_grouped_tensor / select_grouped_operand) uses
   tensors[i]->{rowwise,columnwise}_shape() (PHYSICAL). For FP8 block
   columnwise the two differ (physical is transposed of logical), so
   args.rows became N instead of K. Switch the discrete-A path to use
   data.shape directly — this aligns both paths and makes swap_dims
   redundant for this function.

2. padded_block_{1d,2d}_scale_inv_floats had a columnwise branch that
   swapped first/last in the formula, but the quantizer
   (test_common.cu:get_scales) always uses logical dims regardless of
   direction. Combined with grouped meta passing transposed dims for
   columnwise data, the formula produced wrong per-expert scale stride
   — only visible for dims not divisible by 128. The unified formula
   ceil(last/128) * roundup(first, 4) is correct for both directions
   because the meta swap and quantizer swap cancel out.

Also remove now-redundant check_fp4_output_compat duplication and trim
stale comments.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Comment thread transformer_engine/common/gemm/cublaslt_grouped_gemm.cu Outdated
pggPL and others added 3 commits May 15, 2026 21:02
Handle per-operand FP8 block scaling modes when computing grouped scale offsets, and tighten NVFP4 validation to reject unsupported mixed or non-per-group-alpha paths.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Avoid test cases that require direct columnwise 2D block quantization, which is not a supported test setup path.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

GroupedGemm: NVFP4 via cuBLAS

3 participants