
HCCL Communication Library Innovation Contest Notes: MESH

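HCCL (Huawei Collective Communication Library) is a high-performance collective communication library built for Ascend AI processors. It focuses on traffic scheduling in very large compute clusters and provides them with efficient, reliable communication. It is one of the core components of CANN, Huawei's AI software stack. This post covers the MESH algorithm used for the medium data size.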

Problem Recap

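Environment: one single-node Ascend 910B server with 8 cards.

Task: in an 8P Full Mesh, one direct link is broken (the link between Rank0 and Rank1). Implement the AllReduce collective communication semantics on this topology.

Scoring: semantic correctness is verified with the algorithm analysis tool. The HCCL Test tool is then run on the physical machine, and the mean algorithm bandwidth at each of three data sizes (1 KB, 1 MB and 1 GB) is used as the scoring criterion.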

Approach

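At medium data sizes the algorithm has to balance computation overhead against communication overhead. It should take as few steps as possible while still making good use of the available links. The ranks are therefore split into two groups. Healthy ranks are not affected by the broken link. Unhealthy ranks are affected by it (Rank0 and Rank1). The algorithm finishes in two steps: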

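  1. Healthy-rank aggregation: every healthy rank pulls data directly from all other ranks and reduces it. After this step, every healthy rank holds the final reduced result;
  2. Assisting unhealthy ranks: each unhealthy rank is paired with one healthy rank and obtains the final result directly from it.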
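The two steps can be checked with a small host-side model. The sketch below is plain C++ with no HCCL calls. It enumerates every transfer the scheme performs and verifies that none of them crosses the broken Rank0-Rank1 link. The proxy mapping (Rank2 serves Rank0, Rank3 serves Rank1) is taken from the header in the implementation. Transfer, BuildPlan and the other helper names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Host-side model of the communication plan (hypothetical names, no HCCL calls).
constexpr uint32_t kRankSize = 8;
constexpr uint32_t kBrokenX = 0;  // Rank0 and Rank1 lost their direct link
constexpr uint32_t kBrokenY = 1;

struct Transfer {
    uint32_t src;  // rank whose buffer is read
    uint32_t dst;  // rank whose buffer is written
    int phase;     // 1 = healthy-rank aggregation, 2 = assisting unhealthy ranks
};

bool IsHealthy(uint32_t rank) { return rank != kBrokenX && rank != kBrokenY; }

// Proxy mapping from the header: rank 2 serves rank 0, rank 3 serves rank 1.
uint32_t ProxyOf(uint32_t unhealthyRank) {
    assert(!IsHealthy(unhealthyRank));
    return unhealthyRank == kBrokenX ? 2U : 3U;
}

bool UsesBrokenLink(const Transfer &t) {
    return (t.src == kBrokenX && t.dst == kBrokenY) ||
           (t.src == kBrokenY && t.dst == kBrokenX);
}

std::vector<Transfer> BuildPlan() {
    std::vector<Transfer> plan;
    // Phase 1: every healthy rank pulls from all of its 7 peers.
    for (uint32_t dst = 0; dst < kRankSize; ++dst) {
        if (!IsHealthy(dst)) continue;
        for (uint32_t src = 0; src < kRankSize; ++src) {
            if (src != dst) plan.push_back({src, dst, 1});
        }
    }
    // Phase 2: each unhealthy rank receives the final result from its proxy.
    const uint32_t unhealthyRanks[] = {kBrokenX, kBrokenY};
    for (uint32_t u : unhealthyRanks) {
        plan.push_back({ProxyOf(u), u, 2});
    }
    return plan;
}
```

Under these assumptions, phase 1 consists of 6 × 7 = 42 pulls and phase 2 of only 2 pushes. The two unhealthy ranks never exchange data directly.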

Algorithm Design

Healthy-Rank Aggregation

The high-level flow is:

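  1. Every rank, unhealthy ones included, copies its input from User Input into CCL_OUT. This stages the data so that remote ranks can pull it. Healthy ranks also copy the input into User Output, where the data pulled from remote ranks is reduced together with the local data;
  2. Every healthy rank pulls the data from the CCL_OUT of every other rank and reduces it into its local User Output.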
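Below is a minimal sketch of phase 1 on host memory. It assumes a SUM reduction and integer elements so that results can be compared exactly. std::vector stands in for the device buffers, and RankBuffers and RunPhase1 are hypothetical names. Only the data movement described above is modeled, not the streams or notifies.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical per-rank buffers; names mirror the article, not real HCCL types.
struct RankBuffers {
    std::vector<int64_t> userOut;  // User Output
    std::vector<int64_t> cclOut;   // CCL_OUT
};

bool IsHealthyRank(uint32_t rank) { return rank != 0 && rank != 1; }

// Phase 1 with SUM reduction: returns the buffers after healthy-rank aggregation.
std::vector<RankBuffers> RunPhase1(const std::vector<std::vector<int64_t>> &inputs) {
    const uint32_t rankSize = static_cast<uint32_t>(inputs.size());
    assert(rankSize > 2);
    std::vector<RankBuffers> ranks(rankSize);
    // Step 1: every rank stages its input in CCL_OUT.
    for (uint32_t r = 0; r < rankSize; ++r) {
        ranks[r].cclOut = inputs[r];
        // Healthy ranks seed User Output with their own data; unhealthy ranks
        // leave User Output untouched in phase 1 (modeled as an empty vector).
        if (IsHealthyRank(r)) ranks[r].userOut = inputs[r];
    }
    // Step 2: each healthy rank pulls every peer's CCL_OUT and reduces it into its User Output.
    for (uint32_t r = 0; r < rankSize; ++r) {
        if (!IsHealthyRank(r)) continue;
        for (uint32_t peer = 0; peer < rankSize; ++peer) {
            if (peer == r) continue;
            for (size_t i = 0; i < ranks[r].userOut.size(); ++i) {
                ranks[r].userOut[i] += ranks[peer].cclOut[i];
            }
        }
    }
    return ranks;
}
```

Each rank reduces into its own User Output, and its peers read only from CCL_OUT, which nobody writes during phase 1. As a result, the order in which a healthy rank visits its peers does not change the result.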

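Each healthy rank has to pull data from the other 7 ranks, so it uses 7 slave streams, each responsible for pulling from one peer. The flow of each slave stream is:

  1. Copy the local input data from User Input into User Output and CCL_OUT. In the implementation this is done once, on the master stream, before the slave streams start;
  2. If the peer is a healthy rank, notify it via TxAck: "my data is ready; you may pull from or push to my CCL_OUT at any time";
  3. Block the stream with RxAck until the peer is ready as well (i.e. until the peer sends this rank a TxAck);
  4. Once the peer is ready, pull the data from the peer's CCL_OUT and reduce it into the local User Output;
  5. After pulling and reducing, notify the peer via TxDataSignal: "I am done; you may release resources or move on to other work";
  6. If the peer is a healthy rank, block the stream with RxDataSignal until the peer has also finished pulling and reducing (i.e. until the peer sends this rank a TxDataSignal).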
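1. Copy the data into CCL_OUT and User Output (once, on the master stream)
    copy(userIn, cclOut)
    copy(userIn, userOut)
2. Notify the remote rank that this rank is ready, then wait until the remote rank is ready
    links[dstRank]->TxAck()   # only if the remote rank is healthy
    links[dstRank]->RxAck()
3. Pull from the remote CCL_OUT and reduce into the local User Output
    reduce(remoteCclOut, localUserOut)
4. Notify the remote rank that the pull is done, then wait until the remote rank is done
    links[dstRank]->TxDataSignal()
    links[dstRank]->RxDataSignal()   # only if the remote rank is healthy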
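The sketch below writes out which peer each slave stream serves and which ops it issues. The mapping dstRank = (rankId + round) % rankSize, with stream index round - 1, comes from HealthyPerformAllReduce in the implementation. PeerForRound and HealthyStreamOps are hypothetical helpers, and the op names are strings standing in for the link calls.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

bool IsHealthyRank(uint32_t rank) { return rank != 0 && rank != 1; }

// Slave stream (round - 1) of rankId talks to this peer, as in the implementation's loop.
uint32_t PeerForRound(uint32_t rankId, uint32_t round, uint32_t rankSize) {
    assert(round >= 1 && round < rankSize);
    return (round + rankId) % rankSize;
}

// Ops a healthy rank issues on the slave stream that serves `peer`.
std::vector<std::string> HealthyStreamOps(uint32_t peer) {
    const bool peerHealthy = IsHealthyRank(peer);
    std::vector<std::string> ops;
    if (peerHealthy) ops.push_back("TxAck");         // unhealthy peers never wait for this Ack
    ops.push_back("RxAck");                          // wait until the peer's CCL_OUT is ready
    ops.push_back("Reduce");                         // remote CCL_OUT -> local User Output
    ops.push_back("TxDataSignal");                   // tell the peer the read is finished
    if (peerHealthy) ops.push_back("RxDataSignal");  // wait until the peer has read ours
    return ops;
}
```

For rank 5, rounds 1 to 7 visit ranks 6, 7, 0, 1, 2, 3 and 4, so each of the 7 slave streams talks to a different peer.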

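Unhealthy ranks do not pull data from any other rank, so their work is simpler. Each unhealthy rank uses one slave stream per healthy peer; the stream assigned to the other unhealthy rank only runs an empty task. Each of these streams does the following:

  1. Copy the local input data from User Input into CCL_OUT;
  2. Notify the peer via TxAck: "my data is ready; you may pull from or push to my CCL_OUT at any time";
  3. Block the stream with RxDataSignal until the peer has finished pulling and reducing (i.e. until the peer sends this rank a TxDataSignal).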
1. Copy the data into CCL_OUT
    copy(userIn, cclOut)
2. Notify the remote (healthy) rank that this rank is ready
    links[dstRank]->TxAck()
3. Wait until the remote rank is done
    links[dstRank]->RxDataSignal()
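Every RxAck or RxDataSignal must be matched by a TxAck or TxDataSignal from the other side; otherwise a stream blocks forever. The sketch below checks this pairing with a hypothetical, simplified notify model. Signals are counted per direction, and streams run greedily until none can advance. The model executes the phase-1 sequences of all 8 ranks and reports whether they all finish. It is not how the HCCL runtime schedules streams, only a check of the pairing logic.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// TxData/RxData stand for TxDataSignal/RxDataSignal; Work is a reduce or copy.
enum class Op { TxAck, RxAck, TxData, RxData, Work };

struct StreamModel {
    uint32_t rank;        // rank that owns this slave stream
    uint32_t peer;        // remote rank the stream talks to
    std::vector<Op> ops;  // ops in issue order
    size_t pc = 0;        // index of the next op to execute
};

constexpr uint32_t kRanks = 8;
bool IsHealthyRank(uint32_t r) { return r != 0 && r != 1; }

// Phase-1 op sequences for every (rank, peer) slave stream, as described above.
std::vector<StreamModel> BuildPhase1Streams() {
    std::vector<StreamModel> streams;
    for (uint32_t r = 0; r < kRanks; ++r) {
        for (uint32_t p = 0; p < kRanks; ++p) {
            if (p == r) continue;
            std::vector<Op> ops;
            if (IsHealthyRank(r)) {
                if (IsHealthyRank(p)) ops.push_back(Op::TxAck);
                ops.push_back(Op::RxAck);
                ops.push_back(Op::Work);  // reduce the peer's CCL_OUT into User Output
                ops.push_back(Op::TxData);
                if (IsHealthyRank(p)) ops.push_back(Op::RxData);
            } else if (IsHealthyRank(p)) {
                ops = {Op::TxAck, Op::RxData};  // announce readiness, wait for the reader
            }
            streams.push_back(StreamModel{r, p, ops});
        }
    }
    return streams;
}

// Runs all streams until every one finishes (true) or none can advance (false = deadlock).
bool RunToCompletion(std::vector<StreamModel> streams) {
    // pending[kind][from][to]: signals posted but not yet consumed (kind 0 = Ack, 1 = Data)
    std::array<std::array<std::array<int, kRanks>, kRanks>, 2> pending{};
    bool progressed = true;
    while (progressed) {
        progressed = false;
        for (auto &s : streams) {
            while (s.pc < s.ops.size()) {
                const Op op = s.ops[s.pc];
                if (op == Op::TxAck || op == Op::TxData) {
                    ++pending[op == Op::TxAck ? 0 : 1][s.rank][s.peer];
                } else if (op == Op::RxAck || op == Op::RxData) {
                    int &count = pending[op == Op::RxAck ? 0 : 1][s.peer][s.rank];
                    if (count == 0) break;  // blocked until the peer posts
                    --count;
                }
                ++s.pc;
                progressed = true;
            }
        }
    }
    for (const auto &s : streams) {
        if (s.pc != s.ops.size()) return false;
    }
    return true;
}
```

Now consider a variant in which unhealthy ranks also wait for an Ack (TxAck, RxAck, TxDataSignal, RxDataSignal). There, RunToCompletion returns false: healthy ranks never send a TxAck to an unhealthy peer, so that RxAck never completes. This is why the healthy-rank pseudocode issues TxAck and RxDataSignal only when the peer is healthy.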

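Assisting Unhealthy Ranks

The high-level flow is:

  1. Each unhealthy rank pulls the data from the User Output of its assigned healthy rank into its local CCL_OUT. Alternatively, the assigned healthy rank pushes the data from its User Output into the CCL_OUT of the rank it serves; the implementation below uses this push variant;
  2. Each unhealthy rank copies the data from CCL_OUT into its User Output.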
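Phase 2 can be modeled the same way. This sketch assumes the healthy ranks already hold the reduced result after phase 1. Each proxy pushes its User Output into its unhealthy rank's CCL_OUT, and the unhealthy rank then copies CCL_OUT into User Output. The proxy numbers follow PROXY_RANK_X/Y and BROKEN_RANK_X/Y in the header. Buffers, ProxyFor and RunPhase2 are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical per-rank buffers for the phase-2 model.
struct Buffers {
    std::vector<int64_t> userOut;  // User Output
    std::vector<int64_t> cclOut;   // CCL_OUT
};

// Rank 2 serves rank 0 and rank 3 serves rank 1, as in the header's macros.
uint32_t ProxyFor(uint32_t unhealthyRank) {
    assert(unhealthyRank == 0 || unhealthyRank == 1);
    return unhealthyRank == 0 ? 2U : 3U;
}

// Assumes ranks 2..7 already hold the reduced result in User Output (end of phase 1).
void RunPhase2(std::vector<Buffers> &ranks) {
    assert(ranks.size() == 8);
    const uint32_t unhealthyRanks[] = {0U, 1U};
    for (uint32_t u : unhealthyRanks) {
        const uint32_t proxy = ProxyFor(u);
        // Step 1: the proxy pushes its User Output into the unhealthy rank's CCL_OUT.
        ranks[u].cclOut = ranks[proxy].userOut;
        // Step 2: the unhealthy rank copies CCL_OUT into its own User Output.
        ranks[u].userOut = ranks[u].cclOut;
    }
}
```

The two proxies are different ranks, so the two pushes use different links (2 to 0 and 3 to 1) and can proceed at the same time.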

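The assigned healthy rank (the proxy) only has to push data to the one unhealthy rank it serves, so the master stream alone is enough. The flow of the master stream is:

  1. Block the stream with RxAck until the peer is ready (i.e. until the peer sends this rank a TxAck);
  2. Once the peer is ready, copy the data from the local User Output into the peer's CCL_OUT;
  3. After the push completes, notify the peer via TxDataSignal: "I am done; you may release resources or move on to other work".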
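1. Wait until the remote rank is ready
    links[dstRank]->RxAck()
2. Push the data from the local User Output to the remote CCL_OUT
    copy(localUserOut, remoteCclOut)
3. Notify the remote rank that the push is done
    links[dstRank]->TxDataSignal()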
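In the implementation, the proxy's target is chosen with a chain of ternaries. The chain falls back to rank 0 for any rank that is not a proxy. This is harmless, because ProxyPerformSend is only called when IsProxyRank is true. One possible alternative is a table-driven lookup, sketched below. It keeps the pairing in one place for both directions and makes a non-proxy lookup explicit. kProxyTable, PushTarget and ReceiveSource are hypothetical names; the values match the header's macros. The sketch needs C++17 (std::optional and structured bindings).

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>

// (proxy rank, unhealthy rank) pairs; values match PROXY_RANK_X/Y and BROKEN_RANK_X/Y.
constexpr std::array<std::pair<uint32_t, uint32_t>, 2> kProxyTable{{{2U, 0U}, {3U, 1U}}};

// Rank that `rankId` must push to in phase 2, or nullopt if it is not a proxy.
std::optional<uint32_t> PushTarget(uint32_t rankId) {
    for (const auto &[proxy, unhealthy] : kProxyTable) {
        if (proxy == rankId) return unhealthy;
    }
    return std::nullopt;
}

// Rank that `rankId` receives from in phase 2, or nullopt if it is healthy.
std::optional<uint32_t> ReceiveSource(uint32_t rankId) {
    for (const auto &[proxy, unhealthy] : kProxyTable) {
        if (unhealthy == rankId) return proxy;
    }
    return std::nullopt;
}
```

With this table, PushTarget(4) returns std::nullopt instead of silently resolving to rank 0.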

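Each unhealthy rank only needs to receive the result from its proxy, so this can also be done on the master stream. The flow of the master stream is:

  1. Notify the peer via TxAck: "my CCL_OUT is ready; you may push to it at any time";
  2. Block the stream with RxDataSignal until the peer has finished pushing the data (i.e. until the peer sends this rank a TxDataSignal);
  3. Copy the data from the local CCL_OUT into User Output.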
1. Notify the remote rank that this rank is ready
    links[dstRank]->TxAck()
2. Wait until the remote rank is done
    links[dstRank]->RxDataSignal()
3. Copy the data to User Output
    copy(cclOut, userOut)
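Phase 2 guarantees that the unhealthy rank's copy into User Output cannot run before the proxy's push. Copy follows RxDataSignal, which needs the proxy's TxDataSignal, which the proxy issues only after Push. The sketch below runs both master streams in a hypothetical counter-based notify model and records the order of the two data ops. It checks only the handshake logic for one proxy/unhealthy pair, not real device timing.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

struct MiniStream {
    std::vector<std::string> ops;  // ops in issue order
    size_t pc = 0;                 // index of the next op to execute
};

// Runs the proxy's and the unhealthy rank's master streams for one pair and
// returns the order in which the data-moving ops ("Push", "Copy") executed.
std::vector<std::string> RunPhase2Handshake() {
    MiniStream proxy{{"RxAck", "Push", "TxDataSignal"}};
    MiniStream unhealthy{{"TxAck", "RxDataSignal", "Copy"}};
    int ackToProxy = 0;       // TxAck posted by the unhealthy rank, not yet consumed
    int dataToUnhealthy = 0;  // TxDataSignal posted by the proxy, not yet consumed
    std::vector<std::string> dataOps;
    MiniStream *streams[] = {&proxy, &unhealthy};
    bool progressed = true;
    while (progressed) {
        progressed = false;
        for (MiniStream *s : streams) {
            while (s->pc < s->ops.size()) {
                const std::string &op = s->ops[s->pc];
                if (op == "TxAck") {
                    ++ackToProxy;
                } else if (op == "TxDataSignal") {
                    ++dataToUnhealthy;
                } else if (op == "RxAck") {
                    if (ackToProxy == 0) break;  // proxy waits for the unhealthy rank
                    --ackToProxy;
                } else if (op == "RxDataSignal") {
                    if (dataToUnhealthy == 0) break;  // unhealthy rank waits for the push
                    --dataToUnhealthy;
                } else {
                    dataOps.push_back(op);  // "Push" or "Copy"
                }
                ++s->pc;
                progressed = true;
            }
        }
    }
    // Both streams must have finished; otherwise the handshake deadlocked.
    assert(proxy.pc == proxy.ops.size() && unhealthy.pc == unhealthy.ops.size());
    return dataOps;
}
```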

Task Orchestration

Implementation

Executor implementation (.cc file):

#include "coll_custom_medium_all_reduce_mesh_executor.h"

namespace hccl
{
    CollCustomMediumAllReduceMeshExecutor::CollCustomMediumAllReduceMeshExecutor(const HcclDispatcher dispatcher,
                                                                                 std::unique_ptr<TopoMatcher> &topoMatcher)
        : CollAllReduceExecutor(dispatcher, topoMatcher)
    {
        CCLMemSlice_ = false;
        DMAReduceFlag_ = true;
    }

    // Calculate the amount of scratch memory to request
    HcclResult CollCustomMediumAllReduceMeshExecutor::CalcScratchMemSize(u64 &scratchMemSize)
    {
        // We don't need to use Scratch memory
        scratchMemSize = 0U;
        HCCL_WARNING("[HCCLContest][CollCustomMediumAllReduceMeshExecutor][CalcScratchMemSize] scratchMemSize: %llu",
                     scratchMemSize);
        return HCCL_SUCCESS;
    }

    // Calculate the number of streams to be requested
    HcclResult CollCustomMediumAllReduceMeshExecutor::CalcStreamNum(u32 &streamNum)
    {
        // One stream is required for each remote rank
        u32 totalStreamNum = topoAttr_.deviceNumPerAggregation;
        streamNum = totalStreamNum - 1U;
        HCCL_WARNING("[HCCLContest][CollCustomMediumAllReduceMeshExecutor][CalcStreamNum] streamNum: %u", streamNum);
        return HCCL_SUCCESS;
    }

    // Calculate the number of Notify to be requested
    HcclResult CollCustomMediumAllReduceMeshExecutor::CalcNotifyNum(u32 streamNum, u32 &notifyNum)
    {
        notifyNum = 2U * streamNum;
        HCCL_WARNING("[HCCLContest][CollCustomMediumAllReduceMeshExecutor][CalcNotifyNum] notifyNum: %u", notifyNum);
        return HCCL_SUCCESS;
    }

    // Set up the level-0 mesh topology required for the AllReduce operation
    HcclResult CollCustomMediumAllReduceMeshExecutor::CalcCommInfo(std::vector<LevelNSubCommTransport> &opTransport)
    {
        HCCL_WARNING("[HCCLContest][CollCustomMediumAllReduceMeshExecutor][CalcCommInfo]");

        // Define the source and destination memory types for communication
        TransportMemType inputType = TransportMemType::CCL_INPUT;
        TransportMemType outputType = TransportMemType::CCL_OUTPUT;
        // Construct a mesh topology for level 0
        CommParaInfo commParaLevel0(COMM_LEVEL0, CommType::COMM_TAG_MESH);
        commParaLevel0.meshSinglePlane = true;
        // Compute and populate the transport plan for level-0 communication domain
        CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_LEVEL0], inputType, outputType));
        return HCCL_SUCCESS;
    }

    // Calculate the number of iterations for loop processing
    u64 CollCustomMediumAllReduceMeshExecutor::CalcLoopMaxCount(const u64 cclBuffSize, const u32 unitSize)
    {
        u64 maxCountPerLoop = cclBuffSize / unitSize;
        HCCL_WARNING("[HCCLContest][CollCustomMediumAllReduceMeshExecutor][CalcLoopMaxCount] maxCountPerLoop: %llu",
                     maxCountPerLoop);
        return maxCountPerLoop;
    }

    // Process data for a single iteration of the AllReduce algorithm execution
    HcclResult CollCustomMediumAllReduceMeshExecutor::KernelRun(const OpParam &param, ExecMem &execMem)
    {
        // Get sub-communication domain information for level 0
        CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1));
        SubCommInfo level0CommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0);

        // Determine this rank's position in the topology
        u32 rankId = level0CommInfo.localRank;

        // Phase 1: healthy ranks pull and reduce; unhealthy ranks only stage their data
        if (IsHealthyRank(rankId))
        {
            CHK_RET(HealthyPerformAllReduce(param, execMem));
        }
        else
        {
            CHK_RET(UnhealthyPerformAllReduce(param, execMem));
        }

        // Phase 2: proxy ranks forward the final result to the unhealthy ranks
        if (IsProxyRank(rankId))
        {
            CHK_RET(ProxyPerformSend(param, execMem));
        }
        if (!IsHealthyRank(rankId))
        {
            CHK_RET(UnhealthyPerformReceive(param, execMem));
        }

        HCCL_WARNING("[HCCLContest][CollCustomMediumAllReduceMeshExecutor][KernelRun] localRank: %u, localRankSize: %u",
                     level0CommInfo.localRank, level0CommInfo.localRankSize);

        return HCCL_SUCCESS;
    }

    // The healthy rank pulls and reduces data from each rank
    HcclResult CollCustomMediumAllReduceMeshExecutor::HealthyPerformAllReduce(const OpParam &param, ExecMem &execMem)
    {
        // Retrieve the sub-communication domain information and the master stream
        CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1));
        SubCommInfo level0CommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0);
        hccl::Stream &masterStream = const_cast<hccl::Stream &>(param.stream);

        // Determine the topology and data-size information
        u32 rankSize = level0CommInfo.localRankSize;
        u32 rankId = level0CommInfo.localRank;
        u32 unitSize = SIZE_TABLE[param.DataDes.dataType];
        u64 chunkSize = execMem.count * unitSize;

        // Copy data from user input to CCL_Out
        DeviceMem src = DeviceMem::create(execMem.inputPtr, chunkSize);
        DeviceMem dst = DeviceMem::create(execMem.outputMem.ptr(), chunkSize);
        CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dst, src, masterStream));

        // Copy data from user input to user output
        src = DeviceMem::create(execMem.inputPtr, chunkSize);
        dst = DeviceMem::create(execMem.outputPtr, chunkSize);
        CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dst, src, masterStream));

        // Post notify signals from master stream to all slave streams
        CHK_RET(MainPostToSlaves(param, execMem));
        // Slave streams wait for master's notification
        CHK_RET(SlavesWaitForMain(param, execMem));

        for (u32 round = 1; round < rankSize; round++)
        {
            // Get the slave stream assigned for communication with dstRank
            u32 dstRank = (round + rankId) % rankSize;
            Stream &subStream = algResResp_->slaveStreams[round - 1];

            // Only receive data from unhealthy ranks, other ranks not only receive but also send
            if (!IsHealthyRank(dstRank))
            {
                // Wait for notification from the remote rank
                CHK_RET(level0CommInfo.links[dstRank]->RxAck(subStream));
            }
            else
            {
                // Notify remote rank data has been ready
                CHK_RET(level0CommInfo.links[dstRank]->TxAck(subStream));
                // Wait for notification from the remote rank
                CHK_RET(level0CommInfo.links[dstRank]->RxAck(subStream));
            }

            // Source address on the remote rank's CCL_Out (remote)
            void *srcRemoteMemPtr = nullptr;
            CHK_RET(level0CommInfo.links[dstRank]->GetRemoteMem(UserMemType::OUTPUT_MEM, &srcRemoteMemPtr));
            DeviceMem srcRemote = DeviceMem::create(static_cast<char *>(srcRemoteMemPtr), chunkSize);

            // Destination address on the local rank's user output (local)
            void *dstLocalMemPtr = static_cast<char *>(execMem.outputPtr);
            DeviceMem dstLocal = DeviceMem::create(static_cast<char *>(dstLocalMemPtr), chunkSize);

            // Perform HcclReduceAsync: read the remote CCL_Out and reduce it into the local user output
            CHK_RET(HcclReduceAsync(dispatcher_, static_cast<void *>(srcRemote.ptr()), chunkSize / unitSize,
                                    param.DataDes.dataType, param.reduceType, subStream,
                                    static_cast<void *>(dstLocal.ptr()),
                                    level0CommInfo.links[dstRank]->GetRemoteRank(),
                                    level0CommInfo.links[dstRank]->GetLinkType(),
                                    INLINE_REDUCE_BIT));

            // Only receive data from unhealthy ranks, other ranks not only receive but also send
            if (!IsHealthyRank(dstRank))
            {
                // Notify data transfer and operations are completed
                CHK_RET(level0CommInfo.links[dstRank]->TxDataSignal(subStream));
            }
            else
            {
                // Notify the remote rank that data transfer and operations have completed
                CHK_RET(level0CommInfo.links[dstRank]->TxDataSignal(subStream));
                // Wait until the remote rank has completed its data transfer and reduction
                CHK_RET(level0CommInfo.links[dstRank]->RxDataSignal(subStream));
            }
        }

        // Slave streams notify the master stream that their tasks are completed
        CHK_RET(SlavesPostToMain(param, execMem));
        // The master stream waits for all slave streams to complete their tasks
        CHK_RET(MainWaitForSlaves(param, execMem));

        return HCCL_SUCCESS;
    }

    // Unhealthy ranks only need to prepare data
    HcclResult CollCustomMediumAllReduceMeshExecutor::UnhealthyPerformAllReduce(const OpParam &param, ExecMem &execMem)
    {
        // Retrieve the sub-communication domain information and the master stream
        CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1));
        SubCommInfo level0CommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0);
        hccl::Stream &masterStream = const_cast<hccl::Stream &>(param.stream);

        // Determine the topology and data-size information
        u32 rankSize = level0CommInfo.localRankSize;
        u32 rankId = level0CommInfo.localRank;
        u32 unitSize = SIZE_TABLE[param.DataDes.dataType];
        u64 chunkSize = execMem.count * unitSize;

        // Copy data from user input to CCL_Out
        DeviceMem src = DeviceMem::create(execMem.inputPtr, chunkSize);
        DeviceMem dst = DeviceMem::create(execMem.outputMem.ptr(), chunkSize);
        CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dst, src, masterStream));

        // Post notify signals from master stream to all slave streams
        CHK_RET(MainPostToSlaves(param, execMem));
        // Slave streams wait for master's notification
        CHK_RET(SlavesWaitForMain(param, execMem));

        for (u32 round = 1; round < rankSize; round++)
        {
            // Get the slave stream assigned for communication with dstRank
            u32 dstRank = (round + rankId) % rankSize;
            Stream &subStream = algResResp_->slaveStreams[round - 1];

            // Skipping unhealthy ranks
            if (!IsHealthyRank(dstRank))
            {
                // The slave stream executes an empty task
                CHK_RET(SlaveExecEmptyTask(param, execMem, round - 1));
                continue;
            }

            // Notify remote rank data has been ready
            CHK_RET(level0CommInfo.links[dstRank]->TxAck(subStream));
            // Ensure data transfer and operations are completed
            CHK_RET(level0CommInfo.links[dstRank]->RxDataSignal(subStream));
        }

        // Slave streams notify the master stream that their tasks are completed
        CHK_RET(SlavesPostToMain(param, execMem));
        // The master stream waits for all slave streams to complete their tasks
        CHK_RET(MainWaitForSlaves(param, execMem));

        return HCCL_SUCCESS;
    }

    // The proxy rank passes the result to the unhealthy rank
    HcclResult CollCustomMediumAllReduceMeshExecutor::ProxyPerformSend(const OpParam &param, ExecMem &execMem)
    {
        // Retrieve the sub-communication domain information and the master stream
        CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1));
        SubCommInfo level0CommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0);
        hccl::Stream &masterStream = const_cast<hccl::Stream &>(param.stream);

        // Determine the topology and data-size information
        u32 rankSize = level0CommInfo.localRankSize;
        u32 rankId = level0CommInfo.localRank;
        u32 unitSize = SIZE_TABLE[param.DataDes.dataType];
        u64 chunkSize = execMem.count * unitSize;

        // Each proxy rank is responsible for only one unhealthy rank
        u32 dstRank = 0;
        dstRank = (rankId == PROXY_RANK_X) ? BROKEN_RANK_X : dstRank;
        dstRank = (rankId == PROXY_RANK_Y) ? BROKEN_RANK_Y : dstRank;

        // Wait until the remote rank is ready to receive
        CHK_RET(level0CommInfo.links[dstRank]->RxAck(masterStream));

        // Destination address on the remote rank's CCL_Out (remote)
        void *dstRemoteMemPtr = nullptr;
        CHK_RET(level0CommInfo.links[dstRank]->GetRemoteMem(UserMemType::OUTPUT_MEM, &dstRemoteMemPtr));
        DeviceMem dstRemote = DeviceMem::create(static_cast<char *>(dstRemoteMemPtr), chunkSize);

        // Source address on the local rank's user output (local)
        void *srcLocalMemPtr = static_cast<char *>(execMem.outputPtr);
        DeviceMem srcLocal = DeviceMem::create(static_cast<char *>(srcLocalMemPtr), chunkSize);

        // Perform HcclD2DMemcpyAsync: Copy data from local user output to remote CCL_Out
        CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dstRemote, srcLocal, masterStream));

        // Communication post-synchronization: notify the remote rank that the data transfer has completed
        CHK_RET(level0CommInfo.links[dstRank]->TxDataSignal(masterStream));

        return HCCL_SUCCESS;
    }

    // Unhealthy ranks receive the last results
    HcclResult CollCustomMediumAllReduceMeshExecutor::UnhealthyPerformReceive(const OpParam &param, ExecMem &execMem)
    {
        // Retrieve the sub-communication domain information and the master stream
        CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1));
        SubCommInfo level0CommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0);
        hccl::Stream &masterStream = const_cast<hccl::Stream &>(param.stream);

        // Determine the topology and data-size information
        u32 rankSize = level0CommInfo.localRankSize;
        u32 rankId = level0CommInfo.localRank;
        u32 unitSize = SIZE_TABLE[param.DataDes.dataType];
        u64 chunkSize = execMem.count * unitSize;

        // Each proxy rank is responsible for only one unhealthy rank
        u32 srcRank = 0;
        srcRank = (rankId == BROKEN_RANK_X) ? PROXY_RANK_X : srcRank;
        srcRank = (rankId == BROKEN_RANK_Y) ? PROXY_RANK_Y : srcRank;

        // Notify remote rank data has been ready
        CHK_RET(level0CommInfo.links[srcRank]->TxAck(masterStream));
        // Ensure data transfer and operations are completed
        CHK_RET(level0CommInfo.links[srcRank]->RxDataSignal(masterStream));

        // Copy data from CCL_Out to user output
        DeviceMem dst = DeviceMem::create(execMem.outputPtr, chunkSize);
        DeviceMem src = DeviceMem::create(execMem.outputMem.ptr(), chunkSize);
        CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dst, src, masterStream));

        return HCCL_SUCCESS;
    }

    // Determine if a rank is healthy
    bool CollCustomMediumAllReduceMeshExecutor::IsHealthyRank(u32 rankId)
    {
        return rankId != BROKEN_RANK_X && rankId != BROKEN_RANK_Y;
    }

    // Determine if a rank is proxy
    bool CollCustomMediumAllReduceMeshExecutor::IsProxyRank(u32 rankId)
    {
        return rankId == PROXY_RANK_X || rankId == PROXY_RANK_Y;
    }

    // Post notify signals from master stream to all slave streams
    HcclResult CollCustomMediumAllReduceMeshExecutor::MainPostToSlaves(const OpParam &param, ExecMem &execMem)
    {
        hccl::Stream &masterStream = const_cast<hccl::Stream &>(param.stream);
        for (u32 signalIndex = 0; signalIndex < algResResp_->slaveStreams.size(); signalIndex++)
        {
            CHK_RET(LocalNotify::Post(masterStream, dispatcher_,
                                      algResResp_->notifiesAux[signalIndex], PROF_STAGE_1));
        }
        return HCCL_SUCCESS;
    }

    // Slave streams wait for master's notification
    HcclResult CollCustomMediumAllReduceMeshExecutor::SlavesWaitForMain(const OpParam &param, ExecMem &execMem)
    {
        for (u32 streamIndex = 0; streamIndex < algResResp_->slaveStreams.size(); streamIndex++)
        {
            CHK_RET(LocalNotify::Wait(algResResp_->slaveStreams[streamIndex], dispatcher_,
                                      algResResp_->notifiesAux[streamIndex], PROF_STAGE_1));
        }
        return HCCL_SUCCESS;
    }

    // Slave streams notify the master stream that their tasks are completed
    HcclResult CollCustomMediumAllReduceMeshExecutor::SlavesPostToMain(const OpParam &param, ExecMem &execMem)
    {
        for (u32 streamIndex = 0; streamIndex < algResResp_->slaveStreams.size(); streamIndex++)
        {
            CHK_RET(LocalNotify::Post(algResResp_->slaveStreams[streamIndex], dispatcher_,
                                      algResResp_->notifiesMain[streamIndex], PROF_STAGE_1));
        }
        return HCCL_SUCCESS;
    }

    // The master stream waits for all slave streams to complete their tasks
    HcclResult CollCustomMediumAllReduceMeshExecutor::MainWaitForSlaves(const OpParam &param, ExecMem &execMem)
    {
        hccl::Stream &masterStream = const_cast<hccl::Stream &>(param.stream);
        for (u32 signalIndex = 0; signalIndex < algResResp_->slaveStreams.size(); signalIndex++)
        {
            CHK_RET(LocalNotify::Wait(masterStream, dispatcher_,
                                      algResResp_->notifiesMain[signalIndex], PROF_STAGE_1));
        }
        return HCCL_SUCCESS;
    }

    // The master stream executes an empty task to ensure synchronization
    HcclResult CollCustomMediumAllReduceMeshExecutor::MainExecEmptyTask(const OpParam &param, ExecMem &execMem)
    {
        hccl::Stream &masterStream = const_cast<hccl::Stream &>(param.stream);
        DeviceMem srcTmp = DeviceMem::create(execMem.inputPtr, 0);
        DeviceMem dstTmp = DeviceMem::create(execMem.outputPtr, 0);
        CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dstTmp, srcTmp, masterStream));
        return HCCL_SUCCESS;
    }

    // The slave stream execute an empty task to ensure synchronization
    HcclResult CollCustomMediumAllReduceMeshExecutor::SlaveExecEmptyTask(const OpParam &param, ExecMem &execMem, u32 streamIndex)
    {
        hccl::Stream &masterStream = const_cast<hccl::Stream &>(param.stream);
        DeviceMem srcTmp = DeviceMem::create(execMem.inputPtr, 0);
        DeviceMem dstTmp = DeviceMem::create(execMem.outputPtr, 0);
        CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dstTmp, srcTmp, algResResp_->slaveStreams[streamIndex]));
        return HCCL_SUCCESS;
    }

    REGISTER_EXEC("CustomMediumAllReduceMeshExecutor", CustomMediumAllReduceMesh, CollCustomMediumAllReduceMeshExecutor);
} // namespace hccl
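The CalcLoopMaxCount override caps each loop at cclBuffSize / unitSize elements, so large inputs are processed in several slices. The quick calculation below assumes a 200 MiB CCL buffer and float32 data purely for illustration; the real buffer size depends on the environment configuration. LoopCount is a hypothetical helper (ceiling division), not part of the executor.

```cpp
#include <cassert>
#include <cstdint>

// Mirrors CalcLoopMaxCount: how many elements fit into the CCL buffer per loop.
uint64_t MaxCountPerLoop(uint64_t cclBuffSize, uint32_t unitSize) {
    assert(unitSize > 0);
    return cclBuffSize / unitSize;
}

// Hypothetical helper: number of loops needed for `count` elements (ceiling division).
uint64_t LoopCount(uint64_t count, uint64_t cclBuffSize, uint32_t unitSize) {
    const uint64_t perLoop = MaxCountPerLoop(cclBuffSize, unitSize);
    assert(perLoop > 0);
    return (count + perLoop - 1) / perLoop;
}
```

Under these assumptions, 1 KB and 1 MB each fit in a single loop. 1 GB of float32 data (268,435,456 elements) needs 6 loops of at most 52,428,800 elements each, and each loop presumably repeats the two-phase exchange on its slice.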

Executor header (.h file):

#ifndef COLL_CUSTOM_MEDIUM_ALLREDUCE_MESH_EXECUTOR_H
#define COLL_CUSTOM_MEDIUM_ALLREDUCE_MESH_EXECUTOR_H

#include "coll_all_reduce_executor.h"

// Define the ranks affected by the broken link
#define BROKEN_RANK_X 0
#define BROKEN_RANK_Y 1
// Define the rank responsible for proxying the unhealthy ranks
#define PROXY_RANK_X 2
#define PROXY_RANK_Y 3

namespace hccl
{
    class CollCustomMediumAllReduceMeshExecutor : public CollAllReduceExecutor
    {
    public:
        CollCustomMediumAllReduceMeshExecutor(const HcclDispatcher dispatcher, std::unique_ptr<TopoMatcher> &topoMatcher);
        ~CollCustomMediumAllReduceMeshExecutor() = default;

    private:
        /* *************** Resource calculation *************** */
        HcclResult CalcScratchMemSize(u64 &scratchMemSize) override;
        HcclResult CalcStreamNum(u32 &streamNum) override;
        HcclResult CalcNotifyNum(u32 streamNum, u32 &notifyNum) override;
        HcclResult CalcCommInfo(std::vector<LevelNSubCommTransport> &opTransport) override;
        u64 CalcLoopMaxCount(const u64 cclBuffSize, const u32 unitSize);

        /* *************** Algorithm orchestration *************** */
        HcclResult KernelRun(const OpParam &param, ExecMem &execMem) override;
        HcclResult HealthyPerformAllReduce(const OpParam &param, ExecMem &execMem);
        HcclResult UnhealthyPerformAllReduce(const OpParam &param, ExecMem &execMem);
        HcclResult ProxyPerformSend(const OpParam &param, ExecMem &execMem);
        HcclResult UnhealthyPerformReceive(const OpParam &param, ExecMem &execMem);
        bool IsHealthyRank(u32 rankId);
        bool IsProxyRank(u32 rankId);

        /* *************** Inter-stream synchronization *************** */
        HcclResult MainPostToSlaves(const OpParam &param, ExecMem &execMem);
        HcclResult SlavesWaitForMain(const OpParam &param, ExecMem &execMem);
        HcclResult SlavesPostToMain(const OpParam &param, ExecMem &execMem);
        HcclResult MainWaitForSlaves(const OpParam &param, ExecMem &execMem);
        HcclResult MainExecEmptyTask(const OpParam &param, ExecMem &execMem);
        HcclResult SlaveExecEmptyTask(const OpParam &param, ExecMem &execMem, u32 streamIndex);
    };
} // namespace hccl

#endif

Algorithm Performance

Licensed under CC BY-NC-SA 4.0