Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

[WIP] Allow cub::DeviceRadixSort and cub::DeviceSegmentedRadixSort to use iterator as input #374

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from

Conversation

zasdfgbnm
Copy link
Contributor

@zasdfgbnm zasdfgbnm commented Sep 13, 2021

Comment on lines 1377 to 1515
KeyInIterT d_keys_in_ = d_keys_out;
ValueInIterT d_values_in_ = d_values_out;
onesweep_kernel<<<num_blocks, ONESWEEP_BLOCK_THREADS, 0, stream>>>
(d_lookback, d_ctrs part * num_passes pass,
part < num_parts - 1 ?
d_bins ((part 1) * num_passes pass) * RADIX_DIGITS : NULL,
d_bins (part * num_passes pass) * RADIX_DIGITS,
d_keys_out_,
d_keys_in_ part * PART_SIZE,
d_values_out_,
d_values_in_ part * PART_SIZE,
part_num_items, current_bit, num_bits);
break;
}
}
} else {
using KeyOutIterT = KeyIteratorT;
using ValueOutIterT = ValueIteratorT;
KeyOutIterT d_keys_out_ = d_keys_out;
ValueOutIterT d_values_out_ = d_values_out;
switch (input_mode) {
case INPUT: {
using KeyInIterT = KeyInputIteratorT;
using ValueInIterT = ValueInputIteratorT;
auto onesweep_kernel = DeviceRadixSortOnesweepKernel<
MaxPolicyT, IS_DESCENDING, KeyInIterT, KeyOutIterT,
ValueInIterT, ValueOutIterT, OffsetT>;
KeyInIterT d_keys_in_ = d_keys_in;
ValueInIterT d_values_in_ = d_values_in;
onesweep_kernel<<<num_blocks, ONESWEEP_BLOCK_THREADS, 0, stream>>>
(d_lookback, d_ctrs part * num_passes pass,
part < num_parts - 1 ?
d_bins ((part 1) * num_passes pass) * RADIX_DIGITS : NULL,
d_bins (part * num_passes pass) * RADIX_DIGITS,
d_keys_out_,
d_keys_in_ part * PART_SIZE,
d_values_out_,
d_values_in_ part * PART_SIZE,
part_num_items, current_bit, num_bits);
break;
}
case TMP_STORAGE: {
using KeyInIterT = KeyT *;
using ValueInIterT = ValueT *;
auto onesweep_kernel = DeviceRadixSortOnesweepKernel<
MaxPolicyT, IS_DESCENDING, KeyInIterT, KeyOutIterT,
ValueInIterT, ValueOutIterT, OffsetT>;
KeyInIterT d_keys_in_ = d_keys_tmp;
ValueInIterT d_values_in_ = d_values_tmp;
onesweep_kernel<<<num_blocks, ONESWEEP_BLOCK_THREADS, 0, stream>>>
(d_lookback, d_ctrs part * num_passes pass,
part < num_parts - 1 ?
d_bins ((part 1) * num_passes pass) * RADIX_DIGITS : NULL,
d_bins (part * num_passes pass) * RADIX_DIGITS,
d_keys_out_,
d_keys_in_ part * PART_SIZE,
d_values_out_,
d_values_in_ part * PART_SIZE,
part_num_items, current_bit, num_bits);
break;
}
case OUTPUT: {
using KeyInIterT = KeyIteratorT;
using ValueInIterT = ValueIteratorT;
auto onesweep_kernel = DeviceRadixSortOnesweepKernel<
MaxPolicyT, IS_DESCENDING, KeyInIterT, KeyOutIterT,
ValueInIterT, ValueOutIterT, OffsetT>;
KeyInIterT d_keys_in_ = d_keys_out;
ValueInIterT d_values_in_ = d_values_out;
onesweep_kernel<<<num_blocks, ONESWEEP_BLOCK_THREADS, 0, stream>>>
(d_lookback, d_ctrs part * num_passes pass,
part < num_parts - 1 ?
d_bins ((part 1) * num_passes pass) * RADIX_DIGITS : NULL,
d_bins (part * num_passes pass) * RADIX_DIGITS,
d_keys_out_,
d_keys_in_ part * PART_SIZE,
d_values_out_,
d_values_in_ part * PART_SIZE,
part_num_items, current_bit, num_bits);
break;
}
}
}
if (CubDebug(error = cudaPeekAtLastError())) break;
}

// use the temporary buffers if no overwrite is allowed
if (!is_overwrite_okay && pass == 0)
{
d_keys = num_passes % 2 == 0 ?
DoubleBuffer<KeyT>(d_keys_tmp, d_keys_tmp2) :
DoubleBuffer<KeyT>(d_keys_tmp2, d_keys_tmp);
d_values = num_passes % 2 == 0 ?
DoubleBuffer<ValueT>(d_values_tmp, d_values_tmp2) :
DoubleBuffer<ValueT>(d_values_tmp2, d_values_tmp);
}
d_keys.selector ^= 1;
d_values.selector ^= 1;
input_mode = output_is_tmp ? TMP_STORAGE : OUTPUT;
output_is_tmp = !output_is_tmp;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@allisonvacanti Before I continue my work, I'd like to hear your feedback about this part. In order to support iterators, I have to add this verbose logic because the type of input iter, type of tmp storage, and type of output iter can be different.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if you are OK with this change, but this is very verbose, and I can not think of a better solution.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll need to restore the is_overwrite_okay optimization first and then we can take a closer look at this part.

