Skip to content

refactor(distributed): deduplicate TE module class lookups with caching#2992

Open
muutot wants to merge 2 commits into
NVIDIA:mainfrom
muutot:dev
Open

refactor(distributed): deduplicate TE module class lookups with caching#2992
muutot wants to merge 2 commits into
NVIDIA:mainfrom
muutot:dev

Conversation

@muutot
Copy link
Copy Markdown
Contributor

@muutot muutot commented May 14, 2026

Description

Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Extract common get_te_classes() with @lru_cache for reuse
  • Refactor has_te_modules() and _is_te_module() to use tuple isinstance check
  • Remove duplicated import lists across multiple functions

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@muutot muutot requested a review from ksivaman as a code owner May 14, 2026 11:37
@muutot
Copy link
Copy Markdown
Contributor Author

muutot commented May 14, 2026

/te-ci pytorch

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 14, 2026

Greptile Summary

This PR deduplicates the repeated TE module class import lists that existed inside both has_te_modules() and _is_te_module() by extracting a single get_te_classes() helper decorated with @lru_cache. Both call-sites are then simplified to use Python's native isinstance(x, tuple_of_types) form.

  • New get_te_classes(): lazy-imports all seven TE classes (LayerNorm, RMSNorm, TransformerEngineBaseModule, UnfusedDotProductAttention, DotProductAttention, MultiheadAttention, TransformerLayer) once, caches the result, and returns it as a tuple ready for isinstance checks.
  • has_te_modules(): replaces the nested any(isinstance(m, c) for c in list) loop with the idiomatic single-call any(isinstance(m, te_classes) for m in network.modules()).
  • _is_te_module(): collapses a manual for-loop with a break into a single return isinstance(module, te_classes).

Confidence Score: 5/5

The change is a pure refactor with no behavioral changes to production code paths; all seven TE classes are still imported and checked identically.

Both has_te_modules and _is_te_module preserve their original logic exactly — only the structure changed. The @lru_cache is applied to a zero-argument function, so the cache key is always () and the cached result is always the same tuple of class objects, matching the previous behavior of repeated identical imports from Python's already-cached module system.

No files require special attention beyond the single changed file.

Important Files Changed

Filename Overview
transformer_engine/pytorch/distributed.py Extracts a shared get_te_classes() helper with @lru_cache, removing duplicated import lists from has_te_modules() and _is_te_module(), and simplifies both to use a single isinstance(x, tuple_of_types) call. Logic is preserved exactly; the lru_cache annotation is a mild test-isolation trade-off.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["has_te_modules(network)"] --> C
    B["_is_te_module(module)"] --> C
    C["get_te_classes() [@lru_cache]"] -->|first call| D["lazy import 7 TE classes"]
    D --> E["cache & return tuple"]
    C -->|cached| E
    E --> F1["isinstance(module, te_classes)\nfor each module in network.modules()"]
    E --> F2["isinstance(module, te_classes)"]
    A --> F1
    B --> F2
Loading

Reviews (3): Last reviewed commit: "Merge branch 'main' into dev" | Re-trigger Greptile

Comment on lines +616 to 618
@lru_cache
def get_te_classes():
from .module import LayerNorm, RMSNorm
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 get_te_classes() is now a shared, cached helper that both has_te_modules() and _is_te_module() depend on, but it has no docstring. Every other public helper in this file documents what it returns and why the lazy imports are deferred; adding one here keeps the module consistent.

Suggested change
@lru_cache
def get_te_classes():
from .module import LayerNorm, RMSNorm
@lru_cache
def get_te_classes():
"""Return a tuple of all Transformer Engine module classes.
Imports are deferred to avoid circular dependencies at module load time.
The result is cached so the tuple is built only once per process.
"""
from .module import LayerNorm, RMSNorm

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

@muutot
Copy link
Copy Markdown
Contributor Author

muutot commented May 14, 2026

/te-ci pytorch

- Extract common get_te_classes() with @lru_cache for reuse
- Refactor has_te_modules() and _is_te_module() to use tuple isinstance check
- Remove duplicated import lists across multiple functions

Signed-off-by: Muu <koimuu@163.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant