
Understanding KV Cache

KV Cache is a common optimization for large-model inference: by trading memory for compute, it improves inference performance. This post walks through the idea using the Hugging Face (transformers) implementation of Llama.

  • The code follows LlamaAttention from the transformers library; positional (rotary) embeddings are ignored here for simplicity.

  • Define the network as follows:

import math

import torch
import torch.nn as nn

hidden_size = 1024
num_heads = 16
head_dim = 64
num_key_value_heads = 16

q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=False)
v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=False)
o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)

  • Simulating the forward pass
  1. Prefill stage (processing the user's input); assume a sequence length of 100.
past_key_value = None
hidden_states = torch.randn(16, 100, 1024)
bsz, q_len, _ = hidden_states.size()

query_states = q_proj(hidden_states) # torch.Size([16, 100, 1024])
key_states = k_proj(hidden_states) # torch.Size([16, 100, 1024])
value_states = v_proj(hidden_states) # torch.Size([16, 100, 1024])

query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2) # torch.Size([16, 16, 100, 64])
key_states = key_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2) # torch.Size([16, 16, 100, 64])
value_states = value_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2) # torch.Size([16, 16, 100, 64])

if past_key_value is not None:
    # reuse k, v, self_attention
    key_states = torch.cat([past_key_value[0], key_states], dim=2)
    value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim) # torch.Size([16, 16, 100, 100])

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, hidden_size)

attn_output = o_proj(attn_output) # torch.Size([16, 100, 1024])

  2. Autoregressive decoding stage: one token is generated per step, reusing the cached keys and values.
hidden_states = torch.randn(16, 1, 1024)
bsz, q_len, _ = hidden_states.size() # q_len is now 1

query_states = q_proj(hidden_states) # torch.Size([16, 1, 1024])
key_states = k_proj(hidden_states) # torch.Size([16, 1, 1024])
value_states = v_proj(hidden_states) # torch.Size([16, 1, 1024])

query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2) # torch.Size([16, 16, 1, 64])
key_states = key_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2) # torch.Size([16, 16, 1, 64])
value_states = value_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2) # torch.Size([16, 16, 1, 64])

if past_key_value is not None:
    # reuse k, v, self_attention
    key_states = torch.cat([past_key_value[0], key_states], dim=2)
    value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) # cached sequence length is now 101

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim) # torch.Size([16, 16, 1, 101])

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, hidden_size)

attn_output = o_proj(attn_output) # torch.Size([16, 1, 1024])
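
In a real decoding loop this single step simply repeats, with the cache growing by one position per generated token. A minimal sketch continuing from the state above (random tensors stand in for the embeddings of the newly generated tokens):

for step in range(3):
    hidden_states = torch.randn(16, 1, 1024)  # stand-in for the new token's embedding
    bsz, q_len, _ = hidden_states.size()

    query_states = q_proj(hidden_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
    key_states = k_proj(hidden_states).view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
    value_states = v_proj(hidden_states).view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)

    # only the new token's k/v are computed; everything else comes from the cache
    key_states = torch.cat([past_key_value[0], key_states], dim=2)
    value_states = torch.cat([past_key_value[1], value_states], dim=2)
    past_key_value = (key_states, value_states)

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_output = o_proj(torch.matmul(attn_weights, value_states).transpose(1, 2).reshape(bsz, q_len, hidden_size))

    print(step, past_key_value[0].shape)  # the sequence dimension grows: 102, 103, 104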

  3. One can verify that the k and v obtained by projecting the concatenated (100+1)-token sequence are exactly the same as those computed separately and then concatenated (a minimal check is sketched below). So in the autoregressive decoding stage there is no need to keep concatenating all previous tokens and recomputing k and v; caching the previously computed values greatly reduces the amount of computation.
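
A minimal sketch of that check, reusing the k_proj defined above (the same argument holds for v_proj): projecting the full 101-token sequence gives the same result as projecting the first 100 tokens and the new token separately, then concatenating.

prefill = torch.randn(16, 100, 1024)
new_token = torch.randn(16, 1, 1024)
full = torch.cat([prefill, new_token], dim=1)  # (16, 101, 1024)

k_full = k_proj(full)                                              # project the whole sequence
k_cached = torch.cat([k_proj(prefill), k_proj(new_token)], dim=1)  # project separately, then concat

print(torch.allclose(k_full, k_cached, atol=1e-6))  # True (up to floating-point error)
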
  • Thoughts on current speech synthesis models

Current zero-shot speech synthesis models largely reuse the LLM architecture: during training and inference, text tokens and speech tokens are concatenated into one sequence, e.g.:

<spk embed> <p_t0> <p_t1> ... <s_t0> <s_t1>...<p_s0> <p_s1> ......

However, this layout is unfriendly to prefilling, because the text to be synthesized sits in the middle of the sequence and changes with every request. If the training layout is changed to:

<spk embed> <p_t0> <p_t1> ... <p_s0> <p_s1> ... <s_t0> <s_t1> ......

then for zero-shot voice cloning the prompt, which never changes, can be prefilled once. This avoids recomputing its keys and values on every request and also lowers the overall response latency (a rough sketch follows).
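
A rough sketch of this idea under the reordered layout, reusing the projections defined above. The tensors prompt_hidden and text_hidden are hypothetical stand-ins for the embedded prompt tokens (<spk embed>, <p_t*>, <p_s*>) and the embedded text to be synthesized (<s_t*>); they are not taken from any specific TTS codebase.

# Prefill the fixed prompt once; its keys/values can be cached across requests.
prompt_hidden = torch.randn(1, 80, 1024)   # hypothetical embedded prompt, length 80
bsz, q_len, _ = prompt_hidden.size()
k = k_proj(prompt_hidden).view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
v = v_proj(prompt_hidden).view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
prompt_cache = (k, v)  # computed once per speaker prompt, reused for every request

# At synthesis time, only the variable text tokens (and later the generated speech
# tokens) go through the projections and are appended to the cached k/v.
text_hidden = torch.randn(1, 20, 1024)     # hypothetical embedded text to synthesize
bsz, q_len, _ = text_hidden.size()
k_new = k_proj(text_hidden).view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
v_new = v_proj(text_hidden).view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)

key_states = torch.cat([prompt_cache[0], k_new], dim=2)    # (1, 16, 100, 64)
value_states = torch.cat([prompt_cache[1], v_new], dim=2)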

  • For more detailed explanations of KV Cache, see:
  1. https://www.zhihu.com/question/596900067/answer/3040011798
  2. https://zhuanlan.zhihu.com/p/624740065