-
Notifications
You must be signed in to change notification settings - Fork 53
/
mbarrier.cu
102 lines (91 loc) · 3.03 KB
/
mbarrier.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
// Reference:
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#asynchronous-barrier
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier
// https://github.com/NVIDIA/cutlass/blob/main/include/cute/arch/copy_sm90_desc.hpp
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
// Thin wrappers around the PTX `mbarrier` instructions.
//   - The whole namespace is compiled only for SM80+ device code.
//   - arriveExpectTX and the try_wait-based wait paths are compiled only for SM90+.
//
// Every function takes `smem_barrier_ptr`, the 32-bit address (an "r" operand)
// of a 64-bit mbarrier object in the `.shared` state space.
// NOTE(review): callers presumably derive it with __cvta_generic_to_shared or
// an equivalent conversion; confirm at call sites.
//
// Every asm statement carries a "memory" clobber. These instructions are
// synchronization points. `asm volatile` alone does not stop the compiler from
// moving ordinary loads/stores across the statement. Without the clobber it
// could hoist a read of shared-memory data above the wait that guards it, or
// sink a producer's store below its arrive.
namespace mbarrier {

// Initializes the mbarrier at `smem_barrier_ptr` with an expected arrival
// count of `thread_count` (defaults to a single arriving thread).
__device__ inline void init(
    uint32_t smem_barrier_ptr,
    uint32_t thread_count = 1) {
  asm volatile(
      "mbarrier.init.shared.b64 [%0], %1;\n" ::"r"(smem_barrier_ptr),
      "r"(thread_count)
      : "memory");
}

// Invalidates the mbarrier at `smem_barrier_ptr`
// (PTX `mbarrier.inval`), e.g. before its storage is repurposed.
__device__ inline void inval(uint32_t smem_barrier_ptr) {
  asm volatile("mbarrier.inval.shared.b64 [%0];\n" ::"r"(smem_barrier_ptr)
               : "memory");
}

// Performs an arrive-on operation on the mbarrier.
// Returns the opaque 64-bit state token that identifies the current phase;
// pass it to wait().
__device__ inline uint64_t arrive(uint32_t smem_barrier_ptr) {
  // Plain local: the asm writes it through a register operand. `volatile`
  // here added nothing but could force a needless local-memory round-trip.
  uint64_t state;
  asm volatile("mbarrier.arrive.shared.b64 %0, [%1];\n"
               : "=l"(state)
               : "r"(smem_barrier_ptr)
               : "memory");
  return state;
}

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
// SM90+ only. Atomically raises the mbarrier's expected transaction count by
// `tx_count` and performs an arrive-on operation
// (PTX `mbarrier.arrive.expect_tx`). Returns the state token for wait().
// NOTE(review): `tx_count` is presumably a byte count of the async copies
// (e.g. TMA) that will complete on this barrier; confirm at call sites.
__device__ inline uint64_t arriveExpectTX(
    uint32_t smem_barrier_ptr,
    uint32_t tx_count) {
  uint64_t state;
  asm volatile("mbarrier.arrive.expect_tx.shared.b64 %0, [%1], %2;\n"
               : "=l"(state)
               : "r"(smem_barrier_ptr), "r"(tx_count)
               : "memory");
  return state;
}
#endif

// Blocks the calling thread until the phase identified by `state`
// (as returned by arrive()/arriveExpectTX()) has completed.
// The `{ }` wrapper gives the predicate register and labels their own PTX
// scope, so this can be inlined several times into one kernel.
__device__ inline void wait(uint32_t smem_barrier_ptr, uint64_t state) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // SM90+: spin on try_wait, which may itself suspend the thread for a while.
  asm volatile(
      "{\n"
      ".reg .pred complete;\n"
      "waitLoop:\n"
      "mbarrier.try_wait.shared.b64 complete, [%0], %1;\n"
      "@!complete bra waitLoop;\n"
      "}\n" ::"r"(smem_barrier_ptr),
      "l"(state)
      : "memory");
#else
  // SM80: test_wait is non-blocking, so poll it with a short nanosleep
  // back-off between attempts.
  asm volatile(
      "{\n"
      ".reg .pred P1;\n"
      "LAB_WAIT:\n"
      "mbarrier.test_wait.shared.b64 P1, [%0], %1;\n"
      "@P1 bra.uni DONE;\n"
      "nanosleep.u32 20;\n"
      "bra.uni LAB_WAIT;\n"
      "DONE:\n"
      "}\n" ::"r"(smem_barrier_ptr),
      "l"(state)
      : "memory");
#endif
}

// Blocks the calling thread until the mbarrier phase whose parity is
// `parity` has completed (the `.parity` variants of try_wait/test_wait).
// It is the same loop structure as wait(), keyed on phase parity instead of
// a state token.
__device__ inline void waitParity(uint32_t smem_barrier_ptr, uint32_t parity) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile(
      "{\n"
      ".reg .pred complete;\n"
      "waitLoop:\n"
      "mbarrier.try_wait.parity.shared.b64 complete, [%0], %1;\n"
      "@!complete bra waitLoop;\n"
      "}\n" ::"r"(smem_barrier_ptr),
      "r"(parity)
      : "memory");
#else
  asm volatile(
      "{\n"
      ".reg .pred P1;\n"
      "LAB_WAIT:\n"
      "mbarrier.test_wait.parity.shared.b64 P1, [%0], %1;\n"
      "@P1 bra.uni DONE;\n"
      "nanosleep.u32 20;\n"
      "bra.uni LAB_WAIT;\n"
      "DONE:\n"
      "}\n" ::"r"(smem_barrier_ptr),
      "r"(parity)
      : "memory");
#endif
}

} // namespace mbarrier
#endif // (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))