From 0b4de878a2119bca433f60ed9b8150edd48e44f0 Mon Sep 17 00:00:00 2001 From: Andres Guzman-Ballen Date: Sun, 24 May 2026 23:41:28 -0500 Subject: [PATCH] TST: Add remaining tests for ObjectCode.from_ --- cuda_core/tests/test_module.py | 87 ++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/cuda_core/tests/test_module.py b/cuda_core/tests/test_module.py index 3a438f825a..a34c824732 100644 --- a/cuda_core/tests/test_module.py +++ b/cuda_core/tests/test_module.py @@ -3,6 +3,8 @@ import ctypes import pickle +import shutil +import subprocess import warnings import pytest @@ -33,6 +35,10 @@ """ +def _nvcc_path(): + return shutil.which("nvcc") + + def _is_nvfatbin_available(): """Check if nvfatbin bindings are available.""" try: @@ -44,7 +50,24 @@ def _is_nvfatbin_available(): return False +def _is_nvcc_available(): + return _nvcc_path() is not None + + +def _compile_with_nvcc(*, tmp_path, kernel, flags, arch, suffix): + """Compile kernel with nvcc, return content of output path.""" + src = tmp_path / "kernel.cu" + src.write_text(kernel) + out = tmp_path / f"kernel{suffix}" + compile_cmd = [_nvcc_path(), f"-arch={arch}", *flags, "-o", str(out), str(src)] + result = subprocess.run(compile_cmd, capture_output=True) # noqa: S603 + if result.returncode != 0: + pytest.fail(f"nvcc failed: {result.stderr}") + return out.read_bytes() + + nvfatbin_available = pytest.mark.skipif(not _is_nvfatbin_available(), reason="nvfatbin bindings not available") +nvcc_available = pytest.mark.skipif(not _is_nvcc_available(), reason="nvcc not in PATH") @pytest.fixture(scope="module") @@ -172,6 +195,20 @@ def get_saxpy_fatbin(init_cuda): return bytes(fatbin), sym_map +@pytest.fixture +def get_saxpy_object(init_cuda, tmp_path): + dev = Device() + arch = dev.arch + return _compile_with_nvcc(tmp_path=tmp_path, kernel=SAXPY_KERNEL, flags=["-dc"], arch=f"sm_{arch}", suffix=".o") + + +@pytest.fixture +def get_saxpy_library(init_cuda, tmp_path): + dev = Device() + arch = dev.arch + return _compile_with_nvcc(tmp_path=tmp_path, kernel=SAXPY_KERNEL, flags=["-lib"], arch=f"sm_{arch}", suffix=".a") + + def test_get_kernel(init_cuda): kernel = """extern "C" __global__ void ABC() { }""" @@ -330,6 +367,56 @@ def test_object_code_load_fatbin_from_file(get_saxpy_fatbin, tmp_path, convert_p mod_obj.get_kernel("saxpy") # force loading +@nvcc_available +def test_object_code_load_object(get_saxpy_object): + objct = get_saxpy_object + assert isinstance(objct, bytes) + mod_obj = ObjectCode.from_object(objct) + assert mod_obj.code == objct + assert mod_obj.code_type == "object" + # object doesn't support kernel retrieval directly as it's used for linking + # Test that get_kernel fails for unsupported code type + with pytest.raises(RuntimeError, match=r'Unsupported code type "object"'): + mod_obj.get_kernel("saxpy") + + +@nvcc_available +def test_object_code_load_object_from_file(get_saxpy_object, tmp_path, convert_path): + objct = get_saxpy_object + assert isinstance(objct, bytes) + object_file = tmp_path / "test.o" + object_file.write_bytes(objct) + arg = convert_path(object_file) + mod_obj = ObjectCode.from_object(arg) + assert mod_obj.code == str(arg) + assert mod_obj.code_type == "object" + + +@nvcc_available +def test_object_code_load_library(get_saxpy_library): + library = get_saxpy_library + assert isinstance(library, bytes) + mod_obj = ObjectCode.from_library(library) + assert mod_obj.code == library + assert mod_obj.code_type == "library" + # library doesn't support kernel retrieval directly as it's used for linking + # Test that get_kernel fails for unsupported code type + with pytest.raises(RuntimeError, match=r'Unsupported code type "library"'): + mod_obj.get_kernel("saxpy") + + +@nvcc_available +def test_object_code_load_library_from_file(get_saxpy_library, tmp_path, convert_path): + library = get_saxpy_library + assert isinstance(library, bytes) + library_file = tmp_path / "test.a" + library_file.write_bytes(library) + arg = convert_path(library_file) + mod_obj = ObjectCode.from_library(arg) + assert mod_obj.code == str(arg) + assert mod_obj.code_type == "library" + + def test_saxpy_arguments(get_saxpy_kernel_cubin, cuda12_4_prerequisite_check): krn, _ = get_saxpy_kernel_cubin