
MQTTS

MQTTS: A Vector Quantized Approach for Text to Speech Synthesis on Real-World Spontaneous Speech
By replacing the traditional mel-spectrogram intermediate representation with groups of discrete codes from multiple codebooks, MQTTS addresses the difficulty of modeling real-world spontaneous, conversational speech. This makes it possible to train TTS models on large corpora (e.g. WeNetSpeech), and the synthesized speech sounds more realistic and natural.

Paper

Paper: https://arxiv.org/abs/2302.04215

Appendix: https://cmu.box.com/s/7ghw0bgkbqv5e7hu5jsznhlzuo4rexgx

Code: https://github.com/b04901014/MQTTS

Some Samples

  • first issue um but going back to what he’s saying i’m hundred percent you have have snow tires um a rear wheel drive car is going to .
MQTTS (LJ-Speech embedding) vs. MS-Jenny (Microsoft's Jenny voice)

Paper and Code Walkthrough

The overall framework is shown below:
(Figure: MQTTS)

Roughly, inference proceeds as follows (the details are explained later; a minimal sketch follows this list):

  1. Phonemes and speaker information serve as the conditioning input
  2. Discrete codes are generated autoregressively
  3. The codes are passed through the pretrained decoder to obtain the speech waveform
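
A rough, hypothetical sketch of this flow; the names transformer, quantizer_decoder and generate stand in for the actual modules and are not the repo's API:

import torch

@torch.no_grad()
def synthesize(transformer, quantizer_decoder, phone_ids, speaker_emb):
    # 1. condition on phonemes and the speaker embedding,
    # 2. autoregressively sample one group of code indices per frame
    codes = transformer.generate(phone_ids, speaker_emb)   # [T, n_codebooks] integer codes
    # 3. decode the code sequence with the pretrained quantizer decoder to get a waveform
    return quantizer_decoder(codes, speaker_emb)            # [num_samples]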

Model training consists of two main parts: quantization of raw speech, and conditional synthesis with a Transformer.

Quantization of Raw Speech

The vector-quantization framework largely borrows the structure of HiFi-GAN. The Quantizer Decoder is exactly the HiFi-GAN generator, and the Quantizer Encoder is the same generator with its transposed convolutions replaced by strided convolutions. Adversarial training is used here as well, and the discriminators are identical to HiFi-GAN's.

The main difference is the learning of multiple codebooks: besides the adversarial losses, a vector-quantization loss is added during training.
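
A rough, hypothetical sketch of how these terms can be combined in one generator update; the helper names and the weights (taken from HiFi-GAN's defaults) are illustrative, not the repo's exact training code:

def quantizer_generator_step(wav, encoder, quantizer, decoder, discriminator,
                             adv_loss, fm_loss, mel_l1, lambda_fm=2.0, lambda_mel=45.0):
    # Illustrative generator objective: HiFi-GAN terms plus the vector-quantization loss.
    z = encoder(wav)                    # [B, 512, T']: downsampled latent
    z_q, vq_loss, _ = quantizer(z)      # nearest-codebook lookup + VQ/commitment loss (see Quantizer below)
    wav_hat = decoder(z_q)              # HiFi-GAN-style generator reconstructs the waveform
    return (adv_loss(discriminator(wav_hat))       # adversarial term
            + lambda_fm * fm_loss(wav, wav_hat)    # feature-matching term
            + lambda_mel * mel_l1(wav, wav_hat)    # mel-spectrogram reconstruction term
            + vq_loss)                             # vector-quantization loss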

Alternatively, Meta's pretrained EnCodec model could be used for this stage; that is what Microsoft's VALL-E does.

  • Encoder

1-D convolutions reduce the temporal resolution and dilated convolutions enlarge the receptive field. The downsampling strides are [2, 2, 8, 8], so audio of shape [b, 1, 8192] is reduced to a latent of shape [b, 512, 32].

class Encoder(torch.nn.Module):
    def __init__(self, h):
        super(Encoder, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.conv_pre = weight_norm(Conv1d(1, 32, 7, 1, padding=3))
        self.normalize = nn.ModuleList()
        resblock = ResBlock1 if h.resblock == '1' else ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(list(reversed(list(zip(h.upsample_rates, h.upsample_kernel_sizes))))):
            self.ups.append(weight_norm(
                Conv1d(32*(2**i), 32*(2**(i+1)),
                       k, u, padding=(u//2 + u%2))))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = 32*(2**(i+1))
            for j, (k, d) in enumerate(zip(list(reversed(h.resblock_kernel_sizes)), list(reversed(h.resblock_dilation_sizes)))):
                self.resblocks.append(resblock(h, ch, k, d))
                self.normalize.append(torch.nn.GroupNorm(ch // 16, ch, eps=1e-6, affine=True))

        self.conv_post = Conv1d(512, 512, 3, 1, padding=1)
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        # x.shape [b, 1, 8192]; 8192 is the number of audio samples per training segment, following HiFi-GAN
        x = self.conv_pre(x)
        # x.shape [b, 32, 8192]
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i*self.num_kernels+j](x)
                    xs = self.normalize[i*self.num_kernels+j](xs)
                else:
                    xs += self.resblocks[i*self.num_kernels+j](x)
                    xs = self.normalize[i*self.num_kernels+j](xs)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        return x

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)

  • Quantizer

A simple way to think about vector quantization is rounding. Here nn.Embedding is used to learn the codebooks. The Encoder above already compresses the time axis by a factor of 256, so one second of 16 kHz audio becomes 62.5 time steps. But the values stored at each step still span a very large range: for 16-bit audio, each sample can be anywhere from -32768 to 32767, which makes autoregressive prediction hard. Vector quantization solves this: each latent vector is mapped to a fixed integer index in the range 0 to 159. That is the quantization step, and it makes the later autoregressive prediction much easier.
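
The arithmetic behind those numbers, as a quick sanity check:

sample_rate = 16000                       # Hz
downsample_factor = 2 * 2 * 8 * 8         # the encoder's [2, 2, 8, 8] strides -> 256x compression
frames_per_second = sample_rate / downsample_factor
print(frames_per_second)                  # 62.5 code frames per second
n_codebooks, codes_per_book = 4, 160
# Each frame is now just 4 integers in [0, 160) instead of 256 raw 16-bit samples,
# which is what makes autoregressive prediction over the codes tractable.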

How exactly is this mapping done?

  1. Split the 512-channel Encoder output into 4 groups of 128 dimensions each, one per codebook.
  2. Compute the Euclidean distance between each of the 4 group vectors and the rows of nn.embedding.weight, and take the index of the nearest entry; that index is the code. Each nn.Embedding codebook has 160 entries.
  3. Looking the code up in nn.Embedding gives z_q.
  4. z_q is fed into the Decoder to reconstruct the raw waveform.
  5. There are two kinds of losses: the adversarial losses (plus the other losses used to train HiFi-GAN) and the vector-quantization loss (between the Encoder output and z_q).
class Quantizer_module(torch.nn.Module):
    def __init__(self, n_e, e_dim):
        # n_e = 160 codes per codebook, e_dim = 128 dimensions per code
        super(Quantizer_module, self).__init__()
        # nn.Embedding holds the codebook: input is an index, output is a 128-d vector; there are 160 entries
        self.embedding = nn.Embedding(n_e, e_dim)
        self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e)

    def forward(self, x):
        # x has shape [512, 128]: 512 frame vectors of dimension 128 (the 512 varies with utterance length)
        # For each of these vectors, find the nearest entry in the [160, 128] codebook and return its index
        d = torch.sum(x ** 2, 1, keepdim=True) + torch.sum(self.embedding.weight ** 2, 1) - 2 * torch.matmul(x, self.embedding.weight.T)
        # Broadcasting computes the pairwise squared Euclidean distances: a [512, 160] matrix
        min_indicies = torch.argmin(d, 1)
        z_q = self.embedding(min_indicies)
        return z_q, min_indicies


class Quantizer(torch.nn.Module):
    def __init__(self, h):
        super(Quantizer, self).__init__()
        assert 512 % h.n_code_groups == 0
        self.quantizer_modules = nn.ModuleList([
            Quantizer_module(h.n_codes, 512 // h.n_code_groups) for _ in range(h.n_code_groups)
        ])
        self.h = h

    def forward(self, xin):
        # B, C, T
        xin = xin.transpose(1, 2)  # [B, T, C], e.g. [16, 32, 512]
        x = xin.reshape(-1, 512)   # [B*T, 512], e.g. [512, 512]
        x = torch.split(x, 512 // self.h.n_code_groups, dim=-1)  # split x into 4 groups, one per codebook
        min_indicies = []
        z_q = []
        for _x, m in zip(x, self.quantizer_modules):
            _z_q, _min_indicies = m(_x)
            z_q.append(_z_q)
            min_indicies.append(_min_indicies)  # B * T,
        z_q = torch.cat(z_q, -1).reshape(xin.shape)
        # commitment loss + codebook loss between the Encoder output and z_q
        loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
        # straight-through estimator: gradients flow from z_q back to the encoder output
        z_q = xin + (z_q - xin).detach()
        z_q = z_q.transpose(1, 2)
        return z_q, loss, min_indicies

    def embed(self, x):
        # idx: N, T, 4
        x = torch.split(x, 1, 2)
        ret = []
        for q, embed in zip(x, self.quantizer_modules):
            q = embed.embedding(q.squeeze(-1))
            ret.append(q)
        ret = torch.cat(ret, -1)
        return ret.transpose(1, 2)  # N, C, T

Conditional Synthesis with Transformer

The framework of this module is shown below:

(Figure: MQTTS-Transformer)

This module uses a Transformer to autoregressively predict the codes, but with quite a few modifications, including:

  1. A global speaker embedding enables multi-speaker synthesis; the embedding is extracted with the pretrained pyannote model.
  2. ALiBi replaces positional encoding, so quality does not degrade when the input sequence grows very long (a short sketch follows this list).
  3. The alignment mechanism is modified to match the monotonic nature of TTS.
  4. A Sub-Decoder is added to predict the code of each codebook within a frame.
  5. A Repetition Token marks codes that repeat the previous frame, reflecting how speech behaves.
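
To make the second point concrete, here is a minimal sketch of the standard ALiBi bias (assumed to behave like the repo's alibi module, which is not shown here): a fixed, head-specific linear penalty added to the attention logits in place of positional embeddings.

import torch

def alibi_bias(n_heads: int, seq_len: int) -> torch.Tensor:
    # Head-specific slopes 2^(-8/n), 2^(-16/n), ... (the standard ALiBi recipe for power-of-two head counts)
    slopes = torch.tensor([2.0 ** (-8.0 * (h + 1) / n_heads) for h in range(n_heads)])
    pos = torch.arange(seq_len)
    dist = (pos[:, None] - pos[None, :]).clamp(min=0).float()   # distance i - j for keys j <= i
    return -slopes[:, None, None] * dist                        # [n_heads, T, T]; farther keys get larger penalties

bias = alibi_bias(n_heads=4, seq_len=8)   # would be passed as the attn_bias argument of the layers below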

The detailed structure is as follows:

  • Phone Encoder
encoder = TransformerEncoder(
    nn.ModuleList(
        [TransformerEncoderLayer(hp) for i in range(hp.enc_nlayers)]
    )
)
# 6 TransformerEncoder layers


class TransformerEncoderLayer(nn.Module):
    def __init__(self, hp, dropout=0.1):
        super().__init__()
        self.hp = hp
        self.d_model = hp.hidden_size  # 768
        self.dropout_p = dropout  # 0.1
        self.self_attn = MultiheadAttention(self.d_model, hp.nheads, dropout=0.1)  # nheads: 12
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(self.d_model, hp.ffd_size)  # ffd_size: 3072
        self.dropout = nn.Dropout(dropout)  # 0.1
        self.linear2 = nn.Linear(hp.ffd_size, self.d_model)

        self.norm1 = nn.LayerNorm(self.d_model, eps=hp.layer_norm_eps)
        self.norm2 = nn.LayerNorm(self.d_model, eps=hp.layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = nn.GELU()

    def forward(self, src, src_mask=None, attn_bias=None, src_key_padding_mask=None):
        res, self_attn = self.self_attn(src, src, src, attn_mask=src_mask, attn_bias=attn_bias,
                                        key_padding_mask=src_key_padding_mask)
        src = src + self.dropout1(res)
        src = self.norm1(src)
        res = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(res)
        src = self.norm2(src)
        return src, self_attn

  • Decoder

The overall structure is similar to the Encoder: 6 TransformerDecoderLayer layers. The input is q_input, which is the ground-truth code sequence shifted by one frame: at time 0, q_input is the special start frame [161, 161, 161, 161], and at time 1 it is the ground-truth codes of time 0.
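
For concreteness, a minimal sketch (assumed shapes; not the repo's data-loading code) of how such a shifted q_input can be built from the ground-truth codes, with 161 as the start index:

import torch

def shift_codes(gt_codes: torch.Tensor, start_idx: int = 161) -> torch.Tensor:
    # gt_codes: [T, 4] integer codes; returns q_input: [T, 4], shifted right by one frame
    start = torch.full((1, gt_codes.size(1)), start_idx, dtype=gt_codes.dtype)
    return torch.cat([start, gt_codes[:-1]], dim=0)

# q_input[0] == [161, 161, 161, 161]; q_input[1] == gt_codes[0], and so on.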

self.decoder = TransformerDecoder(
    nn.ModuleList(
        [TransformerDecoderLayer(hp, with_cross_attention=False) for i in range(hp.dec_nlayers)]
    )
)


class TransformerDecoderLayer(nn.Module):
    def __init__(self, hp, with_cross_attention, dropout=0.1):
        super().__init__()
        self.hp = hp
        self.d_model = hp.hidden_size  # 256
        self.dropout_p = dropout  # 0.1
        self.self_attn = MultiheadAttention(self.d_model, hp.nheads, dropout=0.1)  # nheads: 4
        self.with_cross_attention = with_cross_attention  # False
        if with_cross_attention:
            self.multihead_attn = MultiheadAttention(self.d_model, hp.nheads, dropout=0.1)
            self.norm2 = nn.LayerNorm(self.d_model, eps=hp.layer_norm_eps)
            self.dropout2 = nn.Dropout(dropout)
        # Two attention modules are distinguished here (the naming is a bit loose):
        # self_attn attends within the target sequence itself (it is also multi-head);
        # multihead_attn attends from the target to the memory, i.e. it is really cross-attention
        # and a name like cross_attention would be clearer.

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(self.d_model, hp.ffd_size)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(hp.ffd_size, self.d_model)

        self.norm1 = nn.LayerNorm(self.d_model, eps=hp.layer_norm_eps)
        self.norm3 = nn.LayerNorm(self.d_model, eps=hp.layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.activation = nn.GELU()

    def forward(self, tgt, memory=None, tgt_mask=None, attn_bias=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None, past_kv=None):
        tgt2, self_attn = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask, attn_bias=attn_bias,
                                         key_padding_mask=tgt_key_padding_mask, past_kv=past_kv)
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        attn = None
        if self.with_cross_attention:
            assert memory is not None
            tgt2, attn = self.multihead_attn(tgt, memory, memory,
                                             key_padding_mask=memory_key_padding_mask, past_kv=past_kv)
            tgt = tgt + self.dropout2(tgt2)
            tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt, attn, self_attn

class TransformerDecoder(nn.Module):
    def __init__(self, decoder_layers):
        super().__init__()
        self.layers = decoder_layers
        self.num_layers = len(decoder_layers)

    def forward(self, tgt, memory, tgt_mask=None, attn_bias=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None, past_kvs=None):
        output = tgt
        attns = []
        self_attns = []
        outputs = []
        if past_kvs is None:
            past_kvs = [None for _ in range(len(self.layers))]
        for i, mod in enumerate(self.layers):
            output, attn, self_attn = mod(output, memory, tgt_mask=tgt_mask, attn_bias=attn_bias,
                                          tgt_key_padding_mask=tgt_key_padding_mask,
                                          memory_key_padding_mask=memory_key_padding_mask,
                                          past_kv=past_kvs[i])
            if attn is not None:
                attns.append(attn.detach())
            if self_attn is not None:
                self_attns.append(self_attn.detach())
            outputs.append(output)
        return output, attns, self_attns, outputs

  • Aligner

This can be viewed as the last layer of the Decoder; only this last layer uses cross-attention. It fuses the information from the phones and q_input and also produces the alignment. The Aligner uses single-head attention.

output, alignment = self.aligner(output, phone, tgt_mask=tgt_mask,
                                 tgt_key_padding_mask=q_mask, memory_key_padding_mask=phone_mask)

class CrossAttnOnlyLayer(nn.Module):
    def __init__(self, hp, dropout=0.1):
        super().__init__()
        self.dropout_p = dropout
        # Only one head for alignment!
        self.multihead_attn = MultiheadAttention(hp.hidden_size, 1, dropout=0.1, softmax_temp=hp.aligner_softmax_temp)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(hp.hidden_size, hp.ffd_size)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(hp.ffd_size, hp.hidden_size)

        self.norm1 = nn.LayerNorm(hp.hidden_size, eps=hp.layer_norm_eps)
        self.norm2 = nn.LayerNorm(hp.hidden_size, eps=hp.layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = nn.GELU()

    def forward(self, tgt, memory, tgt_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        tgt2, attn = self.multihead_attn(tgt, memory, memory,
                                         key_padding_mask=memory_key_padding_mask)
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        return tgt, attn

  • Sub-Decoder

Within each frame, the codes q are predicted autoregressively across the codebooks; a [R] (repetition) token is also introduced.

audio_output = self.transducer.decode(audio_output, q_input)
# audio_output: [B, T, 4, 163], where 163 is the number of classes per codebook (160 codes plus special tokens)

class ARCodeTransformer(nn.Module):
    def __init__(self, hp, n_decoder_codes):
        super().__init__()
        self.hp = hp
        ar_hp = Namespace(hidden_size=hp.ar_hidden_size, nheads=hp.ar_nheads, layer_norm_eps=hp.layer_norm_eps,
                          ffd_size=hp.ar_ffd_size)
        self.model = TransformerDecoder(
            nn.ModuleList(
                [TransformerDecoderLayer(ar_hp, with_cross_attention=False) for i in range(hp.ar_layer)]
            )
        )
        self.embedding = nn.ModuleList(
            [
                nn.Embedding(n_decoder_codes, hp.ar_hidden_size) for _ in range(self.hp.n_cluster_groups - 1)
            ]
        )
        # Why only 3 embeddings rather than 4? Because the first input position is cond, which needs no embedding.

        self.linear = nn.Linear(hp.hidden_size, hp.ar_hidden_size)
        self.layer_norm = nn.LayerNorm(hp.ar_hidden_size, eps=hp.layer_norm_eps)
        self.decoders = nn.ModuleList([
            nn.Linear(hp.ar_hidden_size, n_decoder_codes)
            for i in range(hp.n_cluster_groups)
        ])
        tgt_mask = (torch.tril(torch.ones(hp.n_cluster_groups, hp.n_cluster_groups), diagonal=0) == 0)
        self.register_buffer('tgt_mask', tgt_mask)

    def forward(self, cond, gt):
        # cond: N, T, C
        # gt: N, T, 4
        # return: N, T, 4, n_codes
        N, T, _ = cond.size()
        cond, gt = cond.reshape(N * T, -1), gt.reshape(N * T, -1)
        # cond.shape: [N*T, 768]; gt.shape: [N*T, 4]

        cond = self.linear(cond)  # [N*T, 256]

        gt = gt[:, : -1]  # NT, 3
        gt_in = []
        for i in range(self.hp.n_cluster_groups - 1):
            gt_in.append(self.embedding[i](gt[:, i]))  # 3 x [NT, 256]

        inp = torch.stack([cond] + gt_in, 1)  # NT, 4, C
        # inp is the autoregressive input: position 0 is cond, position 1 is the previous ground-truth code, and so on
        inp = self.layer_norm(inp)
        out, _, _, _ = self.model(inp, memory=None, tgt_mask=self.tgt_mask)
        # out is the predicted q: each code is predicted from the previous ones; teacher forcing makes training parallel
        ret = []
        for i in range(self.hp.n_cluster_groups):
            ret.append(self.decoders[i](out[:, i]))
        ret = torch.stack(ret, 1).reshape(N, T, self.hp.n_cluster_groups, -1)
        # reshape q back to [N, T, 4, 163]
        return ret

Finally, the model is trained with a cross-entropy loss over the predicted codes.
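
A minimal sketch of that loss, assuming logits of shape [N, T, 4, n_classes] from the Sub-Decoder above and integer targets of shape [N, T, 4] (padding handling omitted):

import torch.nn.functional as F

def code_cross_entropy(logits, targets):
    # logits: [N, T, 4, n_classes]; targets: [N, T, 4] integer code indices
    n_classes = logits.size(-1)
    return F.cross_entropy(logits.reshape(-1, n_classes), targets.reshape(-1))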

Inference (synthesizing audio from text)

  1. Clean speech is used as a prompt so that the generated audio is relatively clean: the training data comes from ASR corpora and is fairly noisy, so a clean-speech prompt is fed in at inference time.
  2. Monotonic alignment is used, and the stop decision is also derived from the alignment rather than from a binary stop classifier, which makes stopping more stable.
  • Prompt
low_background_noise = torch.randn(batch_size, int(self.hp.sample_rate * 5.0)) * self.hp.prior_noise_level
# low_background_noise.shape: [b, 80000]; 5 seconds of near-silent audio ('prior_noise_level': 1e-5) serve as the clean prompt
base_prior = self.vocoder.encode(low_background_noise.cuda())
# vector-quantize the prompt audio
# base_prior.shape: [b, 312, 4]
if self.hp.clean_speech_prior:
    prior = base_prior[:, :self.hp.prior_frame]
    # the first 3 frames are enough for the prompt
    # prior.shape: [b, 3, 4]
else:
    prior = None
  • Autoregressive generation

The inputs are the phone features, the speaker_embedding, and the prompt; the output is speech in that speaker's voice.

synthetic = self.TTSdecoder.inference_topkp_sampling_batch(phone_features, speaker_embedding, phone_masks, prior=prior)
def inference_topkp_sampling_batch(self, phone, spkr, phone_mask, prior=None, output_alignment=False):
    batch_size = phone.size(0)
    final_outputs = [0 for _ in range(batch_size)]
    spkr = self.layer_norm_spkr(spkr.unsqueeze(1))
    inp = self.layer_norm(self.transducer.start_token(phone.device))  # 1, 1, C
    # inp.shape: [1, 1, 768]; the start token index is 161
    inp = inp.expand(batch_size, -1, -1)  # N, 1, C
    inp = torch.cat([spkr, inp], 1)
    # prepend the speaker embedding

    prior_size = 0
    if prior is not None:
        prior = self.transducer.encode(prior)
        prior = self.layer_norm(prior)
        prior_size = prior.size(1)
        inp = torch.cat([inp, prior], 1)
        # append the near-silent clip as the prompt
    phone = self.encode_phone(phone, spkr, phone_mask)
    tgt_mask = self.tgt_mask[:inp.size(1), :inp.size(1)].to(inp.device)
    inps = inp
    # Decode
    past_kvs1, past_kv_cross, past_kvs2, clusters = None, None, None, torch.empty([batch_size, 0, self.hp.n_cluster_groups], device=phone.device, dtype=torch.long)
    audio_alibi = self.alibi(inp)
    audio_alibi[:, 0] = 0
    audio_alibi[:, :, 0] = 0
    back_map = torch.zeros([batch_size, 1], device=phone.device, dtype=torch.long)
    # one zero entry per stream: back_map tracks each stream's current phone position

    length_counter = torch.zeros([batch_size], device=phone.device, dtype=torch.long)
    real_phone_lengths = (~phone_mask).long().sum(-1)  # N,
    if output_alignment:
        assert batch_size == 1, "Now only support output alignment for bs = 1 for debugging issues..."
        alignment = torch.zeros((1, self.hp.max_output_length, self.hp.max_output_length), device=phone.device)
    for i in range(self.hp.max_output_length):
        cond, _, _, new_1 = self.decoder(inp, memory=None, attn_bias=audio_alibi, tgt_mask=tgt_mask, past_kvs=past_kvs1)
        # cond.shape (first step): [B, 5, 768]; the 5 positions are the speaker, the start token and the 3 prompt frames

        # Only feed in the current frame and the next frame attending!
        t_length, c_length = phone.size(1), phone.size(2)  # T, C=768
        selected_phone = phone.reshape(-1, c_length)  # N*T, C
        index_map = torch.arange(self.hp.phone_context_window, device=phone.device)  # [0, 1, 2]
        index_map = back_map[:, -1:] + index_map.repeat(batch_size, 1)
        # last column of back_map + the window offsets
        # presumably this builds the monotonic-alignment window

        add = torch.arange(batch_size, device=index_map.device).unsqueeze(1)  # N, 1

        index_map = index_map + add * t_length  # the batch was flattened, so the indices need an offset
        index_map = index_map.reshape(-1)  # N * 3; phone has been flattened!
        selected_phone = selected_phone[index_map].reshape(batch_size, self.hp.phone_context_window, c_length)  # N*3, C
        # Mask for the starting phones
        phone_mask = torch.arange(self.hp.phone_context_window, device=phone.device).repeat(batch_size, 1)
        phone_mask = (phone_mask <= (back_map[:, -1:] + 1).expand(-1, self.hp.phone_context_window))
        phone_mask = ~phone_mask
        cond, _align = self.aligner(cond, selected_phone, tgt_mask=tgt_mask, memory_key_padding_mask=phone_mask)
        # only the local phone window is attended to, not the whole phone sequence!

        cond = cond[:, -1].unsqueeze(1)  # N, 1, C

        # Run sub-decoder inference
        output = []
        for j in range(self.hp.n_cluster_groups):
            # predict the code of each codebook in turn
            q_input = torch.cat(output, 1) if j else None
            logit = self.transducer.decoder.infer(cond, q_input)  # N, n_codes (e.g. [9, 163])
            # Block Start Token
            logit[:, self.hp.n_codes + 1] = -float("Inf")  # the start token must never be sampled
            # Don't output stop token if alignment not near end
            logit_tmp = logit[back_map[:, -1] < (real_phone_lengths - 2)]
            logit_tmp[:, self.hp.n_codes] = -float("Inf")
            logit[back_map[:, -1] < (real_phone_lengths - 2)] = logit_tmp
            # Repetition penalty
            if self.hp.use_repetition_token and self.hp.repetition_penalty != 1.0:
                logit[:, self.hp.n_codes + 2] /= self.hp.repetition_penalty
            if self.hp.use_repetition_gating:
                logit[:, self.hp.n_codes + 2] = torch.min(torch.max(logit[:, :self.hp.n_codes]), logit[:, self.hp.n_codes + 2])
            # Top_p
            if self.hp.top_p < 1.0 and self.hp.top_p > 0.0:
                sorted_logits, sorted_idxs = torch.sort(logit, descending=True)
                cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                additional_prob = (self.hp.length_penalty_max_prob - self.hp.top_p) * (length_counter / self.hp.length_penalty_max_length)
                idx_to_remove = cum_probs > (self.hp.top_p + additional_prob).unsqueeze(-1)
                idx_to_remove[:, :self.hp.min_top_k] = False
                idx_to_remove = idx_to_remove.scatter(1, sorted_idxs, idx_to_remove)
                logit[idx_to_remove] = -float("Inf")
            # Sampling
            probs = torch.softmax(logit / self.hp.sampling_temperature, dim=-1)
            idx = torch.multinomial(probs, 1)  # N, 1
            # If is repetition token
            if self.hp.use_repetition_token:
                if clusters.size(1) == 0:  # First token, random choice
                    idx[idx==(self.hp.n_codes + 2)] = torch.randint(self.hp.n_codes, size=(1,), device=idx.device)
                else:
                    idx[idx==(self.hp.n_codes + 2)] = clusters[:, -1:, j][idx==(self.hp.n_codes + 2)]
            output.append(idx)
        output = torch.cat(output, 1).unsqueeze(1)  # N, 1, n_groups
        # output: the codes of the current frame

        # Stop criterion
        stopping_streams = (back_map[:, -1] == (real_phone_lengths - self.hp.phone_context_window))
        stopping_streams = (stopping_streams & self.transducer.is_end_token_batch(output)) | (stopping_streams & (torch.argmax(_align[:, 0, -1], dim=-1) == self.hp.phone_context_window - 1))  # N,
        if i == self.hp.max_output_length - 1:
            stopping_streams[:] = True
        stopping_streams_idx = np.where(stopping_streams.detach().cpu().numpy())[0]
        num_stopped = stopping_streams.long().sum().item()
        if num_stopped > 0:
            stopped = clusters[stopping_streams]
            n_seats, stop_seats = 0, 0
            for n_s, seat in enumerate(final_outputs):
                if type(seat) == int:
                    n_seats += 1
                    if n_seats - 1 in stopping_streams_idx:
                        final_outputs[n_s] = stopped[stop_seats]
                        stop_seats += 1
        n_remained = sum([int(type(x) == int) for x in final_outputs])
        if n_remained == 0:
            break
        # Trim batches
        batch_size = batch_size - num_stopped
        output = output[~stopping_streams]
        phone = phone[~stopping_streams]
        real_phone_lengths = real_phone_lengths[~stopping_streams]
        clusters = clusters[~stopping_streams]
        back_map = back_map[~stopping_streams]
        length_counter = length_counter[~stopping_streams]
        _align = _align[~stopping_streams]
        news = [inps] + new_1
        inps = inps[~stopping_streams]
        for layer in range(len(news)):
            news[layer] = news[layer][~stopping_streams]
        if past_kvs1 is not None:
            for layer in range(len(past_kvs1)):
                past_kvs1[layer] = past_kvs1[layer][~stopping_streams]

        # Update args
        tgt_mask = self.tgt_mask[i+3+prior_size, :i+3+prior_size].to(phone.device).unsqueeze(0)
        audio_alibi = self.alibi(tgt_mask)[:, -1].unsqueeze(1)
        audio_alibi[:, :, 0] = 0
        if output_alignment:
            alignment[:, i, back_map[0, -1]: back_map[0, -1]+self.hp.phone_context_window] = _align[:, 0, -1].unsqueeze(0)
        next_idx = (_align[:, 0, -1, 0] < (1 / self.hp.phone_context_window)).long()
        next_idx[length_counter >= self.hp.length_penalty_max_length] = 1
        new_bk = torch.minimum(back_map[:, -1] + next_idx, real_phone_lengths - self.hp.phone_context_window)
        back_map = torch.cat([back_map, new_bk.unsqueeze(1)], 1)
        length_counter[next_idx == 0] += 1
        length_counter[next_idx != 0] = 0
        if i == 0:
            past_kvs1 = news[:self.hp.dec_nlayers]
        else:
            news = [x[:, -1:] for x in news]
            for ii, (p, n) in enumerate(zip(past_kvs1, news[:self.hp.dec_nlayers])):
                past_kvs1[ii] = torch.cat([p, n], 1)

        inp = self.transducer.encode(output)
        inp = self.layer_norm(inp)
        inps = torch.cat([inps, inp], 1)
        clusters = torch.cat([clusters, output], 1)  # N, T, 4
    if output_alignment:
        return final_outputs, alignment[:, :i, :phone.size(1)]
    return final_outputs

  • Decoding the codes into speech
synthetic = self.vocoder(padded_synthetic, norm_spkr)

Chinese model (WeNetSpeech)

The Chinese model is fine-tuned on studio-recorded data to compensate for the inaccurate transcriptions in WeNetSpeech.

Chinese-English mixed model (GigaSpeech + WeNetSpeech)