merge sort
本篇介绍 cuda samples 中的 mergeSort.
大体上来讲, mergeSort 分为两个阶段.
- 对含有
SHARED_SIZE_LIMIT
(即 1024) 个元素的数组进行排序. - 合并多个有序数组.
其中第一个阶段调用一次函数 mergeSortShared
结束.
而第二个阶段需要循环调用三个函数: generateSampleRanks
, mergeRanksAndIndices
和 mergeElementaryIntervals
.
我将分别讲述这两个阶段.
第一阶段: 段内排序
函数 mergeSortShared
比较简单,
就是对完整的数据执行 mergeSortSharedKernel
函数.
其函数声明为
1 | template <uint sortDir> |
该 kernel 函数对一个长为 arrayLength
的数组排序,
在代码中 arrayLength
接收的参数为宏变量 SHARED_SIZE_LIMIT
(即 1024).
调用该函数的 grid 和 block 大小分别为 N / arrayLength
和 arrayLength / 2
.
每个 block 排序一段长为 arrayLength
的数组.
在函数 mergeSortSharedKernel
中排序通过一个循环进行.
该循环在长度为 stride
的数组已经排序基础上,
将 2 个长为 stride
的数组进行合并.
因此 stride
的大小从 1 开始每次增大 1 倍直到 arrayLength / 2
.
在循环开始之前,
长度 stride
的数组已经是排序了的状态.
循环中数量为 stride
的 thread 为一组,
排序相邻的两个长为 stride
的数组,
记前一个数组为 A
,
后一个数组为 B
.
因此需要计算每一组 thread 内部的 thread ID 以及这组 thread 对应数据的起始位置.
1 | // 计算组内 thread ID, [0, thread) |
相对应的,
每个 thread 在两个数组中都有一个对应的数据,
分别为 keyA
和 keyB
.
1 | uint keyA = baseKey[lPos + 0]; |
接下来找到 keyA
和 keyB
在两个数组合并后的位置.
因为数组 A
和数组 B
都是排序的,
因此在数组 A
中有 lPos
个数比 keyA
小,
只需要找到在数组 B
中有多少个数比 keyA
小,
两者相加就得到了 keyA
合并后的位置.
这个寻找的过程可以通过二分法查找.
keyB
同理也能找到合并后的位置.
但是这里有个问题就是某个 thread 对应的 keyA
和 keyB
相同,
那么这样计算得到的最后的位置也是相同的,
产生了冲突.
代码中解决这个可能的冲突的方法是 keyA
在数组 B
中寻找小于 keyA
的数,
keyB
在数组 A
中寻找小于等于 keyB
的数.
这样做的好处是保持了排序的稳定性.
数组 A
的元素总是排在数组 B
的相同元素之前.
1 | uint posA = binarySearchExclusive<sortDir>(keyA, baseKey + stride, stride, stride) + lPos; |
函数 binarySearchExclusive
和 binarySearchInclusive
有着相似的结构,
只有内部的微弱区别.
通过二分法查找 val
在数组中的位置,
即返回数组中小于 (binarySearchExclusive
) 或小于等于 (binarySearchInclusive) val
的元素的数量.
1 | // val 待查找的元素 |
merge sort 的第一个阶段就完成了.
该阶段将每 SHARED_SIZE_LIMIT
(即 1024) 个元素进行排序,
得到了 N / SHARED_SIZE_LIMIT
个有序数组.
第二阶段: 合并有序段
第二个阶段循环合并连续 2 个长为 stride
的有序数组.
因此 stride
的大小从 SHARED_SIZE_LIMIT
(即 1024) 开始每次增大 1 倍直到 N
.
这个阶段主要分成三个部分: 生成排序, 合并排序和合并.
也就是调用三个函数 generateSampleRanks
, mergeRanksAndIndices
和 mergeElementaryIntervals
.
思路解析
要合并 2 个连续的有序数组,
当然可以像第一阶段一样启动足够多的线程查找每个元素合并后的位置.
但是这样的方法在后期的时延会增加很多,
因为 stride
长度倍增.
这里采用的方法是先分组再查找.
要理解这部分的代码,
需要理解数组 d_RanksA
, d_RanksB
, d_LimitsA
和 d_LimitsB
.
这个会在后面详细讲解.
先把 stride
分为大小为 SAMPLE_STRIDE
的子数组.
d_RanksA
表示 2 个 stride
中所有的子数组的起始元素在 stride A
中的位置,
即 A
中小于 (或小于等于) 某个起始元素的元素数.
这一部分对应的就是函数 generateSampleRanks
的部分.
接着把 d_RanksA
排序,
将排序后的值填入对应的 d_LimitsA
中.
这一部分对应的就是函数 mergeRanksAndIndices
的部分.
经过前两步,
就把 A
和 B
每个都分为了 个范围.
之后合并对应的范围就得到 2 个 stride
合并后的结果.
这一部分对应的就是函数 mergeElementaryIntervals
.
代码解析
generateSamplesRanks
这一步生成 d_RanksA
和 d_RanksB
.
通过 kernel 函数 generateSampleRanksKernel
实现的.
无论 stride
大小,
该 kernel 函数的总线程数都不变.
就是 .
1 | // 计算总线程数量 |
调用函数 generateSampleRanksKernel
的 grid 大小为 threadCount / 256
,
block 大小为 256.
在该 kernel 函数中,
把所有的线程分组.
每 stride / SAMPLE_STRIDE
个线程组成一组.
因此,
每个线程的全局 ID 和组内 ID 可以通过如下方式计算.
1 | // 线程的全局 ID |
每个 thread 对应两个长为 SAMPLE_STRIDE
的数组以及 d_RanksA
和 d_RanksB
中的两个元素,
就可以计算每个线程组的数据偏移量.
1 | // 计算每个线程组的数据偏移量 |
每个线程组对应连续 2 个长为 stride
的数组,
前一个数组称为数组 A
,
后一个称为 B
.
每个线程在 A
和 B
中都有对应的一个长为 SAMPLE_STRIDE
的子数组.
就可以计算得到 A
和 B
的长度以及子数组的个数.
1 | // A 和 B 的长度 |
最后求在 A
和 B
两个数组中,
小于 (或者小于等于) 每个子数组起始元素的元素数.
这也是通过 binarySearchExclusive
和 binarySearchInclusive
两个 kernel 函数完成.
1 | if (i < segmentSamplesA) { |
mergeRanksAndIndices
这一步把 d_RanksA
和 d_RanksB
中的内容排序后赋值给 d_LimitsA
和 d_LimitsB
.
对 A
和 B
分别调用 kernel 函数 mergeRanksAndIndicesKernel
,
通过 d_RanksA
生成 d_LimitsA
.
该 kernel 函数有大小 <<<threadCount / 256, 256>>>
.
其中 threadCount = N / (2 * SAMPLE_STRIDE)
.
同样 stride / SAMPLE_STRIDE
个 thread 为一组.
计算组内 ID 和数组偏移同之前一样.
1 | // thread ID |
需要将 d_RanksA
排序.
对 A
中的每个 d_RanksA
查找其在 B
中的位置.
1 | if (i < segmentSamplesA) { |
同样,
对 B
中的每个 d_RanksA
查找其在 A
中的位置.
1 | if (i < segmentSamplesA) { |
经过排序,
实际上是把一个 stride
分为了 2 * stride / SAMPLE_STRIDE
个部分.
划分的标准是每个 SAMPLE_STRIDE
的起始元素在 A
或 B
中的位置.
mergeElementaryIntervals
这一步合并 2 个 stride
.
调用 kernel 函数 mergeElementaryIntervalsKernel
,
其大小为 <<<N / SAMPLE_STRIDE, SAMPLE_STRIDE>>>
.
之前已经把每个 stride
分为了 2 * SAMPLE_STRIDE
个部分.
每个 block 负责 2 个 stride
中对应的一部分,
将他们合并.
首先是把所有的 block 分组,
每组有 2 * stride / SAMPLE_STRIDE
个 block.
再计算组内 ID 和数据偏移.
1 | // 组内 ID |
还需要计算合并后的起始位置, 该 block 的合并长度等数据.
1 | __shared__ uint startSrcA, startSrcB, lenSrcA, lenSrcB, startDstA, startDstB; |
把 global 内存中的数据加载到 shared 内存.
1 | if (threadIdx.x < lenSrcA) { |
调用 merge 函数合并 s_key
中的数据.
merge
函数比较简单,
找到每个元素合并后的位置.
最后把数据存到 global 内存的对应位置.
1 | if (threadIdx.x < lenSrcA) { |