HCCL Communication Library Innovation Contest Notes: Ring AllReduce

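HCCL (Huawei Collective Communication Library) is a high-performance collective communication library for Ascend AI processors. It targets traffic-scheduling problems in very large compute clusters and provides efficient, reliable communication for them. It is one of the core components of CANN, Huawei's AI software stack. When a cluster has severe link faults, the Ring AllReduce algorithm can be used.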

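Problem Recap

Environment: a single Ascend 910B server with 8 NPUs.

Task: in an 8P full mesh, one direct link of the mesh is broken (the link between Rank0 and Rank1). The AllReduce collective must still be implemented with correct semantics.

Scoring: the algorithm analysis tool must first confirm that the semantics are correct. The HCCL Test tool is then run on the physical machine, and the score is based on the algorithm bandwidth measured at three data sizes (1KB, 1MB and 1GB), averaged across the three.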
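The scoring metric can be written down as a few formulas. This sketch assumes HCCL Test defines algorithm bandwidth as nccl-tests does: message size divided by elapsed time, with 1 GB taken as 1e9 bytes here. The helper names are hypothetical. The AllReduce bus-bandwidth factor 2(N - 1)/N is also the nccl-tests convention and is included only for reference.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Algorithm bandwidth of one run in GB/s (1 GB = 1e9 bytes), assuming the
// nccl-tests convention algbw = message size / elapsed time.
double algBandwidthGBps(std::uint64_t bytes, double seconds) {
    return static_cast<double>(bytes) / seconds / 1e9;
}

// AllReduce bus bandwidth under the nccl-tests convention:
// busbw = algbw * 2 * (n - 1) / n, the traffic each rank actually moves.
double allReduceBusBandwidth(double algbw, std::uint32_t n) {
    return algbw * 2.0 * (n - 1) / n;
}

// Mean algorithm bandwidth over several measurements (e.g. 1KB, 1MB, 1GB).
double meanAlgBandwidth(const std::vector<std::uint64_t>& bytes,
                        const std::vector<double>& seconds) {
    if (bytes.empty()) return 0.0;
    double sum = 0.0;
    for (std::size_t k = 0; k < bytes.size(); ++k) {
        sum += algBandwidthGBps(bytes[k], seconds[k]);
    }
    return sum / static_cast<double>(bytes.size());
}
```

Averaging bandwidth rather than time means the small 1KB case counts as much as the 1GB case. Latency at small sizes therefore matters as much as throughput at large sizes.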

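Algorithm Idea

Ring AllReduce splits the data into N blocks (N = number of ranks) and transfers and reduces them step by step along a logical ring. It runs N - 1 ReduceScatter rounds, after which each rank holds one fully reduced block, followed by N - 1 AllGather rounds that circulate those blocks to every rank. The algorithm's schematic diagram is shown below: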
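A host-side simulation makes the data movement of the two phases concrete. This is plain C++, not HCCL code. `Buffers` and `ringAllReduce` are hypothetical names, the reduction is a sum, and the buffer length is assumed divisible by the number of ranks. The block indices match those in the executor later in this post: in round r, rank i pulls block (i - r - 1) mod N during ReduceScatter and block (i - r) mod N during AllGather, always from its left neighbor.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Buffers = std::vector<std::vector<float>>;  // bufs[i] = rank i's data

// Sum-AllReduce over a logical ring 0 -> 1 -> ... -> n-1 -> 0.
// Assumes all buffers have equal length divisible by n.
void ringAllReduce(Buffers& bufs) {
    const std::size_t n = bufs.size();
    const std::size_t block = bufs[0].size() / n;

    // Phase 1: ReduceScatter. In round r, rank i pulls block (i - r - 1) mod n
    // from its left neighbor and adds it into its own copy.
    for (std::size_t r = 0; r + 1 < n; ++r) {
        const Buffers snapshot = bufs;  // every rank reads the pre-round state
        for (std::size_t i = 0; i < n; ++i) {
            const std::size_t left = (i + n - 1) % n;
            const std::size_t b = (i + 2 * n - r - 1) % n;
            for (std::size_t k = b * block; k < (b + 1) * block; ++k) {
                bufs[i][k] += snapshot[left][k];
            }
        }
    }
    // Rank i now holds the fully reduced block (i + 1) mod n.

    // Phase 2: AllGather. In round r, rank i copies block (i - r) mod n
    // from its left neighbor.
    for (std::size_t r = 0; r + 1 < n; ++r) {
        const Buffers snapshot = bufs;
        for (std::size_t i = 0; i < n; ++i) {
            const std::size_t left = (i + n - 1) % n;
            const std::size_t b = (i + n - r) % n;
            for (std::size_t k = b * block; k < (b + 1) * block; ++k) {
                bufs[i][k] = snapshot[left][k];
            }
        }
    }
}
```

In total each rank receives 2(N - 1) blocks of size S/N, where S is the message size. This is why the ring's per-rank traffic stays close to 2S regardless of N.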

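Twisting the logical ring, that is, changing the order of the nodes in the ring table, lets the ring avoid the faulty link. The approach also suits extreme cases such as many broken links. As long as the healthy links still contain a Hamiltonian cycle (a ring that visits every rank exactly once), the cluster keeps working with reasonable performance. In this solution the ring order is 1, 2, 3, 4, 5, 6, 0, 7 (the `twistedRing` table in the header file), so Rank0 and Rank1 are never adjacent.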
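Finding such a twisted ring is a Hamiltonian-cycle search on the graph of healthy links. The contest code hard-codes its ring rather than searching for one, so this backtracking search is a hypothetical illustration. `Link`, `findRing` and `ringAvoids` are illustrative names. Backtracking is exponential in the worst case, which is acceptable for a single 8-card node.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <utility>
#include <vector>

using Link = std::pair<int, int>;  // an undirected link between two ranks

static bool linkOk(int a, int b, const std::set<Link>& faulty) {
    return faulty.count({a, b}) == 0 && faulty.count({b, a}) == 0;
}

// Depth-first extension of a partial ring; backtracks on dead ends.
static bool extend(std::vector<int>& ring, std::vector<bool>& used, int n,
                   const std::set<Link>& faulty) {
    if (static_cast<int>(ring.size()) == n) {
        // Close the cycle: the last rank must also reach the first one.
        return linkOk(ring.back(), ring.front(), faulty);
    }
    for (int next = 0; next < n; ++next) {
        if (used[next] || !linkOk(ring.back(), next, faulty)) continue;
        used[next] = true;
        ring.push_back(next);
        if (extend(ring, used, n, faulty)) return true;
        ring.pop_back();
        used[next] = false;
    }
    return false;
}

// Returns a ring over ranks 0..n-1 that uses no faulty link, or {} if none exists.
std::vector<int> findRing(int n, const std::set<Link>& faulty) {
    std::vector<int> ring{0};
    std::vector<bool> used(n, false);
    used[0] = true;
    if (!extend(ring, used, n, faulty)) return {};
    return ring;
}

// True if no pair of consecutive ranks (including last -> first) uses a faulty link.
bool ringAvoids(const std::vector<int>& ring, const std::set<Link>& faulty) {
    for (std::size_t i = 0; i < ring.size(); ++i) {
        if (!linkOk(ring[i], ring[(i + 1) % ring.size()], faulty)) return false;
    }
    return true;
}
```

With only the 0-1 link marked faulty, `findRing(8, faulty)` returns a valid 8-rank ring. `ringAvoids` also confirms that the hard-coded order {1, 2, 3, 4, 5, 6, 0, 7} never places Rank0 next to Rank1.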

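Algorithm Design

Step 1 (ReduceScatter): iterate around the clockwise ring. Each round proceeds as follows:

  1. Send TxAck to the right node: "My data is ready; you may pull data from / push data to my CCL_OUT at any time."
  2. Block the stream on RxAck until the left node is also ready (i.e., until the left node has sent this node a TxAck).
  3. Once the left node is ready, pull the block from the left node's CCL_OUT and reduce it into the local CCL_OUT.
  4. After pulling and reducing, send TxDataSignal to the left node: "I have finished; you may release resources or move on to other work."
  5. Block the stream on RxDataSignal until the right node has also finished pulling and reducing (i.e., until the right node has sent this node a TxDataSignal).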
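```cpp
// 1. Notify the right node that this node is ready
links[rightRank]->TxAck();
// 2. Wait until the left node is ready
links[leftRank]->RxAck();
// 3. Pull the block from the left node and reduce it into the local CCL_OUT
reduce(/* src */ leftCclOut, /* dst */ localCclOut);
// 4. Notify the left node that this node is done
links[leftRank]->TxDataSignal();
// 5. Wait until the right node is done
links[rightRank]->RxDataSignal();
```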
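The five-step handshake can be imitated on a CPU with one thread per rank and a counting semaphore standing in for each notify. This is a hypothetical model, not the HCCL API. `Notify`, `RankSync` and `ringReduceScatter` are illustrative names, `post()`/`wait()` play the roles of the Tx*/Rx* calls, and each rank's `std::vector<float>` plays its CCL_OUT. Step 2 guarantees the left node has finished writing a block before it is read. Step 5 keeps a rank from racing ahead of its right node.

```cpp
#include <cassert>
#include <condition_variable>
#include <cstddef>
#include <mutex>
#include <thread>
#include <vector>

// Counting semaphore built from a mutex and a condition variable (C++11).
// It stands in for an HCCL notify: post() ~ Tx*, wait() ~ Rx*.
class Notify {
public:
    void post() {
        std::lock_guard<std::mutex> lock(m_);
        ++count_;
        cv_.notify_one();
    }
    void wait() {
        std::unique_lock<std::mutex> lock(m_);
        cv_.wait(lock, [this] { return count_ > 0; });
        --count_;
    }

private:
    std::mutex m_;
    std::condition_variable cv_;
    int count_ = 0;
};

// Notifies owned by one rank: `ack` is posted by its left node (TxAck),
// `done` is posted by its right node (TxDataSignal).
struct RankSync {
    Notify ack;
    Notify done;
};

// Clockwise ReduceScatter with the five-step handshake, one thread per rank.
// cclOut[i] models rank i's CCL_OUT; its length must be divisible by n.
void ringReduceScatter(std::vector<std::vector<float>>& cclOut) {
    const std::size_t n = cclOut.size();
    const std::size_t block = cclOut[0].size() / n;
    std::vector<RankSync> sync(n);
    std::vector<std::thread> ranks;
    for (std::size_t i = 0; i < n; ++i) {
        ranks.emplace_back([&, i] {
            const std::size_t left = (i + n - 1) % n;
            const std::size_t right = (i + 1) % n;
            for (std::size_t r = 0; r + 1 < n; ++r) {
                sync[right].ack.post();  // 1. TxAck: tell the right node we are ready
                sync[i].ack.wait();      // 2. RxAck: wait for the left node
                const std::size_t b = (i + 2 * n - r - 1) % n;  // (i - r - 1) mod n
                for (std::size_t k = b * block; k < (b + 1) * block; ++k) {
                    cclOut[i][k] += cclOut[left][k];  // 3. pull and reduce
                }
                sync[left].done.post();  // 4. TxDataSignal: tell the left node we are done
                sync[i].done.wait();     // 5. RxDataSignal: wait for the right node
            }
        });
    }
    for (std::thread& t : ranks) {
        t.join();
    }
}
```

Take four ranks whose buffers all hold 2s. Rank 0 ends with {2, 8, 6, 4}, matching the ReduceScatter diagram in the executor's comments, and each rank i owns the fully reduced block (i + 1) mod N.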

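Step 2 (AllGather): iterate around the clockwise ring. Each round proceeds as follows:

  1. Send TxAck to the right node: "My data is ready; you may pull data from / push data to my CCL_OUT at any time."
  2. Block the stream on RxAck until the left node is also ready (i.e., until the left node has sent this node a TxAck).
  3. Once the left node is ready, pull the block from the left node's CCL_OUT and copy it into the local CCL_OUT.
  4. After pulling and copying, send TxDataSignal to the left node: "I have finished; you may release resources or move on to other work."
  5. Block the stream on RxDataSignal until the right node has also finished pulling and copying (i.e., until the right node has sent this node a TxDataSignal).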
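```cpp
// 1. Notify the right node that this node is ready
links[rightRank]->TxAck();
// 2. Wait until the left node is ready
links[leftRank]->RxAck();
// 3. Pull the block from the left node and copy it into the local CCL_OUT
memcpy(/* dst */ localCclOut, /* src */ leftCclOut);
// 4. Notify the left node that this node is done
links[leftRank]->TxDataSignal();
// 5. Wait until the right node is done
links[rightRank]->RxDataSignal();
```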
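The AllGather index formula can be checked mechanically. After ReduceScatter, rank i owns block (i + 1) mod N. In round r it copies block (i - r) mod N, which its left node must already hold. The checker below is a hypothetical helper for the clockwise ring, using logical rank indices. `cwReduceScatterBlock` is included for comparison with Step 1.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Block that logical rank i pulls from its left node in round r of the
// clockwise ring (same formulas as the executor code).
std::size_t cwReduceScatterBlock(std::size_t i, std::size_t r, std::size_t n) {
    return (i + 2 * n - r - 1) % n;  // (i - r - 1) mod n
}
std::size_t cwAllGatherBlock(std::size_t i, std::size_t r, std::size_t n) {
    return (i + n - r) % n;  // (i - r) mod n
}

// True if, for every rank, the n - 1 AllGather rounds deliver each block it
// does not own exactly once, and every pulled block is already held by the
// left node (owned after ReduceScatter or received in an earlier round).
bool allGatherScheduleIsValid(std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) {
        const std::size_t left = (i + n - 1) % n;
        std::vector<int> copies(n, 0);
        copies[(i + 1) % n] = 1;  // owned after ReduceScatter
        for (std::size_t r = 0; r + 1 < n; ++r) {
            const std::size_t b = cwAllGatherBlock(i, r, n);
            bool leftHasIt = (b == (left + 1) % n);
            for (std::size_t q = 0; q < r; ++q) {
                leftHasIt = leftHasIt || cwAllGatherBlock(left, q, n) == b;
            }
            if (!leftHasIt) return false;
            ++copies[b];
        }
        for (int c : copies) {
            if (c != 1) return false;
        }
    }
    return true;
}
```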

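The counter-clockwise direction works the same way, except that each node pulls and reduces data from its right node. Each node therefore runs two streams: the master stream drives the clockwise ring and a slave stream drives the counter-clockwise ring. In the implementation each ring processes half of the data, so both directions of every ring link carry traffic at the same time.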
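Concretely, the executor splits each loop's data into two halves and cuts each half into N blocks. The sketch below reproduces that offset arithmetic with hypothetical names (`Layout`, `cwRsOffset`, `ccwRsOffset`, `uncoveredBytes`). It assumes logical rank indices, as the executor uses after `PhysicalToLogical`.

```cpp
#include <cassert>
#include <cstdint>

// Mirror of the executor's layout: the clockwise ring uses bytes [0, half),
// the counter-clockwise ring uses [half, 2 * half), each cut into n blocks.
struct Layout {
    std::uint64_t half;   // subChunkSize in the executor
    std::uint64_t block;  // blockSize in the executor
};

Layout makeLayout(std::uint64_t count, std::uint32_t unitSize, std::uint32_t n) {
    const std::uint64_t half = count * unitSize / 2;
    return {half, half / n};
}

// Byte offset of the block logical rank i pulls in ReduceScatter round r.
std::uint64_t cwRsOffset(const Layout& l, std::uint32_t n, std::uint32_t i, std::uint32_t r) {
    return ((i + 2 * n - r - 1) % n) * l.block;  // from the left node
}
std::uint64_t ccwRsOffset(const Layout& l, std::uint32_t n, std::uint32_t i, std::uint32_t r) {
    return l.half + ((i + r + 1) % n) * l.block;  // from the right node
}

// Bytes that neither ring touches; zero only when count * unitSize is a
// multiple of 2 * n, because of the integer divisions above.
std::uint64_t uncoveredBytes(std::uint64_t count, std::uint32_t unitSize, std::uint32_t n) {
    const Layout l = makeLayout(count, unitSize, n);
    return count * unitSize - 2 * n * l.block;
}
```

For 256 fp32 elements (1KB) on 8 ranks, each block is 64 bytes. Power-of-two byte counts such as 1KB, 1MB and 1GB split exactly. Other sizes would leave a tail of `count * unitSize mod 2N` bytes that needs separate handling.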

Implementation

Source file (.cc):

#include "coll_custom_huge_all_reduce_mesh_executor.h"

namespace hccl
{
    CollCustomHugeAllReduceMeshExecutor::CollCustomHugeAllReduceMeshExecutor(const HcclDispatcher dispatcher,
                                                                             std::unique_ptr<TopoMatcher> &topoMatcher)
        : CollAllReduceExecutor(dispatcher, topoMatcher)
    {
        CCLMemSlice_ = false;
        DMAReduceFlag_ = true;
    }

    // Calculate the amount of scratch memory to request
    HcclResult CollCustomHugeAllReduceMeshExecutor::CalcScratchMemSize(u64 &scratchMemSize)
    {
        // We don't need to use Scratch memory
        scratchMemSize = 0U;
        HCCL_WARNING("[HCCLContest][CollCustomHugeAllReduceMeshExecutor][CalcScratchMemSize] scratchMemSize: %llu",
                     scratchMemSize);
        return HCCL_SUCCESS;
    }

    // Calculate the number of streams to be requested
    HcclResult CollCustomHugeAllReduceMeshExecutor::CalcStreamNum(u32 &streamNum)
    {
        // One slave stream for the counter-clockwise ring (the master stream runs the clockwise ring)
        streamNum = 1U;
        HCCL_WARNING("[HCCLContest][CollCustomHugeAllReduceMeshExecutor][CalcStreamNum] streamNum: %u", streamNum);
        return HCCL_SUCCESS;
    }

    // Calculate the number of Notify to be requested
    HcclResult CollCustomHugeAllReduceMeshExecutor::CalcNotifyNum(u32 streamNum, u32 &notifyNum)
    {
        notifyNum = 2U * streamNum;
        HCCL_WARNING("[HCCLContest][CollCustomHugeAllReduceMeshExecutor][CalcNotifyNum] notifyNum: %u", notifyNum);
        return HCCL_SUCCESS;
    }

    // Set up the level-0 mesh topology required for the AllReduce operation
    HcclResult CollCustomHugeAllReduceMeshExecutor::CalcCommInfo(std::vector<LevelNSubCommTransport> &opTransport)
    {
        HCCL_WARNING("[HCCLContest][CollCustomHugeAllReduceMeshExecutor][CalcCommInfo]");

        // Define the source and destination memory types for communication
        TransportMemType inputType = TransportMemType::CCL_INPUT;
        TransportMemType outputType = TransportMemType::CCL_OUTPUT;
        // Construct a mesh topology for level 0
        CommParaInfo commParaLevel0(COMM_LEVEL0, CommType::COMM_TAG_MESH);
        commParaLevel0.meshSinglePlane = true;
        // Compute and populate the transport plan for level-0 communication domain
        CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_LEVEL0], inputType, outputType));
        return HCCL_SUCCESS;
    }

    // Calculate the maximum number of elements processed per loop iteration
    u64 CollCustomHugeAllReduceMeshExecutor::CalcLoopMaxCount(const u64 cclBuffSize, const u32 unitSize)
    {
        u64 maxCountPerLoop = cclBuffSize / unitSize;
        HCCL_WARNING("[HCCLContest][CollCustomHugeAllReduceMeshExecutor][CalcLoopMaxCount] maxCountPerLoop: %llu",
                     maxCountPerLoop);
        return maxCountPerLoop;
    }

    // Process data for a single iteration of the AllReduce algorithm execution
    HcclResult CollCustomHugeAllReduceMeshExecutor::KernelRun(const OpParam &param, ExecMem &execMem)
    {
        // Get sub-communication domain information for level 0
        CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1));
        SubCommInfo level0CommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0);

        // Post notify signals from master stream to all slave streams
        CHK_RET(MainPostToSlaves(param, execMem));
        // Slave streams wait for master's notification
        CHK_RET(SlavesWaitForMain(param, execMem));

        // Perform ReduceScatter for clockwise direction
        CHK_RET(PerformRingReduceScatter(param, execMem));
        // Perform AllGather for clockwise direction
        CHK_RET(PerformRingAllGather(param, execMem));

        // Perform ReduceScatter for counter-clockwise direction
        CHK_RET(PerformCcwRingReduceScatter(param, execMem));
        // Perform AllGather for counter-clockwise direction
        CHK_RET(PerformCcwRingAllGather(param, execMem));

        // Slave streams notify the master stream that their tasks are completed
        CHK_RET(SlavesPostToMain(param, execMem));
        // The master stream waits for all slave streams to complete their tasks
        CHK_RET(MainWaitForSlaves(param, execMem));

        HCCL_WARNING("[HCCLContest][CollCustomHugeAllReduceMeshExecutor][KernelRun] localRank: %u, localRankSize: %u",
                     level0CommInfo.localRank, level0CommInfo.localRankSize);

        return HCCL_SUCCESS;
    }

    // Perform ReduceScatter for clockwise direction
    //      col0 col1 col2 col3   |      col0 col1 col2 col3
    // R0: [ 2,   2,   2,   2 ]   | R0: [ 2,   8,   6,   4 ]
    // R1: [ 2,   2,   2,   2 ]   | R1: [ 4,   2,   8,   6 ]
    // R2: [ 2,   2,   2,   2 ]   | R2: [ 6,   4,   2,   8 ]
    // R3: [ 2,   2,   2,   2 ]   | R3: [ 8,   6,   4,   2 ]
    HcclResult CollCustomHugeAllReduceMeshExecutor::PerformRingReduceScatter(const OpParam &param, ExecMem &execMem)
    {
        // Get sub-communication domain information for level 0
        CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1));
        SubCommInfo level0CommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0);

        // Retrieve the master stream
        hccl::Stream &masterStream = const_cast<hccl::Stream &>(param.stream);

        // Get the information of the topology
        u64 rankSize = level0CommInfo.localRankSize;
        u32 leftRank = GetLeftNeighbor(level0CommInfo.localRank, rankSize);
        u32 rightRank = GetRightNeighbor(level0CommInfo.localRank, rankSize);

        // Determine unit size, chunk size, block size for this iteration (unit: bytes)
        u32 unitSize = SIZE_TABLE[param.DataDes.dataType];
        u64 subChunkSize = execMem.count * unitSize / 2;
        u64 blockSize = subChunkSize / rankSize;

        // Copy data from user input to CCL_Out buffer
        DeviceMem src = DeviceMem::create(execMem.inputPtr, subChunkSize);
        DeviceMem dst = DeviceMem::create(execMem.outputMem.ptr(), subChunkSize);
        CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dst, src, masterStream));

        // Perform ReduceScatter operation for each round
        for (u32 round = 0; round < rankSize - 1; round++)
        {
            // Rank i receives a block from the left rank, the index is (i - r + N - 1) % N
            u32 logicalRankIdx = PhysicalToLogical(level0CommInfo.localRank, rankSize);
            u32 blockIdx = (logicalRankIdx - round + rankSize - 1) % rankSize;

            // Notify the right rank I am ready to provide data
            CHK_RET(level0CommInfo.links[rightRank]->TxAck(masterStream));
            // Wait for the left rank to notify it is ready to provide data
            CHK_RET(level0CommInfo.links[leftRank]->RxAck(masterStream));

            // Source address on the left rank's CCL_Out (remote)
            void *srcRemoteMemPtr = nullptr;
            CHK_RET(level0CommInfo.links[leftRank]->GetRemoteMem(UserMemType::OUTPUT_MEM, &srcRemoteMemPtr));
            srcRemoteMemPtr = static_cast<char *>(srcRemoteMemPtr) + blockIdx * blockSize;
            DeviceMem srcRemote = DeviceMem::create(static_cast<char *>(srcRemoteMemPtr), blockSize);
            // Destination address on the local rank's CCL_Out (local)
            void *dstLocalMemPtr = static_cast<char *>(execMem.outputMem.ptr()) + blockIdx * blockSize;
            DeviceMem dstLocal = DeviceMem::create(static_cast<char *>(dstLocalMemPtr), blockSize);

            // HcclReduceAsync: read the block from the remote CCL_Out and reduce it into the local CCL_Out (inline reduce)
            CHK_RET(HcclReduceAsync(dispatcher_, static_cast<void *>(srcRemote.ptr()), blockSize / unitSize,
                                    param.DataDes.dataType, param.reduceType, masterStream, static_cast<void *>(dstLocal.ptr()),
                                    level0CommInfo.links[leftRank]->GetRemoteRank(), level0CommInfo.links[leftRank]->GetLinkType(), INLINE_REDUCE_BIT));

            // Notify the left rank I am done copying data
            CHK_RET(level0CommInfo.links[leftRank]->TxDataSignal(masterStream));
            // Wait for the right rank to notify it is done copying data
            CHK_RET(level0CommInfo.links[rightRank]->RxDataSignal(masterStream));
        }

        return HCCL_SUCCESS;
    }

    // Perform AllGather for clockwise direction
    //      col0 col1 col2 col3   |      col0 col1 col2 col3
    // R0: [ 2,   8,   6,   4 ]   | R0: [ 8,   8,   8,   8 ]
    // R1: [ 4,   2,   8,   6 ]   | R1: [ 8,   8,   8,   8 ]
    // R2: [ 6,   4,   2,   8 ]   | R2: [ 8,   8,   8,   8 ]
    // R3: [ 8,   6,   4,   2 ]   | R3: [ 8,   8,   8,   8 ]
    HcclResult CollCustomHugeAllReduceMeshExecutor::PerformRingAllGather(const OpParam &param, ExecMem &execMem)
    {
        // Get sub-communication domain information for level 0
        CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1));
        SubCommInfo level0CommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0);

        // Retrieve the master stream
        hccl::Stream &masterStream = const_cast<hccl::Stream &>(param.stream);

        // Get the information of the topology
        u64 rankSize = level0CommInfo.localRankSize;
        u32 leftRank = GetLeftNeighbor(level0CommInfo.localRank, rankSize);
        u32 rightRank = GetRightNeighbor(level0CommInfo.localRank, rankSize);

        // Determine unit size, chunk size, block size for this iteration (unit: bytes)
        u32 unitSize = SIZE_TABLE[param.DataDes.dataType];
        u64 subChunkSize = execMem.count * unitSize / 2;
        u64 blockSize = subChunkSize / rankSize;

        // Perform AllGather operation for each round
        for (u32 round = 0; round < rankSize - 1; round++)
        {
            // Rank i receives a block from the left rank, the index is (i - r + N) % N
            u32 logicalRankIdx = PhysicalToLogical(level0CommInfo.localRank, rankSize);
            u32 blockIdx = (logicalRankIdx - round + rankSize) % rankSize;

            // Notify the right rank I am ready to provide data
            CHK_RET(level0CommInfo.links[rightRank]->TxAck(masterStream));
            // Wait for the left rank to notify it is ready to provide data
            CHK_RET(level0CommInfo.links[leftRank]->RxAck(masterStream));

            // Source address on the left rank's CCL_Out (remote)
            void *srcRemoteMemPtr = nullptr;
            CHK_RET(level0CommInfo.links[leftRank]->GetRemoteMem(UserMemType::OUTPUT_MEM, &srcRemoteMemPtr));
            srcRemoteMemPtr = static_cast<char *>(srcRemoteMemPtr) + blockIdx * blockSize;
            DeviceMem srcRemote = DeviceMem::create(static_cast<char *>(srcRemoteMemPtr), blockSize);
            // Destination address on the local rank's CCL_Out (local)
            void *dstLocalMemPtr = static_cast<char *>(execMem.outputMem.ptr()) + blockIdx * blockSize;
            DeviceMem dstLocal = DeviceMem::create(static_cast<char *>(dstLocalMemPtr), blockSize);

            // Perform HcclD2DMemcpyAsync
            CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dstLocal, srcRemote, masterStream,
                                       level0CommInfo.links[leftRank]->GetRemoteRank(), level0CommInfo.links[leftRank]->GetLinkType()));

            // Notify the left rank I am done copying data
            CHK_RET(level0CommInfo.links[leftRank]->TxDataSignal(masterStream));
            // Wait for the right rank to notify it is done copying data
            CHK_RET(level0CommInfo.links[rightRank]->RxDataSignal(masterStream));
        }

        // Each rank copies its chunk from the CCL_Out buffer to the user output
        DeviceMem src = DeviceMem::create(static_cast<char *>(execMem.outputMem.ptr()), subChunkSize);
        DeviceMem dst = DeviceMem::create(static_cast<char *>(execMem.outputPtr), subChunkSize);
        CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dst, src, masterStream));

        return HCCL_SUCCESS;
    }

    // Perform ReduceScatter for counter-clockwise direction
    //      col0 col1 col2 col3   |      col0 col1 col2 col3
    // R0: [ 2,   2,   2,   2 ]   | R0: [ 2,   4,   6,   8 ]
    // R1: [ 2,   2,   2,   2 ]   | R1: [ 8,   2,   4,   6 ]
    // R2: [ 2,   2,   2,   2 ]   | R2: [ 6,   8,   2,   4 ]
    // R3: [ 2,   2,   2,   2 ]   | R3: [ 4,   6,   8,   2 ]
    HcclResult CollCustomHugeAllReduceMeshExecutor::PerformCcwRingReduceScatter(const OpParam &param, ExecMem &execMem)
    {
        // Get sub-communication domain information for level 0
        CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1));
        SubCommInfo level0CommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0);

        // Retrieve the slave stream
        Stream &subStream = algResResp_->slaveStreams[C_CW_SLAVE_INDEX];

        // Get the information of the topology
        u64 rankSize = level0CommInfo.localRankSize;
        u32 leftRank = GetLeftNeighbor(level0CommInfo.localRank, rankSize);
        u32 rightRank = GetRightNeighbor(level0CommInfo.localRank, rankSize);

        // Determine unit size, chunk size, block size for this iteration (unit: bytes)
        u32 unitSize = SIZE_TABLE[param.DataDes.dataType];
        u64 subChunkSize = execMem.count * unitSize / 2;
        u64 subChunkBaseOffset = subChunkSize;
        u64 blockSize = subChunkSize / rankSize;

        // Copy data from user input to CCL_Out buffer
        DeviceMem src = DeviceMem::create(static_cast<char *>(execMem.inputPtr) + subChunkBaseOffset, subChunkSize);
        DeviceMem dst = DeviceMem::create(static_cast<char *>(execMem.outputMem.ptr()) + subChunkBaseOffset, subChunkSize);
        CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dst, src, subStream));

        // Perform ReduceScatter operation for each round
        for (u32 round = 0; round < rankSize - 1; round++)
        {
            // Rank i receives a block from the right rank, the index is (i + r + 1) % N
            u32 logicalRankIdx = PhysicalToLogical(level0CommInfo.localRank, rankSize);
            u32 blockIdx = (logicalRankIdx + round + 1) % rankSize;

            // Notify the left rank I am ready to provide data
            CHK_RET(level0CommInfo.links[leftRank]->TxAck(subStream));
            // Wait for the right rank to notify it is ready to provide data
            CHK_RET(level0CommInfo.links[rightRank]->RxAck(subStream));

            // Source address on the right rank's CCL_Out (remote)
            void *srcRemoteMemPtr = nullptr;
            CHK_RET(level0CommInfo.links[rightRank]->GetRemoteMem(UserMemType::OUTPUT_MEM, &srcRemoteMemPtr));
            srcRemoteMemPtr = static_cast<char *>(srcRemoteMemPtr) + subChunkBaseOffset + blockIdx * blockSize;
            DeviceMem srcRemote = DeviceMem::create(static_cast<char *>(srcRemoteMemPtr), blockSize);
            // Destination address on the local rank's CCL_Out (local)
            void *dstLocalMemPtr = static_cast<char *>(execMem.outputMem.ptr()) + subChunkBaseOffset + blockIdx * blockSize;
            DeviceMem dstLocal = DeviceMem::create(static_cast<char *>(dstLocalMemPtr), blockSize);

            // HcclReduceAsync: read the block from the remote CCL_Out and reduce it into the local CCL_Out (inline reduce)
            CHK_RET(HcclReduceAsync(dispatcher_, static_cast<void *>(srcRemote.ptr()), blockSize / unitSize,
                                    param.DataDes.dataType, param.reduceType, subStream, static_cast<void *>(dstLocal.ptr()),
                                    level0CommInfo.links[rightRank]->GetRemoteRank(), level0CommInfo.links[rightRank]->GetLinkType(), INLINE_REDUCE_BIT));

            // Notify the right rank I am done copying data
            CHK_RET(level0CommInfo.links[rightRank]->TxDataSignal(subStream));
            // Wait for the left rank to notify it is done copying data
            CHK_RET(level0CommInfo.links[leftRank]->RxDataSignal(subStream));
        }

        return HCCL_SUCCESS;
    }

    // Perform AllGather for counter-clockwise direction
    //      col0 col1 col2 col3   |      col0 col1 col2 col3
    // R0: [ 2,   4,   6,   8 ]   | R0: [ 8,   8,   8,   8 ]
    // R1: [ 8,   2,   4,   6 ]   | R1: [ 8,   8,   8,   8 ]
    // R2: [ 6,   8,   2,   4 ]   | R2: [ 8,   8,   8,   8 ]
    // R3: [ 4,   6,   8,   2 ]   | R3: [ 8,   8,   8,   8 ]
    HcclResult CollCustomHugeAllReduceMeshExecutor::PerformCcwRingAllGather(const OpParam &param, ExecMem &execMem)
    {
        // Get sub-communication domain information for level 0
        CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1));
        SubCommInfo level0CommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0);

        // Retrieve the slave stream
        Stream &subStream = algResResp_->slaveStreams[C_CW_SLAVE_INDEX];

        // Get the information of the topology
        u64 rankSize = level0CommInfo.localRankSize;
        u32 leftRank = GetLeftNeighbor(level0CommInfo.localRank, rankSize);
        u32 rightRank = GetRightNeighbor(level0CommInfo.localRank, rankSize);

        // Determine unit size, chunk size, block size for this iteration (unit: bytes)
        u32 unitSize = SIZE_TABLE[param.DataDes.dataType];
        u64 subChunkSize = execMem.count * unitSize / 2;
        u64 subChunkBaseOffset = subChunkSize;
        u64 blockSize = subChunkSize / rankSize;

        // Perform AllGather operation for each round
        for (u32 round = 0; round < rankSize - 1; round++)
        {
            // Rank i receives a block from the right rank, the index is (i + r) % N
            u32 logicalRankIdx = PhysicalToLogical(level0CommInfo.localRank, rankSize);
            u32 blockIdx = (logicalRankIdx + round) % rankSize;

            // Notify the left rank I am ready to provide data
            CHK_RET(level0CommInfo.links[leftRank]->TxAck(subStream));
            // Wait for the right rank to notify it is ready to provide data
            CHK_RET(level0CommInfo.links[rightRank]->RxAck(subStream));

            // Source address on the right rank's CCL_Out (remote)
            void *srcRemoteMemPtr = nullptr;
            CHK_RET(level0CommInfo.links[rightRank]->GetRemoteMem(UserMemType::OUTPUT_MEM, &srcRemoteMemPtr));
            srcRemoteMemPtr = static_cast<char *>(srcRemoteMemPtr) + subChunkBaseOffset + blockIdx * blockSize;
            DeviceMem srcRemote = DeviceMem::create(static_cast<char *>(srcRemoteMemPtr), blockSize);
            // Destination address on the local rank's CCL_Out (local)
            void *dstLocalMemPtr = static_cast<char *>(execMem.outputMem.ptr()) + subChunkBaseOffset + blockIdx * blockSize;
            DeviceMem dstLocal = DeviceMem::create(static_cast<char *>(dstLocalMemPtr), blockSize);

            // Perform HcclD2DMemcpyAsync
            CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dstLocal, srcRemote, subStream,
                                       level0CommInfo.links[rightRank]->GetRemoteRank(), level0CommInfo.links[rightRank]->GetLinkType()));

            // Notify the right rank I am done copying data
            CHK_RET(level0CommInfo.links[rightRank]->TxDataSignal(subStream));
            // Wait for the left rank to notify it is done copying data
            CHK_RET(level0CommInfo.links[leftRank]->RxDataSignal(subStream));
        }

        // Each rank copies its chunk from the CCL_Out buffer to the user output
        DeviceMem src = DeviceMem::create(static_cast<char *>(execMem.outputMem.ptr()) + subChunkBaseOffset, subChunkSize);
        DeviceMem dst = DeviceMem::create(static_cast<char *>(execMem.outputPtr) + subChunkBaseOffset, subChunkSize);
        CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dst, src, subStream));

        return HCCL_SUCCESS;
    }

    // Post notify signals from master stream to all slave streams
    HcclResult CollCustomHugeAllReduceMeshExecutor::MainPostToSlaves(const OpParam &param, ExecMem &execMem)
    {
        hccl::Stream &masterStream = const_cast<hccl::Stream &>(param.stream);
        for (u32 signalIndex = 0; signalIndex < algResResp_->slaveStreams.size(); signalIndex++)
        {
            CHK_RET(LocalNotify::Post(masterStream, dispatcher_,
                                      algResResp_->notifiesAux[signalIndex], PROF_STAGE_1));
        }
        return HCCL_SUCCESS;
    }

    // Slave streams wait for master's notification
    HcclResult CollCustomHugeAllReduceMeshExecutor::SlavesWaitForMain(const OpParam &param, ExecMem &execMem)
    {
        for (u32 streamIndex = 0; streamIndex < algResResp_->slaveStreams.size(); streamIndex++)
        {
            CHK_RET(LocalNotify::Wait(algResResp_->slaveStreams[streamIndex], dispatcher_,
                                      algResResp_->notifiesAux[streamIndex], PROF_STAGE_1));
        }
        return HCCL_SUCCESS;
    }

    // Slave streams notify the master stream that their tasks are completed
    HcclResult CollCustomHugeAllReduceMeshExecutor::SlavesPostToMain(const OpParam &param, ExecMem &execMem)
    {
        for (u32 streamIndex = 0; streamIndex < algResResp_->slaveStreams.size(); streamIndex++)
        {
            CHK_RET(LocalNotify::Post(algResResp_->slaveStreams[streamIndex], dispatcher_,
                                      algResResp_->notifiesMain[streamIndex], PROF_STAGE_1));
        }
        return HCCL_SUCCESS;
    }

    // The master stream waits for all slave streams to complete their tasks
    HcclResult CollCustomHugeAllReduceMeshExecutor::MainWaitForSlaves(const OpParam &param, ExecMem &execMem)
    {
        hccl::Stream &masterStream = const_cast<hccl::Stream &>(param.stream);
        for (u32 signalIndex = 0; signalIndex < algResResp_->slaveStreams.size(); signalIndex++)
        {
            CHK_RET(LocalNotify::Wait(masterStream, dispatcher_,
                                      algResResp_->notifiesMain[signalIndex], PROF_STAGE_1));
        }
        return HCCL_SUCCESS;
    }

    // The master stream executes an empty task to ensure synchronization
    HcclResult CollCustomHugeAllReduceMeshExecutor::MainExecEmptyTask(const OpParam &param, ExecMem &execMem)
    {
        hccl::Stream &masterStream = const_cast<hccl::Stream &>(param.stream);
        DeviceMem srcTmp = DeviceMem::create(execMem.inputPtr, 0);
        DeviceMem dstTmp = DeviceMem::create(execMem.outputPtr, 0);
        CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dstTmp, srcTmp, masterStream));
        return HCCL_SUCCESS;
    }

    // The slave stream executes an empty task to ensure synchronization
    HcclResult CollCustomHugeAllReduceMeshExecutor::SlaveExecEmptyTask(const OpParam &param, ExecMem &execMem, u32 streamIndex)
    {
        DeviceMem srcTmp = DeviceMem::create(execMem.inputPtr, 0);
        DeviceMem dstTmp = DeviceMem::create(execMem.outputPtr, 0);
        CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dstTmp, srcTmp, algResResp_->slaveStreams[streamIndex]));
        return HCCL_SUCCESS;
    }

    // Convert a logical ring index to a physical rank ID
    u32 CollCustomHugeAllReduceMeshExecutor::LogicalToPhysical(u32 logicalRank, u32 rankSize)
    {
        return twistedRing[logicalRank];
    }

    // Convert a physical rank ID to its logical ring index (inverse of twistedRing)
    u32 CollCustomHugeAllReduceMeshExecutor::PhysicalToLogical(u32 physicalRank, u32 rankSize)
    {
        u32 logicalRankIdx = 0;
        if (0 < physicalRank && physicalRank < rankSize - 1)
        {
            logicalRankIdx = physicalRank - 1;
        }
        else if (physicalRank == 0)
        {
            logicalRankIdx = rankSize - 2;
        }
        else
        {
            logicalRankIdx = rankSize - 1;
        }
        return logicalRankIdx;
    }

    // Get the left neighbor rank of the specified rank
    u32 CollCustomHugeAllReduceMeshExecutor::GetLeftNeighbor(u32 rank, u32 rankSize)
    {
        // Get the index of the specified rank
        u32 rankIdx = PhysicalToLogical(rank, rankSize);

        // Get the left neighbor rank
        u32 leftRankIdx = (rankSize + rankIdx - 1) % rankSize;
        return twistedRing[leftRankIdx];
    }

    // Get the right neighbor rank of the specified rank
    u32 CollCustomHugeAllReduceMeshExecutor::GetRightNeighbor(u32 rank, u32 rankSize)
    {
        // Get the index of the specified rank
        u32 rankIdx = PhysicalToLogical(rank, rankSize);

        // Get the right neighbor rank
        u32 rightRankIdx = (rankIdx + 1) % rankSize;
        return twistedRing[rightRankIdx];
    }

    REGISTER_EXEC("CustomHugeAllReduceMeshExecutor", CustomHugeAllReduceMesh, CollCustomHugeAllReduceMeshExecutor);
} // namespace hccl

Header file (.h):

#ifndef COLL_CUSTOM_HUGE_ALLREDUCE_MESH_EXECUTOR_H
#define COLL_CUSTOM_HUGE_ALLREDUCE_MESH_EXECUTOR_H

#include "coll_all_reduce_executor.h"

// Different Ring AllReduce directions use different slave streams
#define C_CW_SLAVE_INDEX 0 // Counter-clockwise Ring AllReduce slave stream index

namespace hccl
{
    class CollCustomHugeAllReduceMeshExecutor : public CollAllReduceExecutor
    {
    public:
        CollCustomHugeAllReduceMeshExecutor(const HcclDispatcher dispatcher, std::unique_ptr<TopoMatcher> &topoMatcher);
        ~CollCustomHugeAllReduceMeshExecutor() = default;

    private:
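        /* *************** Key parameters *************** */
        // Ring order (logical index -> physical rank); Rank0 and Rank1 are never adjacent,
        // so the broken 0-1 link is never used
        std::vector<u32> twistedRing = {1, 2, 3, 4, 5, 6, 0, 7};
        /* *************** Resource calculation *************** */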
        HcclResult CalcScratchMemSize(u64 &scratchMemSize) override;
        HcclResult CalcStreamNum(u32 &streamNum) override;
        HcclResult CalcNotifyNum(u32 streamNum, u32 &notifyNum) override;
        HcclResult CalcCommInfo(std::vector<LevelNSubCommTransport> &opTransport) override;
        u64 CalcLoopMaxCount(const u64 cclBuffSize, const u32 unitSize);
        /* *************** Algorithm orchestration *************** */
        HcclResult KernelRun(const OpParam &param, ExecMem &execMem) override;
        HcclResult PerformRingReduceScatter(const OpParam &param, ExecMem &execMem);
        HcclResult PerformRingAllGather(const OpParam &param, ExecMem &execMem);
        HcclResult PerformCcwRingReduceScatter(const OpParam &param, ExecMem &execMem);
        HcclResult PerformCcwRingAllGather(const OpParam &param, ExecMem &execMem);
        HcclResult MainPostToSlaves(const OpParam &param, ExecMem &execMem);
        HcclResult SlavesWaitForMain(const OpParam &param, ExecMem &execMem);
        HcclResult SlavesPostToMain(const OpParam &param, ExecMem &execMem);
        HcclResult MainWaitForSlaves(const OpParam &param, ExecMem &execMem);
        HcclResult MainExecEmptyTask(const OpParam &param, ExecMem &execMem);
        HcclResult SlaveExecEmptyTask(const OpParam &param, ExecMem &execMem, u32 streamIndex);
        u32 LogicalToPhysical(u32 logicalRank, u32 rankSize);
        u32 PhysicalToLogical(u32 physicalRank, u32 rankSize);
        u32 GetLeftNeighbor(u32 rank, u32 rankSize);
        u32 GetRightNeighbor(u32 rank, u32 rankSize);
    };
} // namespace hccl

#endif
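`PhysicalToLogical` hard-codes the inverse of `twistedRing` with branches that only hold for this particular ring. A table-driven variant derives the inverse from the ring table itself, so any ring order would work unchanged. The free functions `invertRing`, `leftNeighbor` and `rightNeighbor` below are hypothetical and not part of the executor.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Build logicalOf[physical] = logical from ring[logical] = physical.
std::vector<std::uint32_t> invertRing(const std::vector<std::uint32_t>& ring) {
    std::vector<std::uint32_t> logicalOf(ring.size());
    for (std::size_t logical = 0; logical < ring.size(); ++logical) {
        logicalOf[ring[logical]] = static_cast<std::uint32_t>(logical);
    }
    return logicalOf;
}

// Physical rank immediately before `physical` on the ring.
std::uint32_t leftNeighbor(const std::vector<std::uint32_t>& ring,
                           const std::vector<std::uint32_t>& logicalOf,
                           std::uint32_t physical) {
    const std::uint32_t n = static_cast<std::uint32_t>(ring.size());
    return ring[(logicalOf[physical] + n - 1) % n];
}

// Physical rank immediately after `physical` on the ring.
std::uint32_t rightNeighbor(const std::vector<std::uint32_t>& ring,
                            const std::vector<std::uint32_t>& logicalOf,
                            std::uint32_t physical) {
    const std::uint32_t n = static_cast<std::uint32_t>(ring.size());
    return ring[(logicalOf[physical] + 1) % n];
}
```

With the contest ring {1, 2, 3, 4, 5, 6, 0, 7}, Rank0 maps to logical index 6, and its neighbors are Rank6 and Rank7. Rank1's neighbors are Rank7 and Rank2. These match the hard-coded `PhysicalToLogical` and confirm that the 0-1 link is never used.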

Algorithm Performance

Licensed under CC BY-NC-SA 4.0