0%

Delay Pattern LLM

我最早看到 Delay Pattern 应该是在 Meta 的 MusicGen 中。由于音频编解码器采用 RVQ(残差向量量化)形式,每一帧对应多层 token,用 LLM 建模时就会有多种不同的序列构造方式,MusicGen 中对比了其中几种。总体实验下来,Delay 模式在效果和速度之间取得了很好的平衡。

整体的对比图如下所示。此后很多涉及 RVQ 的工作也陆续采用了 Delay 模式的 LLM,如 Parler-TTS、AudioCraft 等。另一种常见方案是 RQ-Transformer,后续再做介绍。

Delay Pattern LLM-1
Delay Pattern LLM-2
Delay Pattern LLM-3

使用 Delay 模式时,只需要在构造训练数据和推理采样这两个环节特别注意,其余部分与普通 LLM 基本一致。

训练阶段

  1. 将 RVQ token 变为 Delay 模式
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def build_delay_pattern_mask(
    self, input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int, num_codebooks: int
):
    """Build a delayed pattern mask for ``input_ids``.

    Each codebook is offset by one step relative to the previous codebook, which gives a
    delayed pattern at both the start and the end of the sequence. For example, with 4
    codebooks and ``max_length == 8`` the pattern mask of shape ``(codebooks, seq_len)`` is::

        [B, -1, -1, -1, -1,  P,  P,  P]
        [B,  B, -1, -1, -1, -1,  P,  P]
        [B,  B,  B, -1, -1, -1, -1,  P]
        [B,  B,  B,  B, -1, -1, -1, -1]

    where ``B`` is ``bos_token_id``, ``P`` is ``pad_token_id`` and ``-1`` marks a position that
    still has to be predicted. When a prompt is given, its tokens (offset by the codebook
    index) replace the corresponding ``-1`` entries::

        [B, a, b, -1, -1,  P,  P, P]
        [B, B, c,  d, -1, -1,  P, P]
        [B, B, B,  e,  f, -1, -1, P]
        [B, B, B,  B,  g,  h, -1, -1]

    Only the remaining ``-1`` positions are meant to be overwritten during prediction.

    Args:
        input_ids: codebook tokens; reshaped here to ``(-1, num_codebooks, seq_len)``.
            Column 0 of every codebook falls inside the BOS triangle and is overwritten.
        bos_token_id: value written into the lower-left (start) triangle.
        pad_token_id: value written into the upper-right (end) triangle.
        max_length: width of the returned pattern mask. ``seq_len + num_codebooks - 1`` must
            not exceed it, otherwise the shifted copy below does not fit.
        num_codebooks: number of codebooks per sample.

    Returns:
        ``(input_ids, pattern_mask)``, both flattened to ``(bsz * num_codebooks, -1)``.

        - ``pattern_mask`` is the full ``max_length``-wide delayed pattern.
        - ``input_ids`` is that pattern truncated just before the first ``-1`` of the first
          codebook, or to ``seq_len`` columns if there is no ``-1``.

        If ``max_length < 2 * num_codebooks - 1``, the inputs are returned unchanged together
        with a mask full of ``-1``.
    """
    device = input_ids.device
    # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
    input_ids = input_ids.reshape(-1, num_codebooks, input_ids.shape[-1])
    bsz, num_codebooks, seq_len = input_ids.shape

    # -1 marks "not filled yet / to be predicted".
    input_ids_shifted = torch.full((bsz, num_codebooks, max_length), -1, dtype=torch.long, device=device)

    # We only apply the mask if the sequence is long enough - otherwise return as is.
    if max_length < 2 * num_codebooks - 1:
        return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1)

    # Copy the prompt entries, offset by the codebook index (mono: one codebook at a time).
    for codebook in range(num_codebooks):
        input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]

    # Build both triangles directly on the target device instead of on CPU + copy.
    ones = torch.ones((num_codebooks, max_length), dtype=torch.bool, device=device)
    # Upper-right triangle: EOS padding.
    eos_delay_pattern = torch.triu(ones, diagonal=max_length - num_codebooks + 1)
    # Lower-left triangle: BOS padding.
    bos_delay_pattern = torch.tril(ones)

    # The guard above (max_length >= 2 * num_codebooks - 1) makes the two triangles disjoint,
    # so two masked fills are equivalent to the masked arithmetic sum.
    input_ids = input_ids_shifted.masked_fill(bos_delay_pattern, bos_token_id).masked_fill(
        eos_delay_pattern, pad_token_id
    )

    # Generation starts at the first -1; it is always in the first codebook (no offset there).
    start_ids = (input_ids[:, 0, :] == -1).nonzero()[:, 1]
    # Tensor.min() is a single reduction; builtin min() would iterate (and sync) per element.
    first_start_id = int(start_ids.min()) if start_ids.numel() > 0 else seq_len

    # (bsz, num_codebooks, max_length) -> (bsz * num_codebooks, max_length)
    pattern_mask = input_ids.reshape(bsz * num_codebooks, -1)
    input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)
    return input_ids, pattern_mask

