PD 分离中的 GDR

Posted by w@hidva.com on April 17, 2025

这篇文章 我提到过我们基于 vllm 实现了 PD 分离, 而且设计思路恰好与 Nvidia Dynamo 撞车了. 大致查询链路也与 Dynamo 相似, 简单来说是:

  1. 请求 R 从 D 节点接入,
  2. D 节点决策模块确定请求 R 应在某个 P 实例上完成, 此时 D 为 R 分配 kvcache block R.d_kv_blocks, 并会调用 P 实例 do_prefill(R, R.d_kv_blocks) rpc.
  3. P 实例负责完成 prefill, 并将 prefill 生成的 kvcache layer-by-layer 地通过 RDMA Write + GDR(GPU-Direct RDMA) 直接写入到 R.d_kv_blocks 中.
  4. P do_prefill rpc 会在 first token 生成, 以及 kvcache 传输完成之后返回 P prefill 生成的 first token. D 在 do_prefill rpc 响应之后在本地进行 decode 过程.

B.T.W, 我们这套方案对 vllm 引入性也非常的低, 基本上复用了 vllm 目前现有设施, 对 vllm 改动几乎为 0.

disaggpd-gdr.1.jpg

在本篇文章, 我来探究下 GDR 相关的一些细节.

内存序

mscclpp DeviceSyncer 真的能 sync 么? 所示, 我个人对 memory order 是比较敏感的, 所以接触到 GDR 之后第一反应, 为了实现必需的 Synchronization, 我是否需要做些什么?

参考 Nvidia GDR: Synchronization and Memory Ordering 文档, 目前我们 PD 分离实现中 P 端在发起 RDMA Write 以及 D 端在发起 forward 计算都遵循了这里提到的约定

性能影响

关于 GDR 我们一直存在的担忧点是其对 decode latency 的影响, 毕竟 decode 算是对 HBM 带宽比较敏感的, 而 P GDR RDMA Write 也会占用 D HBM 一部分带宽. 为了探究这个问题我们设计了一组实验.

  1. 首先以 PD 分离模式启动 P, D 实例. 这里使用 Qwen2 72B, tp=4. D worker 需要在 forward 前后进行打点:
+        zy_start_ts = time.perf_counter()
+        prev_rdma_rx_bytes = 0
+        prev_rdma_rx_packets = 0
+        if get_tp_group().is_first_rank and self.vllm_config.disagg_config and self.vllm_config.disagg_config.is_decode():
+            prev_rdma_rx_bytes, prev_rdma_rx_packets = get_rdma_rx_bytes_packets()
         # Run the decoder.
         # Use persistent buffers for CUDA graphs.
         with set_forward_context(attn_metadata, self.vllm_config):
@@ -1076,6 +1092,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
             )
+        torch.cuda.synchronize()
+        now = time.perf_counter()
+        foward_s = now - zy_start_ts
+        rdma_rx_bytes = 0
+        rdma_rx_packets = 0
+        if get_tp_group().is_first_rank and self.vllm_config.disagg_config and self.vllm_config.disagg_config.is_decode():
+            rdma_rx_bytes, rdma_rx_packets = get_rdma_rx_bytes_packets()
+        rdma_rx_bps = (rdma_rx_bytes - prev_rdma_rx_bytes) / foward_s
+        rdma_rx_pps = (rdma_rx_packets - prev_rdma_rx_packets) / foward_s
+        foward_ms = foward_s * 1000
+        num_reqs = self.input_batch.num_reqs
+        context_lens = self.input_batch.num_tokens[:num_reqs]
+        logger.info(f">>>ZYDBG. {foward_ms=} {len(input_ids)=} {num_reqs=} context_lens={np.min(context_lens)}|{np.max(context_lens)}|{np.sum(context_lens)} {rdma_rx_bps=} {rdma_rx_pps=}")
+

get_rdma_rx_bytes_packets 基于 pyroute2, linux rdma netlink 实现, 其源码见本文最后. 如上 worker log 输出日志示例:

INFO 04-17 16:58:46 [gpu_model_runner.py:1107] >>>ZYDBG. foward_ms=32.284699846059084 len(input_ids)=64 num_reqs=64 context_lens=2|2|128 rdma_rx_bps=0.0 rdma_rx_pps=0.0
  1. D 实例在启动 warmup 之后, 下发指定 batch size 个 input len = 1, ignore_eos=True, max_tokens=2000 请求. 我们会使用 context_len = 1500 之后的 500 个 step 的 decode latency 作为参考指标 DecodeLatency. 由于这些请求 input len = 1, 所以 D 实例决策模块会确定这些请求在 D 节点进行 prefill. 也即这 batch size 个请求的处理不会涉及到 PD 分离的逻辑.
