Featured image of post HCCL通信库创新大赛记录——Butterfly

HCCL通信库创新大赛记录——Butterfly

HCCL集合通信库(Huawei Collective Communication Library,简称HCCL)是基于昇腾AI处理器的高性能通信库,聚焦于超大计算集群中的流量调度难题,为大集群提供高效可靠的通信服务,是华为AI软件生态CANN的核心组件之一,本文针对小数据量采用Butterfly算法

赛题回顾

环境:单机8卡昇腾910B服务器一台

赛题内容:在8PMesh中,断掉一根Full Mesh直连的路径(Rank0和Rank1之间的一根链路),完成AllReduce的集合通信语义

评分指标:使用算法分析工具验证,语义实现正确;在物理环境上运行HCCL Test工具,考察1KB、1MB、1GB这三种数据量下的算法带宽的均值,根据三种数据量下的算法带宽作为评判依据

算法思路

如前文所述,在小数据量的场景下,需要格外注重减少计算开销,即需要降低算法的复杂度、减少迭代次数。比较直观的思路是:从Rank2~Rank7中任选一个节点作为根节点,该根节点从其它全部节点拉取数据并聚合(由于根节点不是Rank0或Rank1,因此不受故障链路的影响),根节点在聚合完成后将数据推送到其它全部节点,共需2步即可完成操作。

在实际编程完成并性能测试时,意外发现上述设计的性能非常差,初步分析定位为多条流导致的同步开销。在HCCL中,流(stream)类似于CUDA里的stream,每个stream上的操作是顺序执行的,不同stream可以并行执行,提高通信带宽利用率。例如在根节点从其它全部节点拉取数据的过程中,从流0从Rank0拉取数据、从流1从Rank1拉取数据、……、从流7从Rank7拉取数据,而主流则需要等待所有从流完成数据的拉取,只有全部从流都完成任务后,缓存区里存放的才是最终的聚合结果,根节点才能将数据推送到其它全部节点(该过程同理需要多条从流执行,主流同样需要等待从流完成)。然而在小数据量场景下,使用过多的从流并不能提高性能,反而因同步开销导致性能衰减。

因此提升性能的核心在于算法仅需一条流即可完成AllReduce操作,即在主流上进行全部操作,不使用从流。Butterfly算法恰好满足这一需求,仅需3次通信轮次,每轮与固定对等节点交换数据,各节点仅需一条主流,适合数据量小、延迟敏感场景。

非常巧合的是,Rank1与Rank2在Butterfly算法中天然不发生通信,不过赛题中是Rank0与Rank1之间不能发生通信,那就重新映射下编号。物理节点Rank7、Rank0、Rank1、Rank2等物理节点的逻辑编号依次为Rank0、Rank1、Rank2、Rank3,其它节点以此类推。Butterfly算法中,逻辑节点Rank1与Rank2之间不通信,也就意味着物理节点Rank0与Rank1之间不通信。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
// Convert logical rank addresses to actual rank addresses
u32 CollCustomSmallAllReduceMeshExecutor::LogicalToPhysical(u32 logicalRank, u32 rankSize)
{
    return (logicalRank + 1) % rankSize;
}

// Convert actual rank addresses to logical rank addresses
u32 CollCustomSmallAllReduceMeshExecutor::PhysicalToLogical(u32 physicalRank, u32 rankSize)
{
    return (physicalRank + rankSize - 1) % rankSize;
}

算法设计

宏观流程如下:

  1. 在初始化的过程中:全部Rank将User Input中的输入数据全部拷贝至User Output(从远端把数据拉取来,与本地数据聚合)与CCL_OUT(准备好数据,供远端把数据拉走)中;
  2. 在每一轮迭代中:每块Rank从对等节点的CCL_OUT中拉取数据到本地User Output进行聚合规约;
  3. 若是最后一轮迭代,本地User Output中存放的即是最终的聚合结果;若不是最后一轮迭代,还需要将User Output中的数据拷贝到CCL_OUT中,以供下一轮迭代。

