diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py
index b93d81bdef..135fbe8c82 100644
--- a/monai/networks/blocks/transformerblock.py
+++ b/monai/networks/blocks/transformerblock.py
@@ -46,11 +46,20 @@ def __init__(
             dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
             qkv_bias(bool, optional): apply bias term for the qkv linear layer. Defaults to False.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
+            causal (bool, optional): whether to apply causal masking in self-attention. Defaults to False.
+            sequence_length (int | None, optional): sequence length required for causal masking. Defaults to None.
+            with_cross_attention (bool, optional): whether to include cross-attention layers that attend to an
+                external context tensor. When False, cross_attn is set to nn.Identity() so that the attribute
+                always exists for typing and checkpoint compatibility. Defaults to False.
             use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
                 (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
             include_fc: whether to include the final linear layer. Default to True.
             use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
 
+        Raises:
+            ValueError: if dropout_rate is not in [0, 1].
+            ValueError: if hidden_size is not divisible by num_heads.
+
         """
 
         super().__init__()
@@ -79,14 +88,18 @@ def __init__(
         self.with_cross_attention = with_cross_attention
 
         self.norm_cross_attn = nn.LayerNorm(hidden_size)
-        self.cross_attn = CrossAttentionBlock(
-            hidden_size=hidden_size,
-            num_heads=num_heads,
-            dropout_rate=dropout_rate,
-            qkv_bias=qkv_bias,
-            causal=False,
-            use_flash_attention=use_flash_attention,
-        )
+        self.cross_attn: CrossAttentionBlock | nn.Identity
+        if with_cross_attention:
+            self.cross_attn = CrossAttentionBlock(
+                hidden_size=hidden_size,
+                num_heads=num_heads,
+                dropout_rate=dropout_rate,
+                qkv_bias=qkv_bias,
+                causal=False,
+                use_flash_attention=use_flash_attention,
+            )
+        else:
+            self.cross_attn = nn.Identity()
 
     def forward(
         self, x: torch.Tensor, context: torch.Tensor | None = None, attn_mask: torch.Tensor | None = None
diff --git a/monai/networks/nets/transformer.py b/monai/networks/nets/transformer.py
index 3a278c112a..6f2316e218 100644
--- a/monai/networks/nets/transformer.py
+++ b/monai/networks/nets/transformer.py
@@ -147,7 +147,7 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
 
         # fix the renamed norm blocks first norm2 -> norm_cross_attention , norm3 -> norm2
         for k in list(old_state_dict.keys()):
-            if "norm2" in k:
+            if "norm2" in k and k.replace("norm2", "norm_cross_attn") in new_state_dict:
                 new_state_dict[k.replace("norm2", "norm_cross_attn")] = old_state_dict.pop(k)
             if "norm3" in k:
                 new_state_dict[k.replace("norm3", "norm2")] = old_state_dict.pop(k)
diff --git a/tests/networks/blocks/test_transformerblock.py b/tests/networks/blocks/test_transformerblock.py
index b977a38e73..827850f05a 100644
--- a/tests/networks/blocks/test_transformerblock.py
+++ b/tests/networks/blocks/test_transformerblock.py
@@ -16,9 +16,11 @@
 
 import numpy as np
 import torch
+import torch.nn as nn
 from parameterized import parameterized
 
 from monai.networks import eval_mode
+from monai.networks.blocks.crossattention import CrossAttentionBlock
 from monai.networks.blocks.transformerblock import TransformerBlock
 from monai.utils import optional_import
 from tests.test_utils import dict_product
@@ -53,6 +55,36 @@ def test_ill_arg(self):
         with self.assertRaises(ValueError):
             TransformerBlock(hidden_size=622, num_heads=8, mlp_dim=3072, dropout_rate=0.4)
 
+    @skipUnless(has_einops, "Requires einops")
+    def test_cross_attention_is_identity_when_disabled(self):
+        block = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, with_cross_attention=False)
+        # attributes always exist for typing and checkpoint compatibility
+        self.assertTrue(hasattr(block, "cross_attn"))
+        self.assertTrue(hasattr(block, "norm_cross_attn"))
+        # cross_attn is nn.Identity (no parameters) when disabled
+        self.assertIsInstance(block.cross_attn, nn.Identity)
+        param_names = [name for name, _ in block.named_parameters()]
+        self.assertFalse(any(n.startswith("cross_attn") for n in param_names))
+
+    @skipUnless(has_einops, "Requires einops")
+    def test_cross_attention_params_registered_when_enabled(self):
+        block = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, with_cross_attention=True)
+        self.assertIsInstance(block.cross_attn, CrossAttentionBlock)
+        self.assertTrue(hasattr(block, "norm_cross_attn"))
+        param_names = [name for name, _ in block.named_parameters()]
+        self.assertTrue(any(n.startswith("cross_attn.") for n in param_names))
+        self.assertTrue(any("norm_cross_attn" in n for n in param_names))
+
+    @skipUnless(has_einops, "Requires einops")
+    def test_cross_attention_forward_with_context(self):
+        hidden_size = 128
+        block = TransformerBlock(hidden_size=hidden_size, mlp_dim=256, num_heads=4, with_cross_attention=True)
+        x = torch.randn(2, 16, hidden_size)
+        context = torch.randn(2, 8, hidden_size)
+        with eval_mode(block):
+            out = block(x, context=context)
+        self.assertEqual(out.shape, x.shape)
+
     @skipUnless(has_einops, "Requires einops")
     def test_access_attn_matrix(self):
         # input format
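Usage note (not part of the diff): a minimal sketch of how the changed constructor is expected to behave,
mirroring the added tests; the batch, token, and hidden sizes below are arbitrary, and einops must be
installed for the attention blocks to be built.

    import torch
    import torch.nn as nn
    from monai.networks.blocks.transformerblock import TransformerBlock

    # disabled (default): cross_attn is nn.Identity, so no cross-attention parameters are registered
    block = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, with_cross_attention=False)
    assert isinstance(block.cross_attn, nn.Identity)

    # enabled: cross_attn attends to an external context tensor passed to forward()
    block = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, with_cross_attention=True)
    x = torch.randn(2, 16, 128)       # (batch, tokens, hidden_size)
    context = torch.randn(2, 8, 128)  # context sequence may have a different length
    out = block(x, context=context)   # out.shape == x.shape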