From deb633b94898cbb6ed8d8cad578c6afcde32660a Mon Sep 17 00:00:00 2001 From: Muu Date: Thu, 14 May 2026 19:29:23 +0800 Subject: [PATCH] refactor(distributed): deduplicate TE module class lookups with caching - Extract common get_te_classes() with @lru_cache for reuse - Refactor has_te_modules() and _is_te_module() to use tuple isinstance check - Remove duplicated import lists across multiple functions Signed-off-by: Muu --- transformer_engine/pytorch/distributed.py | 48 +++++++++-------- 1 file changed, 18 insertions(+), 30 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index a0d4ac3530..670eecaa5e 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -613,18 +613,21 @@ def get_activation_recompute_contexts(): return forward_ctx, recompute_ctx -def has_te_modules(network): +@lru_cache +def get_te_classes(): """ - Check if there are any Transformer Engine modules in the network. + Return the tuple of Transformer Engine module classes used for isinstance checks. """ from .module import LayerNorm, RMSNorm from .module.base import TransformerEngineBaseModule + from .attention.dot_product_attention.dot_product_attention import ( + DotProductAttention, + ) from .attention.dot_product_attention.backends import UnfusedDotProductAttention - from .attention.dot_product_attention.dot_product_attention import DotProductAttention from .attention.multi_head_attention import MultiheadAttention from .transformer import TransformerLayer - te_classes_list = [ + return ( LayerNorm, RMSNorm, TransformerEngineBaseModule, @@ -632,12 +635,17 @@ def has_te_modules(network): DotProductAttention, MultiheadAttention, TransformerLayer, - ] + ) + +def has_te_modules(network): + """ + Check if there are any Transformer Engine modules in the network. 
+ """ + te_classes = get_te_classes() if isinstance(network, torch.nn.Module): - for module in network.modules(): - if any(isinstance(module, te_class) for te_class in te_classes_list): - return True + if any(isinstance(module, te_classes) for module in network.modules()): + return True return False # Cannot check for TE modules inside a custom class/callable that's not a torch.nn.Module, @@ -2040,28 +2048,8 @@ def _is_te_module(module): Check if given module is a Transformer Engine module that requires the TE checkpoint implementation for activation recompute. """ - from .module import LayerNorm, RMSNorm - from .module.base import TransformerEngineBaseModule - from .attention.dot_product_attention.dot_product_attention import DotProductAttention - from .attention.dot_product_attention.backends import UnfusedDotProductAttention - from .attention.multi_head_attention import MultiheadAttention - from .transformer import TransformerLayer - - te_classes_list = [ - LayerNorm, - RMSNorm, - TransformerEngineBaseModule, - UnfusedDotProductAttention, - DotProductAttention, - MultiheadAttention, - TransformerLayer, - ] - is_te_module = False - for te_class in te_classes_list: - if isinstance(module, te_class): - is_te_module = True - break - return is_te_module + te_classes = get_te_classes() + return isinstance(module, te_classes) def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None: