1. Install NeMo
pip install -U "nemo_toolkit[all]" ASR-metrics
2. Download the pretrained ASR model locally (Hugging Face is recommended; it is much faster than the NVIDIA site)
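For example, a minimal download sketch using huggingface_hub (the repo id nvidia/stt_zh_quartznet15x5 and the file name are assumptions; substitute whichever mirror or local path you actually use):
from huggingface_hub import hf_hub_download
# assumption: the .nemo checkpoint is published under this repo id on Hugging Face
nemo_path = hf_hub_download(repo_id="nvidia/stt_zh_quartznet15x5",
                            filename="stt_zh_quartznet15x5.nemo")
print(nemo_path)  # local cache path of the downloaded checkpoint
# alternatively, if the name is available in the NGC catalog, NeMo can fetch it directly:
# nemo_asr.models.EncDecCTCModel.from_pretrained("stt_zh_quartznet15x5")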
3. Create the ASR model from the local checkpoint
import nemo.collections.asr as nemo_asr
asr_model = nemo_asr.models.EncDecCTCModel.restore_from("stt_zh_quartznet15x5.nemo")
4. Define train_manifest, a JSON-lines file where each line gives an audio file path, its duration, and the reference transcript
{"audio_filepath": "test.wav", "duration": 8.69, "text": "誒前天跟我說昨天跟我說十二期利率是多少工號幺九零八二六十二期的話零點八一萬的話分十二期利息八十嘛"}
5. Read the model's YAML configuration
# read the QuartzNet model configuration file with ruamel YAML
try:
    from ruamel.yaml import YAML
except ModuleNotFoundError:
    from ruamel_yaml import YAML
config_path = "/NeMo/examples/asr/conf/quartznet/quartznet_15x5_zh.yaml"
yaml = YAML(typ='safe')
with open(config_path) as f:
    params = yaml.load(f)
print(params['model']['train_ds']['manifest_filepath'])
print(params['model']['validation_ds']['manifest_filepath'])
6. Set the training and validation manifests
train_manifest = "train_manifest.json"
val_manifest = "train_manifest.json"  # for a real run, point this at a separate validation manifest
params['model']['train_ds']['manifest_filepath'] = train_manifest
params['model']['validation_ds']['manifest_filepath'] = val_manifest
print(params['model']['train_ds']['manifest_filepath'])
print(params['model']['validation_ds']['manifest_filepath'])
asr_model.setup_training_data(train_data_config=params['model']['train_ds'])
asr_model.setup_validation_data(val_data_config=params['model']['validation_ds'])
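Optionally, the optimizer settings from the same YAML can be carried over before training; a sketch, assuming the config file contains an optim section as the stock QuartzNet configs do:
from omegaconf import DictConfig
params['model']['optim']['lr'] = 0.001  # illustrative: a lower learning rate is common for fine-tuning
asr_model.setup_optimization(optim_config=DictConfig(params['model']['optim']))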
7. Train with pytorch_lightning
import pytorch_lightning as pl
trainer = pl.Trainer(accelerator='gpu', devices=1, max_epochs=10)
trainer.fit(asr_model)  # call 'fit' to start training
8. Save the fine-tuned model
asr_model.save_to('my_stt_zh_quartznet15x5.nemo')
9. Check the results after fine-tuning
my_asr_model = nemo_asr.models.EncDecCTCModel.restore_from("my_stt_zh_quartznet15x5.nemo")
queries = my_asr_model.transcribe(['test1.wav'])
print(queries)
#['誒前天跟我說的昨天跟我說十二期利率是多少工號幺九零八二六零十二期的話零點八一萬的話分十二期利息八十嘛']
10. Compute the character error rate (CER)
from ASR_metrics import utils as metrics
s1 = "誒前天跟我說昨天跟我說十二期利率是多少工號幺九零八二六十二期的話零點八一萬的話分十二期利息八十嘛"  # reference transcript
s2 = " ".join(queries)  # recognition result
print("CER: {}".format(metrics.calculate_cer(s1, s2)))  # character error rate
print("Accuracy: {}".format(1 - metrics.calculate_cer(s1, s2)))  # 1 - CER
# CER: 0.041666666666666664
# Accuracy: 0.9583333333333334
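For reference, CER is the character-level edit distance between reference and hypothesis divided by the reference length; a minimal pure-Python sketch, which should match ASR_metrics up to whatever text normalization that library applies:
def edit_distance(ref, hyp):
    # single-row dynamic-programming Levenshtein distance over characters
    dp = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, dp[0] = dp[0], i
        for j, h in enumerate(hyp, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (r != h))
    return dp[-1]

print("CER (manual):", edit_distance(s1, s2) / len(s1))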
11. Add punctuation to the transcription output
from zhpr.predict import DocumentDataset, merge_stride, decode_pred
from transformers import AutoModelForTokenClassification, AutoTokenizer
from torch.utils.data import DataLoader

def predict_step(batch, model, tokenizer):
    batch_out = []
    batch_input_ids = batch
    encodings = {'input_ids': batch_input_ids}
    output = model(**encodings)
    predicted_token_class_id_batch = output['logits'].argmax(-1)
    for predicted_token_class_ids, input_ids in zip(predicted_token_class_id_batch, batch_input_ids):
        out = []
        tokens = tokenizer.convert_ids_to_tokens(input_ids)

        # find where padding starts in input_ids and truncate the prediction accordingly
        # print(tokenizer.decode(batch_input_ids))
        input_ids = input_ids.tolist()
        try:
            input_id_pad_start = input_ids.index(tokenizer.pad_token_id)
        except ValueError:
            input_id_pad_start = len(input_ids)
        input_ids = input_ids[:input_id_pad_start]
        tokens = tokens[:input_id_pad_start]

        # map predicted class ids to punctuation labels and pair them with the tokens
        predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids]
        predicted_tokens_classes = predicted_tokens_classes[:input_id_pad_start]
        for token, ner in zip(tokens, predicted_tokens_classes):
            out.append((token, ner))
        batch_out.append(out)
    return batch_out
if __name__ == "__main__":
    window_size = 256
    step = 200
    text = queries[0]
    dataset = DocumentDataset(text, window_size=window_size, step=step)
    dataloader = DataLoader(dataset=dataset, shuffle=False, batch_size=5)
    model_name = 'zh-wiki-punctuation-restore'  # local path or Hugging Face repo id of the punctuation model
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model_pred_out = []
    for batch in dataloader:
        batch_out = predict_step(batch, model, tokenizer)
        for out in batch_out:
            model_pred_out.append(out)

    merge_pred_result = merge_stride(model_pred_out, step)
    merge_pred_result_decode = decode_pred(merge_pred_result)
    merge_pred_result_decode = ''.join(merge_pred_result_decode)
    print(merge_pred_result_decode)
#誒前天跟我說的。昨天跟我說十二期利率是多少。工號幺九零八二六零十二期的話,零點八一萬的話,分十二期利息八十嘛。
This concludes the walkthrough of fine-tuning a NeMo Chinese/English ASR model.