一、代碼
#---------------------------------------------------#
# 檢測圖片
#---------------------------------------------------#
def detect_image(self, image, count=False, name_classes=None):
#---------------------------------------------------------#
# 在這里將圖像轉(zhuǎn)換成RGB圖像,防止灰度圖在預(yù)測時(shí)報(bào)錯(cuò)。
# 代碼僅僅支持RGB圖像的預(yù)測,所有其它類型的圖像都會(huì)轉(zhuǎn)化成RGB
#---------------------------------------------------------#
image = cvtColor(image)
#---------------------------------------------------#
# 對(duì)輸入圖像進(jìn)行一個(gè)備份,后面用于繪圖
#---------------------------------------------------#
old_img = copy.deepcopy(image)
orininal_h = np.array(image).shape[0]
orininal_w = np.array(image).shape[1]
#---------------------------------------------------------#
# 給圖像增加灰條,實(shí)現(xiàn)不失真的resize
# 也可以直接resize進(jìn)行識(shí)別
#---------------------------------------------------------#
image_data, nw, nh = resize_image(image, (self.input_shape[1],self.input_shape[0]))
#---------------------------------------------------------#
# 添加上batch_size維度
#---------------------------------------------------------#
image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)
with torch.no_grad():
images = torch.from_numpy(image_data)
if self.cuda:
images = images.cuda()
#---------------------------------------------------#
# 圖片傳入網(wǎng)絡(luò)進(jìn)行預(yù)測
#---------------------------------------------------#
pr = self.net(images)[0]
#---------------------------------------------------#
# 取出每一個(gè)像素點(diǎn)的種類
#---------------------------------------------------#
pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy()
#--------------------------------------#
# 將灰條部分截取掉
#--------------------------------------#
pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
#---------------------------------------------------#
# 進(jìn)行圖片的resize
#---------------------------------------------------#
pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR)
#---------------------------------------------------#
# 取出每一個(gè)像素點(diǎn)的種類
#---------------------------------------------------#
pr = pr.argmax(axis=-1)
seg_img = np.reshape(np.array(self.colors, np.uint8)[np.reshape(pr, [-1])], [orininal_h, orininal_w, -1])
#------------------------------------------------#
# 將新圖片轉(zhuǎn)換成Image的形式
#------------------------------------------------#
image = Image.fromarray(np.uint8(seg_img))
#------------------------------------------------#
# 將新圖與原圖及進(jìn)行混合
#------------------------------------------------#
image = Image.blend(old_img, image, 0.7)
二、代碼逐步debug調(diào)試
(1)讀圖
#---------------------------------------------------------#
# 在這里將圖像轉(zhuǎn)換成RGB圖像,防止灰度圖在預(yù)測時(shí)報(bào)錯(cuò)。
# 代碼僅僅支持RGB圖像的預(yù)測,所有其它類型的圖像都會(huì)轉(zhuǎn)化成RGB
#---------------------------------------------------------#
image = cvtColor(image)
(2) Letterbox
無論輸入的圖片尺寸多大,都會(huì)經(jīng)過letter_box后,變?yōu)?12x512尺寸
(3) 歸一化、HWC 轉(zhuǎn) CHW,并expand維度到NCHW,轉(zhuǎn)tensor
def preprocess_input(image):
image /= 255.0
return image
#---------------------------------------------------------#
# 添加上batch_size維度
#---------------------------------------------------------#
image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)
(4) 前向傳播
#---------------------------------------------------#
# 圖片傳入網(wǎng)絡(luò)進(jìn)行預(yù)測
#---------------------------------------------------#
pr = self.net(images)[0]
21個(gè)channel代表(20+1)個(gè)類別,512x512為模型輸入及輸入尺寸
(5) softmax 計(jì)算像素類別概率
#---------------------------------------------------#
# 取出每一個(gè)像素點(diǎn)的種類
#---------------------------------------------------#
pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy()
經(jīng)過softmax后,512x512的mask圖中,每個(gè)位置(x,y)對(duì)應(yīng)的21個(gè)channel的值和為1。
(6) 截取灰條部分,并resize到原圖尺寸(逆letter_box)
#--------------------------------------#
# 將灰條部分截取掉
#--------------------------------------#
pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
#---------------------------------------------------#
# 進(jìn)行圖片的resize
#---------------------------------------------------#
pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR)
pr類型是np,array,所以可以通過這種方式進(jìn)行逆letter_box操作,將mask的寬高,還原到原始輸入圖片的寬高。
(7) 利用argmax,計(jì)算每個(gè)像素屬于的類別
#---------------------------------------------------#
# 取出每一個(gè)像素點(diǎn)的種類
#---------------------------------------------------#
pr = pr.argmax(axis=-1)
返回最后一個(gè)維度(channel)中,最大值所對(duì)應(yīng)的索引,即類別。例如,像素點(diǎn)(x1,y1)所對(duì)應(yīng)的21個(gè)channel中,第5個(gè)channel的值最大,則像素點(diǎn)(x1,y1)對(duì)應(yīng)類別則是class=5。文章來源:http://www.zghlxwxcb.cn/news/detail-853081.html
(8) 可視化
seg_img = np.reshape(np.array(self.colors, np.uint8)[np.reshape(pr, [-1])], [orininal_h, orininal_w, -1])
#------------------------------------------------#
# 將新圖片轉(zhuǎn)換成Image的形式
#------------------------------------------------#
image = Image.fromarray(np.uint8(seg_img))
#------------------------------------------------#
# 將新圖與原圖及進(jìn)行混合
#------------------------------------------------#
image = Image.blend(old_img, image, 0.7)
將預(yù)測的結(jié)果與原圖進(jìn)行混合。文章來源地址http://www.zghlxwxcb.cn/news/detail-853081.html
到了這里,關(guān)于【深度學(xué)習(xí)實(shí)戰(zhàn)(6)】搭建通用的語義分割推理流程的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!