NCCL 源码解读(17): Primitives Simple

Posted by w@hidva.com on March 9, 2025

本文是 NCCL 源码解读系列之一, NCCL 源码解读系列使用 NCCL 版本: v2.25.1-1. 本文介绍 Primitives<ProtoSimple> 相关实现细节.

前景提要

在之前关于 p2p transport 的文章中, 我们已经介绍了在发送方会在 p2pSendSetup 时分配 sizeof(ncclSendMem) 大小显存, 接收方会分配 sizeof(ncclRecvMem) + buffSizes[LL] + buffSizes[LL128] + buffSizes[SIMPLE] 大小显存, 如下所示:

    sender                               receiver
+----------------------+              +------------+
|  head: u64           |              |  tail: u64 |
|  ptrExchange: void*  |              |ncclRecvMem |
|    ncclSendMem       |              |            |
+----------------------+              +------------+ <- ncclConnInfo.buffs[LL]
                                      |   LL       |
                                      | BufSiz     |
                                      +------------+ <- ncclConnInfo.buffs[LL128]
                                      | LL128      |
                                      | BufSiz     |
                                      +------------+ <- ncclConnInfo.buffs[SIMPLE]
                                      | Simple     |
                                      | BufSiz     |
                                      +------------+

之后发送方在 p2pSendConnect() 时会连接接收方, 此时 sender.ncclDevComm.channels[channelId].peers[peerRank].send[connIndex] 将保存相关信息. 同样接收方会在 p2pRecvConnect 时连接发送方, 此时相关信息保存在 receiver.ncclDevComm.channels[channelId].peers[peerRank].recv[connIndex] 中; 此时这俩结构中同名字段将具有相同的值:

struct ncclConnInfo {
  char *buffs[NCCL_NUM_PROTOCOLS]; // Local for recv, remote for send
  // tail 为 &ncclRecvMem.tail
  uint64_t *tail;     // Local for recv, remote for send
  // head 为 &ncclSendMem.head
  uint64_t *head;     // Local for send, remote for recv
  // ptrExchange 为 &ncclSendMem.ptrExchange
  void **ptrExchange; // Pointer exchange for direct communication
}

Primitives 构造

Ok! 在我们了解如上背景知识之后, 首先看下 Primitives 是如何使用的. Primitives 实现了 send,recv,reduce 这些基本的通信原语. 第一步很明显是构造一个 Primitives 对象:

auto prims = Primitives<T, RedOp, Fan=FanSymmetric<1>, Direct=1, Proto, P2p=0, isNetOffload=false>(
  tid, nthreads,
  recvPeers=&ring->prev, sendPeers=&ring->next,
  inputBuf=work->sendbuff, outputBuf=work->recvbuff,
  work->redOpArg,
  group=0,
  connIndexRecv=0, connIndexSend=0,
  collWork=work);

nthreads, 指定了当前 CTA 中参与通信的线程数, 其一定是 WARP_SIZE 的整数倍. tid 则是当前线程标识, 其范围 [0, nthreads), 其与 tidInBlock 值可以认为是一样的. group 与 NetDeviceUnpack 特性相关, 我们暂且忽略.

MaxRecv=Fan::MaxRecv=FanSymmetric<1>=1 指定了 recvPeers 指向内存块的容量. nrecv,nsend 为本次参与通信的 receiver, sender 个数, 即指定了当前 rank 应该从 nrecv 个发送方处接受数据, 并将数据发往 nsend 个接收方.

int nrecv=0, nsend=0;
while (nrecv < MaxRecv && recvPeers[nrecv] != -1) nrecv++;
while (nsend < MaxSend && sendPeers[nsend] != -1) nsend++;
this->fan = Fan(nrecv, nsend);

connIndexRecv,connIndexSend 对应着 peers[peerRank].recv[connIndex] 中的 connIndex:

struct ncclDevChannelPeer {
  struct ncclConnInfo send[NCCL_MAX_CONNS];
  struct ncclConnInfo recv[NCCL_MAX_CONNS];
};

这里 NCCL_MAX_CONNS=2, connIndex=1, connIndex=0 分别对应着不同的使用场景, connIndex=1 一般用在 nvls, p2p, COLLNET_DIRECT 等这些场景中.

Primitives 会在构造时为参与到通信的线程分配对应的角色, 如下是 nrecv=2, nsend=2, nthreads=32 时各个线程角色划分:

  tid       0        1       2       3       4  ...  28       29       30       31
 flag   WaitRecv WaitRecv WaitSend WaitSend        PostRecv PostRecv PostSend PostSend
index       0        1       0       1                0        1        0        1
 peer   rPeer[0] rPeer[1] sPeer[0] sPeer[1]        rPeer[0] rPeer[1] sPeer[0] sPeer[1]

这里 index,peer 指定了 peer 端信息. 比如对于 WaitRecv 线程来说, index,peer 表明该线程负责与 peer 发送方交互.

loadRecvConn/loadSendConn, 根据各个线程的角色结合 ncclConnInfo 中信息初始化相应字段:

        WaitRecv                             WaitSend                 PostRecv                PostSend
