Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 18 additions & 30 deletions transformer_engine/pytorch/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,31 +613,39 @@ def get_activation_recompute_contexts():
return forward_ctx, recompute_ctx


def has_te_modules(network):
@lru_cache
def get_te_classes():
"""
Check if there are any Transformer Engine modules in the network.
Return all Transformer Engine modules.
"""
from .module import LayerNorm, RMSNorm
Comment on lines +616 to 621
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

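P2 get_te_classes() is now a shared, cached helper that both has_te_modules() and _is_te_module() depend on, but its docstring ("Return all Transformer Engine modules.") is terse and slightly inaccurate: the function returns a tuple of module classes (suitable as the second argument to isinstance), not module instances. The docstring also does not mention that the imports are deliberately deferred into the function body or that the result is cached by lru_cache. Expanding it here keeps this helper consistent with the other documented helpers in the file.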

Suggested change
@lru_cache
def get_te_classes():
from .module import LayerNorm, RMSNorm
@lru_cache
def get_te_classes():
"""Return a tuple of all Transformer Engine module classes.
Imports are deferred to avoid circular dependencies at module load time.
The result is cached so the tuple is built only once per process.
"""
from .module import LayerNorm, RMSNorm

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

from .module.base import TransformerEngineBaseModule
from .attention.dot_product_attention.dot_product_attention import (
DotProductAttention,
)
from .attention.dot_product_attention.backends import UnfusedDotProductAttention
from .attention.dot_product_attention.dot_product_attention import DotProductAttention
from .attention.multi_head_attention import MultiheadAttention
from .transformer import TransformerLayer

te_classes_list = [
return (
LayerNorm,
RMSNorm,
TransformerEngineBaseModule,
UnfusedDotProductAttention,
DotProductAttention,
MultiheadAttention,
TransformerLayer,
]
)


def has_te_modules(network):
"""
Check if there are any Transformer Engine modules in the network.
"""
te_classes = get_te_classes()
if isinstance(network, torch.nn.Module):
for module in network.modules():
if any(isinstance(module, te_class) for te_class in te_classes_list):
return True
if any(isinstance(module, te_classes) for module in network.modules()):
return True
return False

# Cannot check for TE modules inside a custom class/callable that's not a torch.nn.Module,
Expand Down Expand Up @@ -2040,28 +2048,8 @@ def _is_te_module(module):
Check if given module is a Transformer Engine module that requires the TE checkpoint
implementation for activation recompute.
"""
from .module import LayerNorm, RMSNorm
from .module.base import TransformerEngineBaseModule
from .attention.dot_product_attention.dot_product_attention import DotProductAttention
from .attention.dot_product_attention.backends import UnfusedDotProductAttention
from .attention.multi_head_attention import MultiheadAttention
from .transformer import TransformerLayer

te_classes_list = [
LayerNorm,
RMSNorm,
TransformerEngineBaseModule,
UnfusedDotProductAttention,
DotProductAttention,
MultiheadAttention,
TransformerLayer,
]
is_te_module = False
for te_class in te_classes_list:
if isinstance(module, te_class):
is_te_module = True
break
return is_te_module
te_classes = get_te_classes()
return isinstance(module, te_classes)


def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
Expand Down