Preface
Text recognition is a common task in computer vision. In scene-text OCR, the positions of text in an image are detected first, and the detected text is then recognized. The CRNN model introduced in this article handles the latter step: recognizing the text inside a detected region.
An End-to-End Trainable Neural Network for Image-Based Sequence Recognition and Its Application to Scene Text Recognition
Original paper: link
I. Introduction to the CRNN Model
1. Model Structure
CRNN combines a CNN with an RNN: the CNN extracts image features, and the RNN processes those features to produce the outputs that correspond to the final label.
CRNN consists of three parts: a convolutional layer, a recurrent layer, and a transcription layer. The words in different images vary in length, while the feature sequence extracted by the CNN always has the same length, so a transcription layer is needed to turn the per-time-step outputs into the final result.
The figure shows the overall structure of the model.
The model input is an image of shape (1, 32, 100), i.e. (channel, height, width).
After the convolutional network, the shape becomes (512, 1, 24), i.e. (new_channel, new_height, new_width). The channel and height dimensions are merged, giving shape (512, 24), and the two remaining dimensions are then swapped, giving (24, 512), i.e. (new_width, new_channel × new_height). Since these features will be fed into a recurrent network, the 24 acts as the number of time steps: 24 time steps in total. The output feature map of shape (24, 512) can be understood as splitting the original image into 24 columns, each represented by a 512-dimensional feature vector, as shown in the figure.
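To make the shape bookkeeping concrete, here is a minimal PyTorch sketch of the merge-and-permute step described above (the tensor is random, standing in for a real CNN output):

import torch

# stand-in for the CNN output described above: (batch, channel, height, width)
feature_map = torch.randn(1, 512, 1, 24)
batch, channel, height, width = feature_map.shape
seq = feature_map.view(batch, channel * height, width)  # merge channel and height: (1, 512, 24)
seq = seq.permute(2, 0, 1)  # reorder to (width, batch, features): (24, 1, 512)
print(seq.shape)  # torch.Size([24, 1, 512]): 24 time steps, each a 512-dim feature vector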
The 24 feature vectors are fed into the recurrent network, which in the paper is two stacked LSTMs. This yields outputs for 24 time steps; a fully connected layer followed by softmax then produces a probability matrix of shape (T, num_class), where T is the number of time steps and num_class is the number of classes: the digits 0-9, the letters a-z, plus a blank symbol, 37 classes in total. There are 24 time-step outputs, but the number of characters in an image varies from sample to sample, so the transcription layer converts the per-step outputs into the final label.
2. CTC Loss
A traditional loss function would require aligned training samples: with 24 time steps, we would need 24 corresponding labels. That is clearly impractical for this task, unless every character in the image could be detected individually, one label per character, which would require a very strong text detector. CTC loss needs no such alignment.
We still get 24 labels from the 24 time steps, and a β transform then produces the final label. Think of the original image as 24 columns, each emitting one label: sometimes a letter spans several columns, e.g. if the letter S occupies three columns, all three columns should output class S; columns with no letter output the blank class. To obtain the final label, consecutive repeated characters are collapsed into one, except when the same characters sit on either side of a blank. This exception matters because real labels can contain genuinely repeated characters: the two consecutive e's in "green" must not be merged, so a valid path writes them as e-e, which survives the collapse. Finally the blanks are removed, giving the final label. (A small code sketch after the examples below illustrates this collapse rule.)
The β transform is defined as follows:
$$\beta : L'^{T} \rightarrow L^{\le T}$$
Here T is the sequence length (the number of time steps); since consecutive repeated characters are collapsed, the length after the transform is at most T.
Here are a few examples of the β transform, with the blank written as -:
β(--sstaaat-ee) = state
β(--s-tt-a-t-e) = state
β(-s-st-aat-e) = sstate
β(-s-tta-tt-ee) = state
As these examples show, more than one path can produce the output "state".
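The collapse rule is easy to express in code. Below is a tiny helper, purely for illustration (the name ctc_collapse is mine, not from the original code):

def ctc_collapse(path, blank='-'):
    # collapse consecutive repeats, then drop blanks, exactly as β does
    out, prev = [], None
    for c in path:
        if c != prev and c != blank:
            out.append(c)
        prev = c
    return ''.join(out)

print(ctc_collapse('--sstaaat-ee'))  # state
print(ctc_collapse('-s-st-aat-e'))   # sstate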
The LSTM output is passed to the transcription layer. Let x denote the per-time-step output sequence of the LSTM; the probability of outputting label sequence l is:
$$p(l \mid x) = \sum_{\pi \in \beta^{-1}(l)} p(\pi \mid x)$$
where $\pi \in \beta^{-1}(l)$ denotes the set of paths that the β transform maps to l. For each path $\pi$ we have:
$$p(\pi \mid x) = \prod_{t=1}^{T} y_{\pi_t}^{t}$$
where $y_{\pi_t}^{t}$ is the probability that the path takes label $\pi_t$ at time step t; the product over all time steps is the probability of the whole path.
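On a toy problem these two formulas can be checked by brute force: enumerate every path, keep those that β maps to the target label, and sum their path probabilities. A sketch with made-up numbers (T = 3, classes blank and 'a'):

import itertools
import math

# made-up per-step probabilities: probs[t][c] = probability of class c at step t
probs = [{'-': 0.4, 'a': 0.6},
         {'-': 0.5, 'a': 0.5},
         {'-': 0.7, 'a': 0.3}]

def collapse(path):  # the β transform, as sketched earlier
    out, prev = [], None
    for c in path:
        if c != prev and c != '-':
            out.append(c)
        prev = c
    return ''.join(out)

total = 0.0
for path in itertools.product('-a', repeat=3):
    if collapse(path) == 'a':
        p = math.prod(probs[t][c] for t, c in enumerate(path))  # p(π|x): product over steps
        total += p  # sum over all π with β(π) = 'a'
print(total)  # p('a'|x)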
The optimization objective of CTC loss is to maximize

$$p(l \mid x) = \sum_{\pi \in \beta^{-1}(l)} p(\pi \mid x),$$

so the loss is taken as the negative log-likelihood,

$$loss = -\log p(l \mid x) = -\log \sum_{\pi \in \beta^{-1}(l)} p(\pi \mid x),$$

and minimizing it updates the parameters of the preceding LSTM and CNN. Computing the CTC loss efficiently is somewhat involved (a dynamic-programming recursion over paths), so it is not covered here; PyTorch provides a ready-made CTCLoss interface that we can use directly.
from torch.nn import CTCLoss
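A minimal sketch of how this interface is called, with shapes assumed from this article (T = 24 time steps, 37 classes, blank index 0; the numbers themselves are illustrative):

import torch
from torch.nn import CTCLoss

T, N, C = 24, 2, 37  # time steps, batch size, classes (36 characters + blank)
logits = torch.randn(T, N, C, requires_grad=True)        # stand-in for the network output
log_probs = logits.log_softmax(2)                        # CTCLoss expects log-probabilities
targets = torch.randint(1, C, (N, 5), dtype=torch.long)  # label indices; 0 is the blank
input_lengths = torch.full((N,), T, dtype=torch.long)    # each sample uses all T steps
target_lengths = torch.full((N,), 5, dtype=torch.long)   # each label is 5 characters here

criterion = CTCLoss(blank=0)
loss = criterion(log_probs, targets, input_lengths, target_lengths)
loss.backward()  # gradients flow back to the logits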
3. Beam Search
During training, CTC loss is used to update the parameters. At test time, a brute-force decoder would compute the probability of every path and take the most probable one, which is prohibitively expensive: with 37 classes and a sequence length of 24, there are $37^{24}$ paths, and that is for a single sample. Beam search is therefore used to make decoding tractable.
As shown in the figure, the process works as follows: first, at the first time step, keep the three labels with the highest probabilities (the beam width of three is configurable). Extending each of these, the second step scores nine candidates, each candidate's score being the first-step probability multiplied by the probability of the label appended at the second step; of these nine, the top three are kept, and the search continues in the same way. After the last time step, three paths remain; the most probable one is taken and passed through CTC decoding to obtain the final label.
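Below is a minimal sketch of this beam search over paths. It is deliberately simple: it scores whole paths and does not merge prefixes that collapse to the same string, matching the description above; the function name is mine, not from the original code.

def beam_search_decode(log_probs, beam_size=3, blank=0):
    # log_probs: a (T, num_class) list/array of per-step log-probabilities for one sample
    beams = [((), 0.0)]  # (path so far, accumulated log-probability)
    for step in log_probs:
        candidates = []
        for path, score in beams:
            for c, lp in enumerate(step):
                # extending a path multiplies probabilities, i.e. adds log-probabilities
                candidates.append((path + (c,), score + lp))
        candidates.sort(key=lambda x: x[1], reverse=True)
        beams = candidates[:beam_size]  # keep only the top beam_size paths

    best_path, _ = beams[0]
    # final CTC decode: collapse repeats, then drop blanks
    decoded, prev = [], None
    for c in best_path:
        if c != prev and c != blank:
            decoded.append(c)
        prev = c
    return decoded

A stronger variant, prefix beam search, merges candidates that collapse to the same string before pruning; the simple version above follows the figure's procedure.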
II. Implementing CRNN with PyTorch
Dataset
Several datasets were merged and preprocessed, yielding just over eight thousand images.
Only the key parts of the code are shown here.
Code and dataset: https://pan.baidu.com/s/1j1sUFIgdB1qga1Cfrh-jlw (extraction code: lf2m)
dataset.py
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np


class Synth90kDataset(Dataset):
    CHARS = '0123456789abcdefghijklmnopqrstuvwxyz'
    CHAR2LABEL = {char: i + 1 for i, char in enumerate(CHARS)}  # 0 is reserved for the blank
    LABEL2CHAR = {label: char for char, label in CHAR2LABEL.items()}

    def __init__(self, root_dir=None, image_dir=None, mode=None, file_names=None, img_height=32, img_width=100):
        if mode == "train":
            file_names, texts = self._load_from_raw_files(root_dir, mode)
        else:
            texts = None
        self.root_dir = root_dir
        self.image_dir = image_dir
        self.file_names = file_names
        self.texts = texts
        self.img_height = img_height
        self.img_width = img_width

    def _load_from_raw_files(self, root_dir, mode):
        paths_file = None
        if mode == 'train':
            paths_file = 'train.txt'
        elif mode == 'test':
            paths_file = 'test.txt'

        file_names = []
        texts = []
        with open(os.path.join(root_dir, paths_file), 'r') as fr:
            for line in fr.readlines():
                file_name, ext = line.strip().split('.')
                # the ground-truth text is encoded at the end of the file name
                text = file_name.split('_')[-1].lower()
                file_names.append(file_name + "." + ext)
                texts.append(text)
        return file_names, texts

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, index):
        file_name = self.file_names[index]
        file_path = os.path.join(self.image_dir, file_name)
        image = Image.open(file_path).convert('L')  # grey-scale
        image = image.resize((self.img_width, self.img_height), resample=Image.BILINEAR)
        image = np.array(image)
        image = image.reshape((1, self.img_height, self.img_width))
        image = (image / 127.5) - 1.0  # normalize to [-1, 1]
        image = torch.FloatTensor(image)
        if self.texts:
            text = self.texts[index]
            target = [self.CHAR2LABEL[c] for c in text]
            target_length = [len(target)]
            target = torch.LongTensor(target)
            target_length = torch.LongTensor(target_length)
            # If the DataLoader has no collate_fn, this is what iterating it yields.
            return image, target, target_length
        else:
            return image


def synth90k_collate_fn(batch):
    # zip(*batch) unpacks the list of (image, target, target_length) tuples
    images, targets, target_lengths = zip(*batch)
    # stack adds a new dimension and piles the tensors along it, like stacking
    # slices of bread: N images of shape (1, H, W) become one (N, 1, H, W) batch
    images = torch.stack(images, 0)
    # cat concatenates along an existing dimension without adding one, so the
    # variable-length targets become one long flat vector
    targets = torch.cat(targets, 0)
    target_lengths = torch.cat(target_lengths, 0)
    # This is what each iteration of train_loader yields: three values per batch.
    return images, targets, target_lengths


if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from config import train_config as config

    img_width = config['img_width']
    img_height = config['img_height']
    data_dir = config['data_dir']
    train_batch_size = config['train_batch_size']
    cpu_workers = config['cpu_workers']

    # image_dir matches the path used in train.py
    train_dataset = Synth90kDataset(root_dir=data_dir, image_dir='../data/images', mode='train',
                                    img_height=img_height, img_width=img_width)

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=cpu_workers,
        collate_fn=synth90k_collate_fn)
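Since torch.stack and torch.cat (used in synth90k_collate_fn above) are easy to confuse, here is a quick toy demonstration of what each does:

import torch

a, b = torch.zeros(1, 32, 100), torch.ones(1, 32, 100)
print(torch.stack([a, b], 0).shape)  # torch.Size([2, 1, 32, 100]): a new batch dimension

t1, t2 = torch.tensor([1, 2, 3]), torch.tensor([4, 5])
print(torch.cat([t1, t2], 0))  # tensor([1, 2, 3, 4, 5]): joined along an existing dim

Because targets of different lengths are flattened by cat, the separate target_lengths tensor is what lets CTCLoss recover where each sample's label starts and ends.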
model.py
import torch.nn as nn


class CRNN(nn.Module):
    def __init__(self, img_channel, img_height, img_width, num_class,
                 map_to_seq_hidden=64, rnn_hidden=256, leaky_relu=False):
        super(CRNN, self).__init__()
        self.cnn, (output_channel, output_height, output_width) = \
            self._cnn_backbone(img_channel, img_height, img_width, leaky_relu)
        self.map_to_seq = nn.Linear(output_channel * output_height, map_to_seq_hidden)
        self.rnn1 = nn.LSTM(map_to_seq_hidden, rnn_hidden, bidirectional=True)
        # a bidirectional LSTM outputs 2 * hidden features, hence the * 2 on the next input
        self.rnn2 = nn.LSTM(2 * rnn_hidden, rnn_hidden, bidirectional=True)
        self.dense = nn.Linear(2 * rnn_hidden, num_class)

    # CNN backbone
    def _cnn_backbone(self, img_channel, img_height, img_width, leaky_relu):
        assert img_height % 16 == 0
        assert img_width % 4 == 0

        # layer hyperparameters
        channels = [img_channel, 64, 128, 256, 256, 512, 512, 512]
        kernel_sizes = [3, 3, 3, 3, 3, 3, 2]
        strides = [1, 1, 1, 1, 1, 1, 1]
        paddings = [1, 1, 1, 1, 1, 1, 0]

        cnn = nn.Sequential()

        def conv_relu(i, batch_norm=False):
            # shape of input: (batch, input_channel, height, width)
            input_channel = channels[i]
            output_channel = channels[i + 1]
            cnn.add_module(
                f'conv{i}',
                nn.Conv2d(input_channel, output_channel, kernel_sizes[i], strides[i], paddings[i])
            )
            if batch_norm:
                cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(output_channel))
            relu = nn.LeakyReLU(0.2, inplace=True) if leaky_relu else nn.ReLU(inplace=True)
            cnn.add_module(f'relu{i}', relu)

        # size of image: (channel, height, width) = (img_channel, img_height, img_width)
        conv_relu(0)
        cnn.add_module('pooling0', nn.MaxPool2d(kernel_size=2, stride=2))
        # (64, img_height // 2, img_width // 2)
        conv_relu(1)
        cnn.add_module('pooling1', nn.MaxPool2d(kernel_size=2, stride=2))
        # (128, img_height // 4, img_width // 4)
        conv_relu(2)
        conv_relu(3)
        cnn.add_module(
            'pooling2',
            nn.MaxPool2d(kernel_size=(2, 1))
        )  # (256, img_height // 8, img_width // 4)
        conv_relu(4, batch_norm=True)
        conv_relu(5, batch_norm=True)
        cnn.add_module(
            'pooling3',
            nn.MaxPool2d(kernel_size=(2, 1))
        )  # (512, img_height // 16, img_width // 4)
        conv_relu(6)  # (512, img_height // 16 - 1, img_width // 4 - 1)

        output_channel, output_height, output_width = \
            channels[-1], img_height // 16 - 1, img_width // 4 - 1
        return cnn, (output_channel, output_height, output_width)

    # CNN + LSTM forward pass
    def forward(self, images):
        # shape of images: (batch, channel, height, width)
        conv = self.cnn(images)
        batch, channel, height, width = conv.size()
        conv = conv.view(batch, channel * height, width)
        conv = conv.permute(2, 0, 1)  # (width, batch, feature)
        # Linear layer after the conv stack: input shape (width, batch, channel * height),
        # output shape (width, batch, hidden) = (sequence length, batch, features),
        # which is exactly the input layout the LSTM expects.
        seq = self.map_to_seq(conv)
        recurrent, _ = self.rnn1(seq)
        recurrent, _ = self.rnn2(recurrent)
        output = self.dense(recurrent)
        return output  # shape: (seq_len, batch, num_class)
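A quick sanity check of the shapes the model produces, using the input size from this article (37 classes = 36 characters + blank; assumes the file above is saved as model.py):

import torch
from model import CRNN

crnn = CRNN(img_channel=1, img_height=32, img_width=100, num_class=37)
x = torch.randn(4, 1, 32, 100)  # a batch of 4 grey-scale 32x100 images
out = crnn(x)
print(out.shape)  # torch.Size([24, 4, 37]): (seq_len, batch, num_class)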
train.py
import os
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.nn import CTCLoss

from dataset import Synth90kDataset, synth90k_collate_fn
from model import CRNN
from evaluate import evaluate
from config import train_config as config


def train_batch(crnn, data, optimizer, criterion, device):
    crnn.train()
    images, targets, target_lengths = [d.to(device) for d in data]

    logits = crnn(images)
    log_probs = torch.nn.functional.log_softmax(logits, dim=2)

    batch_size = images.size(0)
    # every sample in the batch uses the full sequence length as its CTC input length
    input_lengths = torch.LongTensor([logits.size(0)] * batch_size)
    target_lengths = torch.flatten(target_lengths)

    loss = criterion(log_probs, targets, input_lengths, target_lengths)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


def main():
    epochs = config['epochs']
    train_batch_size = config['train_batch_size']
    lr = config['lr']
    show_interval = config['show_interval']
    valid_interval = config['valid_interval']
    save_interval = config['save_interval']
    cpu_workers = config['cpu_workers']
    reload_checkpoint = config['reload_checkpoint']
    img_width = config['img_width']
    img_height = config['img_height']
    data_dir = config['data_dir']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'device: {device}')

    train_dataset = Synth90kDataset(root_dir=data_dir, image_dir='../data/images', mode='train',
                                    img_height=img_height, img_width=img_width)

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=cpu_workers,
        collate_fn=synth90k_collate_fn)

    num_class = len(Synth90kDataset.LABEL2CHAR) + 1  # + 1 for the blank class
    crnn = CRNN(1, img_height, img_width, num_class,
                map_to_seq_hidden=config['map_to_seq_hidden'],
                rnn_hidden=config['rnn_hidden'],
                leaky_relu=config['leaky_relu'])
    if reload_checkpoint:
        crnn.load_state_dict(torch.load(reload_checkpoint, map_location=device))
    crnn.to(device)

    optimizer = optim.RMSprop(crnn.parameters(), lr=lr)
    criterion = CTCLoss(reduction='sum')
    criterion.to(device)

    assert save_interval % valid_interval == 0 or valid_interval % save_interval == 0
    i = 1
    for epoch in range(1, epochs + 1):
        print(f'epoch: {epoch}')
        tot_train_loss = 0.
        tot_train_count = 0
        for train_data in train_loader:
            loss = train_batch(crnn, train_data, optimizer, criterion, device)
            train_size = train_data[0].size(0)

            tot_train_loss += loss
            tot_train_count += train_size
            if i % show_interval == 0:
                print('train_batch_loss[', i, ']: ', loss / train_size)

            if i % save_interval == 0:
                save_model_path = os.path.join(config["checkpoints_dir"], "crnn.pt")
                torch.save(crnn.state_dict(), save_model_path)
                print('save model at ', save_model_path)
            i += 1

        print('train_loss: ', tot_train_loss / tot_train_count)


if __name__ == '__main__':
    main()
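For completeness, here is a hedged inference sketch: load a trained checkpoint, run one image through the model, and greedily decode (argmax per step, then the CTC collapse). The checkpoint path and the demo file name are my assumptions, not from the original code:

import torch
from dataset import Synth90kDataset
from model import CRNN

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_class = len(Synth90kDataset.LABEL2CHAR) + 1
crnn = CRNN(1, 32, 100, num_class).to(device)
crnn.load_state_dict(torch.load('checkpoints/crnn.pt', map_location=device))  # assumed path
crnn.eval()

# with mode=None the dataset returns images only; 'demo.jpg' is a hypothetical file
dataset = Synth90kDataset(image_dir='../data/images', file_names=['demo.jpg'])
with torch.no_grad():
    image = dataset[0].unsqueeze(0).to(device)          # (1, 1, 32, 100)
    preds = crnn(image).argmax(2).squeeze(1).tolist()   # best class per time step

chars, prev = [], None
for p in preds:
    if p != prev and p != 0:  # collapse repeats, drop the blank (label 0)
        chars.append(Synth90kDataset.LABEL2CHAR[p])
    prev = p
print(''.join(chars))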
The recognition results are reasonably good, as the test examples show.