Preface
A while back we built an asynchronous kvcache load/save component on top of the vllm v1 connector, mainly because of two issues in the community vllm v1 kv connector design:
- Requests whose kvcache load/save has not finished are also put into the waiting queue, and the design relies on "empty" steps to refresh kvcache state. This mixes kvcache load/save with compute work, which both adds complexity and hurts performance, as the BUG here shows.
- The synchronous interfaces block the step. get_num_new_matched_tokens, for example, may need to call an external service API, as the current lmcache get_num_new_matched_tokens implementation does.
We therefore wanted to separate "compute" from "kvcache load/save": a request whose kvcache load/save is not yet ready should be zero-overhead for the scheduler. Going further, we split the component into two modules, connector and backend. The backend is only responsible for transferring, loading, and storing kvcache; whoever writes one only needs to understand the kvcache layout and the target storage, with no awareness of vllm scheduler internals at all. The connector provides the runtime environment for backends, plus dynamic scaling, fault tolerance, request lifecycle management, and so on. Today's problem showed up inside one of these backends.
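To make the division of labor concrete, here is a minimal sketch of the two roles; the names and signatures are hypothetical and not our actual interface:

from abc import ABC, abstractmethod
from concurrent.futures import Future
import torch

class KVCacheBackend(ABC):
    """Hypothetical backend interface: it only needs to understand the kvcache
    layout and the target store, never the vllm scheduler."""

    @abstractmethod
    def load(self, request_id: str, block_ids: list[int],
             kv_buffer: torch.Tensor) -> Future:
        """Asynchronously fill kv_buffer for the given blocks."""

    @abstractmethod
    def save(self, request_id: str, block_ids: list[int],
             kv_buffer: torch.Tensor) -> Future:
        """Asynchronously persist the given blocks."""

class Connector:
    """Hypothetical connector: runs backends on their own threads/streams,
    handles retries and scaling, and only reports a request as ready once its
    future completes, so in-flight loads/saves cost the scheduler nothing."""

    def __init__(self, backend: KVCacheBackend) -> None:
        self.backend = backend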
The Hunt
A colleague of mine recently added a new backend that saves a request's kvcache to a global kvstore and loads kvcache for requests from that global kvstore. While debugging it he kept hitting an eerie, intermittent problem: with this backend enabled, the forward computation would produce nan. Through persistent effort he distilled a minimal reproduction, which turned out to be crucial for the rest of the investigation. In short, the repro looks like this:
Main thread        Connector thread
model forward
model forward      with torch.cuda.stream(CONNECTOR_STREAM):
model forward          NEW_EVENT.record()
model forward
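Rendered as a runnable sketch it looks roughly like the following. Note this sketch alone will not reproduce the nan without vllm's stream patching and TP collectives; the dummy matmul loop merely stands in for "model forward", and the real backend does more around the event record than shown here.

import threading
import torch

CONNECTOR_STREAM = torch.cuda.Stream()   # created in the main thread in this variant
NEW_EVENT = torch.cuda.Event()

def connector_loop():
    # What the connector thread does in the repro: enter its own stream and
    # record an event, nothing else.
    with torch.cuda.stream(CONNECTOR_STREAM):
        NEW_EVENT.record()

def main():
    x = torch.randn(1024, 1024, device="cuda")
    t = threading.Thread(target=connector_loop)
    t.start()
    for _ in range(100):                 # stands in for "model forward"
        y = torch.relu(x @ x)
        x = y / (y.norm() + 1e-6)
    t.join()
    torch.cuda.synchronize()
    print(x.isnan().any())

if __name__ == "__main__":
    main()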
And there was another curious symptom: if CONNECTOR_STREAM is created in the main thread, we get nan; if it is created in the connector thread, we instead get the error: enqueue.cc: Cuda failure: 400 invalid resource handle. That error points at this spot in nccl's enqueue.cc:
//CUDACHECK(cudaLaunchKernelExC(&launchConfig, fnAddr, args));
CUCHECK(cuLaunchKernelEx(&launchConfig, fn, nullptr, extra)); // line: 1500
return ncclSuccess;
Now, nccl is something I know very well, so I started there. Out came my debug container, which carries debug builds of python, torch, and nccl. Set a breakpoint, run info locals, and there it is: launchStream=0x7f1f20710500! That value looked suspicious; the stream used for collectives is normally 0 (the default stream) or some other short value. Wait, could this be CONNECTOR_STREAM???! I added a log printing {CONNECTOR_STREAM!r}
and it really was!!!!
Which means the connector thread's torch.cuda.stream
was affecting the main thread's current stream!!!
# Code1
class StreamContext(AbstractContextManager):
    def __enter__(self):
        cur_stream = self.stream
        if cur_stream is None:
            return
        # _current_stream is just a module-level global
        global _current_stream
        self.prev_stream = _current_stream
        _current_stream = cur_stream
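If __enter__ really updates a module-level global like that, then whatever one thread sets instantly becomes visible to every other thread. Stripped of CUDA, the suspected mechanism is just this:

import threading

_current_stream = "default-stream"      # stands in for the module-level global

def enter_stream(new_stream):
    # Same shape as Code1's __enter__: overwrite the shared global.
    global _current_stream
    _current_stream = new_stream

t = threading.Thread(target=enter_stream, args=("connector-stream",))
t.start()
t.join()
# The main thread now sees the other thread's "current stream".
print(_current_stream)                  # -> connector-stream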
The Wind Rises
I asked Qwen to help write a demo to verify this in isolation.
import torch
import threading
import time

def thread_with_new_stream():
    time.sleep(3)
    s = torch.cuda.Stream()
    print(f"[Thread-A] New CUDA Stream: {s!r}")
    with torch.cuda.stream(s):
        print(f"[Thread-A] Enter stream context: {torch.cuda.current_stream()!r}")
        time.sleep(1300)

def thread_print_stream():
    for i in range(400):
        # Note: each thread (i.e. each Python Thread) has its own stream stack
        print(f"[Thread-B] {i}: Current stream: {torch.cuda.current_stream()!r}")
        time.sleep(1)

def main():
    torch.cuda.init()  # make sure CUDA is initialized
    t1 = threading.Thread(target=thread_with_new_stream)
    t2 = threading.Thread(target=thread_print_stream)
    t1.start()
    t2.start()
    t1.join()
    t2.join()

if __name__ == "__main__":
    main()
$ python dbg_stream.py
[Thread-B] 0: Current stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x0>
[Thread-B] 1: Current stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x0>
[Thread-B] 2: Current stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x0>
[Thread-A] New CUDA Stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x7f525c003040>
[Thread-A] Enter stream context: <torch.cuda.Stream device=cuda:0 cuda_stream=0x7f525c003040>
[Thread-B] 3: Current stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x0>
[Thread-B] 4: Current stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x0>
[Thread-B] 5: Current stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x0>
As you can see, Thread-A entering torch.cuda.stream(s) does not affect Thread-B at all!!! Looking again more carefully, the StreamContext in Code1 is the one from torch/cpu/__init__.py, while the StreamContext in torch/cuda/__init__.py is perfectly fine... Huh? Now what? I had already bragged about cracking the case!!!
And then I suddenly recognized this pain, this stupidity: I had lived through it once before, three years ago...
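In hindsight there is a cheap way to avoid reading the wrong file: ask Python where the class you are staring at actually lives, for example:

import inspect
import torch

# Both classes are called StreamContext, but they live in different modules.
print(inspect.getsourcefile(torch.cuda.StreamContext))  # .../torch/cuda/__init__.py
print(inspect.getsourcefile(torch.cpu.StreamContext))   # .../torch/cpu/__init__.py
print(inspect.getsourcefile(torch.cuda.stream))         # the helper that returns the cuda StreamContext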
Back to Earth
No way around it, grit the teeth and keep digging! One thing was already certain: gdb clearly showed that nccl was using CONNECTOR_STREAM!
#1 0x00007fb1044818a5 in ncclLaunchKernel (comm=0x11d591c0, plan=0x18e82398) at enqueue.cc:1505
1505 CUCHECK(cuLaunchKernel(fn, grid.x, grid.y, grid.z, block.x, block.y, block.z, smem, launchStream, nullptr, extra));
(gdb) info locals
errStr = 0x7fb0b0a576a0 "invalid resource handle"
err = CUDA_ERROR_INVALID_HANDLE
launchStream = 0x7fa1187104e0
$ grep -F 0x7fa1187104e0 nan1.log
(VllmWorker rank=2 pid=333471) loop thread: save! <torch.cuda.Stream device=cuda:0 cuda_stream=0x7fa1187104e0> <torch.cuda.Event 0x7fa1199edd50> os.getpid()=333471 <torch.cuda.Stream device=cuda:0 cuda_stream=0x7fa1187104e0>
(VllmWorker rank=2 pid=333471) loop thread: save! <torch.cuda.Stream device=cuda:0 cuda_stream=0x7fa1187104e0> <torch.cuda.Event 0x7fa1199edf40> os.getpid()=333471 <torch.cuda.Stream device=cuda:0 cuda_stream=0x7fa1187104e0>
So the overall direction still looked right! Let's use py-bt
to see where the stream used by nccl all reduce comes from:
(gdb) py-bt
Traceback (most recent call first):
  File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/pynccl_wrapper.py", line 291, in ncclAllReduce
    self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count,
  File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/pynccl.py", line 127, in all_reduce
    self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
  File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/cuda_communicator.py", line 92, in all_reduce
    out = pynccl_comm.all_reduce(input_)
def all_reduce(self,
               in_tensor: torch.Tensor,
               op: ReduceOp = ReduceOp.SUM,
               stream=None) -> torch.Tensor:
    if stream is None:
        stream = current_stream()
    self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
                            buffer_type(out_tensor.data_ptr()),
                            in_tensor.numel(),
                            ncclDataTypeEnum.from_torch(in_tensor.dtype),
                            ncclRedOpTypeEnum.from_torch(op), self.comm,
                            cudaStream_t(stream.cuda_stream))
And the current_stream here is not torch.cuda.current_stream but vllm's own helper, together with the patch that keeps it up to date:
def current_stream() -> torch.cuda.Stream:
    from vllm.platforms import current_platform
    global _current_stream
    return _current_stream

def _patched_set_stream(stream: torch.cuda.Stream) -> None:
    global _current_stream
    _current_stream = stream
    prev_set_stream(stream)

torch.cuda.set_stream = _patched_set_stream
Hahaha, so the little gremlin was you all along, vllm!!!
import torch
import threading
import time
import vllm.utils

def thread_with_new_stream():
    time.sleep(3)
    s = torch.cuda.Stream()
    print(f"[Thread-A] New CUDA Stream: {s!r}")
    with torch.cuda.stream(s):
        print(f"[Thread-A] Enter stream context: {torch.cuda.current_stream()!r}")
        print(f"[Thread-A] Current stream: {vllm.utils.current_stream()!r}")
        time.sleep(1300)

def thread_print_stream():
    for i in range(400):
        # Note: each thread (i.e. each Python Thread) has its own stream stack
        print(f"[Thread-B] {i}: Current stream: {torch.cuda.current_stream()!r}")
        print(f"[Thread-B] {i}: Current stream: {vllm.utils.current_stream()!r}")
        time.sleep(1)

def main():
    torch.cuda.init()  # make sure CUDA is initialized
    t1 = threading.Thread(target=thread_with_new_stream)
    t2 = threading.Thread(target=thread_print_stream)
    t1.start()
    t2.start()
    t1.join()
    t2.join()

if __name__ == "__main__":
    main()
$ python stream.py
[Thread-B] 2: Current stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x0>
[Thread-B] 2: Current stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x0>
[Thread-A] New CUDA Stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x7f8160003040>
[Thread-A] Enter stream context: <torch.cuda.Stream device=cuda:0 cuda_stream=0x7f8160003040>
[Thread-A] Current stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x7f8160003040>
[Thread-B] 3: Current stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x0>
[Thread-B] 3: Current stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x7f8160003040>
[Thread-B] 4: Current stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x0>
[Thread-B] 4: Current stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x7f8160003040>
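This is not a claim about how vLLM did or should fix this; it is just a sketch of one possible direction, keeping the patched bookkeeping per-thread so that a connector thread entering torch.cuda.stream(...) cannot leak its stream into the worker thread that launches the nccl collectives:

import threading
import torch

_tls = threading.local()                 # per-thread instead of process-global
_prev_set_stream = torch.cuda.set_stream

def _patched_set_stream(stream: torch.cuda.Stream) -> None:
    _tls.stream = stream
    _prev_set_stream(stream)

def current_stream() -> torch.cuda.Stream:
    # Threads that never called set_stream fall back to torch's own
    # per-thread notion of the current stream.
    return getattr(_tls, "stream", None) or torch.cuda.current_stream()

torch.cuda.set_stream = _patched_set_stream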
Postscript
Writing this up, I suddenly came back to the earlier observation:
"And there was another curious symptom: if CONNECTOR_STREAM is created in the main thread, we get nan; if it is created in the connector thread, we instead get the error: enqueue.cc: Cuda failure: 400 invalid resource handle."
The reason why a CONNECTOR_STREAM created in the main thread does not trigger Cuda failure: 400 invalid resource handle is now clear as well, namely: