
CLVP

CLVP is a text-speech pretraining model used to select higher-quality speech synthesis results. The approach is analogous to CLIP in the image domain.

The underlying idea is simple:

(Figure: CLIP contrastive pretraining diagram)

As shown in the figure above, text passes through a Text Encoder to produce a Text Embedding, and an image passes through an Image Encoder to produce an Image Embedding; the model is then trained with a contrastive objective that pulls matching pairs together and pushes mismatched pairs apart. At inference time it can perform zero-shot tasks, and the number of classes is not fixed in advance.
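To make the zero-shot part concrete: classification reduces to a similarity search in the shared latent space. Below is a minimal sketch with toy tensors; the shapes and variable names are illustrative only and are not taken from the CLVP code.

```python
import torch
import torch.nn.functional as F

image_latent = torch.randn(1, 512)        # output of the Image Encoder (toy values)
text_latents = torch.randn(5, 512)        # one embedding per candidate class prompt (toy values)

# L2-normalize so the dot product is a cosine similarity.
image_latent = F.normalize(image_latent, p=2, dim=-1)
text_latents = F.normalize(text_latents, p=2, dim=-1)

scores = image_latent @ text_latents.t()  # [1, 5] similarities to each class prompt
pred = scores.argmax(dim=-1)              # the best-matching prompt is the predicted class
```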

  • Code
class VoiceCLIP(nn.Module):
    """
    CLIP model retrofitted for performing contrastive evaluation between tokenized audio data and the corresponding
    transcribed text.

    Originally from https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/dalle_pytorch.py
    """

    def __init__(
            self,
            *,
            dim_text=512,
            dim_speech=512,
            dim_latent=512,
            num_text_tokens=256,
            text_enc_depth=6,
            text_seq_len=120,
            text_heads=8,
            num_speech_tokens=8192,
            speech_enc_depth=6,
            speech_heads=8,
            speech_seq_len=250,
            text_mask_percentage=0,
            voice_mask_percentage=0,
            wav_token_compression=1024,
            use_xformers=False,
            clip_mels=False,
            min_mel_size=10,  # Default is approximately .5sec with default mel specs.
            distributed_collect=False,
    ):
        super().__init__()
        # Token embeddings and projections into the shared latent space.
        # mbnb.nn.Embedding / mbnb.nn.Linear are drop-in replacements for nn.Embedding / nn.Linear.
        self.text_emb = mbnb.nn.Embedding(num_text_tokens, dim_text)
        self.to_text_latent = mbnb.nn.Linear(dim_text, dim_latent, bias=False)

        self.speech_emb = mbnb.nn.Embedding(num_speech_tokens, dim_speech)
        self.to_speech_latent = mbnb.nn.Linear(dim_speech, dim_latent, bias=False)

        if use_xformers:
            self.text_transformer = CheckpointedXTransformerEncoder(
                needs_permute=False,
                exit_permute=False,
                max_seq_len=-1,
                attn_layers=Encoder(
                    dim=dim_text,
                    depth=text_enc_depth,
                    heads=text_heads,
                    ff_dropout=.1,
                    ff_mult=2,
                    attn_dropout=.1,
                    use_rmsnorm=True,
                    ff_glu=True,
                    rotary_pos_emb=True,
                ))
            self.speech_transformer = CheckpointedXTransformerEncoder(
                needs_permute=False,
                exit_permute=False,
                max_seq_len=-1,
                attn_layers=Encoder(
                    dim=dim_speech,
                    depth=speech_enc_depth,
                    heads=speech_heads,
                    ff_dropout=.1,
                    ff_mult=2,
                    attn_dropout=.1,
                    use_rmsnorm=True,
                    ff_glu=True,
                    rotary_pos_emb=True,
                ))
        else:
            self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth,
                                                heads=text_heads)
            self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
                                                  depth=speech_enc_depth, heads=speech_heads)

        self.temperature = nn.Parameter(torch.tensor(1.))
        self.text_mask_percentage = text_mask_percentage
        self.voice_mask_percentage = voice_mask_percentage
        self.wav_token_compression = wav_token_compression
        self.xformers = use_xformers
        self.clip_mels = clip_mels
        self.min_mel_size = min_mel_size
        self.distributed_collect = distributed_collect
        if not use_xformers:
            # Learned positional embeddings (the xformers path uses rotary position embeddings instead).
            self.text_pos_emb = mbnb.nn.Embedding(text_seq_len, dim_text)
            self.speech_pos_emb = mbnb.nn.Embedding(num_speech_tokens, dim_speech)

    def embed_text(self, text):
        text_mask = torch.ones_like(text.float()).bool()
        text_emb = self.text_emb(text)
        enc_text = self.text_transformer(text_emb, mask=text_mask)
        text_latents = masked_mean(enc_text, text_mask, dim=1)
        text_latents = self.to_text_latent(text_latents)
        return text_latents

    def forward(
            self,
            text,
            speech_tokens,
            return_loss=True
    ):
        # print(f"text: {text}, speech_token: {speech_tokens}")
        # print(f"text shape: {text.shape}, speech_token shape: {speech_tokens.shape}")

        b, device = text.shape[0], text.device
        if self.training:
            if self.clip_mels:
                margin = speech_tokens.shape[-1] - self.min_mel_size
                speech_tokens = speech_tokens[:, :self.min_mel_size + randint(0, margin)]
                voice_mask = torch.ones_like(speech_tokens.float()).bool()  # Disable voice masking in this case.
            else:
                voice_mask = torch.rand_like(speech_tokens.float()) > self.voice_mask_percentage
            text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
        else:
            text_mask = torch.ones_like(text.float()).bool()
            voice_mask = torch.ones_like(speech_tokens.float()).bool()

        text_emb = self.text_emb(text)
        speech_emb = self.speech_emb(speech_tokens)

        if not self.xformers:
            text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device))
            speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device))

        enc_text = self.text_transformer(text_emb, mask=text_mask)
        enc_speech = self.speech_transformer(speech_emb, mask=voice_mask)

        text_latents = masked_mean(enc_text, text_mask, dim=1)
        speech_latents = masked_mean(enc_speech, voice_mask, dim=1)

        text_latents = self.to_text_latent(text_latents)
        speech_latents = self.to_speech_latent(speech_latents)

        if self.distributed_collect:
            collective = [torch.zeros_like(text_latents) for _ in range(torch.distributed.get_world_size())]
            torch.distributed.all_gather(collective, text_latents)
            collective[torch.distributed.get_rank()] = text_latents  # For gradient propagation.
            text_latents = torch.cat(collective, dim=0)
            collective = [torch.zeros_like(speech_latents) for _ in range(torch.distributed.get_world_size())]
            torch.distributed.all_gather(collective, speech_latents)
            collective[torch.distributed.get_rank()] = speech_latents  # For gradient propagation.
            speech_latents = torch.cat(collective, dim=0)
            b = text_latents.shape[0]

        text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))

        temp = self.temperature.exp()

        if not return_loss:
            sim = einsum('n d, n d -> n', text_latents, speech_latents) * temp
            return sim

        sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temp
        labels = torch.arange(b, device=device)
        loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
        return loss
  • Understanding the loss function
  1. After each modality passes through its own encoder, text_latent and mel_token_latent are averaged along the time axis (with masking, via masked_mean), giving two [16, 768] tensors (batch size 16 in this example);
  2. sim = einsum('i d, j d -> i j', text_latents, speech_latents) computes all pairwise similarities (scaled by the learned temperature). The einsum is a pairwise dot product: output[i, j] = sum_d(text_latents[i, d] * speech_latents[j, d]);
  3. labels = torch.arange(b, device=device), where b is the batch size: [0, 1, 2, ..., 15];
  4. loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2.

The similarity matrix sim is treated as the input to a classification task: each row holds one sample's logits over all candidates, and the target for sample i is i itself, so the true labels sit exactly on the diagonal. The loss is computed a second time on the transpose because sim is not symmetric (rows score text against all speech, columns score speech against all text); in both directions the diagonal entries should be high and every off-diagonal entry low. A toy sketch of this symmetric loss is given below.
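A minimal, self-contained sketch of the masked mean pooling plus the symmetric cross-entropy loss, using toy tensors. The masked_mean here is a simplified stand-in written for this sketch (the repo has its own helper); the shapes follow the [16, 768] example above.

```python
import torch
import torch.nn.functional as F
from torch import einsum

def masked_mean(t, mask, dim=1):
    # Average over the sequence dimension, counting only unmasked positions.
    t = t.masked_fill(~mask.unsqueeze(-1), 0.)
    return t.sum(dim=dim) / mask.sum(dim=dim).unsqueeze(-1)

b, seq, d = 16, 50, 768
enc_text = torch.randn(b, seq, d)                 # text encoder outputs (toy values)
text_mask = torch.rand(b, seq) > 0.2              # True = keep this position

text_latents = masked_mean(enc_text, text_mask)   # [16, 768]
speech_latents = torch.randn(b, d)                # pretend speech side, already pooled (toy values)

text_latents = F.normalize(text_latents, p=2, dim=-1)
speech_latents = F.normalize(speech_latents, p=2, dim=-1)

sim = einsum('i d, j d -> i j', text_latents, speech_latents)  # [16, 16] pairwise similarities
labels = torch.arange(b)                                       # diagonal entries are the matching pairs
loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
```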

  • Notes
  1. At inference time, the zero-shot <text, speech> prompt can be concatenated onto the candidate before scoring; the scores then better reflect how well the result matches the prompt's style (see the sketch after this list);
  2. During training the batch size has to be large, at least 2048 or more.
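A minimal reranking sketch for note 1. It assumes the VoiceCLIP class above (and its repo dependencies) is importable, and that the token tensors below already exist; all of these variable names are illustrative, not from the original repo.

```python
import torch

# Hypothetical pre-tokenized inputs (names are illustrative):
#   text_tokens:          [1, T_text]  token ids of the target transcript
#   prompt_text_tokens:   [1, T_pt]    token ids of the zero-shot prompt transcript
#   prompt_speech_tokens: [1, T_ps]    speech/mel token ids of the prompt audio
#   candidates:           list of [1, T_i] speech-token tensors produced by the TTS model

clvp = VoiceCLIP(use_xformers=True).eval()  # assumes the class above is in scope; rotary embeddings avoid a fixed seq_len

scores = []
with torch.no_grad():
    for cand in candidates:
        # Note 1: prepend the prompt's text/speech tokens so the score also
        # reflects how well the candidate matches the prompt's style.
        full_text = torch.cat([prompt_text_tokens, text_tokens], dim=1)
        full_speech = torch.cat([prompt_speech_tokens, cand], dim=1)
        # return_loss=False returns the per-pair similarity (one value per batch item).
        scores.append(clvp(full_text, full_speech, return_loss=False).item())

best = candidates[int(torch.tensor(scores).argmax())]  # keep the highest-scoring synthesis
```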