3月底整理了一个关于经典Paged Attention算法的ppt, 想起这个几年没写过的blog,把PPT改成一篇文章证明我还活着(-_-)。
vLLM 的 Paged Attention
开始前先说明一下,vLLM里的Paged Attention Kernel是有好几个不同的版本的,大概是下面这样子:
vLLM早期版本:
- Prefilling -> Flash Attention的flash_attn_varlen_func
 - Dedocding -> 自己实现的Paged Attention
- paged_attention_v1 : 用于比较短的sequence
 - paged_attention_v2 : 用于不想用v1的情况 :)
 
 
源码大概是这样的:
    # NOTE(woosuk): We use a simple heuristic to decide whether to use
    # PagedAttention V1 or V2. If the number of partitions is 1, we use
    # V1 to avoid the overhead of reduction. Also, if the number of
    # sequences or heads is large, we use V1 since there is enough work
    # to parallelize.
    # TODO(woosuk): Tune this heuristic.
    # For context len > 8192, use V2 kernel to avoid shared memory
    # shortage.
    use_v1 = (max_seq_len <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512))
vLLM 最新版本就已经全部转向Flash Attention, 用cutlass实现了。
NVIDIA GPU 基础
在深入 Paged Attention 的实现之前,我们需要了解 NVIDIA GPU的基本架构。(这里我们主要讲A100)
在做开发时,GPU 的 CUDA 程序包括 Grid -> Thread Block -> Threads三层架构。 这三层架构对应到GPU的硬件:GPU -> SM -> Cuda Core
在实际执行的时候,Threads会以每32个为一组执行,因此这里还多了一层:Thread Wrap,因此结构变成:
- CUDA程序:Grid -> Thread Block -> Thread Wrap -> Threads四层架构。
 - GPU硬件:GPU -> SM(多次执行) -> SM(一次执行) -> Cuda Core,也是四层。
 
在 A100 GPU 上,我们有:
- 108个SM
 - 每个SM有4个Wrap scheduler
- 最多有4个Thread Wrap同时在一个SM上执行
 
 - 每个Wrap scheduler有一个长度为16的调度队列
- 一个SM上最多可以调度64个Thread Wrap
 
 
基本上这些数字在设计Kernel的时候都可以被考虑到,从而最大化一个Kernel的硬件利用率。
vLLM Kernel 映射
现在我们看一下vLLM Kernel的设计:(处于简化的目的,我们认为没有TP)
设计Kernel的第一步是把程序拆分成不同的Thread Block来简化问题,vLLM中每个Thread Block会负责1个Query Token的一个Query Head的计算。
这个设计其实比较粗糙。不过没关系,Flash Attention里有更多优化。
设计好计算粒度后,是内存布局的优化,vLLM的Kernel对Q,K, V使用了不同的内存布局,看代码:
    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
                                          // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                          // head_size, block_size]
内存设计的时候,我们一般要考虑的是如何能够更好的做到:1. 每次读取连续的内存块(向量读指令,最好能翻译成_ld.global.b128,以128bit也就是16Bytes为单位);2. 降低不同thread的读冲突。QKV的内存布局基本上也是考虑这些来做的:
- Q的布局是最简单的,序列长度->head数量->head维度。
 - K的布局比较复杂,最外层的num_block和num_kv_heads比较好理解,对应了一块KV Cache。但是在没有按照head_size连续存储,还引入了一个参数x。这个布局其实是为了优化K的读取效率,我们在后面再讲。
 - 最后是V的布局,比K要直接一些,没有额外的维度x。
 
