diff --git a/python/infinicore/tensor.py b/python/infinicore/tensor.py index 4348887d6..76c8547e4 100644 --- a/python/infinicore/tensor.py +++ b/python/infinicore/tensor.py @@ -186,17 +186,29 @@ def strided_from_blob(data_ptr, size, strides, *, dtype=None, device=None): def from_torch(torch_tensor) -> Tensor: + if torch_tensor.device.type == "npu": + import torch + + torch.npu.synchronize() + infini_type = to_infinicore_dtype(torch_tensor.dtype) infini_device = infinicore.device(torch_tensor.device.type, 0) - return Tensor( - _infinicore.from_blob( + if torch_tensor.is_contiguous(): + underlying = _infinicore.from_blob( torch_tensor.data_ptr(), list(torch_tensor.shape), dtype=infini_type._underlying, device=infini_device._underlying, - ), - _torch_ref=torch_tensor, - ) + ) + else: + underlying = _infinicore.strided_from_blob( + torch_tensor.data_ptr(), + list(torch_tensor.shape), + list(torch_tensor.stride()), + dtype=infini_type._underlying, + device=infini_device._underlying, + ) + return Tensor(underlying, _torch_ref=torch_tensor) def from_numpy(