前言
本文來源論文《Simple Copy-Paste is a Strong Data Augmentation Method
for Instance Segmentation》(CVPR2020),對其數據增強方式進行實現。
論文地址:https://arxiv.org/abs/2012.07177
解讀:https://mp.weixin.qq.com/s/nKC3bEe3m1eqPDI0LpVTIA
主要思想:
本文參考該數據增強的語義分割實現[1],相應修改為對應目標檢測的實現,坐標變換的寫法參考[2]。
其中,對應的標注信息為txt格式,如果自己的數據集是VOC或COCO格式,可自行修改,也可先轉換成txt格式再使用下述代碼。
1. 效果展示
數據來源CCPD2019數據集,下圖分別為img_main和img_src:
將img_src的車牌目標“復制-粘貼”到img_main的結果:
新生成的圖片大小與img_main一致,空白的部分會補灰邊。
代碼說明
'''
Descripttion: Data Augment for Object Detection.
version: 1.0.0
Author: lakuite
Date: 2021-08-06 13:37:38
Copyright: Copyright(c) 2021 lakuite. All Rights Reserved
'''
import numpy as np
import cv2
import os
import tqdm
import argparse
from skimage.draw import polygon
import random
def random_flip_horizontal(img, box, p=0.5):
'''
對img和mask隨機進行水平翻轉。box為二維np.array。
https://blog.csdn.net/weixin_41735859/article/details/106468551
img[:,:,::-1] gbr-->bgr、img[:,::-1,:] 水平翻轉、img[::-1,:,:] 上下翻轉
'''
if np.random.random() < p:
w = img.shape[1]
img = img[:, ::-1, :]
box[:, [0, 2, 4, 6]] = w - box[:, [2, 0, 6, 4]] # 僅針對4個點變換
return img, box
def Large_Scale_Jittering(img, box, min_scale=0.1, max_scale=2.0):
'''
對img和box進行0.1-2.0的大尺度抖動,并變回h*w的大小。
'''
rescale_ratio = np.random.uniform(min_scale, max_scale)
h, w, _ = img.shape
# rescale
h_new, w_new = int(h * rescale_ratio), int(w * rescale_ratio)
img = cv2.resize(img, (w_new, h_new), interpolation=cv2.INTER_LINEAR)
# crop or padding
# x,y是隨機選擇左上角的一個點,讓小圖片在這個位置,或者讓大圖片從這個位置開始裁剪
x, y = int(np.random.uniform(0, abs(w_new - w))), int(np.random.uniform(0, abs(h_new - h)))
# 如果圖像縮小了,那么其余部分要填充為像素168大小
if rescale_ratio <= 1.0: # padding
img_pad = np.ones((h, w, 3), dtype=np.uint8) * 168
img_pad[y:y + h_new, x:x + w_new, :] = img
box[:, [0, 2, 4, 6]] = box[:, [0, 2, 4, 6]] * w_new/w + x # x坐標
box[:, [1, 3, 5, 7]] = box[:, [1, 3, 5, 7]] * h_new/h + y # y坐標
return img_pad, box
# 如果圖像放大了,那么要裁剪成h*w的大小
else: # crop
img_crop = img[y:y + h, x:x + w, :]
box[:, [0, 2, 4, 6]] = box[:, [0, 2, 4, 6]] * w_new/w - x
box[:, [1, 3, 5, 7]] = box[:, [1, 3, 5, 7]] * h_new/h - y
return img_crop, box
def img_add(img_src, img_main, mask_src, box_src):
'''
將src加到main圖像中,結果圖還是main圖像的大小。
'''
if len(img_main.shape) == 3:
h, w, c = img_main.shape
elif len(img_main.shape) == 2:
h, w = img_main.shape
src_h, src_w = img_src.shape[0], img_src.shape[1]
mask = np.asarray(mask_src, dtype=np.uint8)
# mask是二值圖片,對src進行局部遮擋,即只露出目標物體的像素。
sub_img01 = cv2.add(img_src, np.zeros(np.shape(img_src), dtype=np.uint8), mask=mask) # 報錯深度不一致
mask_02 = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
mask_02 = np.asarray(mask_02, dtype=np.uint8)
sub_img02 = cv2.add(img_main, np.zeros(np.shape(img_main), dtype=np.uint8),
mask=mask_02) # 在main圖像上對應位置挖了一塊
# main圖像減去要粘貼的部分的圖,然后加上復制過來的圖
img_main = img_main - sub_img02 + cv2.resize(sub_img01, (w, h),
interpolation=cv2.INTER_NEAREST)
box_src[:, [0, 2, 4, 6]] = box_src[:, [0, 2, 4, 6]] * w/src_w
box_src[:, [1, 3, 5, 7]] = box_src[:, [1, 3, 5, 7]] * h/src_h
return img_main, box_src
def normal_(jpg_path, txt_path="", box=None):
"""
根據txt獲得box或者根據box獲得mask。
:param jpg_path: 圖片路徑
:param txt_path: x1,y1,x2,y2 x3,y3,x4,y4...
:param box: 如果有box,則為根據box生成mask
:return: 圖像,box 或 掩碼
"""
if isinstance(jpg_path, str): # 如果是路徑就讀取圖片
jpg_path = cv2.imread(jpg_path)
img = jpg_path.copy()
if box is None: # 一定有txt_path
lines = open(txt_path).readlines()
box = []
for line in lines:
ceils = line.strip().split(',')
xy = []
for ceil in ceils:
xy.append(round(float(ceil)))
box.append(np.array(xy))
return np.array(img), np.array(box)
else: # 獲得mask
h, w = img.shape[:2]
mask = np.zeros((h, w), dtype=np.float32)
for xy in box: # 對每個框
xy = np.array(xy).reshape(-1, 2)
cv2.fillPoly(mask, [xy.astype(np.int32)], 1)
return np.array(mask)
def is_coincide(polygon_1, polygon_2):
'''
判斷2個四邊形是否重合
:param polygon_1: [x1, y1,...,x4, y4]
:param polygon_2:
:return: bool,1表示重合
'''
rr1, cc1 = polygon([polygon_1[i] for i in range(0, len(polygon_1), 2)],
[polygon_1[i] for i in range(1, len(polygon_1), 2)])
rr2, cc2 = polygon([polygon_2[i] for i in range(0, len(polygon_2), 2)],
[polygon_2[i] for i in range(1, len(polygon_2), 2)])
try: # 能包含2個四邊形的最小矩形長寬
r_max = max(rr1.max(), rr2.max()) + 1
c_max = max(cc1.max(), cc2.max()) + 1
except:
return 0
# 相當于canvas是包含了2個多邊形的一個畫布,有2個多邊形的位置像素為1,重合位置像素為2
canvas = np.zeros((r_max, c_max))
canvas[rr1, cc1] += 1
canvas[rr2, cc2] += 1
intersection = np.sum(canvas == 2)
return 1 if intersection!=0 else 0
def copy_paste(img_main_path, img_src_path, txt_main_path, txt_src_path, coincide=False, muti_obj=True):
'''
整個復制粘貼操作,輸入2張圖的圖片和坐標路徑,返回其融合后的圖像和坐標結果。
1. 傳入隨機選擇的main圖像和src圖像的img和txt路徑;
2. 對其進行隨機水平翻轉;
3. 對其進行隨機抖動;
4. 獲得src變換完后對應的mask;
5. 將src的結果加到main中,返回對應main_new的img和src圖的box.
'''
# 讀取圖像和坐標
img_main, box_main = normal_(img_main_path, txt_main_path)
img_src, box_src = normal_(img_src_path, txt_src_path)
# 隨機水平翻轉
img_main, box_main = random_flip_horizontal(img_main, box_main)
img_src, box_src = random_flip_horizontal(img_src, box_src)
# LSJ, Large_Scale_Jittering 大尺度抖動,并變回h*w大小
img_main, box_main = Large_Scale_Jittering(img_main, box_main)
img_src, box_src = Large_Scale_Jittering(img_src, box_src)
if not muti_obj or box_src.ndim==1: # 只復制粘貼一個目標
id = random.randint(0, len(box_src)-1)
box_src = box_src[id]
box_src = box_src[np.newaxis, :] # 增加一維
# 獲得一系列變換后的img_src的mask
mask_src = normal_(img_src_path, box=box_src)
# 將src結果加到main圖像中,返回main圖像的大小的疊加圖
img, box_src = img_add(img_src, img_main, mask_src, box_src)
# 判斷融合后的區(qū)域是否重合
if not coincide:
for point_main in box_main:
for point_src in box_src:
if is_coincide(point_main, point_src):
return None, None
box = np.vstack((box_main, box_src))
return img, box
def save_res(img, img_path, box, txt_path):
'''
保存圖片和txt坐標結果。
'''
cv2.imwrite(img_path, img)
h, w = img.shape[:2]
with open(txt_path, 'w+') as ftxt:
for point in box: # [x1,y1,...x4,,y4]
strxy = ""
for i, p in enumerate(point):
if i%2==0: # x坐標
p = np.clip(p, 0, w-1)
else: # y坐標
p = np.clip(p, 0, h-1)
strxy = strxy + str(p) + ','
strxy = strxy[:-1] # 去掉最后一個逗號
ftxt.writelines(strxy + "\n")
def main(args):
# 圖像和坐標txt文件輸入路徑
JPEGs = os.path.join(args.input_dir, 'jpg')
BOXes = os.path.join(args.input_dir, 'txt')
# 輸出路徑
os.makedirs(args.output_dir, exist_ok=True)
os.makedirs(os.path.join(args.output_dir, 'cpAug_jpg'), exist_ok=True)
os.makedirs(os.path.join(args.output_dir, 'cpAug_txt'), exist_ok=True)
# 參與數據增強的圖片名稱,不含后綴
imgs_list = open(args.aug_txt, 'r').read().splitlines()
flag = '.jpg' # 圖像的后綴名 .jpg ,png
tbar = tqdm.tqdm(imgs_list, ncols=100) # 進度條顯示
for src_name in tbar:
# src圖像
img_src_path = os.path.join(JPEGs, src_name+flag)
txt_src_path = os.path.join(BOXes, src_name+'.txt')
# 隨機選擇main圖像
main_name = np.random.choice(imgs_list)
img_main_path = os.path.join(JPEGs, main_name+flag)
txt_main_path = os.path.join(BOXes, main_name+'.txt')
# 數據增強
img, box = copy_paste(img_main_path, img_src_path, txt_main_path, txt_src_path,
args.coincide, args.muti_obj)
if img is None:
continue
# 保存結果
img_name = "copy_" + src_name + "_paste_" + main_name
save_res(img, os.path.join(args.output_dir, 'cpAug_jpg', img_name+flag),
box, os.path.join(args.output_dir, 'cpAug_txt', img_name+'.txt'))
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--input_dir", default="./input_dir", type=str,
help="要進行數據增強的圖像路徑,路徑結構下應有jpg和txt文件夾")
parser.add_argument("--output_dir", default="./output_dir", type=str,
help="保存數據增強結果的路徑")
parser.add_argument("--aug_txt", default="./input_dir/test.txt",
type=str, help="要進行數據增強的圖像的名字,不包含后綴")
parser.add_argument("--coincide", default=False, type=bool,
help="True表示允許數據增強后的圖像目標出現重合,默認不允許重合")
parser.add_argument("--muti_obj", default=False, type=bool,
help="True表示將src圖上的所有目標都復制粘貼,False表示只隨機粘貼一個目標")
return parser.parse_args()
if __name__ == "__main__":
args = get_args()
main(args)
-
圖像路徑:
input_dir存放要數據增強的圖片和其對應的txt,其中圖片和txt名稱應相同,圖片后綴可修改 flag,默認為.jpg。output_dir輸出數據增強后的圖片,無需創(chuàng)建。 -
需進行增強的圖片列表test.txt,不含后綴:
生成test.txt代碼[3]:
# 獲取驗證集訓練集劃分的txt文件,劃分僅保存名字,不包含后綴
import os
import random
random.seed(0)
xmlfilepath = './input_dir/txt' # 標簽路徑
saveBasePath = "./input_dir" # 保存的位置
trainval_percent = 0.9 # 訓練+驗證集的比例,不為1說明有測試集
train_percent = 1 # 訓練集在訓練+驗證集中占的比例,如果代碼是從訓練集分出的驗證集,那就不用改
temp_xml = os.listdir(xmlfilepath)
total_xml = []
for xml in temp_xml:
if xml.endswith(".txt"):
total_xml.append(xml)
num = len(total_xml)
list = range(num)
tv = int(num * trainval_percent)
tr = int(tv * train_percent)
trainval = random.sample(list, tv)
train = random.sample(trainval, tr)
print("train and val size", tv)
print("traub suze", tr)
ftrainval = open(os.path.join(saveBasePath, 'trainval.txt'), 'w')
ftest = open(os.path.join(saveBasePath, 'test.txt'), 'w')
ftrain = open(os.path.join(saveBasePath, 'train.txt'), 'w')
fval = open(os.path.join(saveBasePath, 'val.txt'), 'w')
for i in list:
name = total_xml[i][:-4] + '\n'
if i in trainval:
ftrainval.write(name)
if i in train:
ftrain.write(name)
else:
fval.write(name)
else:
ftest.write(name)
ftrainval.close()
ftrain.close()
fval.close()
ftest.close()
運行后可在input_dir下生成4個.txt,其中test.txt僅包含10% input_dir中的圖片。
3.標簽txt格式:
3. 參考文檔
參考文檔
[1] 代碼復現:Copy-Paste 數據增強for 語義分割 https://blog.csdn.net/oyezhou/article/details/111696577
[2] 目標檢測中的數據增強方法(附詳細代碼講解)https://www.cnblogs.com/xiamuzi/p/13471386.html文章來源:http://www.zghlxwxcb.cn/news/detail-664548.html
4. 不合適點
以上是人家的代碼,但用在我這邊不合適,是因為:它的車牌不會有交叉覆蓋,我的是煙火識別,
煙和火是兩個目標,有覆蓋。 所以不合適。文章來源地址http://www.zghlxwxcb.cn/news/detail-664548.html
import glob
import cv2
import numpy as np
import random
def crop_image(image, x, y, width, height):
cropped_image = image[y:y + height, x:x + width]
return cropped_image
def convert_to_absolute(label, image_width, image_height):
class_id, relative_x_center, relative_y_center, relative_width, relative_height = label
# 計算邊界框的絕對坐標
absolute_x_center = relative_x_center * image_width
absolute_y_center = relative_y_center * image_height
absolute_width = relative_width * image_width
absolute_height = relative_height * image_height
# 計算邊界框的左上角和右下角坐標
left = absolute_x_center - absolute_width / 2
top = absolute_y_center - absolute_height / 2
right = absolute_x_center + absolute_width / 2
bottom = absolute_y_center + absolute_height / 2
# 返回絕對坐標形式的邊界框
return [class_id, left, top, right, bottom]
def convert_to_yolo_format(class_id, left, top, right, bottom, image_width, image_height):
# 計算目標框的中心點坐標和寬高
x = (left + right) / 2
y = (top + bottom) / 2
width = right - left
height = bottom - top
# 將坐標和尺寸歸一化到[0, 1]之間
x /= image_width
y /= image_height
width /= image_width
height /= image_height
# 返回Yolo格式的標注
return f"{class_id} {x} {y} {width} {height}"
def get_src():
img_list = glob.glob(r"E:\Dataset\zhongwaiyun\data_fire(1w)\data_fire(1w)\scr_copy_paste\images\*.jpg")
random.shuffle(img_list)
img_path = img_list[0]
txt_path = img_list[0].replace("images", "txt").replace(".jpg", ".txt")
return img_path, txt_path
img_list = glob.glob(r"E:\Dataset\zhongwaiyun\zwy_make_background\*.jpg")
for img_b_path in img_list:
img_a_path, img_a_txt = get_src()
image_a = cv2.imread(img_a_path)
image_height, image_width, _ = image_a.shape
img_b_txt = img_b_path.replace(".jpg", ".txt").replace("zwy_make_background", "zwy_make_fire_and_smoke")
img_b_path_new = img_b_path.replace("zwy_make_background", "zwy_make_fire_and_smoke")
src_location_map = []
with open(img_a_txt) as f:
for line_str in f:
line_info = line_str.strip().split(" ")
label = [int(line_info[0]), float(line_info[1]), float(line_info[2]), float(line_info[3]),
float(line_info[4])]
class_id, left, top, right, bottom = convert_to_absolute(label, image_width, image_height)
src_location_map.append([class_id, left, top, right, bottom])
image_b = cv2.imread(img_b_path)
res_list = []
for row in src_location_map:
class_id, left, top, right, bottom = row
if left or top or right or bottom:
try:
# 目標可以出現在空白圖片的任何位置,只要沒有超過限制即可
x = int(left) # 指定區(qū)域的起始橫坐標
y = int(top) # 指定區(qū)域的起始縱坐標
width = int(right - left) # 指定區(qū)域的寬度
height = int(bottom - top) # 指定區(qū)域的高度
cropped_image_a = crop_image(image_a, int(x), int(y), int(width), int(height))
image_b_height, image_b_width, _ = image_b.shape
b_x = random.randint(0, int(image_b_width - width - 5))
b_y = random.randint(0, int(image_b_height - height - 5))
image_b[b_y:b_y + height, b_x:b_x + width] = cropped_image_a
res = convert_to_yolo_format(class_id, b_x, b_y, b_x + width, b_y + height, image_b_width, image_b_height)
print("--==", img_b_txt)
with open(img_b_txt, "a") as f:
f.write(res)
cv2.imwrite(img_b_path_new, image_b)
break
except:
break
到了這里,關于【目標檢測】“復制-粘貼 copy-paste” 數據增強實現的文章就介紹完了。如果您還想了解更多內容,請在右上角搜索TOY模板網以前的文章或繼續(xù)瀏覽下面的相關文章,希望大家以后多多支持TOY模板網!