connStepPtr = conn.tail               connStepPtr = conn.head  connStepPtr = conn.head  connStepPtr = conn.tail
connStepCache=*connStepPtr            connStepCache
connStepSize=conn.stepSize/sizeof(T)  connStepSize
connEltsFifo=conn.buffs[SIMPLE]       connEltsFifo                                      connEltsFifo
flag |= DirectWrite                   flag |= DirectWrite

recvProvider                          sendAcceptor
directBuff = recvbuff                 directBuff = *ptrExchange
*ptrExchange =
  recvbuffRmtAddrs[localPeer] +
  recvbuffOffset

如上 WaitSend 下 connStepCache 意味着这个线程对 connStepCache 字段进行了设置, 其值等同于最左端 WaitRecv 赋值逻辑, 即在 WaitSend/PostSend 线程中, connStepCache 字段也是等于 *connStepPtr.

Direct,DirectWrite 与 User Buffer Registration 特性有关, 简单来说 Direct=0 时, 数据会先写入 nccl 内部缓冲区中, 之后再拷贝到用户缓冲区. 而在 DirectWrite 情况下, 数据可以直接写往用户指定的缓冲区中. 比如在 p2p 情况下, 以 ncclAllReduce(const void* sendbuff, void* recvbuff) 为例, 若用户事先为 recvbuff 所载内存块调用过 ncclCommRegister() 进行了注册, 则 nccl 会将 recvbuff 注册到每一个 send peer 处. coll.recvbuffRmtAddrs[localPeer] 存放着 coll.recvbuff - coll.recvbuffOffset 在 localPeer 处的映射, 即 localPeer 可以通过 coll.recvbuffRmtAddrs[localPeer] + coll.recvbuffOffset 来访问到 recvbuff.

DirectRead 与 p2p read 有关. p2p read 目前看仅在 Ampere 系列显卡上启用. flag 不可能同时设置 DirectRead 以及 DirectWrite.

setDataPtrs(), 对于参与到通信的线程来说, recvProvider/sendAcceptor/sendProvider/recvAcceptor 至多只有一个为 true. 我们这里只关注 recvProvider/sendAcceptor, 简单来说 recvProvider 通过 ptrExchange 字段将 recvbuff 在 send 端的映射告知了 send.

step,slice,chunk

这里我们以 directSend(inpIx, outIx, eltN) 过程来引入一次发送过程的实现, directSend 其语义很直观, 将本端 inpIx 偏移处 eltN 个元素发送到接收方 outIx 偏移处; 这里是以元素为单位的偏移, 而不是以字节. 这里 eltN 可能 <=0 意味着 noop. 此时对应 genericOp 各个模板参数的值:

DirectRecv = 0, DirectSend = 1, Recv = 0, Send = 1, SrcBuf = Input, DstBuf = -1
Src = 1, Dst = 0,

在 DirectWrite 情况下, step, slice, chunk 这些概念不太具体, 更多的是抽象用于计数的. 所以我们以 DirectWrite=0 链路看下 step, slice, chunk 这几个概念, 此时数据会写入到 nccl 维护的缓冲区中, 即 connEltsFifo=buffs[SIMPLE] 中, nccl 会将 buffs[SIMPLE] 切分为 NCCL_STEPS 个块, 每个块可以容纳元素个数为 Primitives.stepSize, stepSize = buffSizes[SIMPLE] / NCCL_STEPS / sizeof(T). Primitives.connStepSize 是发送方/接收方协商的 stepSize, 而 Primitives.stepSize 是 ncclComm 粒度的 stepSize, 暂且认为这俩是一回事.

                                            tail
                                             |
+--------------+--------------+--------------+--------------+
|   step0      |  step1       |  step2       |  step3       |
+--------------+--------------+--------------+--------------+
               |
             head

StepPerSlice 指定了一个 slice 占用多少 step, genericOp.sliceSize 指定了一个 slice 可以容纳的元素个数. 对于 directSend -> genericOp 来说其任务就是将 [inpIx, inpIx + eltN) 处元素发往 [outIx, outIx + eltN) 处. genericOp.srcIx = inpIx, genericOp.dstIx = outIx. genericOp 一次只会传输一个 slice; offset 指定了待传输 slice 在 [inpIx, inpIx + eltN) 中的偏移, 即 userInput + srcIx + offset 为待传输 slice 的起始地址. slice 指定了当前已经传输了多少个 slice 了. 这里先不管那么多, 先走几遍 genericOp 对一个 slice 的传输过程:

      tid0                                   WaitSend                         PostSend
srcs[0] = userInput + srcIx + offset
                   wait conn.head + NCCL_STEPS >= step + StepPerSlice
                   dsts[0] = connEltsFifo + (step%NCCL_STEPS)*connStepSize
                   step += StepPerSlice, step=StepPerSlice
          ALL THREADS: reduceCopy(srcs[0], dst[0], sliceSize)
                                                                         step += StepPerSlice,step=StepPerSlice
                                                                         conn.tail = step
