
Quickly Extracting Speech Tokens with DAC

DAC (Descript-Audio-Codec) is a genuinely strong audio codec. Its main advantage is robustness: it reconstructs speech, sound effects, and music well, and even handles multi-speaker audio, noise, and reverberation decently. Many LLM-based speech generation models use this codec, such as parler-tts and Genhancer, and many end-to-end spoken dialogue models also adopt RVQ codecs, e.g. Moshi. This post records how to quickly extract DAC tokens for downstream model training.

The idea is to combine multiprocessing with multithreading: all audio files are split evenly across the GPUs, each GPU is served by a pool of threads (mainly to overlap audio-loading I/O), and the GPUs themselves are driven by separate processes.

  • Code
import argparse
import torch
import torchaudio
import dac
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import multiprocessing

def single_job(utt, utt2wav, dac_model, device):
    """One extraction task: load, resample to 44.1 kHz, and encode a single utterance."""
    torch.cuda.set_device(device)
    audio, sample_rate = torchaudio.load(utt2wav[utt])
    if sample_rate != 44100:
        audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=44100)(audio)

    with torch.no_grad():
        _, speech_token, _, _, _ = dac_model.encode(audio.to(device).unsqueeze(0))
    return utt, speech_token.to('cpu').numpy()

def process_chunk(chunk, utt2wav, device, result_dict):
    """Worker process: each process handles one chunk of utterances on its own GPU."""
    torch.cuda.set_device(device)
    dac_model = dac.DAC.load('/data/yuanxin/pretrained_ckpt/weights.pth').eval().to(device)

    with ThreadPoolExecutor(max_workers=4) as executor:
        future_to_utt = {executor.submit(single_job, utt, utt2wav, dac_model, device): utt for utt in chunk}

        for future in tqdm(as_completed(future_to_utt), total=len(chunk), desc=f"Processing on GPU {device}"):
            utt, speech_token = future.result()
            result_dict[utt] = speech_token

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dir", type=str)
    args = parser.parse_args()
    torch.multiprocessing.set_start_method('spawn')  # 'spawn' is required when child processes use CUDA

    # Read the utt2wav mapping
    utt2wav = {}
    with open(f'{args.dir}/target_clean.scp') as f:
        for l in f:
            l = l.strip().split(' ')
            utt2wav[l[0]] = ' '.join(l[1:])

    # Split the work: 5000 utterances per process
    chunk_size = 5000
    keys = list(utt2wav.keys())
    chunks = [keys[i:i + chunk_size] for i in range(0, len(keys), chunk_size)]

    # Result dict shared across processes
    manager = multiprocessing.Manager()
    result_dict = manager.dict()

    # Launch one process per chunk, each bound to a GPU
    processes = []
    num_gpus = torch.cuda.device_count()
    for i, chunk in enumerate(chunks):
        device = i % num_gpus  # round-robin GPU assignment
        p = multiprocessing.Process(target=process_chunk, args=(chunk, utt2wav, device, result_dict))
        processes.append(p)
        p.start()

    # Wait for all processes to finish
    for p in processes:
        p.join()

    # Save the results
    torch.save(dict(result_dict), f'{args.dir}/utt2speech_token.pt')
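
Once extraction finishes, the saved file maps each utterance ID to its RVQ code array. A quick way to inspect it (a minimal sketch; the shape comment assumes mono input, as the script above does):

import torch

# Load the token dict written by the extraction script (path as saved above)
utt2speech_token = torch.load('utt2speech_token.pt')

# With mono input, each value is a numpy array of shape (1, n_codebooks, n_frames)
utt, codes = next(iter(utt2speech_token.items()))
print(utt, codes.shape, codes.dtype)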
  • Summary

In testing, processing 10,000 hours of audio on an 8-GPU machine takes roughly 3 hours.

The official DAC library only ships PyTorch inference and has no ONNX support, so here DAC is converted to an ONNX model for inference.

  • Model export code:
import torch
import torch.nn as nn
import dac

class DACEncoderONNX(nn.Module):
    """Wraps the DAC encoder + quantizer so the exported graph returns only the codes."""
    def __init__(self, encoder, quantizer):
        super().__init__()
        self.encoder = encoder
        self.quantizer = quantizer

    def forward(self, audio_data: torch.Tensor):
        z = self.encoder(audio_data)
        _, codes, _, _, _ = self.quantizer(z)
        return codes

# Load the original model
dac_model = dac.DAC.load('/data/yuanxin/pretrained_ckpt/weights.pth')
dac_model.eval()

# Build the inference model
encoder_model = DACEncoderONNX(dac_model.encoder, dac_model.quantizer)
encoder_model.eval()

# Example input: one second of mono audio at 44.1 kHz
dummy_input = torch.randn(1, 1, 44100)

# Export to ONNX with a dynamic time axis
torch.onnx.export(
    encoder_model,
    dummy_input,
    "dac_encoder.onnx",
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=["audio_input"],
    output_names=["codes"],
    dynamic_axes={"audio_input": {2: "audio_length"}, "codes": {2: "code_length"}}
)

print("ONNX model exported successfully!")

  • Inference code:
import argparse
from concurrent.futures import ThreadPoolExecutor, as_completed
import torch
from tqdm import tqdm
import onnxruntime  # onnxruntime 1.14.1
import numpy as np
import torchaudio
import os


def single_job(utt):
    """Load one utterance, resample to 44.1 kHz, and run the ONNX encoder."""
    audio, sample_rate = torchaudio.load(utt2wav[utt])
    if sample_rate != 44100:
        audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=44100)(audio)

    speech_token = ort_session.run([ort_session.get_outputs()[0].name],
                                   {ort_session.get_inputs()[0].name: audio.unsqueeze(0).numpy()})

    return utt, speech_token[0]

def main(args):
    all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()]
    utt2speech_token = {}
    for future in tqdm(as_completed(all_task), total=len(utt2wav), desc=f"Processing on GPU {args.split}"):
        utt, speech_token = future.result()
        utt2speech_token[utt] = speech_token
    os.makedirs(os.path.join(args.dir, 'dac', f'{args.type}_utt2speech_token_{args.epoch}'), exist_ok=True)
    torch.save(utt2speech_token, os.path.join(args.dir, 'dac', f'{args.type}_utt2speech_token_{args.epoch}',
                                              f'{args.type}_utt2speech_token_{args.epoch}_device{args.split}.pt'))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dir", type=str)    # e.g. /home/data2/tts_data/ctts/v2/mos4.5_same_asr
    parser.add_argument("--epoch", type=str)  # tag for the simulated-data epoch
    parser.add_argument("--split", type=str)  # per-GPU split tag
    parser.add_argument("--type", type=str)   # "noise" or "clean" tag

    args = parser.parse_args()

    utt2wav = {}
    with open('{}/{}_epoch_{}_split{}.scp'.format(os.path.join(args.dir, 'wavscp'),
                                                  args.type, args.epoch, args.split)) as f:
        for l in f:
            l = l.replace('\n', '').split(' ')
            utt2wav[l[0]] = ' '.join(l[1:])

    num_processes = 8  # Adjust as needed
    executor = ThreadPoolExecutor(max_workers=num_processes)

    option = onnxruntime.SessionOptions()
    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    option.intra_op_num_threads = 1
    providers = ["CUDAExecutionProvider"]

    ort_session = onnxruntime.InferenceSession('/data/yuanxin/pretrained_ckpt/dac_encoder.onnx',
                                               sess_options=option, providers=providers)
    main(args)
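
Each GPU split writes its own .pt shard, so the per-device files need to be merged into a single dict before training. A minimal sketch (merge_shards is an illustrative helper, not part of the original scripts; the glob pattern mirrors the save path above):

import glob
import os
import torch

def merge_shards(dir_, type_, epoch):
    """Merge the per-device token shards written by the inference script into one dict."""
    pattern = os.path.join(dir_, 'dac', f'{type_}_utt2speech_token_{epoch}',
                           f'{type_}_utt2speech_token_{epoch}_device*.pt')
    merged = {}
    for shard in sorted(glob.glob(pattern)):
        merged.update(torch.load(shard))
    return merged

# Example: merge the "clean" shards for epoch 0 into a single file
tokens = merge_shards('/home/data2/tts_data/ctts/v2/mos4.5_same_asr', 'clean', '0')
torch.save(tokens, 'clean_utt2speech_token_0_merged.pt')
print(f"merged {len(tokens)} utterances")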