if self.scheduler.disagg_config and self.scheduler.disagg_config.is_decode():
    for i in range(batch_size):
        self.add_req_threadsafe(EngineCoreRequest(
            request_id = f'zhanyi-debug-disaggpd-gdr-{i}',
            prompt=None,
            mm_inputs=None,
            mm_hashes=None,
            eos_token_id=None,
            lora_request=None,
            mm_placeholders=None,
            prompt_token_ids=[111308],  # 您好
            sampling_params=SamplingParams(ignore_eos=True, max_tokens=2000),
            arrival_time=time.time(),
        ))
    self.scheduler.max_num_scheduled_tokens = batch_size

这里同时会设置 scheduler.max_num_scheduled_tokens = batch_size 确保 scheduler, worker 只会处理这 batch size 个请求.

  1. 之后我会在 context_len = 500 左右启动 vllm benchmark_serving, 以指定 qps 压测 input len = 2000, output len = 7 的请求. 由于我们已经变更了 scheduler.max_num_scheduled_tokens, 对于这些请求, D 节点只会调用 P do_prefill rpc, 并在收到响应之后将请求加入到 D scheduler 上进行排队. 也即 最关键的, 除了 P 发起的 GDR RDMA Write 流量之外, 这些请求的存在对 D worker 进行 decode 没有任何影响! 在我们早期基于内部的推理框架实现的 PD 分离中, D worker 每次 step 都需要对这些请求做一些 PD 分离特有逻辑的处理, 导致不太好完成本节提到的实现.

如下数据为指定 qps 下, 指标 DecodeLatency 相对 qps0 时相对值. 可以看到 GDR 对 decode latency 基本没有影响.

DecodeLatency qps1 qps2 qps3
min 0.0016539 0.0005427 0.0009722
max -0.0032505 -0.0020395 -0.0029623
avg -0.0008911 -0.0007972 -0.0010694

P.S. 测试使用了阿里云 ERDMA, 具体显卡类型以及指标绝对值我不太好把握, 所以有意不展示出来了, 相对值就已经可以说明情况了.

指定 qps 下, Prefill 发起的 GDR RDMA Write 流量如图所示:

disaggpd-gdr.2.jpg

可以看到 GDR RDMA Write 流量与压测 qps 关系不大, 这是因为就算是 qps1 已经足以跑满 prefill 每次 step, GDR RDMA Write 流量与 Prefill chunked prefill batched token 数目有关, 我这里使用的是我们线上值. 或许可以在不同 batched token 下跑下实验, 但我太懒了=.

rdma_stat

get_rdma_rx_bytes_packets 依赖的源码实现. 这应该是这篇文章最大的贡献了! 这也稍微抵消了我在上面不得不隐藏一些细节的内疚感…

from pyroute2.netlink import nlmsg, nla, NLA_F_NESTED
from pyroute2.netlink.nlsocket import AsyncNetlinkSocket, NetlinkRequest
from typing import Optional
import asyncio

# RDMA Netlink 协议常量, copy from iproute2
NLM_F_REQUEST = 0x1
NLM_F_ROOT = 0x100
NLM_F_MATCH = 0x200
NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH
NETLINK_RDMA = 20
RDMA_NL_NLDEV = 5
RDMA_NLDEV_ATTR_DEV_INDEX = 1
RDMA_NLDEV_ATTR_DEV_NAME = 2
RDMA_NLDEV_ATTR_PORT_INDEX = 3
RDMA_NLDEV_ATTR_STAT_HWCOUNTERS = 80
RDMA_NLDEV_ATTR_STAT_HWCOUNTER_ENTRY = 81
RDMA_NLDEV_ATTR_STAT_HWCOUNTER_ENTRY_NAME = 82
RDMA_NLDEV_ATTR_STAT_HWCOUNTER_ENTRY_VALUE = 83
RDMA_NLDEV_CMD_STAT_GET = 17
NLM_F_ACK = 0x4

# #define RDMA_NL_GET_TYPE(client, op) ((client << 10) + op)
def RDMA_NL_GET_TYPE(client, op):
  return (client << 10) + op

RDMA_STATE_GET_TYPE = RDMA_NL_GET_TYPE(RDMA_NL_NLDEV, RDMA_NLDEV_CMD_STAT_GET)

class RdmaStatGetMsg(nlmsg):
  nla_map = ((RDMA_NLDEV_ATTR_DEV_INDEX, 'DevIndex', 'uint32'),
             (RDMA_NLDEV_ATTR_PORT_INDEX, 'PortIndex', 'uint32'))

  @staticmethod
  def new(devidx, portidx):
    msg = RdmaStatGetMsg()
    msg['attrs'] = [
      # 对应 nla attr (RDMA_NLDEV_ATTR_DEV_INDEX, devidx)
      ['DevIndex', devidx],
      ['PortIndex', portidx]]
    msg['header']['type'] = RDMA_STATE_GET_TYPE
    msg['header']['flags'] = NLM_F_REQUEST | NLM_F_ACK
    return msg

