Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Soft link NVRTC for cupy_backends.cuda.libs.nvrtc #7621

Merged
merged 20 commits into from
Aug 30, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .flake8.cython
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[flake8]
filename = *.pyx, *.pxd, *.pxi
exclude = .git, .eggs, *.py
ignore = W503,W504,E225,E226,E227,E275,E402,E999
ignore = W503,W504,E211,E225,E226,E227,E275,E402,E999
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E211 added to avoid Whitespace before '(' error:

ctypedef nvrtcResult (*F_nvrtcVersion)(int *major, int *minor) nogil
                    ^

1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ repos:
rev: v0.15.0
hooks:
- id: cython-lint
args: ["--max-line-length", "79"]
Copy link
Member Author

@kmaehashi kmaehashi Aug 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the same line-length limit as the current coding standard (.flake8.cython).


- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.4.1
Expand Down
3 changes: 2 additions & 1 deletion cupy_backends/cuda/_softlink.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ ctypedef int (*func_ptr)(...) nogil # NOQA

cdef class SoftLink:
cdef:
object error
str prefix
object _cdll
str _prefix
func_ptr get(self, str name)
22 changes: 15 additions & 7 deletions cupy_backends/cuda/_softlink.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,32 @@ cimport cython


cdef class SoftLink:
def __init__(self, object libname, str prefix):
def __init__(self, object libname, str prefix, *, bint mandatory=False):
self.error = None
self.prefix = prefix
self._cdll = None
if libname is not None:
if libname is None:
# Stub build or CUDA/HIP only library.
self.error = RuntimeError(
'The library is unavailable in the current platform.')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!!

else:
try:
self._cdll = ctypes.CDLL(libname)
except Exception as e:
warnings.warn(
f'Warning: CuPy failed to load "{libname}": '
f'({type(e).__name__}: {e})')
self._prefix = prefix
self.error = e
msg = (
f'CuPy failed to load {libname}: {type(e).__name__}: {e}')
if mandatory:
raise RuntimeError(msg) from e
warnings.warn(msg)

cdef func_ptr get(self, str name):
"""
Returns a function pointer for the API.
"""
if self._cdll is None:
return <func_ptr>_fail_unsupported
cdef str funcname = f'{self._prefix}{name}'
cdef str funcname = f'{self.prefix}{name}'
cdef object func = getattr(self._cdll, funcname, None)
if func is None:
return <func_ptr>_fail_not_found
Expand Down
22 changes: 0 additions & 22 deletions cupy_backends/cuda/cupy_nvrtc.h

This file was deleted.

139 changes: 139 additions & 0 deletions cupy_backends/cuda/libs/_cnvrtc.pxi
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import sys as _sys

from cupy_backends.cuda.api cimport runtime
from cupy_backends.cuda._softlink cimport SoftLink


ctypedef int nvrtcResult
ctypedef void* nvrtcProgram
# TODO(kmaehashi): Remove this alias.
ctypedef nvrtcProgram Program

ctypedef const char* (*F_nvrtcGetErrorString)(nvrtcResult result) nogil
cdef F_nvrtcGetErrorString nvrtcGetErrorString

ctypedef nvrtcResult (*F_nvrtcVersion)(int *major, int *minor) nogil
cdef F_nvrtcVersion nvrtcVersion

ctypedef nvrtcResult (*F_nvrtcCreateProgram)(
nvrtcProgram* prog, const char* src, const char* name, int numHeaders,
const char** headers, const char** includeNames) nogil
cdef F_nvrtcCreateProgram nvrtcCreateProgram

ctypedef nvrtcResult (*F_nvrtcDestroyProgram)(nvrtcProgram *prog) nogil
cdef F_nvrtcDestroyProgram nvrtcDestroyProgram

ctypedef nvrtcResult (*F_nvrtcCompileProgram)(
nvrtcProgram prog, int numOptions, const char** options) nogil
Comment on lines 26 to 27
Copy link
Member Author

@kmaehashi kmaehashi Aug 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ctypedefs can be generated by copy-pasting arguments from the header file.

nvrtc.h:

nvrtcResult nvrtcCompileProgram(nvrtcProgram prog,
                                int numOptions, const char * const *options);

cdef F_nvrtcCompileProgram nvrtcCompileProgram

ctypedef nvrtcResult (*F_nvrtcGetPTXSize)(nvrtcProgram prog, size_t *ptxSizeRet) nogil # NOQA
cdef F_nvrtcGetPTXSize nvrtcGetPTXSize

ctypedef nvrtcResult (*F_nvrtcGetPTX)(nvrtcProgram prog, char *ptx) nogil
cdef F_nvrtcGetPTX nvrtcGetPTX

ctypedef nvrtcResult (*F_nvrtcGetCUBINSize)(nvrtcProgram prog, size_t *cubinSizeRet) nogil # NOQA
cdef F_nvrtcGetCUBINSize nvrtcGetCUBINSize

ctypedef nvrtcResult (*F_nvrtcGetCUBIN)(nvrtcProgram prog, char *cubin) nogil
cdef F_nvrtcGetCUBIN nvrtcGetCUBIN

ctypedef nvrtcResult (*F_nvrtcGetProgramLogSize)(nvrtcProgram prog, size_t* logSizeRet) nogil # NOQA
cdef F_nvrtcGetProgramLogSize nvrtcGetProgramLogSize

ctypedef nvrtcResult (*F_nvrtcGetProgramLog)(nvrtcProgram prog, char* log) nogil # NOQA
cdef F_nvrtcGetProgramLog nvrtcGetProgramLog

ctypedef nvrtcResult (*F_nvrtcAddNameExpression)(nvrtcProgram, const char*) nogil # NOQA
cdef F_nvrtcAddNameExpression nvrtcAddNameExpression

