在 这篇文章 我提到过我们基于 vllm 实现了 PD 分离, 而且设计思路恰好与 Nvidia Dynamo 撞车了. 大致查询链路也与 Dynamo 相似, 简单来说是:
- 请求 R 从 D 节点接入,
- D 节点决策模块确定请求 R 应在某个 P 实例上完成, 此时 D 为 R 分配 kvcache block R.d_kv_blocks, 并会调用 P 实例 do_prefill(R, R.d_kv_blocks) rpc.
- P 实例负责完成 prefill, 并将 prefill 生成的 kvcache layer-by-layer 地通过 RDMA Write + GDR(GPU-Direct RDMA) 直接写入到 R.d_kv_blocks 中.
- P do_prefill rpc 会在 first token 生成, 以及 kvcache 传输完成之后返回 P prefill 生成的 first token. D 在 do_prefill rpc 响应之后在本地进行 decode 过程.
B.T.W, 我们这套方案对 vllm 引入性也非常的低, 基本上复用了 vllm 目前现有设施, 对 vllm 改动几乎为 0.
在本篇文章, 我来探究下 GDR 相关的一些细节.
内存序
如 mscclpp DeviceSyncer 真的能 sync 么? 所示, 我个人对 memory order 是比较敏感的, 所以接触到 GDR 之后第一反应, 为了实现必需的 Synchronization, 我是否需要做些什么?
参考 Nvidia GDR: Synchronization and Memory Ordering 文档, 目前我们 PD 分离实现中 P 端在发起 RDMA Write 以及 D 端在发起 forward 计算都遵循了这里提到的约定
性能影响
关于 GDR 我们一直存在的担忧点是其对 decode latency 的影响, 毕竟 decode 算是对 HBM 带宽比较敏感的, 而 P GDR RDMA Write 也会占用 D HBM 一部分带宽. 为了探究这个问题我们设计了一组实验.
- 首先以 PD 分离模式启动 P, D 实例. 这里使用 Qwen2 72B, tp=4. D worker 需要在 forward 前后进行打点:
+ zy_start_ts = time.perf_counter()
+ prev_rdma_rx_bytes = 0
+ prev_rdma_rx_packets = 0
+ if get_tp_group().is_first_rank and self.vllm_config.disagg_config and self.vllm_config.disagg_config.is_decode():
+ prev_rdma_rx_bytes, prev_rdma_rx_packets = get_rdma_rx_bytes_packets()
# Run the decoder.
# Use persistent buffers for CUDA graphs.
with set_forward_context(attn_metadata, self.vllm_config):
@@ -1076,6 +1092,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
+ torch.cuda.synchronize()
+ now = time.perf_counter()
+ foward_s = now - zy_start_ts
+ rdma_rx_bytes = 0
+ rdma_rx_packets = 0
+ if get_tp_group().is_first_rank and self.vllm_config.disagg_config and self.vllm_config.disagg_config.is_decode():
+ rdma_rx_bytes, rdma_rx_packets = get_rdma_rx_bytes_packets()
+ rdma_rx_bps = (rdma_rx_bytes - prev_rdma_rx_bytes) / foward_s
+ rdma_rx_pps = (rdma_rx_packets - prev_rdma_rx_packets) / foward_s
+ foward_ms = foward_s * 1000
+ num_reqs = self.input_batch.num_reqs
+ context_lens = self.input_batch.num_tokens[:num_reqs]
+ logger.info(f">>>ZYDBG. {foward_ms=} {len(input_ids)=} {num_reqs=} context_lens={np.min(context_lens)}|{np.max(context_lens)}|{np.sum(context_lens)} {rdma_rx_bps=} {rdma_rx_pps=}")
+
get_rdma_rx_bytes_packets
基于 pyroute2, linux rdma netlink 实现, 其源码见本文最后. 如上 worker log 输出日志示例:
INFO 04-17 16:58:46 [gpu_model_runner.py:1107] >>>ZYDBG. foward_ms=32.284699846059084 len(input_ids)=64 num_reqs=64 context_lens=2|2|128 rdma_rx_bps=0.0 rdma_rx_pps=0.0
- D 实例在启动 warmup 之后, 下发指定 batch size 个 input len = 1, ignore_eos=True, max_tokens=2000 请求. 我们会使用 context_len = 1500 之后的 500 个 step 的 decode latency 作为参考指标 DecodeLatency. 由于这些请求 input len = 1, 所以 D 实例决策模块会确定这些请求在 D 节点进行 prefill. 也即这 batch size 个请求的处理不会涉及到 PD 分离的逻辑.
if self.scheduler.disagg_config and self.scheduler.disagg_config.is_decode():
for i in range(batch_size):
self.add_req_threadsafe(EngineCoreRequest(
request_id = f'zhanyi-debug-disaggpd-gdr-{i}',
prompt=None,
mm_inputs=None,
mm_hashes=None,
eos_token_id=None,
lora_request=None,
mm_placeholders=None,
prompt_token_ids=[111308], # 您好
sampling_params=SamplingParams(ignore_eos=True, max_tokens=2000),
arrival_time=time.time(),
))
self.scheduler.max_num_scheduled_tokens = batch_size
这里同时会设置 scheduler.max_num_scheduled_tokens = batch_size
确保 scheduler, worker 只会处理这 batch size 个请求.
- 之后我会在 context_len = 500 左右启动 vllm benchmark_serving, 以指定 qps 压测 input len = 2000, output len = 7 的请求. 由于我们已经变更了 scheduler.max_num_scheduled_tokens, 对于这些请求, D 节点只会调用 P do_prefill rpc, 并在收到响应之后将请求加入到 D scheduler 上进行排队. 也即 最关键的, 除了 P 发起的 GDR RDMA Write 流量之外, 这些请求的存在对 D worker 进行 decode 没有任何影响! 在我们早期基于内部的推理框架实现的 PD 分离中, D worker 每次 step 都需要对这些请求做一些 PD 分离特有逻辑的处理, 导致不太好完成本节提到的实现.
如下数据为指定 qps 下, 指标 DecodeLatency 相对 qps0 时相对值. 可以看到 GDR 对 decode latency 基本没有影响.
DecodeLatency | qps1 | qps2 | qps3 |
---|---|---|---|
min | 0.0016539 | 0.0005427 | 0.0009722 |
max | -0.0032505 | -0.0020395 | -0.0029623 |
avg | -0.0008911 | -0.0007972 | -0.0010694 |
P.S. 测试使用了阿里云 ERDMA, 具体显卡类型以及指标绝对值我不太好把握, 所以有意不展示出来了, 相对值就已经可以说明情况了.
指定 qps 下, Prefill 发起的 GDR RDMA Write 流量如图所示:
可以看到 GDR RDMA Write 流量与压测 qps 关系不大, 这是因为就算是 qps1 已经足以跑满 prefill 每次 step, GDR RDMA Write 流量与 Prefill chunked prefill batched token 数目有关, 我这里使用的是我们线上值. 或许可以在不同 batched token 下跑下实验, 但我太懒了=.
rdma_stat
get_rdma_rx_bytes_packets 依赖的源码实现. 这应该是这篇文章最大的贡献了! 这也稍微抵消了我在上面不得不隐藏一些细节的内疚感…
from pyroute2.netlink import nlmsg, nla, NLA_F_NESTED
from pyroute2.netlink.nlsocket import AsyncNetlinkSocket, NetlinkRequest
from typing import Optional
import asyncio
# RDMA Netlink 协议常量, copy from iproute2
NLM_F_REQUEST = 0x1
NLM_F_ROOT = 0x100
NLM_F_MATCH = 0x200
NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH
NETLINK_RDMA = 20
RDMA_NL_NLDEV = 5
RDMA_NLDEV_ATTR_DEV_INDEX = 1
RDMA_NLDEV_ATTR_DEV_NAME = 2
RDMA_NLDEV_ATTR_PORT_INDEX = 3
RDMA_NLDEV_ATTR_STAT_HWCOUNTERS = 80
RDMA_NLDEV_ATTR_STAT_HWCOUNTER_ENTRY = 81
RDMA_NLDEV_ATTR_STAT_HWCOUNTER_ENTRY_NAME = 82
RDMA_NLDEV_ATTR_STAT_HWCOUNTER_ENTRY_VALUE = 83
RDMA_NLDEV_CMD_STAT_GET = 17
NLM_F_ACK = 0x4
# #define RDMA_NL_GET_TYPE(client, op) ((client << 10) + op)
def RDMA_NL_GET_TYPE(client, op):
return (client << 10) + op
RDMA_STATE_GET_TYPE = RDMA_NL_GET_TYPE(RDMA_NL_NLDEV, RDMA_NLDEV_CMD_STAT_GET)
class RdmaStatGetMsg(nlmsg):
nla_map = ((RDMA_NLDEV_ATTR_DEV_INDEX, 'DevIndex', 'uint32'),
(RDMA_NLDEV_ATTR_PORT_INDEX, 'PortIndex', 'uint32'))
@staticmethod
def new(devidx, portidx):
msg = RdmaStatGetMsg()
msg['attrs'] = [
# 对应 nla attr (RDMA_NLDEV_ATTR_DEV_INDEX, devidx)
['DevIndex', devidx],
['PortIndex', portidx]]
msg['header']['type'] = RDMA_STATE_GET_TYPE
msg['header']['flags'] = NLM_F_REQUEST | NLM_F_ACK
return msg
# iproute2 rdma stat 有 bug, 其使用了 u32
# strace -e 捕捉到的内核返回:
# \x68\x77\x5f\x72\x78\x5f\x62\x79\x74\x65\x73\x5f\x63\x6e\x74\x00\x0c\x00\x53\x00\xa2\x10\xbc\x24\x82\x02\x00\x00
# \x68\x77\x5f\x72\x78\x5f\x62\x79\x74\x65\x73\x5f\x63\x6e\x74, 即 hw_rx_bytes_cnt
# 其对应值为 \xa2\x10\xbc\x24\x82\x02\x00\x00. 但 rdma stat 输出:
# hw_rx_bytes_cnt 616304802, 对应着 \xa2\x10\xbc\x24.
# 搞得我排查了半天是不是我逻辑问题.
class RdmaStatGetResp(nlmsg):
# 在 decode 时会将不在 nla_map 的 nla 映射为 UNKNOWN NLA
# 对应: ('UNKNOWN', {'header': {'length': 12, 'type': 2}})
nla_map = ((RDMA_NLDEV_ATTR_DEV_INDEX, 'DevIndex', 'uint32'),
(RDMA_NLDEV_ATTR_DEV_NAME, 'DevName', 'asciiz'),
(RDMA_NLDEV_ATTR_STAT_HWCOUNTERS, 'HWC', 'HWC'),
(RDMA_NLDEV_ATTR_PORT_INDEX, 'PortIndex', 'uint32'))
class HWC(nla):
nla_flags = NLA_F_NESTED
# 这里缺少 ',' 会导致异常. 但 pyroute2 decode 并不会输出任何信息,
# 只是输出结果缺少 HWC attr.
nla_map = ((RDMA_NLDEV_ATTR_STAT_HWCOUNTER_ENTRY, 'HWCE', 'HWCE'),)
# RDMA_NLDEV_ATTR_STAT_HWCOUNTER_ENTRY
# print 输出 ('HWCE', {'attrs': [('Name', 'listen_create_cnt'), ('Value', 0)]}, 32768)
# 见 nla_slot.__repr__, 32768 是 flag.
class HWCE(nla):
# nla_flags 这里只影响 encode. 对 decode 没有影响.
nla_flags = NLA_F_NESTED
nla_map = ((RDMA_NLDEV_ATTR_STAT_HWCOUNTER_ENTRY_NAME, 'Name', 'asciiz'),
# 内核 rdma netlink 这里使用的是本机字节序
(RDMA_NLDEV_ATTR_STAT_HWCOUNTER_ENTRY_VALUE, 'Value', 'uint64'))
# val: {'attrs': [('Name', 'listen_create_cnt'), ('Value', 0)]}
@staticmethod
def from_parsed(input: dict) -> tuple[str, int]:
name: Optional[str] = None
val: Optional[int] = None
for k, v in input['attrs']:
if k == 'Name':
name = v
continue
if k == 'Value':
val = v
continue
assert name is not None
assert val is not None
return name, val
_g_sock: Optional[AsyncNetlinkSocket] = None
def _get_sock():
# use thread local?
global _g_sock
if _g_sock is not None:
return _g_sock
_g_sock = AsyncNetlinkSocket(family=NETLINK_RDMA)
_g_sock.register_policy(RDMA_STATE_GET_TYPE, RdmaStatGetResp)
return _g_sock
async def get_rdma_stat(devidx, portidx) -> dict[str, int]:
sock = _get_sock()
msg = RdmaStatGetMsg.new(devidx, portidx)
req = NetlinkRequest(sock, msg)
await req.send()
ret: dict[str, int] = dict()
async for resp in req.response():
for k, v in resp['attrs']:
if k == 'DevIndex':
assert v == devidx
continue
if k == 'PortIndex':
assert v == portidx
continue
if k != 'HWC':
continue
hwc = v
for hwce_n, hwce_v in hwc['attrs']:
if hwce_n != 'HWCE':
continue
ck, cv = RdmaStatGetResp.HWC.HWCE.from_parsed(hwce_v)
assert ck not in ret
ret[ck] = cv
#print(resp)
return ret
if __name__ == '__main__':
import timeit
loop = asyncio.get_event_loop()
res = loop.run_until_complete(get_rdma_stat(0, 1))
print(res)
timeit_env = globals()
timeit_env.update(locals())
timeit_res = timeit.timeit('loop.run_until_complete(get_rdma_stat(0, 1))', number=1000, globals=timeit_env)
print(timeit_res)