# iproute2 rdma stat 有 bug, 其使用了 u32
# strace -e 捕捉到的内核返回:
# \x68\x77\x5f\x72\x78\x5f\x62\x79\x74\x65\x73\x5f\x63\x6e\x74\x00\x0c\x00\x53\x00\xa2\x10\xbc\x24\x82\x02\x00\x00
# \x68\x77\x5f\x72\x78\x5f\x62\x79\x74\x65\x73\x5f\x63\x6e\x74, 即 hw_rx_bytes_cnt
# 其对应值为 \xa2\x10\xbc\x24\x82\x02\x00\x00. 但 rdma stat 输出:
# hw_rx_bytes_cnt 616304802, 对应着 \xa2\x10\xbc\x24.
# 搞得我排查了半天是不是我逻辑问题.

class RdmaStatGetResp(nlmsg):
  # 在 decode 时会将不在 nla_map 的 nla 映射为 UNKNOWN NLA
  # 对应: ('UNKNOWN', {'header': {'length': 12, 'type': 2}})
  nla_map = ((RDMA_NLDEV_ATTR_DEV_INDEX, 'DevIndex', 'uint32'),
             (RDMA_NLDEV_ATTR_DEV_NAME, 'DevName', 'asciiz'),
             (RDMA_NLDEV_ATTR_STAT_HWCOUNTERS, 'HWC', 'HWC'),
             (RDMA_NLDEV_ATTR_PORT_INDEX, 'PortIndex', 'uint32'))

  class HWC(nla):
    nla_flags = NLA_F_NESTED
    # 这里缺少 ',' 会导致异常. 但 pyroute2 decode 并不会输出任何信息,
    # 只是输出结果缺少 HWC attr.
    nla_map = ((RDMA_NLDEV_ATTR_STAT_HWCOUNTER_ENTRY, 'HWCE', 'HWCE'),)

    # RDMA_NLDEV_ATTR_STAT_HWCOUNTER_ENTRY
    # print 输出 ('HWCE', {'attrs': [('Name', 'listen_create_cnt'), ('Value', 0)]}, 32768)
    # 见 nla_slot.__repr__, 32768 是 flag.
    class HWCE(nla):
      # nla_flags 这里只影响 encode. 对 decode 没有影响.
      nla_flags = NLA_F_NESTED
      nla_map = ((RDMA_NLDEV_ATTR_STAT_HWCOUNTER_ENTRY_NAME, 'Name', 'asciiz'),
                 # 内核 rdma netlink 这里使用的是本机字节序
                 (RDMA_NLDEV_ATTR_STAT_HWCOUNTER_ENTRY_VALUE, 'Value', 'uint64'))

      # val: {'attrs': [('Name', 'listen_create_cnt'), ('Value', 0)]}
      @staticmethod
      def from_parsed(input: dict) -> tuple[str, int]:
        name: Optional[str] = None
        val: Optional[int] = None
        for k, v in input['attrs']:
          if k == 'Name':
            name = v
            continue
          if k == 'Value':
            val = v
            continue
        assert name is not None
        assert val is not None
        return name, val


_g_sock: Optional[AsyncNetlinkSocket] = None

def _get_sock():
  # use thread local?
  global _g_sock
  if _g_sock is not None:
    return _g_sock
  _g_sock = AsyncNetlinkSocket(family=NETLINK_RDMA)
  _g_sock.register_policy(RDMA_STATE_GET_TYPE, RdmaStatGetResp)
  return _g_sock


async def get_rdma_stat(devidx, portidx) -> dict[str, int]:
  sock = _get_sock()

  msg = RdmaStatGetMsg.new(devidx, portidx)
  req = NetlinkRequest(sock, msg)
  await req.send()
  ret: dict[str, int] = dict()
  async for resp in req.response():
    for k, v in resp['attrs']:
      if k == 'DevIndex':
        assert v == devidx
        continue
      if k == 'PortIndex':
        assert v == portidx
        continue
      if k != 'HWC':
        continue
      hwc = v
      for hwce_n, hwce_v in hwc['attrs']:
        if hwce_n != 'HWCE':
          continue
        ck, cv = RdmaStatGetResp.HWC.HWCE.from_parsed(hwce_v)
        assert ck not in ret
        ret[ck] = cv
    #print(resp)
  return ret


if __name__ == '__main__':
  import timeit

  loop = asyncio.get_event_loop()
  res = loop.run_until_complete(get_rdma_stat(0, 1))
  print(res)

  timeit_env = globals()
  timeit_env.update(locals())
  timeit_res = timeit.timeit('loop.run_until_complete(get_rdma_stat(0, 1))', number=1000, globals=timeit_env)
  print(timeit_res)