
G2PW

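Abstract

g2pW: A Conditional Weighted Softmax BERT for Polyphone Disambiguation in Mandarin

The paper's biggest highlight is that it releases model weights trained on the Mandarin Polyphone dataset with Bopomofo (MPB) 👍. MPB covers 436 polyphonic characters and 2,610,344 text samples that each contain a polyphonic character.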

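Main architecture

The overall model architecture is shown in the figure below:

[Figure: g2pw]

The conditional weight layer is shown in the figure below:

[Figure: cond-weight]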

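The main problems are:

  • The POLYPHONIC_CHARS.txt dictionary shipped with the pretrained model is incomplete: some common polyphonic characters are missing from it. A workaround is given at the end of this post.
  • The biggest pain point is slow inference. This comes from the model design itself; fixing it means changing the architecture, and once I change the architecture I don't have enough data to retrain it 🤣. Inference is slow because a sentence often contains many polyphonic characters. For example, 欢迎拨打卫健委流调电话,小卫为您服务。 contains 5 of them. Because the model is trained to predict one character at a time, this single sentence has to go through the model 5 times, which wastes time and compute.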
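To make the cost concrete, here is a minimal sketch of how the number of forward passes grows with the number of polyphonic characters. The `POLYPHONIC_CHARS` set below is a hypothetical subset picked for illustration so that the example sentence yields five queries; it is not the content of the real POLYPHONIC_CHARS.txt.

```python
def find_query_positions(sentence, polyphonic_chars):
    """Return the index of every character that needs its own model query."""
    return [i for i, ch in enumerate(sentence) if ch in polyphonic_chars]


# Hypothetical subset of the polyphone dictionary (assumption, for illustration only).
POLYPHONIC_CHARS = {"卫", "调", "为", "服"}

sentence = "欢迎拨打卫健委流调电话,小卫为您服务。"
positions = find_query_positions(sentence, POLYPHONIC_CHARS)

# One (sentence, position) sample is created per polyphone,
# so the encoder runs once per entry in `positions`.
num_forward_passes = len(positions)
print(positions, num_forward_passes)
```

Under this assumption the sentence is encoded once per polyphone, even though the surrounding text is identical each time.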

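Code walkthrough

As the architecture figures above show, the model is fairly simple.

Data processing

A PyTorch dataset is built, and processing each sample produces the following fields: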

outputs = {
    'input_ids': input_ids,
    'token_type_ids': token_type_ids,
    'attention_mask': attention_mask,
    'phoneme_mask': phoneme_mask,
    'char_id': char_id,
    'position_id': position_id,
    'pos_id': pos_id,
    'label_id': label_id
}

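1. input_ids: IDs of the original text, as produced by the BERT tokenizer.
2. token_type_ids: an all-zero vector required by BERT; all zeros means every token belongs to the first segment.
3. attention_mask: an all-one vector required by BERT; all ones means nothing is masked and every token takes part in the computation.
4. phoneme_mask: the hard mask for the polyphonic character. Its length is the total number of candidate pronunciations over all polyphonic characters (650 for the CPP dataset). Positions of the pronunciations this character can take are 1, all other positions are 0.
5. char_id: ID of the polyphonic character, i.e. its index in the list of polyphonic characters; used later for the embedding lookup.
6. position_id: position of the polyphonic character within the sentence; used to pick out the matching BERT output.
7. pos_id: part-of-speech (POS) tag of the polyphonic character; used for the auxiliary POS prediction and to help polyphone prediction.
8. label_id: ground-truth pronunciation of the polyphonic character.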
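As a small illustration of field 4, the sketch below builds a phoneme_mask for one query character. The label list and the `char2phonemes` mapping are toy values invented for this example (the real ones come from the model's dictionary files); the helper name `build_phoneme_mask` is likewise hypothetical.

```python
def build_phoneme_mask(query_char, char2phonemes, num_labels):
    """1 at every label index the character may take, 0 everywhere else."""
    allowed = set(char2phonemes[query_char])
    return [1 if i in allowed else 0 for i in range(num_labels)]


# Toy label space (assumption): 6 pronunciations shared by 3 polyphonic characters.
labels = ["chang2", "zhang3", "wei2", "wei4", "diao4", "tiao2"]
# Each character maps to the indices of its candidate pronunciations in `labels`.
char2phonemes = {"长": [0, 1], "为": [2, 3], "调": [4, 5]}

mask_wei = build_phoneme_mask("为", char2phonemes, len(labels))
print(mask_wei)
```

The mask has the same length for every character, which is what lets a single classifier head of size len(labels) serve all polyphones.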

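Points to note in data processing:

  • The original sentence is cropped with window=32: centered on the polyphonic character, at most 16 characters are kept on each side, and anything beyond that is cut off.
  • The POS tag of each polyphonic character is generated with the ckiptagger package. During training, teacher forcing is used: the ground-truth POS tag is fed into the downstream polyphone prediction.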
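The cropping step can be sketched as follows. This is my own minimal version, not the repository's code: the name `crop_window` is hypothetical, and I assume the query character itself counts toward the right half, so the cropped text never exceeds window_size characters.

```python
def crop_window(text, query_id, window_size=32):
    """Crop `text` around position `query_id` and return the new text and new position."""
    half = window_size // 2
    start = max(0, query_id - half)
    end = min(len(text), query_id + half)  # assumption: query char is part of the right half
    return text[start:end], query_id - start


# A 100-character dummy sentence with the polyphone "长" at index 50.
long_text = "一" * 50 + "长" + "二" * 49
cropped, new_id = crop_window(long_text, 50)
print(len(cropped), new_id, cropped[new_id])

# Short sentences pass through unchanged.
short_cropped, short_id = crop_window("长期", 0)
```

The position has to be shifted together with the text, because position_id is later used to index into the BERT output.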

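Core model

The core is BERT plus linear layers. Points to note:

  • BERT's parameters are not frozen; they are trained together with the rest of the model.
  • conditional weight layer: three nn.Embedding tables, one indexed by the polyphonic character id, one indexed by the combined character+POS id, and one acting as a bias. Their outputs are summed and passed through a sigmoid, which gives a soft weight for every candidate label. This is then multiplied by phoneme_mask (the hard mask) to obtain the final weight.
  • Teacher forcing is used where POS feeds into polyphone prediction.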
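Written as formulas, the conditional weight layer and the weighted softmax that consumes it look roughly like this. The notation is chosen for this post and is not taken from the paper: $z$ are the classifier logits, $m$ is the hard phoneme_mask, and $b$, $e_{c}$, $e_{c,p}$ are the rows returned by the bias, character, and character+POS embeddings.

```latex
\begin{aligned}
w   &= m \odot \sigma\left(b + e_{c} + e_{c,p}\right) \\
p_i &= \frac{w_i \exp(z_i)}{\sum_{j} w_j \exp(z_j)}
\end{aligned}
```

Any label with $m_i = 0$ gets $w_i = 0$ and therefore zero probability before clamping, while the sigmoid term rescales the labels that remain.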
# Simplified model: only the options used by the CPP configuration are kept.
import torch
import torch.nn as nn
import transformers
from transformers import BertModel, BertPreTrainedModel

# ModifiedFocalLoss and the helper methods _get_char_pos_ids and
# _get_pos_loss_scaling_when_using_focal belong to the original code and are omitted here.


class G2PW(BertPreTrainedModel):
    def __init__(self, model_source, labels, chars, pos_tags,
                 use_conditional=False, param_conditional=None,
                 use_focal=False, param_focal=None,
                 use_pos=False, param_pos=None):
        super().__init__(model_source)  # bert-base-chinese

        self.num_labels = len(labels)  # 650 polyphone phonemes
        self.num_chars = len(chars)  # 623 polyphonic characters
        self.num_pos_tags = len(pos_tags)  # 11 POS tags

        self.bert = BertModel(self.config)

        self.classifier = nn.Linear(self.config.hidden_size, self.num_labels)  # 768 -> 650

        self.use_conditional = use_conditional  # True
        self.param_conditional = param_conditional
        if self.use_conditional:
            conditional_affect_location = self.param_conditional['affect_location']
            target_size = self.config.hidden_size if conditional_affect_location == 'emb' else self.num_labels

            if self.param_conditional['bias']:  # True
                self.descriptor_bias = nn.Embedding(1, target_size)
            if self.param_conditional['char-linear']:  # True
                self.char_descriptor = nn.Embedding(self.num_chars, target_size)
            if self.param_conditional['char+pos-second']:  # True
                self.second_order_descriptor = nn.Embedding(self.num_chars * self.num_pos_tags, target_size)

        # forward() reads these two attributes, so they have to be stored
        self.use_focal = use_focal
        self.param_focal = param_focal

        self.use_pos = use_pos  # True
        self.param_pos = param_pos
        # 'param_pos': {
        #     'weight': 0.1,
        #     'pos_joint_training': True,
        #     'train_pos_path': 'train.pos',
        #     'valid_pos_path': 'dev.pos',
        #     'test_pos_path': 'test.pos'
        # }
        if self.use_pos and self.param_pos['pos_joint_training']:
            # POS classification head
            self.pos_classifier = nn.Linear(self.config.hidden_size, self.num_pos_tags)

    def _weighted_softmax(self, logits, weights, eps):
        max_logits, _ = torch.max(logits, dim=-1, keepdim=True)
        # softmax is written by hand after the linear layer so that it can be multiplied by weights
        weighted_exp_logits = torch.exp(logits - max_logits) * weights
        norm = torch.sum(weighted_exp_logits, dim=-1, keepdim=True)
        probs = weighted_exp_logits / norm
        probs = torch.clamp(probs, min=eps, max=1-eps)
        return probs

    def forward(self, input_ids, token_type_ids, attention_mask, phoneme_mask, char_ids, position_ids,
                pos_ids=None, label_ids=None, eps=1e-6):
        transformers_major_ver = int(transformers.__version__.split('.')[0])
        # 4.6.1
        if transformers_major_ver >= 4:
            sequence_output, pooled_output = self.bert(
                input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                return_dict=False
            )
        else:
            sequence_output, pooled_output = self.bert(
                input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask
            )
        # sequence_output [256, 34, 768]
        batch_size = input_ids.size(0)
        # pick the hidden state at the polyphone's position in each sample
        orig_selected_hidden = sequence_output[torch.arange(batch_size), position_ids]
        selected_hidden = orig_selected_hidden  # [256, 768]
        if self.use_conditional:
            if (self.param_conditional['char+pos-second']
                    or self.param_conditional['char+pos-second_lowrank']
                    or self.param_conditional['char+pos-second_fm']
                    or self.param_conditional['pos-linear']
                    or self.param_conditional['fix_mode'] == 'count_distr:char+pos'):
                # teacher forcing while training: use the ground-truth POS
                pred_pos_ids = pos_ids if self.training or not self.param_pos['pos_joint_training'] \
                    else self.pos_classifier(orig_selected_hidden).argmax(dim=-1)
                # pred_pos_ids [256]
            affect_terms = []
            if self.param_conditional['bias']:  # True
                bias_tensor = self.descriptor_bias(torch.zeros_like(char_ids))
                # char_ids.shape    [256]
                # bias_tensor.shape [256, 650]
                affect_terms.append(bias_tensor)
            if self.param_conditional['char-linear']:  # True
                affect_terms.append(self.char_descriptor(char_ids))
            if self.param_conditional['char+pos-second']:  # True
                char_pos_ids = self._get_char_pos_ids(char_ids, pred_pos_ids)
                affect_terms.append(self.second_order_descriptor(char_pos_ids))
            affect_hidden = sum(affect_terms)
            # soft weight from the sigmoid, gated by the hard phoneme mask
            phoneme_mask = phoneme_mask * torch.sigmoid(affect_hidden)

        logits = self.classifier(selected_hidden)
        probs = self._weighted_softmax(logits, phoneme_mask, eps)
        if label_ids is not None:
            # polyphone loss
            if self.use_focal:
                loss_layer = ModifiedFocalLoss(alpha=self.param_focal['alpha'], gamma=self.param_focal['gamma'])
                loss = loss_layer(probs, label_ids)
            else:
                loss_layer = nn.NLLLoss()
                log_probs = torch.log(probs)
                loss = loss_layer(log_probs, label_ids)

            pos_logits = None
            if self.use_pos and pos_ids is not None and self.param_pos['pos_joint_training']:
                pos_logits = self.pos_classifier(orig_selected_hidden)
                loss_fct = nn.CrossEntropyLoss()  # combines nn.LogSoftmax() and nn.NLLLoss()
                pos_loss = loss_fct(pos_logits, pos_ids)
                scaling = self._get_pos_loss_scaling_when_using_focal(probs, label_ids) if self.use_focal else 1.
                loss += self.param_pos['weight'] * scaling * pos_loss

            return probs, loss, pos_logits
        else:
            return probs
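To see what the weighted softmax does to a single sample, here is a dependency-free sketch that mirrors `_weighted_softmax` with plain Python lists instead of tensors. All numbers are made up for illustration; `affect_hidden` stands in for the sum of the three embedding outputs.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def weighted_softmax(logits, weights, eps=1e-6):
    """Softmax whose exponentials are scaled by per-label weights, then clamped."""
    max_logit = max(logits)  # subtracted for numerical stability only
    weighted = [math.exp(z - max_logit) * w for z, w in zip(logits, weights)]
    norm = sum(weighted)
    return [min(max(v / norm, eps), 1 - eps) for v in weighted]


# Toy sample (assumption): 4 labels, the query character may only take labels 0 and 1.
logits = [2.0, 1.0, 0.5, 3.0]
hard_mask = [1, 1, 0, 0]
affect_hidden = [1.5, -0.5, 0.3, 2.0]

weights = [m * sigmoid(a) for m, a in zip(hard_mask, affect_hidden)]
probs = weighted_softmax(logits, weights)
best = probs.index(max(probs))
print(probs, best)
```

Label 3 has the largest raw logit, but it is masked out, so after clamping it only keeps the floor value eps and the prediction falls on one of the character's own pronunciations.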

Inference code

  • Simple inference
from poly import G2PWConverter
conv = G2PWConverter(model_dir="saved_models/CPP_BERT_M_DescWS-Sec-cLin-B_POSw01/",
                     style='pinyin', enable_non_tradional_chinese=False, use_cuda=True)
conv('你好')

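Test results

  • In the inference code, some polyphonic characters come out wrong on the bopomofo side; this is caused by how the dictionaries are set up.
  • Inference is rather slow, roughly 0.2 s per sentence. That is not usable in TTS as it stands; the API provided by G2PWConverter needs to be rewritten.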
Correctly predicted case:
{
    卫健委流调<diao4, tiao2>电话,为<wei2, wei4>您服务
    [['wei4', 'jian4', 'wei3', 'liu2', 'diao4', 'dian4', 'hua4', None, 'wei4', 'nin2', 'fu2', 'wu4']]
}
Failed case (this is a bopomofo dictionary problem and needs to be improved later):
{
    长<chang2, zhang3>期
    [['zhang3', 'qi1']]
}

Improvements

  • Full inference version
import json
import os
import time

import torch
from opencc import OpenCC
from torch.utils.data import DataLoader
from transformers import BertTokenizer

# G2PW, TextDataset, predict, load_config, download_model, get_phoneme_labels and
# get_char_phoneme_labels come from the g2pW code base and are not repeated here.


class G2PWConverter:
    def __init__(self, model_dir='G2PWModel/', style='bopomofo', model_source=None, use_cuda=False,
                 num_workers=None, batch_size=None, turnoff_tqdm=True, enable_non_tradional_chinese=False):
        if not os.path.exists(os.path.join(model_dir, 'best_accuracy.pth')):
            download_model(model_dir)
        self.model_dir = model_dir
        self.config = load_config(os.path.join(model_dir, 'config.py'), use_default=True)

        self.num_workers = num_workers if num_workers else self.config.num_workers
        self.batch_size = batch_size if batch_size else self.config.batch_size
        self.model_source = model_source if model_source else self.config.model_source
        self.turnoff_tqdm = turnoff_tqdm
        self.enable_opencc = enable_non_tradional_chinese

        self.device = torch.device('cuda' if use_cuda else 'cpu')

        self.tokenizer = BertTokenizer.from_pretrained(self.config.model_source)

        polyphonic_chars_path = os.path.join(model_dir, 'POLYPHONIC_CHARS.txt')
        polyphonic_chars_path_s = os.path.join(model_dir, 'POLYPHONIC_CHARS_S.txt')
        monophonic_chars_path = os.path.join(model_dir, 'MONOPHONIC_CHARS.txt')
        self.polyphonic_chars = [line.split('\t') for line in open(polyphonic_chars_path).read().strip().split('\n')]
        self.polyphonic_chars_s = [line.split('\t') for line in open(polyphonic_chars_path_s).read().strip().split('\n')]
        # polyphonic_chars: 8022 entries
        self.monophonic_chars = [line.split('\t') for line in open(monophonic_chars_path).read().strip().split('\n')]
        # monophonic_chars: 9476 entries
        self.labels, self.char2phonemes = get_char_phoneme_labels(self.polyphonic_chars) \
            if self.config.use_char_phoneme else get_phoneme_labels(self.polyphonic_chars)
        self.labels_s, self.char2phonemes_s = get_char_phoneme_labels(self.polyphonic_chars_s) \
            if self.config.use_char_phoneme else get_phoneme_labels(self.polyphonic_chars_s)
        # self.labels: 1305 bopomofo pronunciations
        # char2phonemes: 3582 polyphonic characters
        self.chars = sorted(list(self.char2phonemes.keys()))
        self.pos_tags = TextDataset.POS_TAGS

        self.model = G2PW.from_pretrained(
            self.model_source,
            labels=self.labels,
            chars=self.chars,
            pos_tags=self.pos_tags,
            use_conditional=self.config.use_conditional,
            param_conditional=self.config.param_conditional,
            use_focal=self.config.use_focal,
            param_focal=self.config.param_focal,
            use_pos=self.config.use_pos,
            param_pos=self.config.param_pos
        )
        checkpoint = os.path.join(model_dir, 'best_accuracy.pth')
        self.model.load_state_dict(torch.load(checkpoint, map_location=self.device))
        self.model.to(self.device)

        with open(os.path.join(os.path.dirname(os.path.abspath(__file__)),
                               'bopomofo_to_pinyin_wo_tune_dict.json'), 'r') as fr:
            self.bopomofo_convert_dict = json.load(fr)
        # Only 424 entries here. What is "wo tune"? Judging by _convert_bopomofo_to_pinyin it is
        # probably "without tone". This should just be the bopomofo-to-pinyin mapping,
        # so the error should not come from here.
        self.style_convert_func = {
            'bopomofo': lambda x: x,
            'pinyin': self._convert_bopomofo_to_pinyin,
        }[style]

        with open(os.path.join(os.path.dirname(os.path.abspath(__file__)),
                               'char_bopomofo_dict.json'), 'r') as fr:
            self.char_bopomofo_dict = json.load(fr)
        # char_bopomofo_dict: character-to-bopomofo dictionary, 41497 entries

        if self.enable_opencc:
            self.cc = OpenCC('s2tw')  # Simplified Chinese to Traditional Chinese (Taiwan standard)

    def _convert_bopomofo_to_pinyin(self, bopomofo):
        tone = bopomofo[-1]
        assert tone in '12345'
        component = self.bopomofo_convert_dict.get(bopomofo[:-1])
        if component:
            return component + tone
        else:
            print(f'Warning: "{bopomofo}" cannot convert to pinyin')
            return None

    def __call__(self, sentences):
        # s = time.time()
        if isinstance(sentences, str):
            sentences = [sentences]
        # e = time.time()
        # print('step 1 time: %4f' % (e - s))
        # s = time.time()
        if self.enable_opencc:
            translated_sentences = []
            for sent in sentences:
                translated_sent = self.cc.convert(sent)
                assert len(translated_sent) == len(sent)
                translated_sentences.append(translated_sent)
            sentences = translated_sentences
        # e = time.time()
        # print('step 2 time: %4f' % (e - s))
        # s = time.time()
        texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences)
        # e = time.time()
        # print('step 3 time: %4f' % (e - s))
        # s = time.time()
        dataset = TextDataset(self.tokenizer, self.labels, self.char2phonemes_s, self.chars, texts, query_ids,
                              use_mask=self.config.use_mask, use_char_phoneme=self.config.use_char_phoneme,
                              window_size=self.config.window_size, for_train=False)

        dataloader = DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            collate_fn=dataset.create_mini_batch,
            num_workers=self.num_workers
        )
        # e = time.time()
        # print('time data: %4f' % (e-s))
        s = time.time()
        preds, confidences = predict(self.model, dataloader, self.device, self.labels, turnoff_tqdm=self.turnoff_tqdm)
        e = time.time()
        print('time predict: %4f' % (e-s))
        if self.config.use_char_phoneme:
            preds = [pred.split(' ')[1] for pred in preds]

        # s = time.time()
        results = partial_results
        for sent_id, query_id, pred in zip(sent_ids, query_ids, preds):
            if self.model_dir == 'G2PWModel/':
                results[sent_id][query_id] = self.style_convert_func(pred)
            else:
                results[sent_id][query_id] = pred
        # e = time.time()
        # print('time results: %4f' % (e - s))
        return results

    def _prepare_data(self, sentences):
        # polyphonic_chars = set(self.chars)  # 3582 polyphonic characters (incomplete, e.g. "长" is missing)
        # I added a separate polyphone dictionary that is used only for lookup. The network uses a
        # char_id embedding, so the original dictionary must not be modified; my idea was to fall
        # back to a random number when fetching char_id for the new characters.
        polyphonic_chars = set(self.char2phonemes_s)
        monophonic_chars_dict = {
            char: phoneme for char, phoneme in self.monophonic_chars
        }
        texts, query_ids, sent_ids, partial_results = [], [], [], []
        for sent_id, sent in enumerate(sentences):
            partial_result = [None] * len(sent)
            for i, char in enumerate(sent):
                if char in polyphonic_chars:
                    texts.append(sent)
                    query_ids.append(i)
                    sent_ids.append(sent_id)
                elif char in monophonic_chars_dict:
                    # first try the custom monophonic dictionary; "长" is still not in it
                    partial_result[i] = self.style_convert_func(monophonic_chars_dict[char])
                elif char in self.char_bopomofo_dict:
                    # then fall back to the largest dictionary (this is where the problem lies:
                    # for a polyphonic character only the first listed pronunciation is taken!)
                    partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0])
            partial_results.append(partial_result)
        return texts, query_ids, sent_ids, partial_results

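The main changes here are:

  • The polyphone dictionary is loaded a second time: I built POLYPHONIC_CHARS_S.txt, which is used to check whether the sentences contain polyphonic characters. New polyphonic characters can now be added there, which makes the setup much easier to extend.
  • The model uses the polyphone's char_id in an embedding to assist training, and characters added through POLYPHONIC_CHARS_S.txt have no char_id of their own. I rewrote the dataset part so that a polyphonic character without a char_id is mapped to 0. There is not much theory behind this; my view is that BERT matters far more to the model than the char_id embedding, and in practice that turned out to be the case.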
# Change in the dataset part
char_id = self.chars.index(query_char) if query_char in self.chars else 0
phoneme_mask = [1 if i in self.char2phonemes[query_char] else 0 for i in range(len(self.labels))] \
    if self.use_mask else [1] * len(self.labels)
# Note: char2phonemes here is built from the new polyphone dictionary.

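A problem that can still occur:

  • A polyphonic character is added, but one of its pronunciations is not among the phonemes. This is hard to fix, because the number of labels used at inference is fixed by the model; it would require changing the architecture and retraining.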
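Before adding characters to the extended dictionary, it may help to check which of them the trained label set can actually express. The sketch below is a hypothetical helper, not part of the g2pW code; the labels and entries are toy values, and the real labels would be whatever `get_phoneme_labels` returns for the trained model.

```python
def split_by_label_coverage(new_entries, labels):
    """Separate new characters into usable ones and ones with unknown pronunciations."""
    known = set(labels)
    usable, blocked = {}, {}
    for char, phonemes in new_entries.items():
        missing = [p for p in phonemes if p not in known]
        if missing:
            blocked[char] = missing  # the model has no output slot for these
        else:
            usable[char] = phonemes
    return usable, blocked


# Toy label set and candidate additions (assumptions, for illustration only).
labels = ["chang2", "zhang3", "wei2", "wei4"]
new_entries = {"长": ["chang2", "zhang3"], "行": ["xing2", "hang2"]}

usable, blocked = split_by_label_coverage(new_entries, labels)
print(usable, blocked)
```

Characters that end up in `blocked` are the ones that would need a retrained classifier head.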

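Thoughts on faster inference

  • Mark all polyphonic characters in a sentence at once, so that each sentence goes through the model only once.
  • Merge in the prosody prediction module, so that there is one model fewer. A paper to refer to: Unified Mandarin TTS Front-end Based on Distilled BERT Model.
  • Replace the base model bert-base-chinese with a smaller, faster pretrained language model, or distill it.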
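The first idea can be sketched at the bookkeeping level: group the query positions by sentence, so that one encoder pass per sentence would be enough and the hidden states at all grouped positions could be read from that single output. The function name is hypothetical, and whether accuracy holds without retraining is untested, since the released model was trained with one query per sample.

```python
from collections import defaultdict


def group_queries_by_sentence(sent_ids, query_ids):
    """Collect all polyphone positions of each sentence under one key."""
    grouped = defaultdict(list)
    for sent_id, query_id in zip(sent_ids, query_ids):
        grouped[sent_id].append(query_id)
    return dict(grouped)


# Same layout as the sent_ids / query_ids lists returned by _prepare_data (toy values).
sent_ids = [0, 0, 0, 1, 1]
query_ids = [4, 8, 13, 0, 5]

grouped = group_queries_by_sentence(sent_ids, query_ids)
per_query_passes = len(query_ids)    # current behaviour: one encoder pass per polyphone
per_sentence_passes = len(grouped)   # proposed: one encoder pass per sentence
print(grouped, per_query_passes, per_sentence_passes)
```

Note that the 32-character window is currently centered on each polyphone, so a per-sentence pass would also need a different cropping strategy for long sentences.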