在第step轮迭代中的具体流程如下:

  1. 将本节点的物理编号通过PhysicalToLogical函数映射为逻辑编号;
  2. 根据逻辑编号计算在本轮迭代中的对等节点,即与谁交换数据;
  3. 将对等节点的逻辑编号通过LogicalToPhysical函数映射回物理编号;
  4. 将上一轮中位于User Output中的聚合数据拷贝到CCL_OUT中;
  5. 通过TxAck通知对等节点:“本节点已经将数据准备好了,随时可以从CCL_OUT拉取数据/向CCL_OUT推送数据”;
  6. 通过RxAck阻塞流,直至对等节点也准备就绪(即对等节点向本节点发送TxAck通知);
  7. 对等节点准备就绪后,即可从对等节点的CCL_OUT拉取数据,并聚合到本地的User Output;
  8. 完成数据的拉取、聚合后,通过TxDataSignal通知对等节点:“本节点已经完成操作,你可以释放资源或进行其它工作了”;
  9. 通过RxDataSignal阻塞流,直至对等节点也完成数据的拉取、聚合(即对等节点向本节点发送TxDataSignal通知)。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
1. 确定本轮通信的远端节点
    logicalRankIdx = PhysicalToLogical(rankIdx)
    logicalDstRankIdx = logicalRankIdx ^ (1 << step)
    dstRank = LogicalToPhysical(logicalDstRankIdx)
2. 搬运上一轮聚合数据到CCL_OUT
    copy(userOut, cclOut)
3. 通知远端节点就绪,并等待远端节点就绪
    links[dstRank]->TxAck()
    links[dstRank]->RxAck()
4. 从远端CCL_OUT拉取数据,并聚合到User Output
    reduce(remoteCclOut, localUserOut)
5. 通知远端节点拉取完成,并等待远端节点完成
    links[dstRank]->TxDataSignal()
    links[dstRank]-RxDataSignal()

编程实现

cc

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
#include "coll_custom_small_all_reduce_mesh_executor.h"

namespace hccl
{
    CollCustomSmallAllReduceMeshExecutor::CollCustomSmallAllReduceMeshExecutor(const HcclDispatcher dispatcher, std::unique_ptr<TopoMatcher> &topoMatcher)
        : CollAllReduceExecutor(dispatcher, topoMatcher)
    {
        CCLMemSlice_ = false;
        DMAReduceFlag_ = true;
    }

    // Calculate the amount of scratch memory to request
    HcclResult CollCustomSmallAllReduceMeshExecutor::CalcScratchMemSize(u64 &scratchMemSize)
    {
        scratchMemSize = 0U;
        HCCL_WARNING("[HCCLContest][CollCustomSmallAllReduceMeshExecutor][CalcScratchMemSize] scratchMemSize: %u",
                     scratchMemSize);
        return HCCL_SUCCESS;
    }

    // Calculate the number of streams to be requested
    HcclResult CollCustomSmallAllReduceMeshExecutor::CalcStreamNum(u32 &streamNum)
    {
        streamNum = 0U;
        HCCL_WARNING("[HCCLContest][CollCustomSmallAllReduceMeshExecutor][CalcStreamNum] streamNum: %u", streamNum);
        return HCCL_SUCCESS;
    }

    // Calculate the number of Notify to be requested
    HcclResult CollCustomSmallAllReduceMeshExecutor::CalcNotifyNum(u32 streamNum, u32 &notifyNum)
    {
        notifyNum = 0U * streamNum;
        HCCL_WARNING("[HCCLContest][CollCustomSmallAllReduceMeshExecutor][CalcNotifyNum] notifyNum: %u", notifyNum);
        return HCCL_SUCCESS;
    }

