-
-
Notifications
You must be signed in to change notification settings - Fork 863
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Soft link NVRTC for cupy_backends.cuda.libs.nvrtc
#7621
Changes from 16 commits
d215dcc
e21ad39
5526d8a
2ec7e33
897b553
3adf560
a731cdd
fe17225
ffe14c5
3cf7831
c1890dc
0613c47
be52b85
62d1a38
80d13fa
a665de2
9b8857d
947fc74
97250fe
6688747
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 1,4 @@ | ||
[flake8] | ||
filename = *.pyx, *.pxd, *.pxi | ||
exclude = .git, .eggs, *.py | ||
ignore = W503,W504,E225,E226,E227,E275,E402,E999 | ||
ignore = W503,W504,E211,E225,E226,E227,E275,E402,E999 | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,6 36,7 @@ repos: | |
rev: v0.15.0 | ||
hooks: | ||
- id: cython-lint | ||
args: ["--max-line-length", "79"] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the same limitation as current coding standard (.flake8.cython) |
||
|
||
- repo: https://github.com/pre-commit/mirrors-mypy | ||
rev: v1.4.1 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,24 6,32 @@ cimport cython | |
|
||
|
||
cdef class SoftLink: | ||
def __init__(self, object libname, str prefix): | ||
def __init__(self, object libname, str prefix, *, bint mandatory=False): | ||
self.error = None | ||
self.prefix = prefix | ||
self._cdll = None | ||
if libname is not None: | ||
if libname is None: | ||
# Stub build or CUDA/HIP only library. | ||
self.error = RuntimeError( | ||
'The library is unavailable in the current platform.') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice!! |
||
else: | ||
try: | ||
self._cdll = ctypes.CDLL(libname) | ||
except Exception as e: | ||
warnings.warn( | ||
f'Warning: CuPy failed to load "{libname}": ' | ||
f'({type(e).__name__}: {e})') | ||
self._prefix = prefix | ||
self.error = e | ||
msg = ( | ||
f'CuPy failed to load {libname}: {type(e).__name__}: {e}') | ||
if mandatory: | ||
raise RuntimeError(msg) from e | ||
warnings.warn(msg) | ||
|
||
cdef func_ptr get(self, str name): | ||
""" | ||
Returns a function pointer for the API. | ||
""" | ||
if self._cdll is None: | ||
return <func_ptr>_fail_unsupported | ||
cdef str funcname = f'{self._prefix}{name}' | ||
cdef str funcname = f'{self.prefix}{name}' | ||
cdef object func = getattr(self._cdll, funcname, None) | ||
if func is None: | ||
return <func_ptr>_fail_not_found | ||
|
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 1,139 @@ | ||
import sys as _sys | ||
|
||
from cupy_backends.cuda.api cimport runtime | ||
from cupy_backends.cuda._softlink cimport SoftLink | ||
|
||
|
||
ctypedef int nvrtcResult | ||
ctypedef void* nvrtcProgram | ||
# TODO(kmaehashi): Remove this alias. | ||
ctypedef nvrtcProgram Program | ||
|
||
ctypedef const char* (*F_nvrtcGetErrorString)(nvrtcResult result) nogil | ||
cdef F_nvrtcGetErrorString nvrtcGetErrorString | ||
|
||
ctypedef nvrtcResult (*F_nvrtcVersion)(int *major, int *minor) nogil | ||
cdef F_nvrtcVersion nvrtcVersion | ||
|
||
ctypedef nvrtcResult (*F_nvrtcCreateProgram)( | ||
nvrtcProgram* prog, const char* src, const char* name, int numHeaders, | ||
const char** headers, const char** includeNames) nogil | ||
cdef F_nvrtcCreateProgram nvrtcCreateProgram | ||
|
||
ctypedef nvrtcResult (*F_nvrtcDestroyProgram)(nvrtcProgram *prog) nogil | ||
cdef F_nvrtcDestroyProgram nvrtcDestroyProgram | ||
|
||
ctypedef nvrtcResult (*F_nvrtcCompileProgram)( | ||
nvrtcProgram prog, int numOptions, const char** options) nogil | ||
Comment on lines
26
to
27
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
nvrtcResult nvrtcCompileProgram(nvrtcProgram prog,
int numOptions, const char * const *options); |
||
cdef F_nvrtcCompileProgram nvrtcCompileProgram | ||
|
||
ctypedef nvrtcResult (*F_nvrtcGetPTXSize)(nvrtcProgram prog, size_t *ptxSizeRet) nogil # NOQA | ||
cdef F_nvrtcGetPTXSize nvrtcGetPTXSize | ||
|
||
ctypedef nvrtcResult (*F_nvrtcGetPTX)(nvrtcProgram prog, char *ptx) nogil | ||
cdef F_nvrtcGetPTX nvrtcGetPTX | ||
|
||
ctypedef nvrtcResult (*F_nvrtcGetCUBINSize)(nvrtcProgram prog, size_t *cubinSizeRet) nogil # NOQA | ||
cdef F_nvrtcGetCUBINSize nvrtcGetCUBINSize | ||
|
||
ctypedef nvrtcResult (*F_nvrtcGetCUBIN)(nvrtcProgram prog, char *cubin) nogil | ||
cdef F_nvrtcGetCUBIN nvrtcGetCUBIN | ||
|
||
ctypedef nvrtcResult (*F_nvrtcGetProgramLogSize)(nvrtcProgram prog, size_t* logSizeRet) nogil # NOQA | ||
cdef F_nvrtcGetProgramLogSize nvrtcGetProgramLogSize | ||
|
||
ctypedef nvrtcResult (*F_nvrtcGetProgramLog)(nvrtcProgram prog, char* log) nogil # NOQA | ||
cdef F_nvrtcGetProgramLog nvrtcGetProgramLog | ||
|
||
ctypedef nvrtcResult (*F_nvrtcAddNameExpression)(nvrtcProgram, const char*) nogil # NOQA | ||
cdef F_nvrtcAddNameExpression nvrtcAddNameExpression | ||
|
||
ctypedef nvrtcResult (*F_nvrtcGetLoweredName)(nvrtcProgram, const char*, const char**) nogil # NOQA | ||
cdef F_nvrtcGetLoweredName nvrtcGetLoweredName | ||
|
||
ctypedef nvrtcResult (*F_nvrtcGetNumSupportedArchs)(int* numArchs) nogil | ||
cdef F_nvrtcGetNumSupportedArchs nvrtcGetNumSupportedArchs | ||
|
||
ctypedef nvrtcResult (*F_nvrtcGetSupportedArchs)(int* supportedArchs) nogil | ||
cdef F_nvrtcGetSupportedArchs nvrtcGetSupportedArchs | ||
|
||
ctypedef nvrtcResult (*F_nvrtcGetNVVMSize)(nvrtcProgram prog, size_t *nvvmSizeRet) nogil # NOQA | ||
cdef F_nvrtcGetNVVMSize nvrtcGetNVVMSize | ||
|
||
ctypedef nvrtcResult (*F_nvrtcGetNVVM)(nvrtcProgram prog, char *nvvm) nogil | ||
cdef F_nvrtcGetNVVM nvrtcGetNVVM | ||
|
||
|
||
cdef SoftLink _L = None | ||
cdef void initialize(): | ||
kmaehashi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
global _L | ||
if _L is not None: | ||
return | ||
_L = _get_softlink() | ||
|
||
global nvrtcGetErrorString | ||
nvrtcGetErrorString = <F_nvrtcGetErrorString>_L.get('GetErrorString') | ||
global nvrtcVersion | ||
nvrtcVersion = <F_nvrtcVersion>_L.get('Version') | ||
global nvrtcCreateProgram | ||
nvrtcCreateProgram = <F_nvrtcCreateProgram>_L.get('CreateProgram') | ||
global nvrtcDestroyProgram | ||
nvrtcDestroyProgram = <F_nvrtcDestroyProgram>_L.get('DestroyProgram') | ||
global nvrtcCompileProgram | ||
nvrtcCompileProgram = <F_nvrtcCompileProgram>_L.get('CompileProgram') | ||
global nvrtcGetPTXSize | ||
nvrtcGetPTXSize = <F_nvrtcGetPTXSize>_L.get('GetPTXSize' if _L.prefix == 'nvrtc' else 'GetCodeSize') # NOQA | ||
global nvrtcGetPTX | ||
nvrtcGetPTX = <F_nvrtcGetPTX>_L.get('GetPTX' if _L.prefix == 'nvrtc' else 'GetCode') # NOQA | ||
global nvrtcGetCUBINSize | ||
nvrtcGetCUBINSize = <F_nvrtcGetCUBINSize>_L.get('GetCUBINSize') | ||
global nvrtcGetCUBIN | ||
nvrtcGetCUBIN = <F_nvrtcGetCUBIN>_L.get('GetCUBIN') | ||
global nvrtcGetProgramLogSize | ||
nvrtcGetProgramLogSize = <F_nvrtcGetProgramLogSize>_L.get('GetProgramLogSize') # NOQA | ||
global nvrtcGetProgramLog | ||
nvrtcGetProgramLog = <F_nvrtcGetProgramLog>_L.get('GetProgramLog') | ||
global nvrtcAddNameExpression | ||
nvrtcAddNameExpression = <F_nvrtcAddNameExpression>_L.get('AddNameExpression') # NOQA | ||
global nvrtcGetLoweredName | ||
nvrtcGetLoweredName = <F_nvrtcGetLoweredName>_L.get('GetLoweredName') | ||
global nvrtcGetNumSupportedArchs | ||
nvrtcGetNumSupportedArchs = <F_nvrtcGetNumSupportedArchs>_L.get('GetNumSupportedArchs') # NOQA | ||
global nvrtcGetSupportedArchs | ||
nvrtcGetSupportedArchs = <F_nvrtcGetSupportedArchs>_L.get('GetSupportedArchs') # NOQA | ||
global nvrtcGetNVVMSize | ||
nvrtcGetNVVMSize = <F_nvrtcGetNVVMSize>_L.get('GetNVVMSize') | ||
global nvrtcGetNVVM | ||
nvrtcGetNVVM = <F_nvrtcGetNVVM>_L.get('GetNVVM') | ||
|
||
|
||
cdef SoftLink _get_softlink(): | ||
cdef int runtime_version | ||
cdef str prefix = 'nvrtc' | ||
cdef object libname = None | ||
|
||
if CUPY_CUDA_VERSION != 0: | ||
runtime_version = runtime.runtimeGetVersion() | ||
if 11020 <= runtime_version < 12000: | ||
# CUDA 11.x (11.2 ) | ||
if _sys.platform == 'linux': | ||
libname = 'libnvrtc.so.11.2' | ||
else: | ||
libname = 'nvrtc64_112_0.dll' | ||
elif 12000 <= runtime_version < 13000: | ||
# CUDA 12.x | ||
if _sys.platform == 'linux': | ||
libname = 'libnvrtc.so.12' | ||
else: | ||
libname = 'nvrtc64_120_0.dll' | ||
elif CUPY_HIP_VERSION != 0: | ||
runtime_version = runtime.runtimeGetVersion() | ||
kmaehashi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
prefix = 'hiprtc' | ||
if runtime_version < 5_00_00000: | ||
# ROCm 4.x | ||
libname = 'libamdhip64.so.4' | ||
elif runtime_version < 6_00_00000: | ||
# ROCm 5.x | ||
libname = 'libamdhip64.so.5' | ||
|
||
return SoftLink(libname, prefix, mandatory=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
E211 added to avoid
Whitespace before '('
error: