A NaN Out of Nowhere

Posted by w@hidva.com on July 13, 2025

Preface

A while back we built an asynchronous kvcache load/save component on top of the vllm v1 connector, mainly because we felt the community's vllm v1 kv connector design has a couple of issues:

  • Requests whose kvcache load/save has not yet finished are still put into the waiting queue, and the design relies on "empty" steps to update kvcache state. This mixes kvcache load/save up with compute work, which both adds complexity and hurts performance, as shown by the BUG here.

  • Synchronous interfaces block the step: get_num_new_matched_tokens, for example, may need to call an external service API, as the current lmcache get_num_new_matched_tokens implementation does.

So we wanted to separate "compute" from "kvcache load/save": a request whose kvcache load/save is not yet ready is zero-overhead as far as the scheduler is concerned. Going one step further we split the component into two modules, connector and backend. The backend is only responsible for transferring, loading and storing kvcache; whoever writes one only needs to understand the kvcache layout and the target storage, and never needs to know anything about vllm scheduler internals. The connector provides backends with a runtime environment, plus dynamic scaling, fault tolerance, request lifecycle management and so on. Today's problem lives in one of these backends.
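
To make the split concrete, here is a rough sketch of what such a backend interface could look like (a sketch only: KVCacheBackend, load_async and save_async are hypothetical names for illustration, not our actual code):

import abc
from concurrent.futures import Future

import torch


class KVCacheBackend(abc.ABC):
    """Sketch: a backend only moves kvcache blocks and never sees the vllm
    scheduler. The connector owns request lifecycle, scaling and fault
    tolerance, and simply drives these two calls."""

    @abc.abstractmethod
    def load_async(self, request_id: str, blocks: list[torch.Tensor]) -> Future:
        """Start loading the kvcache blocks of a request from external storage;
        the returned Future resolves once the blocks are usable, so the
        scheduler never has to poll with empty steps."""

    @abc.abstractmethod
    def save_async(self, request_id: str, blocks: list[torch.Tensor]) -> Future:
        """Start persisting the kvcache blocks of a request to external
        storage (e.g. a global kvstore)."""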

Hunting the Culprit

A colleague of mine recently added a new backend that saves a request's kvcache to a global kvstore and loads kvcache from the global kvstore for requests. While debugging it he kept hitting a bizarre, intermittent problem: with this backend enabled, the forward computation would produce NaN. Through his persistence he boiled it down to a minimal reproduction case, which turned out to be crucial for the rest of the investigation. In short, the case is:

  Main thread           connector thread
model forward
model forward           with torch.cuda.stream(CONNECTOR_STREAM):
model forward              NEW_EVENT.record()
model forward
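
In standalone form the timeline corresponds roughly to the sketch below (the matmul and tensor shapes are made-up stand-ins for model forward; the real case runs inside a vllm worker):

import threading

import torch

CONNECTOR_STREAM = torch.cuda.Stream()   # the "created in the main thread" variant
NEW_EVENT = torch.cuda.Event()

def connector_thread():
    with torch.cuda.stream(CONNECTOR_STREAM):
        NEW_EVENT.record()               # record the event on CONNECTOR_STREAM

def main():
    x = torch.randn(1024, 1024, device="cuda")
    w = torch.randn(1024, 1024, device="cuda")
    t = threading.Thread(target=connector_thread)
    t.start()
    for _ in range(4):
        x = x @ w                        # stand-in for "model forward"
    t.join()
    torch.cuda.synchronize()

if __name__ == "__main__":
    main()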

One more observation: if CONNECTOR_STREAM is created in the main thread, we get NaN; if it is created in the connector thread, we instead get an error: enqueue.cc: Cuda failure: 400 invalid resource handle. The enqueue.cc in question is nccl's kernel-launch path:

//CUDACHECK(cudaLaunchKernelExC(&launchConfig, fnAddr, args));
CUCHECK(cuLaunchKernelEx(&launchConfig, fn, nullptr, extra));  // line: 1500
return ncclSuccess;

Now, nccl is something I know very well, so I started there: I pulled out my debug container, which keeps debug builds of python, torch and nccl, fired it up, set a breakpoint, ran info locals, and saw launchStream=0x7f1f20710500! That value looked odd: normally the stream used for collective communication is 0 (the default stream) or some fairly short value. Wait, could this be CONNECTOR_STREAM???! I added a log printing {CONNECTOR_STREAM!r}, and it really was!!!!

Which means the connector thread's torch.cuda.stream affected the main thread's current stream!!!

# Code1
class StreamContext(AbstractContextManager):

    def __enter__(self):
        cur_stream = self.stream
        if cur_stream is None:
            return

        # _current_stream here is a plain module-level global, shared by every thread
        global _current_stream
        self.prev_stream = _current_stream
        _current_stream = cur_stream

The Wind Rises

I asked Qwen to help write a demo to verify this in isolation.

import torch
import threading
import time

def thread_with_new_stream():
    time.sleep(3)
    s = torch.cuda.Stream()
    print(f"[Thread-A] New CUDA Stream: {s!r}")
    with torch.cuda.stream(s):
        print(f"[Thread-A] Enter stream context: {torch.cuda.current_stream()!r}")
        time.sleep(1300)

def thread_print_stream():
    for i in range(400):
        # Note: each thread (i.e. each Python thread) has its own stream stack
        print(f"[Thread-B] {i}: Current stream: {torch.cuda.current_stream()!r}")
        time.sleep(1)

def main():
    torch.cuda.init()  # make sure CUDA is initialized

    t1 = threading.Thread(target=thread_with_new_stream)
    t2 = threading.Thread(target=thread_print_stream)

    t1.start()
    t2.start()

    t1.join()
    t2.join()

if __name__ == "__main__":
    main()
$ python dbg_stream.py
[Thread-B] 0: Current stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x0>
[Thread-B] 1: Current stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x0>
[Thread-B] 2: Current stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x0>
[Thread-A] New CUDA Stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x7f525c003040>
[Thread-A] Enter stream context: <torch.cuda.Stream device=cuda:0 cuda_stream=0x7f525c003040>
[Thread-B] 3: Current stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x0>
[Thread-B] 4: Current stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x0>
[Thread-B] 5: Current stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x0>

As you can see, Thread-A's torch.cuda.stream(s) does not affect Thread-B at all!!! Only on a closer look did I realize that the StreamContext in Code1 is the one from torch/cpu/__init__.py, while the StreamContext in torch/cuda/__init__.py is perfectly fine… Uh. Now what? I had already bragged about cracking it!!!
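
In hindsight a ten-second sanity check would have caught the mix-up, e.g. asking the context manager which file its class really comes from (the exact path depends on your torch install):

import inspect

import torch

s = torch.cuda.Stream()
ctx = torch.cuda.stream(s)
# Which StreamContext are we actually dealing with?
print(type(ctx), inspect.getsourcefile(type(ctx)))
# expected: <class 'torch.cuda.StreamContext'> .../torch/cuda/__init__.py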

And then I suddenly recognized this pain, this stupidity! I had been through it before, three years ago…

The Wind Falls

Nothing for it but to grit my teeth and keep digging! One thing was certain: gdb clearly showed nccl using CONNECTOR_STREAM!

#1  0x00007fb1044818a5 in ncclLaunchKernel (comm=0x11d591c0, plan=0x18e82398) at enqueue.cc:1505
1505      CUCHECK(cuLaunchKernel(fn, grid.x, grid.y, grid.z, block.x, block.y, block.z, smem, launchStream, nullptr, extra));
(gdb) info locals
errStr = 0x7fb0b0a576a0 "invalid resource handle"
err = CUDA_ERROR_INVALID_HANDLE
launchStream = 0x7fa1187104e0
$ grep -F 0x7fa1187104e0 nan1.log
(VllmWorker rank=2 pid=333471) loop thread: save! <torch.cuda.Stream device=cuda:0 cuda_stream=0x7fa1187104e0> <torch.cuda.Event 0x7fa1199edd50> os.getpid()=333471 <torch.cuda.Stream device=cuda:0 cuda_stream=0x7fa1187104e0>
(VllmWorker rank=2 pid=333471) loop thread: save! <torch.cuda.Stream device=cuda:0 cuda_stream=0x7fa1187104e0> <torch.cuda.Event 0x7fa1199edf40> os.getpid()=333471 <torch.cuda.Stream device=cuda:0 cuda_stream=0x7fa1187104e0>

So the general direction seemed right! Let's py-bt to see where the stream used by the nccl all reduce comes from:

(gdb) py-bt
Traceback (most recent call first):
  File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/pynccl_wrapper.py", line 291, in ncclAllReduce
    self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count,
  File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/pynccl.py", line 127, in all_reduce
    self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
  File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/cuda_communicator.py", line 92, in all_reduce
    out = pynccl_comm.all_reduce(input_)
def all_reduce(self,
               in_tensor: torch.Tensor,
               op: ReduceOp = ReduceOp.SUM,
               stream=None) -> torch.Tensor:
    if stream is None:
        stream = current_stream()
    self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
                            buffer_type(out_tensor.data_ptr()),
                            in_tensor.numel(),
                            ncclDataTypeEnum.from_torch(in_tensor.dtype),
                            ncclRedOpTypeEnum.from_torch(op), self.comm,
                            cudaStream_t(stream.cuda_stream))

def current_stream() -> torch.cuda.Stream:
    from vllm.platforms import current_platform
    global _current_stream
    return _current_stream

def _patched_set_stream(stream: torch.cuda.Stream) -> None:
    global _current_stream
    _current_stream = stream
    prev_set_stream(stream)

torch.cuda.set_stream = _patched_set_stream

Hahaha, so the little devil was you all along, vllm!!!
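
If I'm reading the pieces right, the chain is: torch.cuda.StreamContext.__enter__ calls torch.cuda.set_stream, vllm monkey-patches torch.cuda.set_stream to stash the stream into a process-wide global, and pynccl's all_reduce later reads that global back via current_stream(). A quick single-thread sanity check of the chain (assuming vllm.utils is importable and CUDA is available):

import torch
import vllm.utils  # brings in the patched torch.cuda.set_stream, as in the demo below

torch.cuda.init()
s = torch.cuda.Stream()
with torch.cuda.stream(s):
    # __enter__ went through the patched torch.cuda.set_stream, so the
    # process-wide "current stream" that pynccl would use is now s.
    assert vllm.utils.current_stream() == s

Across threads, the leak shows up just as clearly: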

import torch
import threading
import time
import vllm.utils

def thread_with_new_stream():
    time.sleep(3)
    s = torch.cuda.Stream()
    print(f"[Thread-A] New CUDA Stream: {s!r}")
    with torch.cuda.stream(s):
        print(f"[Thread-A] Enter stream context: {torch.cuda.current_stream()!r}")
        print(f"[Thread-A] Current stream: {vllm.utils.current_stream()!r}")
        time.sleep(1300)

def thread_print_stream():
    for i in range(400):
        # Note: each thread (i.e. each Python thread) has its own stream stack
        print(f"[Thread-B] {i}: Current stream: {torch.cuda.current_stream()!r}")
        print(f"[Thread-B] {i}: Current stream: {vllm.utils.current_stream()!r}")
        time.sleep(1)

def main():
    torch.cuda.init()  # make sure CUDA is initialized

    t1 = threading.Thread(target=thread_with_new_stream)
    t2 = threading.Thread(target=thread_print_stream)

    t1.start()
    t2.start()

    t1.join()
    t2.join()

if __name__ == "__main__":
    main()
$ python stream.py
[Thread-B] 2: Current stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x0>
[Thread-B] 2: Current stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x0>
[Thread-A] New CUDA Stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x7f8160003040>
[Thread-A] Enter stream context: <torch.cuda.Stream device=cuda:0 cuda_stream=0x7f8160003040>
[Thread-A] Current stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x7f8160003040>
[Thread-B] 3: Current stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x0>
[Thread-B] 3: Current stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x7f8160003040>
[Thread-B] 4: Current stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x0>
[Thread-B] 4: Current stream: <torch.cuda.Stream device=cuda:0 cuda_stream=0x7f8160003040>

Postscript

While writing this up, it suddenly clicked. Recall the observation from earlier:

One more observation: if CONNECTOR_STREAM is created in the main thread, we get NaN; if it is created in the connector thread, we instead get an error: enqueue.cc: Cuda failure: 400 invalid resource handle.

The reason why CONNECTOR_STREAM created in the main thread does not cause Cuda failure: 400 invalid resource handle is now also clear: the main thread has already bound the worker's CUDA device, so the stream it creates is a valid stream for the device the nccl communicator lives on; the launch succeeds, it just silently lands on CONNECTOR_STREAM and races with the forward pass, hence the NaN. The connector thread, on the other hand, never sets the device, so the stream it creates belongs to the default cuda:0 (note the device=cuda:0 in the rank=2 worker's log above), and launching the nccl kernel with a stream from a different device context fails with invalid resource handle.
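
For completeness, one possible mitigation on the backend side (just a sketch, not necessarily the fix we ended up shipping): never enter torch.cuda.stream() in the connector thread at all, so vllm's patched torch.cuda.set_stream is never triggered, and hand the stream to the event explicitly instead.

import torch

torch.cuda.init()
CONNECTOR_STREAM = torch.cuda.Stream()
NEW_EVENT = torch.cuda.Event()

# Instead of
#     with torch.cuda.stream(CONNECTOR_STREAM):
#         NEW_EVENT.record()
# record on the stream explicitly: no stream context is entered, so the
# process-wide "current stream" read by vllm's pynccl stays untouched.
NEW_EVENT.record(CONNECTOR_STREAM)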