def postprocess_dataset(self, labels):
    """Convert RVQ labels into delay-pattern training labels.

    Args:
        labels: codebook token ids, indexed here as ``(batch, num_codebooks, seq_len)``.
            A BOS column of shape ``(labels.shape[0], self.num_codebooks, 1)`` is concatenated
            in front along the last dim.

    Returns:
        Delay-pattern labels flattened to ``(batch * num_codebooks, seq_len + num_codebooks)``.
        Each codebook row is delayed by its index. The start triangle holds
        ``self.start_mel_token`` and every position after a row's tokens holds
        ``self.stop_mel_token``. For example, with 4 codebooks, seq_len 2, B = BOS, E = stop::

            [a, b, E, E, E, E]
            [B, c, d, E, E, E]
            [B, B, e, f, E, E]
            [B, B, B, g, h, E]
    """
    # Build the BOS column with the labels' own dtype/device. The previous
    # `torch.ones(...) * start_mel_token` produced a float32 tensor, which promoted the integer
    # token ids to float in torch.cat (or raised on PyTorch versions without cat type promotion).
    bos_labels = torch.full(
        (labels.shape[0], self.num_codebooks, 1),
        self.start_mel_token,
        dtype=labels.dtype,
        device=labels.device,
    )
    labels = torch.cat([bos_labels, labels], dim=-1)

    labels, delay_pattern_mask = self.build_delay_pattern_mask(
        labels,
        bos_token_id=self.start_mel_token,
        pad_token_id=self.stop_mel_token,
        max_length=labels.shape[-1] + self.num_codebooks,
        num_codebooks=self.num_codebooks,
    )

    # The leading part of the pattern mask is exactly the labels; the remaining -1 slots (right
    # after each row's tokens) become stop tokens, so every row ends with at least one EOS.
    labels = torch.where(delay_pattern_mask == -1, self.stop_mel_token, delay_pattern_mask)

    # Drop the first timestep, which is a column full of BOS. The trailing stop/pad columns are
    # kept: the loss masking is expected to keep only the first stop token per codebook.
    return labels[:, 1:]
  2. 计算 Loss 时切记要 mask 掉因 Delay 模式而填充的开始符号和结束符号,每一层 VQ 只保留一个结束符号(即紧跟在该层最后一个有效 token 之后的那个)。另外,不同层 VQ 的损失权重最好不要都设成一样,否则不利于模型收敛;可以用某个客观指标衡量每一层 VQ 的重要程度,再设置对应的权重。44.1 kHz DAC 的参考权重为 [15, 12.66, 5.43, 2.92, 1.81, 1.48, 0.86, 0.85, 0.75]。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Teacher forcing: the model reads every frame but the last and is scored on every frame.
# NOTE(review): assumes speech_token is (batch, num_codebooks, time) and logits is
# (batch, num_codebooks, time, vocab) aligned with speech_token_target - confirm with the model.
speech_token, speech_token_target = speech_token[:, :, :-1], speech_token
# Right-shift the inputs by one frame so temp_speech_token[..., t] is the token preceding
# speech_token_target[..., t]. The t = 0 filler only has to differ from stop_mel_token; using
# start_mel_token (the real predecessor of the first frame) instead of a hard-coded 1024 avoids
# silently masking the first frame whenever stop_mel_token happens to be 1024.
temp_speech_token = F.pad(speech_token, (1, 0), value=self.start_mel_token)
# BOS positions (the delay-pattern start padding) never contribute to the loss.
speech_token_target = speech_token_target.masked_fill(speech_token_target == self.start_mel_token, -100)
# Keep a position only if it is not BOS and the previous token is not already a stop token.
# This keeps exactly the first stop token of each codebook row and drops the padding after it.
mask = (temp_speech_token != self.stop_mel_token) & (speech_token_target != -100)