ctypedef nvrtcResult (*F_nvrtcGetLoweredName)(nvrtcProgram, const char*, const char**) nogil # NOQA
cdef F_nvrtcGetLoweredName nvrtcGetLoweredName

ctypedef nvrtcResult (*F_nvrtcGetNumSupportedArchs)(int* numArchs) nogil
cdef F_nvrtcGetNumSupportedArchs nvrtcGetNumSupportedArchs

ctypedef nvrtcResult (*F_nvrtcGetSupportedArchs)(int* supportedArchs) nogil
cdef F_nvrtcGetSupportedArchs nvrtcGetSupportedArchs

ctypedef nvrtcResult (*F_nvrtcGetNVVMSize)(nvrtcProgram prog, size_t *nvvmSizeRet) nogil # NOQA
cdef F_nvrtcGetNVVMSize nvrtcGetNVVMSize

ctypedef nvrtcResult (*F_nvrtcGetNVVM)(nvrtcProgram prog, char *nvvm) nogil
cdef F_nvrtcGetNVVM nvrtcGetNVVM


cdef SoftLink _L = None
cdef void initialize():
kmaehashi marked this conversation as resolved.
Show resolved Hide resolved
global _L
if _L is not None:
return
_L = _get_softlink()

global nvrtcGetErrorString
nvrtcGetErrorString = <F_nvrtcGetErrorString>_L.get('GetErrorString')
global nvrtcVersion
nvrtcVersion = <F_nvrtcVersion>_L.get('Version')
global nvrtcCreateProgram
nvrtcCreateProgram = <F_nvrtcCreateProgram>_L.get('CreateProgram')
global nvrtcDestroyProgram
nvrtcDestroyProgram = <F_nvrtcDestroyProgram>_L.get('DestroyProgram')
global nvrtcCompileProgram
nvrtcCompileProgram = <F_nvrtcCompileProgram>_L.get('CompileProgram')
global nvrtcGetPTXSize
nvrtcGetPTXSize = <F_nvrtcGetPTXSize>_L.get('GetPTXSize' if _L.prefix == 'nvrtc' else 'GetCodeSize') # NOQA
global nvrtcGetPTX
nvrtcGetPTX = <F_nvrtcGetPTX>_L.get('GetPTX' if _L.prefix == 'nvrtc' else 'GetCode') # NOQA
global nvrtcGetCUBINSize
nvrtcGetCUBINSize = <F_nvrtcGetCUBINSize>_L.get('GetCUBINSize')
global nvrtcGetCUBIN
nvrtcGetCUBIN = <F_nvrtcGetCUBIN>_L.get('GetCUBIN')
global nvrtcGetProgramLogSize
nvrtcGetProgramLogSize = <F_nvrtcGetProgramLogSize>_L.get('GetProgramLogSize') # NOQA
global nvrtcGetProgramLog
nvrtcGetProgramLog = <F_nvrtcGetProgramLog>_L.get('GetProgramLog')
global nvrtcAddNameExpression
nvrtcAddNameExpression = <F_nvrtcAddNameExpression>_L.get('AddNameExpression') # NOQA
global nvrtcGetLoweredName
nvrtcGetLoweredName = <F_nvrtcGetLoweredName>_L.get('GetLoweredName')
global nvrtcGetNumSupportedArchs
nvrtcGetNumSupportedArchs = <F_nvrtcGetNumSupportedArchs>_L.get('GetNumSupportedArchs') # NOQA
global nvrtcGetSupportedArchs
nvrtcGetSupportedArchs = <F_nvrtcGetSupportedArchs>_L.get('GetSupportedArchs') # NOQA
global nvrtcGetNVVMSize
nvrtcGetNVVMSize = <F_nvrtcGetNVVMSize>_L.get('GetNVVMSize')
global nvrtcGetNVVM
nvrtcGetNVVM = <F_nvrtcGetNVVM>_L.get('GetNVVM')


cdef SoftLink _get_softlink():
cdef int runtime_version
cdef str prefix = 'nvrtc'
cdef object libname = None

if CUPY_CUDA_VERSION != 0:
runtime_version = runtime.runtimeGetVersion()
if 11020 <= runtime_version < 12000:
# CUDA 11.x (11.2+)
if _sys.platform == 'linux':
libname = 'libnvrtc.so.11.2'
else:
libname = 'nvrtc64_112_0.dll'
elif 12000 <= runtime_version < 13000:
# CUDA 12.x
if _sys.platform == 'linux':
libname = 'libnvrtc.so.12'
else:
libname = 'nvrtc64_120_0.dll'
elif CUPY_HIP_VERSION != 0:
runtime_version = runtime.runtimeGetVersion()
kmaehashi marked this conversation as resolved.
Show resolved Hide resolved
prefix = 'hiprtc'
if runtime_version < 5_00_00000:
# ROCm 4.x
libname = 'libamdhip64.so.4'
elif runtime_version < 6_00_00000:
# ROCm 5.x
libname = 'libamdhip64.so.5'

return SoftLink(libname, prefix, mandatory=True)
6 changes: 0 additions & 6 deletions cupy_backends/cuda/libs/nvrtc.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,7 @@ IF CUPY_USE_CUDA_PYTHON:
from cuda.cnvrtc cimport *
# Aliases for compatibility with existing CuPy codebase.
# TODO(kmaehashi): Remove these aliases.
ctypedef nvrtcResult Result
ctypedef nvrtcProgram Program
ELSE:
cdef extern from *:
ctypedef int Result 'nvrtcResult'
ctypedef void* Program 'nvrtcProgram'


cpdef check_status(int status)

Expand Down
Loading