
[JAX] Move quantization to before AG in MoEBlock #2998

@tdophung

Description


The initial MoEBlock implementation follows the same A2Av flow as MaxText: under FSDP, the weights are all-gathered (AG) and then quantized before the grouped GEMM runs. This is wasteful because the quantization kernel processes fsdp times more data than it needs to, which makes it roughly fsdp times slower. Quantization should be moved so that it happens before the AG.
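A minimal sketch of the two orderings, assuming symmetric per-tensor int8 quantization as a stand-in for whatever recipe MoEBlock actually uses, and simulating the FSDP mesh axis with `jax.vmap(..., axis_name="fsdp")` so it runs on a single device. `FSDP`, `quantize`, `ag_then_quantize`, and `quantize_then_ag` are hypothetical names, not Transformer Engine APIs. One subtlety: with a per-tensor scale, the shards must first agree on the global amax (here via `jax.lax.pmax`, a scalar collective). Without that step, quantizing before the gather would not match quantizing after it.

```python
import jax
import jax.numpy as jnp

FSDP = 4  # hypothetical FSDP degree, simulated as a vmapped named axis


def quantize(w, amax):
    # Symmetric per-tensor int8 quantization (assumed recipe, for illustration).
    scale = amax / 127.0
    q = jnp.clip(jnp.round(w / scale), -127, 127).astype(jnp.int8)
    return q, scale


def ag_then_quantize(w_shard):
    # Current flow: all-gather full-precision weights, then quantize.
    # Each device runs the quantization kernel on the full gathered tensor.
    w_full = jax.lax.all_gather(w_shard, "fsdp").reshape(-1, w_shard.shape[-1])
    return quantize(w_full, jnp.max(jnp.abs(w_full)))


def quantize_then_ag(w_shard):
    # Proposed flow: agree on the global amax with a scalar pmax,
    # quantize only the local shard, then all-gather the int8 values.
    amax = jax.lax.pmax(jnp.max(jnp.abs(w_shard)), "fsdp")
    q_shard, scale = quantize(w_shard, amax)
    q_full = jax.lax.all_gather(q_shard, "fsdp").reshape(-1, w_shard.shape[-1])
    return q_full, scale


# Weights laid out as (shards, rows per shard, cols); vmap maps over the shard axis.
w = jax.random.normal(jax.random.PRNGKey(0), (FSDP, 8, 16))
q_a, s_a = jax.vmap(ag_then_quantize, axis_name="fsdp")(w)
q_b, s_b = jax.vmap(quantize_then_ag, axis_name="fsdp")(w)
```

Both orderings yield identical int8 weights and scales on every device. In the second, each device quantizes only its 1/FSDP slice, and the all-gather moves int8 rather than float32 data.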