num_items,
begin_bit,
end_bit,
is_overwrite_okay,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_overwrite_okay is removed, with iterator support, it never overwrites.

@alliepiper alliepiper self-assigned this Sep 21, 2021
@alliepiper alliepiper added this to the 1.15.0 milestone Sep 21, 2021
Copy link
Collaborator

@alliepiper alliepiper left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the patch! I'd like to get this functionality into CUB, but the proposed implementation will be too disruptive, since there is no way for a user to avoid the large temporary storage allocations needed to hold the intermediate keys/values. When users call the DoubleBuffer APIs, they are explicitly providing the scratch space needed to hold these intermediate results -- we need to use this memory instead of reallocating it.

I definitely want to merge this, but we'll need to preserve the is_overwrite_okay optimization and only allocate the extra scratch memory when it is absolutely necessary.

@@ -1281,9 1305,9 @@ struct DispatchRadixSort :
// lookback
max_num_blocks * RADIX_DIGITS * sizeof(AtomicOffsetT),
// extra key buffer
is_overwrite_okay || num_passes <= 1 ? 0 : num_items * sizeof(KeyT),
num_passes <= 1 ? 0 : num_items * sizeof(KeyT),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will drastically increase the amount of temporary storage needed for some invocations of radix sort. We need to keep this optimization in place for the DoubleBuffer overloads in the Device*RadixSort APIs, since folks specifically use those to reduce the temporary storage allocations.

INPUT,
TMP_STORAGE,
OUTPUT
} input_mode = INPUT;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Style: This should be split into separate declarations:

enum InputMode { ... };
InputMode input_mode = INPUT;

@@ -1822,8 1966,8 @@ struct DispatchSegmentedRadixSort :
void* allocations[2] = {};
size_t allocation_sizes[2] =
{
(is_overwrite_okay) ? 0 : num_items * sizeof(KeyT), // bytes needed for 3rd keys buffer
(is_overwrite_okay || (KEYS_ONLY)) ? 0 : num_items * sizeof(ValueT), // bytes needed for 3rd values buffer
num_items * sizeof(KeyT), // bytes needed for 3rd keys buffer
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above, we need to keep this optimization in place.

@alliepiper alliepiper assigned zasdfgbnm and unassigned alliepiper Sep 30, 2021
@alliepiper
Copy link
Collaborator

Also -- @senior-zero is adding a new segmented sort implementation that uses iterator instead of pointers, see #357. It preserves the double buffer optimizations, so it may be a useful reference.

@zasdfgbnm
Copy link
Contributor Author

@allisonvacanti Looks like in #357, iterators are only supported for offsets, for input and output keys and values, they still need to be pointers?

@alliepiper
Copy link
Collaborator

@zasdfgbnm My mistake -- you are correct. We don't have an example that does the switching. I checked the new merge sort implementation and it also copies the keys/values unconditionally. But that was a new algorithm, so it's less concerning.

I'd still like to update this PR to only allocate the extra temporary storage when the iterators aren't pointers to avoid changing the requirements of this algorithm -- it should be possible to have special logic for the first pass that reads from the iterators instead of the temp storage buffers. Let me know if you'd like to discuss this more.

@alliepiper
Copy link
Collaborator

I'll be starting the 1.15 RC next week, and it looks like this will take a bit more work to be ready. Bumping to 1.16 milestone -- let me know if you plan to finish this before Monday and we can keep it at 1.15.

@alliepiper alliepiper modified the milestones: 1.15.0, 1.16.0 Oct 14, 2021
@alliepiper alliepiper added helps: pytorch Helps or needed by PyTorch. P1: should have Necessary, but not critical. labels Oct 14, 2021
@zasdfgbnm
Copy link
Contributor Author

1.16 is fine to me

@alliepiper alliepiper modified the milestones: 1.16.0, 1.17.0 Feb 7, 2022
@alliepiper alliepiper added P3: backlog Unprioritized and removed P1: should have Necessary, but not critical. labels Apr 6, 2022
@alliepiper alliepiper modified the milestones: 1.17.0, Backlog Apr 25, 2022
@cliffburdick
Copy link

Hi @zasdfgbnm , are there any updates on this?

@zasdfgbnm
Copy link
Contributor Author

@cliffburdick No, I am not working on this any more.

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
helps: pytorch Helps or needed by PyTorch. P3: backlog Unprioritized
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Allow iterators in cub::DeviceRadixSort
3 participants