基于这个设计,在列一下相关的代码:
  // 我们希望能用一个thread wrap一次处理一个KV cache block,这里block_size一般是4/8/16这样子。 
  // 因为wrap_size=32,大于cache block size,我们就可以给一个token多分配几个thread来加速计算。
  // 用wrap size 处理 cache_block_size,这样就知道一个wrap thread可以用几个Thread来处理一个token。
  // 这里数字被记作thread group size
  [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  assert(head_size % thread_group_size == 0);
  ...
  # 这里的Num threads是128,最后算出来4个wrap,应该也是对应了A100个一个SM有4个wrap scheduler。
  constexpr int NUM_WARPS = NUM的_THREADS / WARP_SIZE;
  ...
  # 没啥好说的,就是一个thread block用128个thread,处理一个token的一个head
  dim3 grid(num_heads, num_seqs, 1);
  dim3 block(NUM_THREADS);
随着Thread Group的提出,这个CUDA Kernel的架构变的更复杂了 :(
- CUDA程序:Grid -> Thread Block -> Thread Wrap -> Thread Group -> Threads五层架构。
 
Query数据访问
讲了这么多终于开始计算了,先拿张图演示一下Query Token的访问:
  constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll
  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
       i += NUM_THREAD_GROUPS) {
    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
    q_vecs[thread_group_offset][i] =
        *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
  }
对着这块代码讲一下: 可以看到这里因为一个block是专门处理一个token的一个head的,因此把这块数据存在了shared memory里,这样方便复用。
然后这里又引入了一个叫VEC_SIZE的东西,这里其实就是说如果我想一次读16Bytes(_ld.global.b128), 那每个thead一次要读几个元素(因为一共有thread group个线程一起读)。
然后就用各种size来算一下每个thread要读多少次vec,这多么元素又对应多少个VEC,我们就知道一个thread具体要读哪些数据了。
这里thread每次读取的粒度是vec, 而每个thread group每次则读取16Bytes,最后就可以合并成一个向量读的指令来优化IO。
按这个方式把数据都读进来吗,这里其实很多thread要读的数据是一样的,只要有一份数据完成读取,然后其他thread group就可以直接复用这个数据了
Key Cache数据访问
和Query一样,每个thread按VEC访问数据,也是希望一次访问16Bytes,这也是Query最内层有一个X维度的原因。
我们每x个元素的大小是16Bytes,那么每个Thread就可以按16Bytes的方式对数据进行读取。
因为我们需要处理Query Token和所有Key Token的乘积,因此这里要一个block一个block的把所有Key Token读进来。 整个读取过程大概是这样一个三层循环:
每个Thread Wrap通过循环的方式处理多个Paged Block
- 循环的每次大迭代处理一个Paged Block(内部通过两个小循环处理)
- 内层外循环:遍历block内每个Token (block_size次)
- 内层内循环:遍历每个Token的每个vec (head_size / (thread_group_size*vec_size))
 
 
 - 内层外循环:遍历block内每个Token (block_size次)
 
同样,只要一个Thread Wrap完成读取,Thread Block里其他Thread Wrap会复用这个读取结果。
看一个这个三重循环的代码:
  // 外层大循环:4个wrap一起遍历所有所有paged blocks
  // 每次循环:每个wrap处理一个paged block
      // 4 个wrap同时执行 (并行度)
  // wrap_idx = thread_is / wrap_size
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
       block_idx += NUM_WARPS) {
    ...
    // Load a key to registers.
    // Each thread in a thread group has a different part of the key.
    // For example, if the the thread group size is 4, then the first thread in
    // the group has 0, 4, 8, ... th vectors of the key, and the second thread
    // has 1, 5, 9, ... th vectors of the key, and so on.
    // 内层外循环:遍历所有paged blocks
    // 每次循环:32个thread一起遍历当前paged block内所有token
        // NUM_TOKENS_PER_THREAD_GROUP =      DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
        // 每个thread group负责一个token(token_idx)
    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
      const int physical_block_offset =
          (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
      K_vec k_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
        // 内层内循环:按 VEC_SIZE 遍历Token内的所有fp16
        // 每次循环:32个thread一起遍历paged block内每个token
            // NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
            // NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
            // 每个thread group 为一组,处理同一个token
            // 每个threa一次读取一个VEC
            // 每个thread group一共负责 NUM_VECS_PER_THREAD 个VEC
            // K_vec k_vecs[NUM_VECS_PER_THREAD];  (1 Paged Block)
      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
        const cache_t* k_ptr =
            k_cache + physical_block_number * kv_block_stride +
            kv_head_idx * kv_head_stride + physical_block_offset * x;
        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
        const int offset1 = (vec_idx * VEC_SIZE) / x;
        const int offset2 = (vec_idx * VEC_SIZE) % x;
        if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
          k_vecs[j] = *reinterpret_cast<const K_vec*>(
              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
        } else {
          // Vector conversion from Quant_vec to K_vec.
          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
          k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
              k_vec_quant, *k_scale);
        }
      }
      // Compute dot product.
      ....
  }
QK 计算
Query和Key都读进来了,下一步自然就是矩阵乘了,还是先上个图:
基础的计算的逻辑就是上面代码里省略的部分
      // Compute dot product.
      // This includes a reduction across the threads in the same thread group.
      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
                             q_vecs[thread_group_offset], k_vecs);
      // Add the ALiBi bias if slopes are given.
      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
      if (thread_group_offset == 0) {
        // Store the partial reductions to shared memory.
        // NOTE(woosuk): It is required to zero out the masked logits.
        const bool mask = token_idx >= seq_len;
        logits[token_idx - start_token_idx] = mask ? 0.f : qk;
        // Update the max value.
        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
      }
    }