offset += sliceSize; slice += 1;

srcs[0] = userInput + srcIx + offset
                   wait conn.head + NCCL_STEPS >= step + StepPerSlice
                   dsts[0] = connEltsFifo + (step%NCCL_STEPS)*connStepSize
                   step += StepPerSlice,step=StepPerSlice * 2
          ALL THREADS: reduceCopy(srcs[0], dst[0], sliceSize)
                                                                         step += StepPerSlice,step=StepPerSlice * 2
                                                                         conn.tail = step

走一遍传输过程之后, 各个模块的语义, 各个字段啥意思就很清晰了是吧! 现在我们再来看传输过程涉及到的各个函数/字段:

waitPeer(), 负责为本次 slice 分配目标空间. 在 DirectWrite = 1 时, dst[0] = directBuff + dstIx + offset, 这里 directBuff 对应着接收方 recvbuff, 即本 slice 会直接写入到接收方 recvbuff 指定偏移处. 在 DirectWrite = 0 时, slice 需要先写往 connEltsFifo 缓冲区中. connEltsFifo,conn.head,conn.tail 就是一个传统的 SPSC ring buffer, < conn.tail 之前的 step 已经被发送方填充完成, < conn.head 之前的 step 已经被接收方消费完毕. WaitSend 线程 step 值 WaitSend.step 记录着最近一次已经完成填充的 step buffer 下标, 其等同于 conn.tail. conn.head + NCCL_STEPS < conn.tail + StepPerSlice 意味着当前 connEltsFifo 中剩余空间不足以存放一个 slice, 所以这里会一直循环等待接收方消费 connEltsFifo 中 step buffer, 在接收方完成消费之后会前进 conn.head, 从而可以让 connEltsFifo 有足够的空间容纳一个 slice.

P.S. 哎, 这里 step 有时候表示一块 buffer, 有时候表示 buffer 下标.

waitPeer 在等到 connEltsFifo 有足够空间容纳一个 slice 之后, 会将空间地址保存在 dsts[0] = connEltsFifo + (step%NCCL_STEPS)*connStepSize 中. 并同时前进 WaitSend.step, 此时 connEltsFifo 中下标位于 [origin WaitSend.step, current WaitSend.step) 范围内的 buffer 尚未填充, 等待被接下来的 reduceCopy 填充.

reduceCopy() 完成具体的发送工作, 其内容不在本文展开.

postPeer(), 会前进 PostSend 线程 step 值, PostSend.step 并更新 conn.tail 告知消费端指定 step buffer 已经填充完成可以消费了.

chunk, 从 genericOp 实现来看, genericOp 一次调用最多只会传输 SlicePerChunk 个 slice. 那么问题来了, 在调用 genericOp 时我们要求其要传输 nelem 个元素. 有没有可能在 genericOp 已经完成 SlicePerChunk 个 slice 传输之后, 传输量仍然小于 nelem? 即 SlicePerChunk * sliceSize >= nelem 是否一定成立?! Yes!

sliceSize = max(divUp(nelem, 16*SlicePerChunk)*16, sliceSize/32)

V = divUp(nelem, 16*SlicePerChunk), 则 16 * SlicePerChunk * V >= nelem, 而 sliceSize >= 16 * V, 所以 SlicePerChunk * sliceSize >= nelem. 但 sliceSize 也不能太大, 如上 waitPeer 所示, 其要求一个 slice 大小不能超过 StepPerSlice * connStepSize, 即要求用户在调用 genericOp() 时 nelem 不能太大. 所幸 genericOp 都是 nccl 内部调用, 其会确保这一要求. 不过我还是觉得在 genericOp() 中加个断言更合适一点.

nworkers,nthreads

见如上 genericOp 对一个 slice 的传输过程, 可知每一个线程大致工作流:

  1. waitPeer()
  2. barrier()
  3. reduceCopy()
  4. barrier()
  5. postPeer()

即每一个线程在执行实际的传输之前都需要 wait postPeer, 注意如上过程是一个循环, 其在循环 iter=2 中的 barrier 需要等待 PostSend 线程在 iter=1 的 postPeer. 为了避免等待, nccl 会在 nthreads 足够大时, 专门划分一个 WARP 负责 postPeer, 这样 [0, nworkers) 范围内的线程只需要无脑传输, [nworkers, nthreads) 范围内的线程会执行 postPeer. 这样第 2 步的 barrier() 就可以换成 subBarrier() 了. 这里 barrier() 语义上等同于 __syncthreads(), 与 __syncthreads() 要求 CTA 中所有线程都要到达同步点不同, barrier() 只需要要求 CTA 中前 nthreads 个线程到达同步点. 而 subBarrier 更轻量, 只需要要求前 nworkers 个线程到达同步点.

话说有没有思考过为啥要以 NCCL_SIMPLE_EXTRA_GROUP_IF_NTHREADS_GE, 即 3 * WARP_SIZE 作为是否划分 nworkers 的阈值? 这是因为: