diff --git a/test/infinicore/framework/base.py b/test/infinicore/framework/base.py index e57e8ec11..477eab2e0 100644 --- a/test/infinicore/framework/base.py +++ b/test/infinicore/framework/base.py @@ -3,21 +3,22 @@ Contains TestConfig, TestRunner, and BaseOperatorTest classes. """ -import torch -import infinicore import traceback from abc import ABC, abstractmethod -from .results import CaseResult -from .datatypes import to_torch_dtype, to_infinicore_dtype +import torch + +import infinicore + +from .benchmark import BenchmarkUtils from .devices import InfiniDeviceNames, torch_device_map -from .tensor import TensorSpec, TensorInitializer +from .results import CaseResult +from .tensor import TensorSpec +from .utils.compare_utils import create_test_comparator from .utils.tensor_utils import ( clone_torch_tensor, infinicore_tensor_from_torch, ) -from .utils.compare_utils import create_test_comparator -from .benchmark import BenchmarkUtils class TestConfig: @@ -96,7 +97,7 @@ def run_tests(self, devices, test_func, test_type="Test"): self.passed_tests.append( f"{test_case} - {InfiniDeviceNames[device]}" ) - print(f"\033[92m✓\033[0m Passed") + print("\033[92m✓\033[0m Passed") elif test_result.return_code == -1: # Test failed - use the actual error message from test_result fail_msg = f"{test_case} - {InfiniDeviceNames[device]} - {test_result.error_message}" @@ -153,7 +154,6 @@ def print_summary(self): """ total_tests = len(self.test_cases) passed_count = len(self.passed_tests) - skipped_count = len(self.skipped_tests) partial_count = len(self.partial_tests) failed_count = len(self.failed_tests) @@ -179,10 +179,10 @@ def print_summary(self): # If there are skipped or partial tests, show appropriate message if self.skipped_tests or self.partial_tests: print( - f"\n\033[93mTests completed with some implementations missing\033[0m" + "\n\033[93mTests completed with some implementations missing\033[0m" ) else: - print(f"\n\033[92mAll tests passed!\033[0m") + print("\n\033[92mAll tests passed!\033[0m") # Print benchmark summary if benchmarking was enabled if self.config.bench and ( @@ -448,6 +448,7 @@ def run_test(self, device, test_case, config): try: infini_result = self.infinicore_operator(*infini_inputs, **infini_kwargs) + infinicore.sync_device() if infini_result is None: infini_implemented = False except NotImplementedError as e: @@ -621,7 +622,7 @@ def run_test(self, device, test_case, config): is_valid = compare_fn(infini_comparison, torch_comparison) if not is_valid: - raise AssertionError(f"Result comparison failed.") + raise AssertionError("Result comparison failed.") # ========================================================================== # UNIFIED BENCHMARKING LOGIC diff --git a/test/infinicore/framework/utils/tensor_utils.py b/test/infinicore/framework/utils/tensor_utils.py index 19467d238..e16896a21 100644 --- a/test/infinicore/framework/utils/tensor_utils.py +++ b/test/infinicore/framework/utils/tensor_utils.py @@ -1,5 +1,7 @@ import torch + import infinicore + from ..datatypes import to_infinicore_dtype, to_torch_dtype # ================================================================= @@ -32,6 +34,7 @@ def clone_torch_tensor(torch_tensor): def infinicore_tensor_from_torch(torch_tensor): + synchronize_device(torch_tensor.device.type) infini_device = infinicore.device(torch_tensor.device.type, 0) if torch_tensor.is_contiguous(): return infinicore.from_blob( @@ -71,6 +74,8 @@ def convert_infinicore_to_torch(infini_result): ) temp_tensor = infinicore_tensor_from_torch(torch_result_from_infini) temp_tensor.copy_(infini_result) + infinicore.sync_device() + synchronize_device(torch_result_from_infini.device.type) return torch_result_from_infini