1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154
class G2PWConverter:
    """Convert Chinese sentences to per-character pronunciations with a G2PW model.

    Polyphonic characters (those listed in ``POLYPHONIC_CHARS_S.txt``) are
    disambiguated by the neural model; all other characters are resolved by
    dictionary lookup. Characters found in no dictionary yield ``None``.
    """

    def __init__(self, model_dir='G2PWModel/', style='bopomofo', model_source=None, use_cuda=False,
                 num_workers=None, batch_size=None, turnoff_tqdm=True,
                 enable_non_tradional_chinese=False):
        """Load configuration, tokenizer, character tables and model weights.

        Args:
            model_dir: Directory holding ``config.py``, the character tables and
                ``best_accuracy.pth``. ``download_model`` is called when the
                checkpoint is missing.
            style: Output notation, ``'bopomofo'`` or ``'pinyin'``. Any other
                value raises ``KeyError``.
            model_source: Pretrained source passed to the tokenizer and model;
                falls back to ``config.model_source`` when falsy.
            use_cuda: Run on ``'cuda'`` instead of ``'cpu'``.
            num_workers: DataLoader worker count; ``None`` falls back to the
                config value (an explicit ``0`` is honoured).
            batch_size: DataLoader batch size; falls back to the config value
                when falsy.
            turnoff_tqdm: Forwarded to ``predict`` as ``turnoff_tqdm``.
            enable_non_tradional_chinese: When true, input is passed through
                OpenCC ``s2tw`` before conversion.
        """
        checkpoint = os.path.join(model_dir, 'best_accuracy.pth')
        if not os.path.exists(checkpoint):
            download_model(model_dir)
        self.model_dir = model_dir
        self.config = load_config(os.path.join(model_dir, 'config.py'), use_default=True)

        # ``0`` is a meaningful worker count, so only ``None`` selects the config default.
        self.num_workers = num_workers if num_workers is not None else self.config.num_workers
        self.batch_size = batch_size if batch_size else self.config.batch_size
        self.model_source = model_source if model_source else self.config.model_source
        self.turnoff_tqdm = turnoff_tqdm
        self.enable_opencc = enable_non_tradional_chinese

        self.device = torch.device('cuda' if use_cuda else 'cpu')

        # Use the same source as the model below so an overridden
        # ``model_source`` yields a matching tokenizer/model pair.
        self.tokenizer = BertTokenizer.from_pretrained(self.model_source)

        # Author's notes on table sizes: polyphonic_chars 8022 rows,
        # monophonic_chars 9476 rows.
        self.polyphonic_chars = self._read_char_table(
            os.path.join(model_dir, 'POLYPHONIC_CHARS.txt'))
        self.polyphonic_chars_s = self._read_char_table(
            os.path.join(model_dir, 'POLYPHONIC_CHARS_S.txt'))
        self.monophonic_chars = self._read_char_table(
            os.path.join(model_dir, 'MONOPHONIC_CHARS.txt'))

        build_labels = (get_char_phoneme_labels if self.config.use_char_phoneme
                        else get_phoneme_labels)
        # Author's notes: labels holds 1305 bopomofo pronunciations and
        # char2phonemes covers 3582 polyphonic characters.
        self.labels, self.char2phonemes = build_labels(self.polyphonic_chars)
        self.labels_s, self.char2phonemes_s = build_labels(self.polyphonic_chars_s)

        self.chars = sorted(self.char2phonemes)
        self.pos_tags = TextDataset.POS_TAGS

        self.model = G2PW.from_pretrained(
            self.model_source,
            labels=self.labels,
            chars=self.chars,
            pos_tags=self.pos_tags,
            use_conditional=self.config.use_conditional,
            param_conditional=self.config.param_conditional,
            use_focal=self.config.use_focal,
            param_focal=self.config.param_focal,
            use_pos=self.config.use_pos,
            param_pos=self.config.param_pos
        )
        # NOTE(review): torch.load unpickles the checkpoint; only load files
        # from a trusted source.
        self.model.load_state_dict(torch.load(checkpoint, map_location=self.device))
        self.model.to(self.device)

        package_dir = os.path.dirname(os.path.abspath(__file__))

        # Author's note: only 424 entries; "wo tune" presumably means
        # "without tone" - this is the plain bopomofo -> pinyin mapping.
        with open(os.path.join(package_dir, 'bopomofo_to_pinyin_wo_tune_dict.json'),
                  'r', encoding='utf-8') as fr:
            self.bopomofo_convert_dict = json.load(fr)
        self.style_convert_func = {
            'bopomofo': lambda x: x,
            'pinyin': self._convert_bopomofo_to_pinyin,
        }[style]

        # Author's note: maps Chinese characters to bopomofo, 41497 entries.
        with open(os.path.join(package_dir, 'char_bopomofo_dict.json'),
                  'r', encoding='utf-8') as fr:
            self.char_bopomofo_dict = json.load(fr)

        if self.enable_opencc:
            # Simplified Chinese -> Traditional Chinese (Taiwan standard).
            self.cc = OpenCC('s2tw')

    @staticmethod
    def _read_char_table(path):
        """Read a tab-separated table file into a list of per-line column lists."""
        with open(path, encoding='utf-8') as fr:
            return [line.split('\t') for line in fr.read().strip().split('\n')]

    def _convert_bopomofo_to_pinyin(self, bopomofo):
        """Convert a tone-suffixed bopomofo syllable to tone-suffixed pinyin.

        The last character of ``bopomofo`` must be a tone digit ``1``-``5``;
        the remainder is looked up in ``self.bopomofo_convert_dict``.

        Returns:
            The pinyin string with the tone digit appended, or ``None`` (after
            printing a warning) when the syllable is not in the mapping.
        """
        tone = bopomofo[-1]
        assert tone in '12345'
        component = self.bopomofo_convert_dict.get(bopomofo[:-1])
        if component:
            return component + tone
        print(f'Warning: "{bopomofo}" cannot convert to pinyin')
        return None

    def __call__(self, sentences):
        """Return pronunciations for one sentence or a list of sentences.

        Args:
            sentences: A single string or a list of strings.

        Returns:
            One list per sentence with one entry per character: the
            pronunciation, or ``None`` when the character is unknown.
        """
        if isinstance(sentences, str):
            sentences = [sentences]

        if self.enable_opencc:
            translated_sentences = []
            for sent in sentences:
                translated_sent = self.cc.convert(sent)
                # Character positions are used as indices later, so the
                # conversion must not change the sentence length.
                assert len(translated_sent) == len(sent)
                translated_sentences.append(translated_sent)
            sentences = translated_sentences

        texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences)
        if not texts:
            # No polyphonic characters: the dictionary lookups are the full
            # answer, so skip dataset construction and model inference.
            return partial_results

        dataset = TextDataset(self.tokenizer, self.labels, self.char2phonemes_s, self.chars,
                              texts, query_ids,
                              use_mask=self.config.use_mask,
                              use_char_phoneme=self.config.use_char_phoneme,
                              window_size=self.config.window_size, for_train=False)

        dataloader = DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            collate_fn=dataset.create_mini_batch,
            num_workers=self.num_workers
        )

        start = time.time()
        preds, confidences = predict(self.model, dataloader, self.device, self.labels,
                                     turnoff_tqdm=self.turnoff_tqdm)
        end = time.time()
        print('time predict: %4f' % (end - start))
        if self.config.use_char_phoneme:
            # Labels have the form "<char> <phoneme>"; keep the phoneme part.
            preds = [pred.split(' ')[1] for pred in preds]

        results = partial_results
        # NOTE(review): predictions are style-converted only when model_dir is
        # the literal default 'G2PWModel/'; presumably custom models already
        # emit the target notation - confirm.
        apply_style = self.model_dir == 'G2PWModel/'
        for sent_id, query_id, pred in zip(sent_ids, query_ids, preds):
            results[sent_id][query_id] = self.style_convert_func(pred) if apply_style else pred
        return results

    def _prepare_data(self, sentences):
        """Split characters into model queries and dictionary-resolved results.

        Returns:
            A tuple ``(texts, query_ids, sent_ids, partial_results)``. For each
            polyphonic character, ``texts`` gets the full sentence,
            ``query_ids`` the character index and ``sent_ids`` the sentence
            index. ``partial_results`` has one list per sentence holding the
            dictionary pronunciation for non-polyphonic characters and ``None``
            elsewhere.
        """
        # Author's notes: the original ``self.chars`` set (3582 polyphonic
        # characters) is incomplete (e.g. it lacks U+957F), so the supplementary
        # table is searched instead. The original table cannot be modified
        # because the network uses a char_id embedding.
        polyphonic_chars = set(self.char2phonemes_s)
        monophonic_chars_dict = {
            char: phoneme for char, phoneme in self.monophonic_chars
        }
        texts, query_ids, sent_ids, partial_results = [], [], [], []
        for sent_id, sent in enumerate(sentences):
            partial_result = [None] * len(sent)
            for i, char in enumerate(sent):
                if char in polyphonic_chars:
                    texts.append(sent)
                    query_ids.append(i)
                    sent_ids.append(sent_id)
                elif char in monophonic_chars_dict:
                    # First try the custom monophonic dictionary.
                    partial_result[i] = self.style_convert_func(monophonic_chars_dict[char])
                elif char in self.char_bopomofo_dict:
                    # Fall back to the largest dictionary. Known limitation: for
                    # a polyphonic character only the first listed
                    # pronunciation is used.
                    partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0])
            partial_results.append(partial_result)
        return texts, query_ids, sent_ids, partial_results
|