From 1fe927bd247ed143f2b661474d62e0b4dbf73452 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 23 Jun 2026 23:08:47 -0700 Subject: [PATCH 1/3] implement base model output caching in model-level tests --- tests/models/testing_utils/common.py | 88 ++++++++++++------- .../test_models_transformer_hunyuan_dit.py | 4 +- .../test_models_transformer_hunyuan_video.py | 8 +- .../test_models_transformer_wan_animate.py | 4 +- 4 files changed, 64 insertions(+), 40 deletions(-) diff --git a/tests/models/testing_utils/common.py b/tests/models/testing_utils/common.py index 626f1eb7f1bf..ec66bd30f0aa 100644 --- a/tests/models/testing_utils/common.py +++ b/tests/models/testing_utils/common.py @@ -24,9 +24,13 @@ from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant, logging -from diffusers.utils.testing_utils import require_accelerator, require_torch_multi_accelerator -from ...testing_utils import assert_tensors_close, torch_device +from ...testing_utils import ( + assert_tensors_close, + require_accelerator, + require_torch_multi_accelerator, + torch_device, +) def named_persistent_module_tensors( @@ -278,8 +282,30 @@ class TestMyModel(MyModelTestConfig, ModelTesterMixin): pass """ + @pytest.fixture(scope="class") + def base_model_output(self): + """Class-scoped reference forward output, built once and reused across the class. + + Building the model and running its forward pass is fully deterministic (``torch.manual_seed(0)`` + plus the deterministic ``get_dummy_inputs`` contract), so the reference ("base") output is + identical for every test in the class. The save/load and parallelism tests compare a reloaded + model against this output; computing it a single time here — instead of rebuilding the model and + re-running the forward in each test — removes that redundant work and speeds up the suite. 
+ + The hardware-gated tests that consume this fixture use ``pytest.mark.skipif`` (via the + ``require_*`` decorators), which pytest evaluates before fixture setup, so skipping on a machine + without the required accelerators never triggers this forward. + + Tests that still need a live model (e.g. to save it) build their own with the same seed, so the + reloaded model's weights match this cached output. + """ + torch.manual_seed(0) + model = self.model_class(**self.get_init_dict()).eval().to(torch_device) + with torch.no_grad(): + return model(**self.get_dummy_inputs(), return_dict=False)[0] + @torch.no_grad() - def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5): + def test_from_save_pretrained(self, base_model_output, tmp_path, atol=5e-5, rtol=5e-5): torch.manual_seed(0) model = self.model_class(**self.get_init_dict()) model.to(torch_device) @@ -296,13 +322,15 @@ def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5): f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}" ) - image = model(**self.get_dummy_inputs(), return_dict=False)[0] new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0] - assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.") + assert_tensors_close( + base_model_output, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes." 
+ ) @torch.no_grad() - def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0): + def test_from_save_pretrained_variant(self, base_model_output, tmp_path, atol=5e-5, rtol=0): + torch.manual_seed(0) model = self.model_class(**self.get_init_dict()) model.to(torch_device) model.eval() @@ -317,10 +345,11 @@ def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0): new_model.to(torch_device) - image = model(**self.get_dummy_inputs(), return_dict=False)[0] new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0] - assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.") + assert_tensors_close( + base_model_output, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes." + ) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"]) def test_from_save_pretrained_dtype(self, tmp_path, dtype): @@ -360,13 +389,8 @@ def test_determinism(self, atol=1e-5, rtol=0): ) @torch.no_grad() - def test_output(self, expected_output_shape=None): - model = self.model_class(**self.get_init_dict()) - model.to(torch_device) - model.eval() - - inputs_dict = self.get_dummy_inputs() - output = model(**inputs_dict, return_dict=False)[0] + def test_output(self, base_model_output, expected_output_shape=None): + output = base_model_output assert output is not None, "Model output is None" assert output[0].shape == expected_output_shape or self.output_shape, ( @@ -509,14 +533,12 @@ def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype, atol=1e-4, @require_accelerator @torch.no_grad() - def test_sharded_checkpoints(self, tmp_path, atol=1e-5, rtol=0): + def test_sharded_checkpoints(self, base_model_output, tmp_path, atol=1e-5, rtol=0): torch.manual_seed(0) config = self.get_init_dict() model = self.model_class(**config).eval() model = model.to(torch_device) - base_output = model(**self.get_dummy_inputs(), 
return_dict=False)[0] - model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small @@ -537,19 +559,17 @@ def test_sharded_checkpoints(self, tmp_path, atol=1e-5, rtol=0): new_output = new_model(**self.get_dummy_inputs(), return_dict=False)[0] assert_tensors_close( - base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after sharded save/load" + base_model_output, new_output, atol=atol, rtol=rtol, msg="Output should match after sharded save/load" ) @require_accelerator @torch.no_grad() - def test_sharded_checkpoints_with_variant(self, tmp_path, atol=1e-5, rtol=0): + def test_sharded_checkpoints_with_variant(self, base_model_output, tmp_path, atol=1e-5, rtol=0): torch.manual_seed(0) config = self.get_init_dict() model = self.model_class(**config).eval() model = model.to(torch_device) - base_output = model(**self.get_dummy_inputs(), return_dict=False)[0] - model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small variant = "fp16" @@ -575,11 +595,15 @@ def test_sharded_checkpoints_with_variant(self, tmp_path, atol=1e-5, rtol=0): new_output = new_model(**self.get_dummy_inputs(), return_dict=False)[0] assert_tensors_close( - base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after variant sharded save/load" + base_model_output, + new_output, + atol=atol, + rtol=rtol, + msg="Output should match after variant sharded save/load", ) @torch.no_grad() - def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rtol=0): + def test_sharded_checkpoints_with_parallel_loading(self, base_model_output, tmp_path, atol=1e-5, rtol=0): from diffusers.utils import constants torch.manual_seed(0) @@ -587,8 +611,6 @@ def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rt model = self.model_class(**config).eval() model 
= model.to(torch_device) - base_output = model(**self.get_dummy_inputs(), return_dict=False)[0] - model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small @@ -624,7 +646,11 @@ def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rt output_parallel = model_parallel(**self.get_dummy_inputs(), return_dict=False)[0] assert_tensors_close( - base_output, output_parallel, atol=atol, rtol=rtol, msg="Output should match with parallel loading" + base_model_output, + output_parallel, + atol=atol, + rtol=rtol, + msg="Output should match with parallel loading", ) finally: @@ -635,19 +661,17 @@ def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rt @require_torch_multi_accelerator @torch.no_grad() - def test_model_parallelism(self, tmp_path, atol=1e-5, rtol=0): + def test_model_parallelism(self, base_model_output, tmp_path, atol=1e-5, rtol=0): if self.model_class._no_split_modules is None: pytest.skip("Test not supported for this model as `_no_split_modules` is not set.") + torch.manual_seed(0) config = self.get_init_dict() inputs_dict = self.get_dummy_inputs() model = self.model_class(**config).eval() model = model.to(torch_device) - torch.manual_seed(0) - base_output = model(**inputs_dict, return_dict=False)[0] - model_size = compute_module_sizes(model)[""] max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents] @@ -665,5 +689,5 @@ def test_model_parallelism(self, tmp_path, atol=1e-5, rtol=0): new_output = new_model(**inputs_dict, return_dict=False)[0] assert_tensors_close( - base_output, new_output, atol=atol, rtol=rtol, msg="Output should match with model parallelism" + base_model_output, new_output, atol=atol, rtol=rtol, msg="Output should match with model parallelism" ) diff --git a/tests/models/transformers/test_models_transformer_hunyuan_dit.py 
b/tests/models/transformers/test_models_transformer_hunyuan_dit.py index 1c08244b620c..370033ef319f 100644 --- a/tests/models/transformers/test_models_transformer_hunyuan_dit.py +++ b/tests/models/transformers/test_models_transformer_hunyuan_dit.py @@ -120,9 +120,9 @@ def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]: class TestHunyuanDiT(HunyuanDiTTesterConfig, ModelTesterMixin): - def test_output(self): + def test_output(self, base_model_output): batch_size = self.get_dummy_inputs()[self.main_input_name].shape[0] - super().test_output(expected_output_shape=(batch_size,) + self.output_shape) + super().test_output(base_model_output, expected_output_shape=(batch_size,) + self.output_shape) class TestHunyuanDiTTraining(HunyuanDiTTesterConfig, TrainingTesterMixin): diff --git a/tests/models/transformers/test_models_transformer_hunyuan_video.py b/tests/models/transformers/test_models_transformer_hunyuan_video.py index 90c716a336a5..cc934be125aa 100644 --- a/tests/models/transformers/test_models_transformer_hunyuan_video.py +++ b/tests/models/transformers/test_models_transformer_hunyuan_video.py @@ -223,8 +223,8 @@ def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]: class TestHunyuanVideoI2VTransformer(HunyuanVideoI2VTransformerTesterConfig, ModelTesterMixin): - def test_output(self): - super().test_output(expected_output_shape=(1, *self.output_shape)) + def test_output(self, base_model_output): + super().test_output(base_model_output, expected_output_shape=(1, *self.output_shape)) # ======================== HunyuanVideo Token Replace Image-to-Video ======================== @@ -299,5 +299,5 @@ def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]: class TestHunyuanVideoTokenReplaceTransformer(HunyuanVideoTokenReplaceTransformerTesterConfig, ModelTesterMixin): - def test_output(self): - super().test_output(expected_output_shape=(1, *self.output_shape)) + def test_output(self, base_model_output): + 
super().test_output(base_model_output, expected_output_shape=(1, *self.output_shape)) diff --git a/tests/models/transformers/test_models_transformer_wan_animate.py b/tests/models/transformers/test_models_transformer_wan_animate.py index 30f78ca1c3de..bd751974637b 100644 --- a/tests/models/transformers/test_models_transformer_wan_animate.py +++ b/tests/models/transformers/test_models_transformer_wan_animate.py @@ -146,11 +146,11 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]: class TestWanAnimateTransformer3D(WanAnimateTransformer3DTesterConfig, ModelTesterMixin): """Core model tests for Wan Animate Transformer 3D.""" - def test_output(self): + def test_output(self, base_model_output): # Override test_output because the transformer output is expected to have less channels # than the main transformer input. expected_output_shape = (1, 4, 21, 16, 16) - super().test_output(expected_output_shape=expected_output_shape) + super().test_output(base_model_output, expected_output_shape=expected_output_shape) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): From ef206c246cb966de388323106b4315d7abfd67db Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 23 Jun 2026 23:17:28 -0700 Subject: [PATCH 2/3] use single backticks in base_model_output docstring --- tests/models/testing_utils/common.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/models/testing_utils/common.py b/tests/models/testing_utils/common.py index ec66bd30f0aa..129c8197887d 100644 --- a/tests/models/testing_utils/common.py +++ b/tests/models/testing_utils/common.py @@ -286,14 +286,14 @@ class TestMyModel(MyModelTestConfig, ModelTesterMixin): def base_model_output(self): """Class-scoped reference forward output, built once and reused across the class. 
- Building the model and running its forward pass is fully deterministic (``torch.manual_seed(0)`` - plus the deterministic ``get_dummy_inputs`` contract), so the reference ("base") output is + Building the model and running its forward pass is fully deterministic (`torch.manual_seed(0)` + plus the deterministic `get_dummy_inputs` contract), so the reference ("base") output is identical for every test in the class. The save/load and parallelism tests compare a reloaded model against this output; computing it a single time here — instead of rebuilding the model and re-running the forward in each test — removes that redundant work and speeds up the suite. - The hardware-gated tests that consume this fixture use ``pytest.mark.skipif`` (via the - ``require_*`` decorators), which pytest evaluates before fixture setup, so skipping on a machine + The hardware-gated tests that consume this fixture use `pytest.mark.skipif` (via the + `require_*` decorators), which pytest evaluates before fixture setup, so skipping on a machine without the required accelerators never triggers this forward. Tests that still need a live model (e.g. 
to save it) build their own with the same seed, so the From d1fb7e2768152f238e5651f8ab5ab12fc056b194 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 24 Jun 2026 13:09:41 -0700 Subject: [PATCH 3/3] share base_model_output fixture with memory offload tests --- tests/models/testing_utils/common.py | 56 ++++++++++++++++------------ tests/models/testing_utils/memory.py | 38 +++++++------------ 2 files changed, 47 insertions(+), 47 deletions(-) diff --git a/tests/models/testing_utils/common.py b/tests/models/testing_utils/common.py index 129c8197887d..eb120567f3d1 100644 --- a/tests/models/testing_utils/common.py +++ b/tests/models/testing_utils/common.py @@ -262,7 +262,39 @@ def get_dummy_inputs(self) -> Dict[str, Any]: raise NotImplementedError("Subclasses must implement `get_dummy_inputs()`.") -class ModelTesterMixin: +class BaseModelOutputMixin: + """Provides the class-scoped `base_model_output` fixture shared across tester mixins. + + Kept separate from `BaseModelTesterConfig` — which only declares the testing contract and performs no + computation — so any mixin that needs the cached reference output (`ModelTesterMixin`, the memory + offload mixins, ...) can inherit it without duplicating the build-and-forward. + """ + + @pytest.fixture(scope="class") + def base_model_output(self): + """Class-scoped reference forward output, built once and reused across the class. + + Building the model and running its forward pass is fully deterministic (`torch.manual_seed(0)` + plus the deterministic `get_dummy_inputs` contract), so the reference ("base") output is + identical for every test in the class. The save/load, parallelism, and memory-offload tests + compare a reloaded/offloaded model against this output; computing it a single time here — instead + of rebuilding the model and re-running the forward in each test — removes that redundant work and + speeds up the suite. 
+ + The hardware-gated tests that consume this fixture use `pytest.mark.skipif` (via the `require_*` + decorators), which pytest evaluates before fixture setup, so skipping on a machine without the + required accelerators never triggers this forward. + + Tests that still need a live model (e.g. to save or offload it) build their own with the same + seed, so the reloaded model's weights match this cached output. + """ + torch.manual_seed(0) + model = self.model_class(**self.get_init_dict()).eval().to(torch_device) + with torch.no_grad(): + return model(**self.get_dummy_inputs(), return_dict=False)[0] + + +class ModelTesterMixin(BaseModelOutputMixin): """ Base mixin class for model testing with common test methods. @@ -282,28 +314,6 @@ class TestMyModel(MyModelTestConfig, ModelTesterMixin): pass """ - @pytest.fixture(scope="class") - def base_model_output(self): - """Class-scoped reference forward output, built once and reused across the class. - - Building the model and running its forward pass is fully deterministic (`torch.manual_seed(0)` - plus the deterministic `get_dummy_inputs` contract), so the reference ("base") output is - identical for every test in the class. The save/load and parallelism tests compare a reloaded - model against this output; computing it a single time here — instead of rebuilding the model and - re-running the forward in each test — removes that redundant work and speeds up the suite. - - The hardware-gated tests that consume this fixture use `pytest.mark.skipif` (via the - `require_*` decorators), which pytest evaluates before fixture setup, so skipping on a machine - without the required accelerators never triggers this forward. - - Tests that still need a live model (e.g. to save it) build their own with the same seed, so the - reloaded model's weights match this cached output. 
- """ - torch.manual_seed(0) - model = self.model_class(**self.get_init_dict()).eval().to(torch_device) - with torch.no_grad(): - return model(**self.get_dummy_inputs(), return_dict=False)[0] - @torch.no_grad() def test_from_save_pretrained(self, base_model_output, tmp_path, atol=5e-5, rtol=5e-5): torch.manual_seed(0) diff --git a/tests/models/testing_utils/memory.py b/tests/models/testing_utils/memory.py index 8731c644854a..84c3e23133a1 100644 --- a/tests/models/testing_utils/memory.py +++ b/tests/models/testing_utils/memory.py @@ -37,7 +37,7 @@ require_accelerator, torch_device, ) -from .common import cast_inputs_to_dtype, check_device_map_is_respected +from .common import BaseModelOutputMixin, cast_inputs_to_dtype, check_device_map_is_respected def require_offload_support(func): @@ -69,7 +69,7 @@ def wrapper(self, *args, **kwargs): @is_cpu_offload -class CPUOffloadTesterMixin: +class CPUOffloadTesterMixin(BaseModelOutputMixin): """ Mixin class for testing CPU offloading functionality. 
@@ -94,16 +94,14 @@ def model_split_percents(self) -> list[float]: @require_offload_support @torch.no_grad() - def test_cpu_offload(self, tmp_path, atol=1e-5, rtol=0): + def test_cpu_offload(self, base_model_output, tmp_path, atol=1e-5, rtol=0): + torch.manual_seed(0) config = self.get_init_dict() inputs_dict = self.get_dummy_inputs() model = self.model_class(**config).eval() model = model.to(torch_device) - torch.manual_seed(0) - base_output = model(**inputs_dict) - model_size = compute_module_sizes(model)[""] # We test several splits of sizes to make sure it works max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents] @@ -120,21 +118,19 @@ def test_cpu_offload(self, tmp_path, atol=1e-5, rtol=0): new_output = new_model(**inputs_dict) assert_tensors_close( - base_output[0], new_output[0], atol=atol, rtol=rtol, msg="Output should match with CPU offloading" + base_model_output, new_output[0], atol=atol, rtol=rtol, msg="Output should match with CPU offloading" ) @require_offload_support @torch.no_grad() - def test_disk_offload_without_safetensors(self, tmp_path, atol=1e-5, rtol=0): + def test_disk_offload_without_safetensors(self, base_model_output, tmp_path, atol=1e-5, rtol=0): + torch.manual_seed(0) config = self.get_init_dict() inputs_dict = self.get_dummy_inputs() model = self.model_class(**config).eval() model = model.to(torch_device) - torch.manual_seed(0) - base_output = model(**inputs_dict) - model_size = compute_module_sizes(model)[""] max_size = int(self.model_split_percents[0] * model_size) # Force disk offload by setting very small CPU memory @@ -154,21 +150,19 @@ def test_disk_offload_without_safetensors(self, tmp_path, atol=1e-5, rtol=0): new_output = new_model(**inputs_dict) assert_tensors_close( - base_output[0], new_output[0], atol=atol, rtol=rtol, msg="Output should match with disk offloading" + base_model_output, new_output[0], atol=atol, rtol=rtol, msg="Output should match with disk offloading" ) @require_offload_support 
@torch.no_grad() - def test_disk_offload_with_safetensors(self, tmp_path, atol=1e-5, rtol=0): + def test_disk_offload_with_safetensors(self, base_model_output, tmp_path, atol=1e-5, rtol=0): + torch.manual_seed(0) config = self.get_init_dict() inputs_dict = self.get_dummy_inputs() model = self.model_class(**config).eval() model = model.to(torch_device) - torch.manual_seed(0) - base_output = model(**inputs_dict) - model_size = compute_module_sizes(model)[""] model.cpu().save_pretrained(str(tmp_path)) @@ -183,7 +177,7 @@ def test_disk_offload_with_safetensors(self, tmp_path, atol=1e-5, rtol=0): new_output = new_model(**inputs_dict) assert_tensors_close( - base_output[0], + base_model_output, new_output[0], atol=atol, rtol=rtol, @@ -192,7 +186,7 @@ def test_disk_offload_with_safetensors(self, tmp_path, atol=1e-5, rtol=0): @is_group_offload -class GroupOffloadTesterMixin: +class GroupOffloadTesterMixin(BaseModelOutputMixin): """ Mixin class for testing group offloading functionality. @@ -209,10 +203,9 @@ class GroupOffloadTesterMixin: @require_group_offload_support @pytest.mark.parametrize("record_stream", [False, True]) - def test_group_offloading(self, record_stream, atol=1e-5, rtol=0): + def test_group_offloading(self, base_model_output, record_stream, atol=1e-5, rtol=0): init_dict = self.get_init_dict() inputs_dict = self.get_dummy_inputs() - torch.manual_seed(0) @torch.no_grad() def run_forward(model): @@ -224,10 +217,7 @@ def run_forward(model): model.eval() return model(**inputs_dict)[0] - model = self.model_class(**init_dict) - - model.to(torch_device) - output_without_group_offloading = run_forward(model) + output_without_group_offloading = base_model_output torch.manual_seed(0) model = self.model_class(**init_dict)