diff --git a/transformer_engine/jax/setup.py b/transformer_engine/jax/setup.py index 2d25242825..678062df91 100644 --- a/transformer_engine/jax/setup.py +++ b/transformer_engine/jax/setup.py @@ -42,6 +42,11 @@ shutil.rmtree(build_tools_copy) shutil.copytree(build_tools_dir, build_tools_copy) +license_src = current_file_path.parent.parent / "LICENSE" +license_dst = current_file_path / "LICENSE" +if license_src.is_file(): + shutil.copyfile(license_src, license_dst) + from build_tools.build_ext import get_build_ext from build_tools.utils import copy_common_headers, min_python_version_str @@ -131,7 +136,10 @@ def get_cuda_major_version() -> int: python_requires=f">={min_python_version_str()}", install_requires=install_requires, tests_require=test_requirements(), + license_files=("LICENSE",), ) if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): shutil.rmtree(common_headers_dir) shutil.rmtree("build_tools") + if license_dst.is_file(): + license_dst.unlink() diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index 99f6a99efa..593a3169d9 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -43,6 +43,11 @@ shutil.rmtree(build_tools_copy) shutil.copytree(build_tools_dir, build_tools_copy) +license_src = current_file_path.parent.parent / "LICENSE" +license_dst = current_file_path / "LICENSE" +if license_src.is_file(): + shutil.copyfile(license_src, license_dst) + from build_tools.build_ext import get_build_ext from build_tools.utils import copy_common_headers, min_python_version_str @@ -177,7 +182,10 @@ def run(self): python_requires=f">={min_python_version_str()}", install_requires=install_requires, tests_require=test_requirements(), + license_files=("LICENSE",), ) if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): shutil.rmtree(common_headers_dir) shutil.rmtree("build_tools") + if license_dst.is_file(): + license_dst.unlink()