    // Set up the level-0 mesh topology required for the AllReduce operation
    HcclResult CollCustomSmallAllReduceMeshExecutor::CalcCommInfo(std::vector<LevelNSubCommTransport> &opTransport)
    {
        HCCL_WARNING("[HCCLContest][CollCustomSmallAllReduceMeshExecutor][CalcNotifyNum]");

        // Define the source and destination memory types for communication
        TransportMemType inputType = TransportMemType::CCL_INPUT;
        TransportMemType outputType = TransportMemType::CCL_OUTPUT;
        // Construct a mesh topology for level 0
        CommParaInfo commParaLevel0(COMM_LEVEL0, CommType::COMM_TAG_MESH);
        commParaLevel0.meshSinglePlane = true;
        // Compute and populate the transport plan for level-0 communication domain
        CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_LEVEL0], inputType, outputType));
        return HCCL_SUCCESS;
    }

    // Calculate the number of iterations for loop processing
    u64 CollCustomSmallAllReduceMeshExecutor::CalcLoopMaxCount(const u64 cclBuffSize, const u32 unitSize)
    {
        u64 maxCountPerLoop = cclBuffSize / unitSize;
        HCCL_WARNING("[HCCLContest][CollCustomSmallAllReduceMeshExecutor][CalcLoopMaxCount] maxCountPerLoop: %u",
                     maxCountPerLoop);
        return maxCountPerLoop;
    }

    // Process data for a single iteration of the AllReduce algorithm execution
    HcclResult CollCustomSmallAllReduceMeshExecutor::KernelRun(const OpParam &param, ExecMem &execMem)
    {
        // Get sub-communication domain information for level 0
        CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1));
        SubCommInfo level0CommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0);

        // Perform AllReduce
        PerformButterfly(param, execMem);

        HCCL_WARNING("[HCCLContest][CollCustomSmallAllReduceMeshExecutor][KernelRun] localRank: %u, localRankSize: %u",
                     level0CommInfo.localRank, level0CommInfo.localRankSize);

        return HCCL_SUCCESS;
    }

    // Perform AllReduce by Butterfly
    HcclResult CollCustomSmallAllReduceMeshExecutor::PerformButterfly(const OpParam &param, ExecMem &execMem)
    {
        // Get sub-communication domain information for level 0
        CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1));
        SubCommInfo level0CommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0);

        // Dertermine the information for the topology
        u32 rankSize = level0CommInfo.localRankSize;
        u32 rankIdx = level0CommInfo.localRank;

        // Determine unit size, chunk size for this iteration (unit: bytes)
        u32 unitSize = SIZE_TABLE[param.DataDes.dataType];
        u64 chunkSize = execMem.count * unitSize;

        // Retrieve the master stream
        hccl::Stream &masterStream = const_cast<hccl::Stream &>(param.stream);

        // Prepare for the memory device
        DeviceMem userInMem = DeviceMem::create(execMem.inputPtr, chunkSize);
        DeviceMem cclOutMem = execMem.outputMem.range(0, chunkSize);
        DeviceMem userOutMem = DeviceMem::create(execMem.outputPtr, chunkSize);

        // Perform Butterfly
        u32 nSteps = static_cast<u32>(log2(rankSize));
        for (u32 step = 0; step < nSteps; step++)
        {
            // Calculate the remote rank for this step
            u32 logicalRankIdx = PhysicalToLogical(rankIdx, rankSize);
            u32 logicalDstRankIdx = logicalRankIdx ^ (1 << step);
            u32 dstRank = LogicalToPhysical(logicalDstRankIdx, rankSize);

            // Prepare data for this step
            if (step == 0)
            {
                // Copy data from user input to user output and CCL_Out
                CHK_RET(HcclD2DMemcpyAsync(dispatcher_, userOutMem, userInMem, masterStream));
                CHK_RET(HcclD2DMemcpyAsync(dispatcher_, cclOutMem, userInMem, masterStream));
            }
            else
            {
                // Copy data from user output to CCL_Out
                CHK_RET(HcclD2DMemcpyAsync(dispatcher_, cclOutMem, userOutMem, masterStream));
            }

            // Notify the remote rank that it is ready
            CHK_RET(level0CommInfo.links[dstRank]->TxAck(masterStream));
            // Wait for the remote rank to be ready
            CHK_RET(level0CommInfo.links[dstRank]->RxAck(masterStream));

            // Reduce data from the remote rank's CCL_Out to local rank's user output
            void *srcRemotePtr = nullptr;
            CHK_RET(level0CommInfo.links[dstRank]->GetRemoteMem(UserMemType::OUTPUT_MEM, &srcRemotePtr));
            DeviceMem srcRemote = DeviceMem::create(static_cast<u8 *>(srcRemotePtr), chunkSize);
            CHK_RET(HcclReduceAsync(dispatcher_, srcRemote.ptr(), chunkSize / unitSize,
                                    param.DataDes.dataType, param.reduceType,
                                    masterStream, userOutMem.ptr(),
                                    level0CommInfo.links[dstRank]->GetRemoteRank(),
                                    level0CommInfo.links[dstRank]->GetLinkType(),
                                    INLINE_REDUCE_BIT));

            // Notify the remote rank that the transmission is finished
            CHK_RET(level0CommInfo.links[dstRank]->TxDataSignal(masterStream));
            // Wait for the remote rank to complete transmission
            CHK_RET(level0CommInfo.links[dstRank]->RxDataSignal(masterStream));
        }

        return HCCL_SUCCESS;
    }

    // Convert logical rank addresses to actual rank addresses
    u32 CollCustomSmallAllReduceMeshExecutor::LogicalToPhysical(u32 logicalRank, u32 rankSize)
    {
        return (logicalRank + 1) % rankSize;
    }

    // Convert actual rank addresses to logical rank addresses
    u32 CollCustomSmallAllReduceMeshExecutor::PhysicalToLogical(u32 physicalRank, u32 rankSize)
    {
        return (physicalRank + rankSize - 1) % rankSize;
    }

    REGISTER_EXEC("CustomSmallAllReduceMeshExecutor", CustomSmallAllReduceMesh, CollCustomSmallAllReduceMeshExecutor);
} // namespace hccl

