
Grouped Code Modeling

Grouped Code Modeling (GCM) is a technique proposed in VALL-E 2 to speed up model training and inference while also allowing longer sequences to be modeled. The core idea is to merge several adjacent speech tokens and model them jointly. The motivation is that codec frame rates are simply too high: the original EnCodec runs at 75 Hz, CosyVoice1 at 50 Hz, and CosyVoice2 at 25 Hz, whereas for semantic speech tokens a frame rate of around 10 Hz is actually sufficient.
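
As a rough back-of-the-envelope check (the numbers below are illustrative only), grouping divides the number of autoregressive decoding steps by the group factor:

import math

# Illustrative only: how the group factor shrinks the number of AR decoding steps.
def ar_steps(duration_s: float, frame_rate_hz: float, gcm: int) -> int:
    num_tokens = math.ceil(duration_s * frame_rate_hz)   # speech tokens for the utterance
    return math.ceil(num_tokens / gcm)                   # grouped steps the LLM actually runs

print(ar_steps(10, 25, 1))   # CosyVoice2 without grouping: 250 steps
print(ar_steps(10, 25, 2))   # gcm = 2: 125 steps (12.5 groups per second)
print(ar_steps(10, 75, 8))   # EnCodec-like 75 Hz with gcm = 8: ~9.4 groups per second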

Theory

VALL-E 2 still consists of an autoregressive (AR) module and a non-autoregressive (NAR) module; GCM is applied mainly in the AR module, as shown below:

[Figure: GCM applied in the autoregressive module of VALL-E 2]

This is the definition of GCM given in the paper:

[Figures: GCM definition excerpts from the VALL-E 2 paper]

Code

The actual implementation is fairly straightforward:

  • During training, the core work boils down to four things: 1️⃣ how to group the tokens; 2️⃣ passing the grouped sequence through the LLM; 3️⃣ the inverse de-grouping step; 4️⃣ computing the loss.
# 1. Grouping: stack adjacent token embeddings. With gcm = 2 and an embedding dim of 1024,
#    each grouped token embedding has dimension 2048.
speech_token = speech_token.view(speech_token.size(0), speech_token.size(1) // self.gcm, self.gcm * speech_token.size(2))
# 2. After grouping, a Linear layer projects the grouped embedding back down to the LLM input dimension.
# 3. De-grouping: the LLM's last hidden states pass through a Linear layer mapping them to gcm * dim,
#    then a view() restores the ungrouped sequence.
# 4. Compute the CE loss. When building the input and target, be careful not to leak future information:
#    a merged group must never be used to predict the very tokens it was merged from; otherwise the
#    accuracy shoots up but the model has actually learned the wrong thing!
  • During inference, the key change is that each LLM forward step is followed by multiple sampling steps, so that all tokens of the current group are produced from a single forward pass, which is where the speed-up comes from. A minimal standalone sketch of the group/de-group flow and the shifted targets follows this list.
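
To make point 4️⃣ concrete before the full implementation, here is a minimal, self-contained sketch (toy sizes, hypothetical names, and an identity stand-in for the causal LLM) of how grouping, de-grouping, and the shifted targets fit together so that group t only ever predicts the tokens of group t+1:

import torch

B, T, D, gcm, vocab = 2, 8, 16, 2, 100            # toy sizes; T must be a multiple of gcm
tokens = torch.randint(0, vocab, (B, T))           # ungrouped speech tokens
emb = torch.nn.Embedding(vocab, D)
group_in = torch.nn.Linear(D * gcm, D)             # gcm embeddings -> one LLM input vector
group_out = torch.nn.Linear(D, D * gcm)            # one LLM output -> gcm prediction vectors
head = torch.nn.Linear(D, vocab)

# 1. group: (B, T, D) -> (B, T//gcm, gcm*D) -> (B, T//gcm, D)
x = group_in(emb(tokens).view(B, T // gcm, gcm * D))

# 2. stand-in for the causal LLM (identity here); group t's hidden state may only see groups <= t
h = x

# 3. de-group: (B, T//gcm, D) -> (B, T, D), then project to logits
logits = head(group_out(h).view(B, T, D))

# 4. shifted targets: the outputs of group t predict the tokens of group t+1; the last group has no
#    target in this toy (in the real model it predicts gcm stop tokens), and the first group is never
#    predicted from itself, so no future information leaks.
pred = logits[:, :-gcm]                             # drop the predictions of the last group
targets = tokens[:, gcm:]                           # drop the first group of tokens
loss = torch.nn.functional.cross_entropy(pred.reshape(-1, vocab), targets.reshape(-1))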

Experiments

The experiments here follow CosyVoice2.

First, fill in the batch forward method that CosyVoice2's Qwen2Encoder lacks:

import torch
from transformers import Qwen2ForCausalLM


class Qwen2Encoder(torch.nn.Module):
    def __init__(self, pretrain_path):
        super().__init__()
        # self.config = Qwen2Config.from_pretrained(pretrain_path)
        self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)
        # self.model = Qwen2ForCausalLM(self.config)

    def forward_one_step(self, xs, masks, cache=None):
        # incremental decoding: take the last row of the 3D mask as the 2D (B, L) mask for the HF model
        input_masks = masks[:, -1, :]
        outs = self.model(
            inputs_embeds=xs,
            attention_mask=input_masks,
            output_hidden_states=True,
            return_dict=True,
            use_cache=True,
            past_key_values=cache,
        )
        xs = outs.hidden_states[-1]
        new_cache = outs.past_key_values
        return xs, new_cache

    def forward(self, xs, lm_input_len):
        # batch forward for training: build a per-sample causal + padding mask of shape (B, L, L)
        max_length = max(lm_input_len)
        attention_mask = torch.tensor([[1] * length + [0] * (max_length - length) for length in lm_input_len]).to(torch.bool)
        causal_mask = torch.tril(torch.ones((max_length, max_length))).to(torch.bool)
        final_mask = (causal_mask.unsqueeze(0) & attention_mask.unsqueeze(1)).to(xs.device)

        outs = self.model(
            inputs_embeds=xs,
            attention_mask=final_mask,
            output_hidden_states=True,
            return_dict=True,
            use_cache=False,
        )
        return outs
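
To see what that batch mask looks like, here is a tiny standalone reproduction of its construction with toy lengths (illustrative only; whether the resulting 3D boolean mask is accepted as-is by the HF model depends on the transformers version):

import torch

lm_input_len = [4, 2]                      # two sequences, lengths 4 and 2
max_length = max(lm_input_len)
attention_mask = torch.tensor([[1] * l + [0] * (max_length - l) for l in lm_input_len]).to(torch.bool)
causal_mask = torch.tril(torch.ones((max_length, max_length))).to(torch.bool)
final_mask = causal_mask.unsqueeze(0) & attention_mask.unsqueeze(1)
print(final_mask.shape)                    # torch.Size([2, 4, 4])
print(final_mask[1].int())                 # second sample: only the first 2 positions are visible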

Then build a new Qwen2LM_GCM class:

from typing import Callable, Dict, Generator, List, Optional

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence, unpad_sequence

# helpers from the CosyVoice2 repo (import paths as used by cosyvoice/llm/llm.py; adjust if they differ)
from cosyvoice.utils.common import IGNORE_ID, th_accuracy
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss


class Qwen2LM_GCM(torch.nn.Module):
    def __init__(
            self,
            llm_input_size: int,
            llm_output_size: int,
            speech_token_size: int,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            gcm: int = 2,  # grouped code modeling factor from VALL-E 2, can be 1, 2, 4, 8 ...
    ):
        super().__init__()
        self.llm_input_size = llm_input_size
        self.llm_output_size = llm_output_size
        self.speech_token_size = speech_token_size

        # 2. build speech token language model related modules
        self.sos_eos = 0
        self.task_id = 1
        self.fill_token = 2

        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm

        # group / de-group projections: gcm stacked embeddings <-> one LLM-sized vector
        self.group_embedding = nn.Linear(llm_input_size * gcm, llm_input_size)
        self.group_prediction = nn.Linear(llm_output_size, llm_output_size * gcm)

        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 3,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )

        # 3. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)

        # 4. sampling method
        self.sampling = sampling

        self.gcm = gcm

    def sampling_ids(
            self,
            weighted_scores: torch.Tensor,
            decoded_tokens: List,
            sampling: int,
            ignore_eos: bool = True,
    ):
        while True:
            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
            if (not ignore_eos) or (self.speech_token_size not in top_ids):
                break
        return top_ids

    def pad_unpad_sequence(self, sos_eos_emb, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
                    for i in range(len(text_token))]
        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
        return lm_input, lm_input_len

    def de_group(self, lm_output, text_token_len, speech_token_len):
        # split the padded LLM output back into sos/eos, text, task-id and (grouped) speech segments
        sos_eos_emb = []
        text_token = []
        task_id_emb = []
        speech_token = []
        for t, i in enumerate(lm_output):
            sos_eos_emb.append(i[0, :])
            text_token.append(i[1:text_token_len[t] + 1])
            task_id_emb.append(i[text_token_len[t] + 1, :])
            speech_token.append(i[text_token_len[t] + 2:text_token_len[t] + 2 + speech_token_len[t], :])
        return sos_eos_emb, text_token, task_id_emb, speech_token

    def con_cat(self, sos_eos_emb, text_token, task_id_emb, speech_token, speech_token_len):
        # re-concatenate the de-grouped segments into one padded sequence per sample
        lm_output = []
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        for i in range(len(text_token)):
            lm_output.append(torch.concat([sos_eos_emb[i].unsqueeze(dim=0), text_token[i],
                                           task_id_emb[i],
                                           speech_token[i]]))
        lm_output = pad_sequence(lm_output, batch_first=True, padding_value=IGNORE_ID)
        return lm_output

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Args:
            text: (B, L, D)
            text_lengths: (B,)
            audio: (B, T, N) or (B, T)
            audio_lengths: (B,)
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)

        # 0. pad speech_token to a multiple of gcm, using 4299 (the token_v2 id of silence)
        speech_token_list = []
        speech_token_len_list = []
        for i, sl in enumerate(speech_token_len):
            st = speech_token[i, :sl]
            if sl % self.gcm != 0:
                speech_token_list.append(F.pad(st, (0, (sl // self.gcm + 1) * self.gcm - sl), value=4299))
                speech_token_len_list.append((sl // self.gcm + 1) * self.gcm)
            else:
                speech_token_list.append(st)
                speech_token_len_list.append(sl)
        speech_token = pad_sequence(speech_token_list, batch_first=True, padding_value=0).to(device)
        speech_token_len = torch.tensor(speech_token_len_list)

        # 1. prepare lm_target; unlike the original CosyVoice2 target (a single stop token),
        #    GCM appends gcm stop tokens so the final group predicts a full group of stops
        lm_target = [torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
                                  [self.speech_token_size] * self.gcm) for i in range(text_token.size(0))]
        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)

        # 2. encode text_token
        text_token = self.llm.model.model.embed_tokens(text_token)

        # 3. sos/eos and task_id embeddings (the speaker embedding of CosyVoice1 is not used here)
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)

        # 4. encode and group speech_token
        speech_token = self.speech_embedding(speech_token)
        speech_token = speech_token.view(speech_token.size(0), speech_token.size(1) // self.gcm, self.gcm * speech_token.size(2))
        speech_token = self.group_embedding(speech_token)  # h*gcm --> h
        # group task_id_emb (start code)
        task_id_emb = torch.concat([task_id_emb] * self.gcm, dim=2)
        task_id_emb = self.group_embedding(task_id_emb)

        speech_token_len_gcm = speech_token_len // self.gcm
        # 5. unpad and pad
        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, text_token, text_token_len,
                                                         task_id_emb, speech_token, speech_token_len_gcm)

        # 6. run lm forward on the grouped sequence
        lm_output = self.llm(lm_input, lm_input_len.to(device))
        lm_output = lm_output['hidden_states'][-1]

        # 7. de-group: project every grouped hidden state back to gcm per-token prediction vectors
        sos_eos_emb, text_token, task_id_emb, speech_token = self.de_group(lm_output, text_token_len, speech_token_len_gcm)
        speech_token = pad_sequence(speech_token, batch_first=True, padding_value=IGNORE_ID)
        speech_token = self.group_prediction(speech_token)
        speech_token = speech_token.view(speech_token.size(0), speech_token.size(1) * self.gcm, speech_token.size(2) // self.gcm)

        task_id_emb = pad_sequence(task_id_emb, batch_first=True, padding_value=IGNORE_ID)
        task_id_emb = task_id_emb.unsqueeze(1)
        task_id_emb = self.group_prediction(task_id_emb)
        task_id_emb = task_id_emb.view(task_id_emb.size(0), task_id_emb.size(1) * self.gcm, task_id_emb.size(2) // self.gcm)

        # 8. concatenate the de-grouped outputs and compute the loss
        lm_output = self.con_cat(sos_eos_emb, text_token, task_id_emb, speech_token, speech_token_len)
        # NOTE: lm_target now contains multiple stop tokens; the extra ones may need masking, not handled here!
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target)
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}

    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.llm.model.model.embed_tokens(text)

        # 1. concat llm_input (the speaker embedding is not used in this class)
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            # pad the prompt speech tokens to a multiple of gcm with 4299 (silence token)
            if prompt_speech_token_len % self.gcm != 0:
                prompt_speech_token = F.pad(prompt_speech_token, (0, (prompt_speech_token_len // self.gcm + 1) * self.gcm - prompt_speech_token_len), value=4299)
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)

        # group code modeling
        prompt_speech_token_emb = prompt_speech_token_emb.view(prompt_speech_token_emb.size(0), prompt_speech_token_emb.size(1) // self.gcm, self.gcm * prompt_speech_token_emb.size(2))
        prompt_speech_token_emb = self.group_embedding(prompt_speech_token_emb)  # h*gcm --> h
        # group task_id_emb (start code)
        task_id_emb = torch.concat([task_id_emb] * self.gcm, dim=2)  # [1, 1, 2048]
        task_id_emb = self.group_embedding(task_id_emb)  # [1, 1, 1024]

        lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)

        # 2. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)

        # 3. step by step decode: one LLM forward per group, then gcm sampling steps per forward
        out_tokens = []
        cache = None
        for i in range(max_len):
            y_pred, cache = self.llm.forward_one_step(lm_input,
                                                      masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
                                                      cache=cache)
            # de-group the last hidden state into gcm prediction vectors
            y_pred_split = self.group_prediction(y_pred[:, -1])
            y_pred_split = y_pred_split.view(y_pred_split.size(0) * self.gcm, y_pred_split.size(1) // self.gcm)
            logp = self.llm_decoder(y_pred_split).log_softmax(dim=-1)
            llm_input = []
            for gcm_id in range(self.gcm):
                logp_i = logp[gcm_id]
                top_ids_i = self.sampling_ids(logp_i.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
                if top_ids_i == self.speech_token_size:
                    break
                yield top_ids_i
                out_tokens.append(top_ids_i)
                llm_input_i = self.speech_embedding.weight[top_ids_i].reshape(1, 1, -1)
                llm_input.append(llm_input_i)
            # stop once a group contains the end-of-speech token
            if len(llm_input) < self.gcm:
                break
            # re-group the gcm sampled tokens into a single embedding for the next LLM step
            llm_input = torch.concat(llm_input, dim=1)
            llm_input = llm_input.view(llm_input.size(0), llm_input.size(1) // self.gcm, self.gcm * llm_input.size(2))
            lm_input = self.group_embedding(llm_input)
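
For completeness, a minimal sketch of how the streaming generator might be driven. All ids below are hypothetical placeholders; in practice the CosyVoice2 frontend produces the text and prompt token ids, and model is assumed to be a Qwen2LM_GCM instance configured as in CosyVoice2 (speech_token_size = 6561):

# Hypothetical inputs, only to show how the generator is consumed.
text = torch.randint(0, 1000, (1, 12))                  # target text token ids
text_len = torch.tensor([12])
prompt_text = torch.randint(0, 1000, (1, 5))            # prompt text token ids
prompt_text_len = torch.tensor([5])
prompt_speech_token = torch.randint(0, 6561, (1, 40))   # prompt speech token ids (multiple of gcm)
prompt_speech_token_len = torch.tensor([40])
embedding = torch.zeros(1, 192)                         # kept for interface compatibility, unused here

out_tokens = []
for token_id in model.inference(text, text_len, prompt_text, prompt_text_len,
                                prompt_speech_token, prompt_speech_token_len, embedding):
    out_tokens.append(token_id)   # each yielded id can be streamed on to the flow-matching decoder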

Results

  • Prompt: 希望你以后能够做的比我还好呦。 ("I hope you can do even better than me in the future.")
  • Text to synthesize: 收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。 ("Receiving a birthday gift sent from afar by a close friend, the unexpected surprise and heartfelt wishes filled my heart with sweet joy, and my smile bloomed like a flower.")

  • Synthesized by the original model (audio sample)

  • Synthesized by the GCM model (audio sample)

Overall the GCM model reads the text correctly, but its expressiveness does not yet match the original. The main reasons are that relatively little training data was used, and that the model was fine-tuned on very clean data to reduce the character error rate, which lowered the overall prosodic expressiveness.

Summary

This basically validates the GCM approach! However, it sacrifices CosyVoice2's built-in streaming and non-streaming inference as well as features such as Instruct. These can all be added back later; it is just a matter of how the sequences are constructed.