Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions test/infinicore/framework/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,22 @@
Contains TestConfig, TestRunner, and BaseOperatorTest classes.
"""

import torch
import infinicore
import traceback
from abc import ABC, abstractmethod

from .results import CaseResult
from .datatypes import to_torch_dtype, to_infinicore_dtype
import torch

import infinicore

from .benchmark import BenchmarkUtils
from .devices import InfiniDeviceNames, torch_device_map
from .tensor import TensorSpec, TensorInitializer
from .results import CaseResult
from .tensor import TensorSpec
from .utils.compare_utils import create_test_comparator
from .utils.tensor_utils import (
clone_torch_tensor,
infinicore_tensor_from_torch,
)
from .utils.compare_utils import create_test_comparator
from .benchmark import BenchmarkUtils


class TestConfig:
Expand Down Expand Up @@ -96,7 +97,7 @@ def run_tests(self, devices, test_func, test_type="Test"):
self.passed_tests.append(
f"{test_case} - {InfiniDeviceNames[device]}"
)
print(f"\033[92m✓\033[0m Passed")
print("\033[92m✓\033[0m Passed")
elif test_result.return_code == -1:
# Test failed - use the actual error message from test_result
fail_msg = f"{test_case} - {InfiniDeviceNames[device]} - {test_result.error_message}"
Expand Down Expand Up @@ -153,7 +154,6 @@ def print_summary(self):
"""
total_tests = len(self.test_cases)
passed_count = len(self.passed_tests)
skipped_count = len(self.skipped_tests)
partial_count = len(self.partial_tests)
failed_count = len(self.failed_tests)

Expand All @@ -179,10 +179,10 @@ def print_summary(self):
# If there are skipped or partial tests, show appropriate message
if self.skipped_tests or self.partial_tests:
print(
f"\n\033[93mTests completed with some implementations missing\033[0m"
"\n\033[93mTests completed with some implementations missing\033[0m"
)
else:
print(f"\n\033[92mAll tests passed!\033[0m")
print("\n\033[92mAll tests passed!\033[0m")

# Print benchmark summary if benchmarking was enabled
if self.config.bench and (
Expand Down Expand Up @@ -448,6 +448,7 @@ def run_test(self, device, test_case, config):

try:
infini_result = self.infinicore_operator(*infini_inputs, **infini_kwargs)
infinicore.sync_device()
if infini_result is None:
infini_implemented = False
except NotImplementedError as e:
Expand Down Expand Up @@ -621,7 +622,7 @@ def run_test(self, device, test_case, config):

is_valid = compare_fn(infini_comparison, torch_comparison)
if not is_valid:
raise AssertionError(f"Result comparison failed.")
raise AssertionError("Result comparison failed.")

# ==========================================================================
# UNIFIED BENCHMARKING LOGIC
Expand Down
5 changes: 5 additions & 0 deletions test/infinicore/framework/utils/tensor_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import torch

import infinicore

from ..datatypes import to_infinicore_dtype, to_torch_dtype

# =================================================================
Expand Down Expand Up @@ -32,6 +34,7 @@ def clone_torch_tensor(torch_tensor):


def infinicore_tensor_from_torch(torch_tensor):
synchronize_device(torch_tensor.device.type)
infini_device = infinicore.device(torch_tensor.device.type, 0)
if torch_tensor.is_contiguous():
return infinicore.from_blob(
Expand Down Expand Up @@ -71,6 +74,8 @@ def convert_infinicore_to_torch(infini_result):
)
temp_tensor = infinicore_tensor_from_torch(torch_result_from_infini)
temp_tensor.copy_(infini_result)
infinicore.sync_device()
synchronize_device(torch_result_from_infini.device.type)
return torch_result_from_infini


Expand Down
Loading