loss_acc_dict = {}
total_loss = torch.zeros([], device=device)
acc = torch.zeros([], device=device)
for i in range(self.num_codebooks):
    # NOTE(review): .view(batch, -1, ...) requires every sample to keep the same number of
    # positions for codebook i; verify this holds for variable-length batches.
    codebook_logits = logits[:, i, :, :][mask[:, i, :]].view(speech_token_target.shape[0], -1, logits.shape[3])
    codebook = speech_token_target[:, i, :][mask[:, i, :]].view(speech_token_target.shape[0], -1)
    codeloss = self.criterion_ce(codebook_logits, codebook)
    codeacc = th_accuracy(codebook_logits.view(-1, self.speech_token_size + 3), codebook, ignore_label=IGNORE_ID)
    loss_acc_dict.update({f'codeloss_{i}': codeloss,
                          f'codeacc{i}': codeacc})
    # Weighted sum of per-codebook losses. Without configured weights, fall back to an
    # unweighted sum; previously nothing was accumulated and the total loss stayed at zero.
    weight = self.codebook_weights[i] if self.codebook_weights is not None else 1.0
    total_loss += codeloss * weight
    acc += codeacc
loss_acc_dict.update({'loss': total_loss, 'acc': acc / self.num_codebooks})

推理阶段

  1. 采样时要对 logits 加以约束:在上一层 VQ 尚未输出结束符号之前,将下一层 VQ 的结束符号概率强制置为 0;同时保证起始阶段按照 Delay 的方式采样,即在开始的若干步中,后续各层 VQ 只能输出 pad 对应的那个 token(其余 token 的概率全部置为 0)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
class DACLogitsProcessor(LogitsProcessor):
    """Constrain sampling for a delay-pattern model with ``num_codebooks`` codebooks per sample.

    Rows of ``input_ids`` / ``scores`` are laid out as ``batch_size * num_codebooks``, with the
    codebooks of one sample contiguous (row ``b * num_codebooks + k``). Two rules are applied at
    every step:

    * EOS ordering: a codebook may emit an EOS token only if its row index is not above the
      first codebook of its sample that has not produced EOS yet.
    * Delay start-up: codebook ``k`` of a sample is forced to emit ``pad_token_id`` while
      ``input_ids.shape[1] <= k``.

    The processor is stateful (``first_codebooks_unfinished`` advances across calls and is never
    reset), so use a fresh instance for each generation call.
    """

    def __init__(self, eos_token_id, pad_token_id, num_codebooks: int, batch_size: int, device: str = "cpu"):
        if not isinstance(eos_token_id, torch.Tensor):
            if isinstance(eos_token_id, int):
                eos_token_id = [eos_token_id]
            eos_token_id = torch.tensor(eos_token_id, device=device)
        # Validate before storing anything.
        if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
            raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")

        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.batch_size = batch_size
        self.num_codebooks = num_codebooks
        self.device = device

        # Global row index of every (sample, codebook) pair.
        self.codebook_idx = torch.arange(batch_size * num_codebooks, device=device)
        # Index of each row's codebook inside its own sample (0 .. num_codebooks - 1).
        self.local_codebook_idx = self.codebook_idx % num_codebooks
        # Global row of the first codebook of each sample that has not emitted EOS yet.
        self.first_codebooks_unfinished = torch.arange(batch_size, device=device) * num_codebooks
        # Global row of each sample's last codebook; the pointer above never moves past it.
        self.max_codebooks = self.first_codebooks_unfinished + num_codebooks - 1

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Number of EOS tokens generated so far in every row.
        is_eos = torch.isin(input_ids, self.eos_token_id).sum(1)
        # Advance each sample's pointer by one codebook once the pointed codebook has an EOS.
        self.first_codebooks_unfinished = torch.where(
            (is_eos[self.first_codebooks_unfinished] > 0) & (self.first_codebooks_unfinished < self.max_codebooks),
            self.first_codebooks_unfinished + 1,
            self.first_codebooks_unfinished,
        )

        # Every codebook above the first unfinished one of its sample must not emit EOS yet.
        eos_token_mask = self.codebook_idx > self.first_codebooks_unfinished.repeat_interleave(self.num_codebooks)
        blocked_rows = eos_token_mask.nonzero(as_tuple=True)[0]
        # Outer-product indexing (rows x EOS ids), so several EOS ids work as well.
        # The former `scores[mask, eos_ids]` only broadcast correctly for a single EOS id.
        scores[blocked_rows.reshape(-1, 1), self.eos_token_id.reshape(1, -1)] = -math.inf

        # Delay start-up: force the pad token for codebook k while input_ids.shape[1] <= k.
        # Use the within-sample codebook index. The global row index kept every sample after
        # the first one in the start-up phase for b * num_codebooks extra steps.
        pad_token_set = self.local_codebook_idx > input_ids.shape[1] - 1
        pad_scores = scores[:, self.pad_token_id].clone()
        scores[pad_token_set, :] = -math.inf
        scores[:, self.pad_token_id] = pad_scores

        return scores
  2. 对生成结果做反向 Delay(撤销每一层 VQ 的偏移),恢复逐帧对齐的 RVQ token,再通过 RVQ 的 decoder 得到最终音频。

以上代码主要参考了 Parler-TTS,并做了少量修改。