然后要做reduce,因为所有qk的乘积是分布在多个Thread里面的,我们要把Thread block里所有的数据都聚合起来,才好算最后的softmax,这里是经典的两层reduce算法,细节就不解释了。
  // Perform reduction across the threads in the same warp to get the
  // max qk value for each "warp" (not across the thread block yet).
  // The 0-th thread of each thread group already has its max qk value.
#pragma unroll
  // 经典的Thread Wrap内reduce算法
  for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = qk_max;
  }
  __syncthreads();
  // 经典的Thread Block内reduce算法
  // TODO(woosuk): Refactor this part.
  // Get the max qk value for the sequence.
  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
  }
  // Broadcast the max qk value to all threads.
  qk_max = VLLM_SHFL_SYNC(qk_max, 0);
  // 计算softmax
  // Get the sum of the exp values.
  float exp_sum = 0.f;
  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    float val = __expf(logits[i] - qk_max);
    logits[i] = val;
    exp_sum += val;
  }
  exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
  // Compute softmax.
  const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    logits[i] *= inv_sum;
  }
  __syncthreads();
跑到这里以后,softmax就算完了,剩下来的就是在乘一下V Cache,这块就比较简单了。
Value访问和Attneion计算
Value的访问比较直接,直接上图,就是大家一起把所有Value都读进来。

同时边读边计算,都在这个图里了:

就是先按block进行遍历,然后每次重block里读所有token的一部分维度进行计算,把所有维度都算出来,分散的存在各个thread里。
  // 一样的外层大循环:4个wrap一起遍历所有paged blocks
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
       block_idx += NUM_WARPS) {
    // NOTE(woosuk): The block number is stored in int32. However, we cast it to
    // int64 because int32 can lead to overflow when this variable is multiplied
    // by large numbers (e.g., kv_block_stride).
    const int64_t physical_block_number =
        static_cast<int64_t>(block_table[block_idx]);
    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
    L_vec logits_vec;
    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
                                                           start_token_idx));
    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
                           kv_head_idx * kv_head_stride;
#pragma unroll
    // NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
    // 一个Paged Block有几个VEC
    // NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
        // 一个WRAP可以同时处理几个 Paged Block 
    // NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
        // 遍历所有head纬度需要几个迭代
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE) {
        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
        V_vec v_vec;
        if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
          v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
        } else {
          V_quant_vec v_quant_vec =
              *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
          // Vector conversion from V_quant_vec to V_vec.
          v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
                                                                    *v_scale);
        }
        if (block_idx == num_seq_blocks - 1) {
          // num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
          // 对最后一个Paged Block特殊处理,防止越界
          // NOTE(woosuk): When v_vec contains the tokens that are out of the
          // context, we should explicitly zero out the values since they may
          // contain NaNs. See
          // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
          scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
          for (int j = 0; j < V_VEC_SIZE; j++) {
            v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
          }
        }
        // 按Vec分段计算和V的乘积:由多个thread计算,后续需要进一步聚合。
        accs[i] += dot(logits_vec, v_vec);
      }
    }
  }
等算法以后再进行reduce把结果聚合,得到最后的结果。这个是经典的reduce算法:
// Perform reduction within each warp.
// 经典in wrap规约算法
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    float acc = accs[i];
#pragma unroll
    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
      acc += VLLM_SHFL_XOR_SYNC(acc, mask);
    }
    accs[i] = acc;
  }
  // NOTE(woosuk): A barrier is required because the shared memory space for
  // logits is reused for the output.
  __syncthreads();
  // Perform reduction across warps.
  // 经典树状cross wrap规约算法
  float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
  for (int i = NUM_WARPS; i > 1; i /= 2) {
    int mid = i / 2;
    // Upper warps write to shared memory.
    if (warp_idx >= mid && warp_idx < i) {
      float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          dst[row_idx] = accs[i];
        }
      }
    }
    __syncthreads();
    // Lower warps update the output.
    if (warp_idx < mid) {
      const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          accs[i] += src[row_idx];
        }
      }
    }
    __syncthreads();
  }
算完以后就是把结果写到output了,这块就不贴了。
收尾
大概流程就是这样,整个流程还是设计不少细节的,我也不知道写清楚没有,不过对着代码多看几遍总能看明白的。
Last modified on 2025-04-20