import argparse
import torch
import torchaudio
import dac
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import multiprocessing


def single_job(utt, utt2wav, dac_model, device):
    """Execution logic for a single job: load one utterance, resample to 44.1 kHz, and encode it."""
    torch.cuda.set_device(device)
    audio, sample_rate = torchaudio.load(utt2wav[utt])
    if sample_rate != 44100:
        audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=44100)(audio)
    with torch.no_grad():
        # dac.DAC.encode returns (z, codes, latents, commitment_loss, codebook_loss);
        # the discrete codes are used as speech tokens
        _, speech_token, _, _, _ = dac_model.encode(audio.to(device).unsqueeze(0))
    return utt, speech_token.to('cpu').numpy()


def process_chunk(chunk, utt2wav, device, result_dict):
    """Multiprocessing worker: each process handles one subset of utterances on its assigned GPU."""
    torch.cuda.set_device(device)
    dac_model = dac.DAC.load('/data/yuanxin/pretrained_ckpt/weights.pth').eval().to(device)
    with ThreadPoolExecutor(max_workers=4) as executor:
        future_to_utt = {executor.submit(single_job, utt, utt2wav, dac_model, device): utt for utt in chunk}
        for future in tqdm(as_completed(future_to_utt), total=len(chunk), desc=f"Processing on GPU {device}"):
            utt, speech_token = future.result()
            result_dict[utt] = speech_token


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dir", type=str)
    args = parser.parse_args()
    # 'spawn' is required so CUDA can be initialized safely in the child processes
    torch.multiprocessing.set_start_method('spawn')

    # Read the utt2wav mapping (utterance ID -> wav path)
    utt2wav = {}
    with open(f'{args.dir}/target_clean.scp') as f:
        for l in f:
            l = l.strip().split(' ')
            utt2wav[l[0]] = ' '.join(l[1:])

    # Split the work into chunks of 5000 utterances per process
    chunk_size = 5000
    keys = list(utt2wav.keys())
    chunks = [keys[i:i + chunk_size] for i in range(0, len(keys), chunk_size)]

    # Result dict shared across processes
    manager = multiprocessing.Manager()
    result_dict = manager.dict()

    # Launch one process per chunk, each bound to a GPU
    processes = []
    num_gpus = torch.cuda.device_count()
    for i, chunk in enumerate(chunks):
        device = i % num_gpus  # round-robin GPU assignment
        p = multiprocessing.Process(target=process_chunk, args=(chunk, utt2wav, device, result_dict))
        processes.append(p)
        p.start()

    # Wait for all processes to finish
    for p in processes:
        p.join()

    # Save the results
    torch.save(dict(result_dict), f'{args.dir}/utt2speech_token.pt')
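
# Example invocation (sketch; the script filename and data directory below are
# hypothetical, and {dir}/target_clean.scp must map utterance IDs to wav paths):
#
#   python extract_speech_tokens.py --dir /path/to/dataset
#
# The saved file can later be loaded back as a plain dict of numpy arrays:
#
#   utt2speech_token = torch.load('/path/to/dataset/utt2speech_token.pt')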