[WIP] Allow cub::DeviceRadixSort and cub::DeviceSegmentedRadixSort to use iterator as input #374
base: main
            KeyInIterT d_keys_in_ = d_keys_out;
            ValueInIterT d_values_in_ = d_values_out;
            onesweep_kernel<<<num_blocks, ONESWEEP_BLOCK_THREADS, 0, stream>>>
                (d_lookback, d_ctrs + part * num_passes + pass,
                 part < num_parts - 1 ?
                     d_bins + ((part + 1) * num_passes + pass) * RADIX_DIGITS : NULL,
                 d_bins + (part * num_passes + pass) * RADIX_DIGITS,
                 d_keys_out_,
                 d_keys_in_ + part * PART_SIZE,
                 d_values_out_,
                 d_values_in_ + part * PART_SIZE,
                 part_num_items, current_bit, num_bits);
            break;
          }
        }
      } else {
        using KeyOutIterT = KeyIteratorT;
        using ValueOutIterT = ValueIteratorT;
        KeyOutIterT d_keys_out_ = d_keys_out;
        ValueOutIterT d_values_out_ = d_values_out;
        switch (input_mode) {
          case INPUT: {
            using KeyInIterT = KeyInputIteratorT;
            using ValueInIterT = ValueInputIteratorT;
            auto onesweep_kernel = DeviceRadixSortOnesweepKernel<
                MaxPolicyT, IS_DESCENDING, KeyInIterT, KeyOutIterT,
                ValueInIterT, ValueOutIterT, OffsetT>;
            KeyInIterT d_keys_in_ = d_keys_in;
            ValueInIterT d_values_in_ = d_values_in;
            onesweep_kernel<<<num_blocks, ONESWEEP_BLOCK_THREADS, 0, stream>>>
                (d_lookback, d_ctrs + part * num_passes + pass,
                 part < num_parts - 1 ?
                     d_bins + ((part + 1) * num_passes + pass) * RADIX_DIGITS : NULL,
                 d_bins + (part * num_passes + pass) * RADIX_DIGITS,
                 d_keys_out_,
                 d_keys_in_ + part * PART_SIZE,
                 d_values_out_,
                 d_values_in_ + part * PART_SIZE,
                 part_num_items, current_bit, num_bits);
            break;
          }
          case TMP_STORAGE: {
            using KeyInIterT = KeyT *;
            using ValueInIterT = ValueT *;
            auto onesweep_kernel = DeviceRadixSortOnesweepKernel<
                MaxPolicyT, IS_DESCENDING, KeyInIterT, KeyOutIterT,
                ValueInIterT, ValueOutIterT, OffsetT>;
            KeyInIterT d_keys_in_ = d_keys_tmp;
            ValueInIterT d_values_in_ = d_values_tmp;
            onesweep_kernel<<<num_blocks, ONESWEEP_BLOCK_THREADS, 0, stream>>>
                (d_lookback, d_ctrs + part * num_passes + pass,
                 part < num_parts - 1 ?
                     d_bins + ((part + 1) * num_passes + pass) * RADIX_DIGITS : NULL,
                 d_bins + (part * num_passes + pass) * RADIX_DIGITS,
                 d_keys_out_,
                 d_keys_in_ + part * PART_SIZE,
                 d_values_out_,
                 d_values_in_ + part * PART_SIZE,
                 part_num_items, current_bit, num_bits);
            break;
          }
          case OUTPUT: {
            using KeyInIterT = KeyIteratorT;
            using ValueInIterT = ValueIteratorT;
            auto onesweep_kernel = DeviceRadixSortOnesweepKernel<
                MaxPolicyT, IS_DESCENDING, KeyInIterT, KeyOutIterT,
                ValueInIterT, ValueOutIterT, OffsetT>;
            KeyInIterT d_keys_in_ = d_keys_out;
            ValueInIterT d_values_in_ = d_values_out;
            onesweep_kernel<<<num_blocks, ONESWEEP_BLOCK_THREADS, 0, stream>>>
                (d_lookback, d_ctrs + part * num_passes + pass,
                 part < num_parts - 1 ?
                     d_bins + ((part + 1) * num_passes + pass) * RADIX_DIGITS : NULL,
                 d_bins + (part * num_passes + pass) * RADIX_DIGITS,
                 d_keys_out_,
                 d_keys_in_ + part * PART_SIZE,
                 d_values_out_,
                 d_values_in_ + part * PART_SIZE,
                 part_num_items, current_bit, num_bits);
            break;
          }
        }
      }
      if (CubDebug(error = cudaPeekAtLastError())) break;
    }
    // use the temporary buffers if no overwrite is allowed
    if (!is_overwrite_okay && pass == 0)
    {
      d_keys = num_passes % 2 == 0 ?
          DoubleBuffer<KeyT>(d_keys_tmp, d_keys_tmp2) :
          DoubleBuffer<KeyT>(d_keys_tmp2, d_keys_tmp);
      d_values = num_passes % 2 == 0 ?
          DoubleBuffer<ValueT>(d_values_tmp, d_values_tmp2) :
          DoubleBuffer<ValueT>(d_values_tmp2, d_values_tmp);
    }
    d_keys.selector ^= 1;
    d_values.selector ^= 1;
    input_mode = output_is_tmp ? TMP_STORAGE : OUTPUT;
    output_is_tmp = !output_is_tmp;
@allisonvacanti Before I continue my work, I'd like to hear your feedback about this part. In order to support iterators, I have to add this verbose logic because the type of the input iterator, the type of the temporary storage, and the type of the output iterator can all be different.
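For context, a simplified, hypothetical sketch of the dispatch described above (not the actual CUB code; names like onesweep_pass and dispatch_pass are illustrative): the onesweep kernel template is parameterized on the input iterator type, and that type differs per pass -- the user's input iterator, a raw temporary-buffer pointer, or the output iterator -- so each input_mode case needs its own instantiation.

    #include <cstddef>

    enum InputMode { INPUT, TMP_STORAGE, OUTPUT };

    // Stand-in for DeviceRadixSortOnesweepKernel<..., KeyInIterT, KeyOutIterT, ...>.
    template <typename KeyInIterT, typename KeyOutIterT>
    void onesweep_pass(KeyInIterT d_keys_in, KeyOutIterT d_keys_out, std::size_t n)
    {
      for (std::size_t i = 0; i < n; ++i)
        d_keys_out[i] = d_keys_in[i];
    }

    template <typename KeyInputIteratorT, typename KeyT>
    void dispatch_pass(InputMode input_mode,
                       KeyInputIteratorT d_keys_in, // user-provided input iterator
                       KeyT *d_keys_tmp,            // raw temporary storage
                       KeyT *d_keys_out,            // output buffer
                       std::size_t n)
    {
      switch (input_mode)
      {
        case INPUT:       onesweep_pass(d_keys_in,  d_keys_out, n); break; // KeyInIterT = KeyInputIteratorT
        case TMP_STORAGE: onesweep_pass(d_keys_tmp, d_keys_out, n); break; // KeyInIterT = KeyT*
        case OUTPUT:      onesweep_pass(d_keys_out, d_keys_out, n); break; // KeyInIterT = KeyT*
      }
    }

Each case instantiates the template with a different input iterator type, which is what makes the real dispatch above so verbose.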
I am not sure if you are OK with this change, but it is very verbose, and I cannot think of a better solution.
We'll need to restore the is_overwrite_okay optimization first, and then we can take a closer look at this part.
      num_items,
      begin_bit,
      end_bit,
      is_overwrite_okay,
is_overwrite_okay is removed; with iterator support, it never overwrites.
Thanks for the patch! I'd like to get this functionality into CUB, but the proposed implementation will be too disruptive, since there is no way for a user to avoid the large temporary storage allocations needed to hold the intermediate keys/values. When users call the DoubleBuffer APIs, they are explicitly providing the scratch space needed to hold these intermediate results -- we need to use this memory instead of reallocating it.
I definitely want to merge this, but we'll need to preserve the is_overwrite_okay optimization and only allocate the extra scratch memory when it is absolutely necessary.
@@ -1281,9 +1305,9 @@ struct DispatchRadixSort :
       // lookback
       max_num_blocks * RADIX_DIGITS * sizeof(AtomicOffsetT),
       // extra key buffer
-      is_overwrite_okay || num_passes <= 1 ? 0 : num_items * sizeof(KeyT),
+      num_passes <= 1 ? 0 : num_items * sizeof(KeyT),
This will drastically increase the amount of temporary storage needed for some invocations of radix sort. We need to keep this optimization in place for the DoubleBuffer overloads in the Device*RadixSort APIs, since folks specifically use those to reduce the temporary storage allocations.
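For reference, a minimal usage sketch of the DoubleBuffer overload being discussed (context only, not part of this PR; buffer allocation and error checking are assumed to happen elsewhere): the caller supplies both key buffers, so with the is_overwrite_okay optimization in place CUB does not need a hidden num_items-sized key copy in the temporary storage.

    #include <cub/cub.cuh>
    #include <cuda_runtime.h>

    void sort_keys_with_double_buffer(int *d_key_buf, int *d_key_alt_buf, int num_items)
    {
      cub::DoubleBuffer<int> d_keys(d_key_buf, d_key_alt_buf);

      // First call only queries the required temporary storage size.
      void *d_temp_storage = nullptr;
      std::size_t temp_storage_bytes = 0;
      cub::DeviceRadixSort::SortKeys(d_temp_storage, temp_storage_bytes, d_keys, num_items);

      cudaMalloc(&d_temp_storage, temp_storage_bytes);
      cub::DeviceRadixSort::SortKeys(d_temp_storage, temp_storage_bytes, d_keys, num_items);

      // Sorted keys end up in d_keys.Current(); the alternate buffer may be overwritten.
      cudaFree(d_temp_storage);
    }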
      INPUT,
      TMP_STORAGE,
      OUTPUT
    } input_mode = INPUT;
Style: This should be split into separate declarations:
enum InputMode { ... };
InputMode input_mode = INPUT;
@@ -1822,8 +1966,8 @@ struct DispatchSegmentedRadixSort :
     void* allocations[2] = {};
     size_t allocation_sizes[2] =
     {
-      (is_overwrite_okay) ? 0 : num_items * sizeof(KeyT),                  // bytes needed for 3rd keys buffer
-      (is_overwrite_okay || (KEYS_ONLY)) ? 0 : num_items * sizeof(ValueT), // bytes needed for 3rd values buffer
+      num_items * sizeof(KeyT),                                            // bytes needed for 3rd keys buffer
Same as above -- we need to keep this optimization in place.
Also -- @senior-zero is adding a new segmented sort implementation that uses iterators instead of pointers; see #357. It preserves the double-buffer optimizations, so it may be a useful reference.
@allisonvacanti Looks like in #357, iterators are only supported for offsets; for input and output keys and values, they still need to be pointers?
@zasdfgbnm My mistake -- you are correct. We don't have an example that does the switching. I checked the new merge sort implementation and it also copies the keys/values unconditionally. But that was a new algorithm, so it's less concerning. I'd still like to update this PR to only allocate the extra temporary storage when the iterators aren't pointers, to avoid changing the requirements of this algorithm -- it should be possible to have special logic for the first pass that reads from the iterators instead of the temp storage buffers. Let me know if you'd like to discuss this more.
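A hypothetical sketch of that suggestion (names and structure are illustrative, not code from this PR): gate the extra key scratch allocation on whether the input is a plain pointer, so DoubleBuffer/pointer callers keep the existing behavior and only true iterator inputs pay for the additional buffer.

    #include <cstddef>
    #include <type_traits>

    template <typename KeyInputIteratorT, typename KeyT>
    std::size_t extra_key_buffer_bytes(bool is_overwrite_okay,
                                       int num_passes,
                                       std::size_t num_items)
    {
      if (std::is_pointer<KeyInputIteratorT>::value)
      {
        // Original behavior: pointer inputs can reuse the user's buffers.
        return (is_overwrite_okay || num_passes <= 1) ? 0 : num_items * sizeof(KeyT);
      }

      // Non-pointer iterator: the first pass can read directly from the iterator,
      // but later passes still need a buffer to ping-pong keys through.
      return (num_passes <= 1) ? 0 : num_items * sizeof(KeyT);
    }

The same gating would apply to the values buffer in the keys-and-values overloads.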
I'll be starting the 1.15 RC next week, and it looks like this will take a bit more work to be ready. Bumping to the 1.16 milestone -- let me know if you plan to finish this before Monday and we can keep it at 1.15.
1.16 is fine with me.
Hi @zasdfgbnm, are there any updates on this?
@cliffburdick No, I am not working on this any more.
Fixes NVIDIA/cccl#868