h

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#ifndef COLL_CUSTOM_SMALL_ALLREDUCE_MESH_EXECUTOR_H
#define COLL_CUSTOM_SMALL_ALLREDUCE_MESH_EXECUTOR_H

#include "coll_all_reduce_executor.h"

namespace hccl
{
    class CollCustomSmallAllReduceMeshExecutor : public CollAllReduceExecutor
    {
    public:
        CollCustomSmallAllReduceMeshExecutor(const HcclDispatcher dispatcher, std::unique_ptr<TopoMatcher> &topoMatcher);
        ~CollCustomSmallAllReduceMeshExecutor() = default;

    private:
        /* *************** 资源计算 *************** */
        HcclResult CalcScratchMemSize(u64 &scratchMemSize) override;
        HcclResult CalcStreamNum(u32 &streamNum) override;
        HcclResult CalcNotifyNum(u32 streamNum, u32 &notifyNum) override;
        HcclResult CalcCommInfo(std::vector<LevelNSubCommTransport> &opTransport) override;
        u64 CalcLoopMaxCount(const u64 cclBuffSize, const u32 unitSize);

        /* *************** 算法编排 *************** */
        HcclResult KernelRun(const OpParam &param, ExecMem &execMem) override;
        HcclResult PerformButterfly(const OpParam &param, ExecMem &execMem);
        u32 LogicalToPhysical(u32 logicalRank, u32 rankSize);
        u32 PhysicalToLogical(u32 physicalRank, u32 rankSize);
    };
} // namespace hccl

#endif

算法性能

Licensed under CC BY-NC-SA 4.0
皖ICP备2025083746号-1
公安备案 陕公网安备61019002003315号



使用 Hugo 构建
主题 StackJimmy 设计