For the core GaitSet algorithm itself, feel free to jump straight to "4. Core Algorithm Code — 4.1 gaitset.py".
1. Paper & Code Sources
Paper: https://ieeexplore.ieee.org/document/9351667
CASIA-B dataset download: http://www.cbsr.ia.ac.cn/china/Gait%20Databases%20CH.asp
Code: https://github.com/AbnerHqC/GaitSet
2. Environment Setup
2.1 Hardware
1. Identify the GPU model
Right-click "This PC" — "Manage" — "Device Manager" — "Display adapters" to see the GPU model. On my machine:
the card is an NVIDIA GeForce GTX 1650 SUPER.
2. Download the NVIDIA driver
On the NVIDIA driver download page, select the matching card series, operating system, download type, etc.:
3. Choosing and downloading CUDA and cuDNN
Open "NVIDIA Control Panel" — "System Information" — "Components".
The version listed there is the highest CUDA version the card supports. CUDA 10.0 can be downloaded from NVIDIA DEVELOPER; I chose version 10.0, with the options shown below:
For cuDNN I chose v7.6.5.32, which can likewise be downloaded from NVIDIA DEVELOPER.
Note that the CUDA installer is about 2 GB, so try to download it to a drive other than C:. For an installation walkthrough see "CUDA、CUDNN在windows下的安裝及配置"; for the detailed explanation, configuration, and verification see "Windows11 顯卡GTX1650 搭建CUDA+cuDNN環(huán)境,并安裝對應(yīng)版本的Anaconda和TensorFlow-GPU" (the environment this article's code runs in does not require TensorFlow).
2.2 Software
I installed Python 3.7.8 and, matching that Python version, Anaconda3-2020.02-Windows-x86_64.
Other software versions are listed in the table below:
NAME | VERSION |
---|---|
Python | 3.7.8 |
Anaconda | Anaconda3-2020.02 |
Pycharm | 2019.1.4 |
3. Running the Code
ERROR1
When the authors' original code is opened, every from XX import XX line is flagged with a red underline before the code even runs:
The reason is that the project cannot locate the source files to import, because the original project is laid out as shown below:
All of the directories holding code — whether built explicitly or pulled in during a run — are plain folders:
We need to mark the directories whose files must be built as source code (roughly analogous to .h files in C), so the project layout is changed to the following:
The mapping between the old and new layouts is shown below:
※ Notes
Source Root: the subfolders and code inside this folder are treated as source code and will be compiled;
Excluded: the contents of this folder are not indexed by the IDE, comparable to commented-out code;
Resources Root: this folder holds resources used by the project, such as images, XML configuration, and property files;
Template Folder: a folder for templates.
(Official documentation: PyCharm 2019.1 Help)
3.1 The CASIA-B dataset
CASIA-B is one of the CASIA gait databases provided by the Institute of Automation, Chinese Academy of Sciences. The CASIA gait database contains three datasets: Dataset A (small scale), Dataset B (multi-view), and Dataset C (infrared). This work uses Dataset B, a large-scale multi-view dataset collected in January 2005. It covers 124 subjects, each recorded from 11 views (0°, 18°, …, 180°) under 3 walking conditions: normal (nm), wearing a coat (cl), and carrying a bag (bg).
The dataset (gait silhouettes in png format) can be downloaded directly from the CASIA gait database site; the full video data requires filling out the application agreement on the same page.
The silhouette png files are named as: subject ID - walking condition - sequence number - view (angle) - frame number. For example, 005-bg-01-018-128.png is frame 128 of subject 005, bag-carrying sequence 01, viewed at 18°.
3.2 pretreatment.py
Purpose: preprocess the dataset. Each original 320×240 image is cropped around the silhouette's bounding extremes and turned into a 64×64 image.
This is only the image preprocessing step; the images GaitSet actually operates on are not 64×64 but 64×44 — see data_set.py for the reason (a rough sketch of that later crop is given below).
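As a hedged sketch of what that later step amounts to (assuming the data loader simply trims 10 blank columns from each side of the preprocessed 64×64 silhouette, which is how the released data_set.py appears to behave; this is illustration, not the repo's code):
import numpy as np

cut_padding = 10
silhouette = np.zeros((64, 64), dtype=np.uint8)      # one preprocessed 64x64 frame (dummy data)
cropped = silhouette[:, cut_padding:-cut_padding]    # drop 10 columns on each side
print(cropped.shape)                                 # (64, 44) -> what GaitSet consumes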
# -*- coding: utf-8 -*-
# @Author : Abner
# @Time : 2018/12/19
import os
from scipy import misc as scisc
import cv2
import numpy as np
from warnings import warn
from time import sleep
import argparse
from multiprocessing import Pool
from multiprocessing import TimeoutError as MP_TimeoutError
#* these all-caps constants are used as status tags (the 'comment' field) in the log
START = "START"
FINISH = "FINISH"
WARNING = "WARNING"
FAIL = "FAIL"
def boolean_string(s):
if s.upper() not in {'FALSE', 'TRUE'}:
raise ValueError('Not a valid boolean string')
return s.upper() == 'TRUE'
#* the following three lines were changed from the original author's code: the input and output paths are built from the current working directory for loading and exporting the data
wd = os.getcwd()
input_path = os.path.join(wd, 'input_data_path')
output_path = os.path.join(wd, 'output_data_path')
parser = argparse.ArgumentParser(description='Test')
parser.add_argument('--input_path', default=input_path, type=str,
help='Root path of raw dataset.')
parser.add_argument('--output_path', default=output_path, type=str,
help='Root path for output.')
parser.add_argument('--log_file', default='./pretreatment.log', type=str,
                    help='Log file path. Default: ./pretreatment.log')  #* log file written automatically during the run
parser.add_argument('--log', default=False, type=boolean_string,
help='If set as True, all logs will be saved. '
'Otherwise, only warnings and errors will be saved.'
                         'Default: False')  #* if --log is True, every log entry is saved; otherwise only warnings and errors are written
parser.add_argument('--worker_num', default=1, type=int,
help='How many subprocesses to use for data pretreatment. '
                         'Default: 1')  #* number of parallel worker processes used for preprocessing; default 1
opt = parser.parse_args()
INPUT_PATH = opt.input_path
OUTPUT_PATH = opt.output_path
IF_LOG = opt.log
LOG_PATH = opt.log_file
WORKERS = opt.worker_num
#* target height and width of the output images: 64 pixels each
T_H = 64
T_W = 64
The user needs to modify the two variables input_data_path and output_data_path.
input_data_path: the local path of the CASIA-B dataset. (Note: the per-subject archives inside the dataset archive are themselves zip files and must be extracted as well, i.e. the data has to be unpacked twice.)
output_data_path: the directory where the preprocessed data will be written. This directory must be empty, otherwise ERROR2 appears:
FileExistsError: [WinError 183] Cannot create a file when that file already exists.
It is convenient to keep these two folders side by side as siblings, which makes it easy to compare the data before and after preprocessing.
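For reference, once the two paths are set, the script can also be launched from a terminal with the command-line options defined above (the paths here are placeholders, not real directories):
python pretreatment.py --input_path=D:\data\CASIA-B-silh --output_path=D:\data\CASIA-B-pre --worker_num=1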
3.2.1 The log2str function
This function defines the format of generated log entries (not important).
#* builds a log message string
#* arguments: pid - process/job index (process ID)
#*            comment - status tag
#*            logs - message content
def log2str(pid, comment, logs):
    str_log = ''  #* str_log starts out empty
if type(logs) is str:
logs = [logs]
for log in logs:
str_log += "# JOB %d : --%s-- %s\n" % (
pid, comment, log)
return str_log
3.2.2 The log_print function
This function writes log entries to file and/or prints them (not important).
#* log printing function
#* arguments are the same as for log2str
def log_print(pid, comment, logs):
str_log = log2str(pid, comment, logs)
    if comment in [WARNING, FAIL]:  #* warnings and failures are always handled here
        with open(LOG_PATH, 'a') as log_f:  #* append them to the log file on disk
log_f.write(str_log)
if comment in [START, FINISH]:
        if pid % 500 != 0:  #* only print progress for every 500th job
return
print(str_log, end='')
3.2.3 The cut_img function
This function crops a single silhouette image (not especially important; it is enough to know what the steps do).
#* image cropping function
#* arguments: img - image to process
#*            seq_info - sequence info (subject, condition, view)
#*            frame_name - file name within the sequence
#*            pid - process/job index
def cut_img(img, seq_info, frame_name, pid):
    # A silhouette contains too little white pixels
    #* if the silhouette contains too few white pixels,
    # might be not valid for identification.
    #* it may be useless for identification (see the warning examples below)
if img.sum() <= 10000:
message = 'seq:%s, frame:%s, no data, %d.' % (
'-'.join(seq_info), frame_name, img.sum())
warn(message)
log_print(pid, WARNING, message)
return None
# Get the top and bottom point
    #* i.e. the first and last rows that contain any silhouette pixels
y = img.sum(axis=1)
y_top = (y != 0).argmax(axis=0)
y_btm = (y != 0).cumsum(axis=0).argmax(axis=0)
img = img[y_top:y_btm + 1, :]
    # As the height of a person is larger than the width,
    #* since a person is taller than they are wide,
    # use the height to calculate resize ratio.
    #* the height is used to compute the resize ratio (target height T_H)
_r = img.shape[1] / img.shape[0]
_t_w = int(T_H * _r)
img = cv2.resize(img, (_t_w, T_H), interpolation=cv2.INTER_CUBIC)
# Get the median of x axis and regard it as the x center of the person.
    #* take the column where the cumulative sum passes half of the total mass as the horizontal center of the person
sum_point = img.sum()
sum_column = img.sum(axis=0).cumsum()
x_center = -1
for i in range(sum_column.size):
if sum_column[i] > sum_point / 2:
x_center = i
break
if x_center < 0:
message = 'seq:%s, frame:%s, no center.' % (
'-'.join(seq_info), frame_name)
warn(message)
log_print(pid, WARNING, message)
return None
h_T_W = int(T_W / 2)
left = x_center - h_T_W
right = x_center + h_T_W
if left <= 0 or right >= img.shape[1]:
left += h_T_W
right += h_T_W
_ = np.zeros((img.shape[0], h_T_W))
img = np.concatenate([_, img, _], axis=1)
img = img[:, left:right]
return img.astype('uint8')
axis=0 collapses the rows: the pixel values in each column are summed, compressing the image matrix into a single row.
axis=1 collapses the columns: the pixel values in each row are summed, compressing the image matrix into a single column.
argmax: returns the index of the maximum value.
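A tiny, self-contained illustration of these operations on a toy array (purely illustrative, not part of the original code):
import numpy as np

img = np.array([[0, 0, 0],
                [0, 9, 3],
                [0, 0, 0]])          # toy "silhouette" with one bright row
row_sums = img.sum(axis=1)           # one value per row:    [0, 12, 0]
col_sums = img.sum(axis=0)           # one value per column: [0,  9, 3]
top = (row_sums != 0).argmax()       # index of the first non-empty row -> 1
print(row_sums, col_sums, top)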
3.2.4 The cut_pickle function
This function processes one whole sequence: it reads every frame, crops it with cut_img, and saves the result (again, not critical).
#* per-sequence processing function
#* arguments: seq_info - sequence info (subject, condition, view)
#*            pid - process/job index
def cut_pickle(seq_info, pid):
seq_name = '-'.join(seq_info)
log_print(pid, START, seq_name)
seq_path = os.path.join(INPUT_PATH, *seq_info)
out_dir = os.path.join(OUTPUT_PATH, *seq_info)
frame_list = os.listdir(seq_path)
frame_list.sort()
count_frame = 0
for _frame_name in frame_list:
frame_path = os.path.join(seq_path, _frame_name)
img = cv2.imread(frame_path)[:, :, 0]
img = cut_img(img, seq_info, _frame_name, pid)
if img is not None:
# Save the cut img
            #* write the cropped 64×64 silhouette to the output directory
save_path = os.path.join(out_dir, _frame_name)
cv2.imwrite(save_path, img)
count_frame += 1
# Warn if the sequence contains less than 5 frames
    #* a warning is emitted when fewer than 5 valid frames survive (see the warning examples below)
if count_frame < 5:
message = 'seq:%s, less than 5 valid data.' % (
'-'.join(seq_info))
warn(message)
log_print(pid, WARNING, message)
log_print(pid, FINISH,
'Contain %d valid frames. Saved to %s.'
% (count_frame, out_dir))
In addition, two kinds of warnings show up during preprocessing:
WARNING1
UserWarning: seq:005-bg-01-000, less than 5 valid data.
Fewer than 5 valid frames: opening the raw data confirms that this folder really is missing frames ↓
UserWarning: seq:005-bg-01-018, frame:005-bg-01-018-128.png, no data, 0.
A warning caused by too few white pixels; the corresponding image really is empty ↓
Both warnings appear to come from missing data in the dataset itself (probably?), so they can be ignored for now.
3.2.5 Full preprocessing code
# -*- coding: utf-8 -*-
# @Author : Abner
# @Time : 2018/12/19
import os
from scipy import misc as scisc
import cv2
import numpy as np
from warnings import warn
from time import sleep
import argparse
from multiprocessing import Pool
from multiprocessing import TimeoutError as MP_TimeoutError
START = "START"
FINISH = "FINISH"
WARNING = "WARNING"
FAIL = "FAIL"
def boolean_string(s):
if s.upper() not in {'FALSE', 'TRUE'}:
raise ValueError('Not a valid boolean string')
return s.upper() == 'TRUE'
wd = os.getcwd()
input_path = os.path.join(wd, 'D:\PyCharm\Project\Gaitset\GaitDatasetB-silh0\pretreatment')
output_path = os.path.join(wd, 'D:\PyCharm\Project\Gaitset\GaitDatasetB-silh0\output')
parser = argparse.ArgumentParser(description='Test')
parser.add_argument('--input_path', default=input_path, type=str,
help='Root path of raw dataset.')
parser.add_argument('--output_path', default=output_path, type=str,
help='Root path for output.')
parser.add_argument('--log_file', default='./pretreatment.log', type=str,
help='Log file path. Default: ./pretreatment.log')
parser.add_argument('--log', default=False, type=boolean_string,
help='If set as True, all logs will be saved. '
'Otherwise, only warnings and errors will be saved.'
'Default: False')
parser.add_argument('--worker_num', default=1, type=int,
help='How many subprocesses to use for data pretreatment. '
'Default: 1')
opt = parser.parse_args()
INPUT_PATH = opt.input_path
OUTPUT_PATH = opt.output_path
IF_LOG = opt.log
LOG_PATH = opt.log_file
WORKERS = opt.worker_num
T_H = 64
T_W = 64
def log2str(pid, comment, logs):
str_log = ''
if type(logs) is str:
logs = [logs]
for log in logs:
str_log += "# JOB %d : --%s-- %s\n" % (
pid, comment, log)
return str_log
def log_print(pid, comment, logs):
str_log = log2str(pid, comment, logs)
if comment in [WARNING, FAIL]:
with open(LOG_PATH, 'a') as log_f:
log_f.write(str_log)
if comment in [START, FINISH]:
if pid % 500 != 0:
return
print(str_log, end='')
def cut_img(img, seq_info, frame_name, pid):
# A silhouette contains too little white pixels
# might be not valid for identification.
if img.sum() <= 10000:
message = 'seq:%s, frame:%s, no data, %d.' % (
'-'.join(seq_info), frame_name, img.sum())
warn(message)
log_print(pid, WARNING, message)
return None
# Get the top and bottom point
y = img.sum(axis=1)
y_top = (y != 0).argmax(axis=0)
y_btm = (y != 0).cumsum(axis=0).argmax(axis=0)
img = img[y_top:y_btm + 1, :]
# As the height of a person is larger than the width,
# use the height to calculate resize ratio.
_r = img.shape[1] / img.shape[0]
_t_w = int(T_H * _r)
img = cv2.resize(img, (_t_w, T_H), interpolation=cv2.INTER_CUBIC)
# Get the median of x axis and regard it as the x center of the person.
sum_point = img.sum()
sum_column = img.sum(axis=0).cumsum()
x_center = -1
for i in range(sum_column.size):
if sum_column[i] > sum_point / 2:
x_center = i
break
if x_center < 0:
message = 'seq:%s, frame:%s, no center.' % (
'-'.join(seq_info), frame_name)
warn(message)
log_print(pid, WARNING, message)
return None
h_T_W = int(T_W / 2)
left = x_center - h_T_W
right = x_center + h_T_W
if left <= 0 or right >= img.shape[1]:
left += h_T_W
right += h_T_W
_ = np.zeros((img.shape[0], h_T_W))
img = np.concatenate([_, img, _], axis=1)
img = img[:, left:right]
return img.astype('uint8')
def cut_pickle(seq_info, pid):
seq_name = '-'.join(seq_info)
log_print(pid, START, seq_name)
seq_path = os.path.join(INPUT_PATH, *seq_info)
out_dir = os.path.join(OUTPUT_PATH, *seq_info)
frame_list = os.listdir(seq_path)
frame_list.sort()
count_frame = 0
for _frame_name in frame_list:
frame_path = os.path.join(seq_path, _frame_name)
img = cv2.imread(frame_path)[:, :, 0]
img = cut_img(img, seq_info, _frame_name, pid)
if img is not None:
# Save the cut img
save_path = os.path.join(out_dir, _frame_name)
cv2.imwrite(save_path, img)
count_frame += 1
# Warn if the sequence contains less than 5 frames
if count_frame < 5:
message = 'seq:%s, less than 5 valid data.' % (
'-'.join(seq_info))
warn(message)
log_print(pid, WARNING, message)
log_print(pid, FINISH,
'Contain %d valid frames. Saved to %s.'
% (count_frame, out_dir))
if __name__ == '__main__':
pool = Pool(WORKERS)
results = list()
pid = 0
print('Pretreatment Start.\n'
'Input path: %s\n'
'Output path: %s\n'
'Log file: %s\n'
'Worker num: %d' % (
INPUT_PATH, OUTPUT_PATH, LOG_PATH, WORKERS))
id_list = os.listdir(INPUT_PATH)
id_list.sort()
# Walk the input path
for _id in id_list:
seq_type = os.listdir(os.path.join(INPUT_PATH, _id))
seq_type.sort()
for _seq_type in seq_type:
view = os.listdir(os.path.join(INPUT_PATH, _id, _seq_type))
view.sort()
for _view in view:
seq_info = [_id, _seq_type, _view]
out_dir = os.path.join(OUTPUT_PATH, *seq_info)
os.makedirs(out_dir)
results.append(
pool.apply_async(
cut_pickle,
args=(seq_info, pid)))
sleep(0.02)
pid += 1
pool.close()
unfinish = 1
while unfinish > 0:
unfinish = 0
for i, res in enumerate(results):
try:
res.get(timeout=0.1)
except Exception as e:
if type(e) == MP_TimeoutError:
unfinish += 1
continue
else:
print('\n\n\nERROR OCCUR: PID ##%d##, ERRORTYPE: %s\n\n\n',
i, type(e))
raise e
pool.join()
3.3 config.py
conf = {
"WORK_PATH": "./work",
"CUDA_VISIBLE_DEVICES": "0,1,2,3", #*使用的GPU編號(一般設(shè)為0,若有多個GPU,可以根據(jù)剩余容量選擇相應(yīng)的編號)
"data": {
        'dataset_path': "your_dataset_path",  #* path of the preprocessed data, i.e. output_data_path from above
        'resolution': '64',  #* resolution of the preprocessed images (no need to change)
        'dataset': 'CASIA-B',  #* dataset name
        # In CASIA-B, data of subject #5 is incomplete.
        #* in CASIA-B, subject 005 is incomplete,
        # Thus, we ignore it in training.
        #* so it is simply skipped during training
        # For more detail, please refer to
        #* for details see
        # function: utils.data_loader.load_data
        #* the function utils.data_loader.load_data (the files were moved around earlier, so look in the new location)
        'pid_num': 73,  #* number of subjects used for training; CASIA-B has 124 in total, the authors train on 73 and test on the rest
        'pid_shuffle': False,  #* whether to shuffle the 124 IDs before splitting into train/test; False keeps the fixed split
},
"model": {
        'hidden_dim': 256,  #* dimension of the final fully connected (hidden) layer
        'lr': 1e-4,  #* learning rate 0.0001
        'hard_or_full_trip': 'full',  #* which triplet loss variant to optimize (hard or full)
        'batch_size': (8, 16),  #* batch is p*k = 8*16
        'restore_iter': 0,  #* iteration to resume training from
        'total_iter': 80000,  #* total number of training iterations
        'margin': 0.2,  #* margin of the triplet loss
        'num_workers': 3,  #* number of DataLoader worker processes
        'frame_num': 30,  #* number of frames sampled from each sequence in a batch
'model_name': 'GaitSet',
},
}
Note that batch_size here is a tuple of two numbers (p, k): p is the number of people and k is the number of samples drawn per person, so one batch trains on p×k samples.
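A minimal sketch of what a (p, k) batch amounts to (the real sampling is done by TripletSampler in sampler.py; the names and dict layout below are only for illustration):
import random

def sample_pk_batch(label_to_seqs, p=8, k=16):
    # label_to_seqs: hypothetical dict mapping a subject ID to its list of sequences
    subjects = random.sample(list(label_to_seqs), p)              # pick p different people
    batch = []
    for sid in subjects:
        seqs = label_to_seqs[sid]
        # draw k sequences per person (with replacement if fewer than k exist)
        batch += [(sid, random.choice(seqs)) for _ in range(k)]
    return batch                                                  # p * k samples in total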
3.4 train.py
from initialization import initialization
from GaitSet.config import conf
import argparse
def boolean_string(s):
if s.upper() not in {'FALSE', 'TRUE'}:
raise ValueError('Not a valid boolean string')
return s.upper() == 'TRUE'
parser = argparse.ArgumentParser(description='Train')
parser.add_argument('--cache', default=True, type=boolean_string,
help='cache: if set as TRUE all the training data will be loaded at once'
' before the training start. Default: TRUE')
opt = parser.parse_args()
m = initialization(conf, train=opt.cache)[0]
print("Training START")
m.fit()
print("Training COMPLETE")
ERROR3
Traceback (most recent call last):
File "D:/PyCharm/Project/Gaitset/GaitSet-master/GaitSet/train.py", line 18, in <module>
m = initialization(conf, train=opt.cache)[0]
File "D:\PyCharm\Project\Gaitset\GaitSet-master\GaitSet\modelfile\initialization.py", line 57, in initialization
train_source, test_source = initialize_data(config, train, test)
File "D:\PyCharm\Project\Gaitset\GaitSet-master\GaitSet\modelfile\initialization.py", line 15, in initialize_data
train_source, test_source = load_data(**config['data'], cache=(train or test))
File "D:\PyCharm\Project\Gaitset\GaitSet-master\GaitSet\modelfile\data_loader.py", line 42, in load_data
pid_list = np.load(pid_fname)
File "D:\Anaconda\envs\GaitSet-master\lib\site-packages\numpy\lib\npyio.py", line 441, in load
pickle_kwargs=pickle_kwargs)
File "D:\Anaconda\envs\GaitSet-master\lib\site-packages\numpy\lib\format.py", line 743, in read_array
raise ValueError("Object arrays cannot be loaded when "
ValueError: Object arrays cannot be loaded when allow_pickle=False
Object arrays cannot be loaded when allow_pickle=False.
The problem ultimately sits in numpy's format.py; checking the numpy.load documentation shows that
from version 1.16.3 onward, allow_pickle defaults to False.
There are two main fixes:
1. Downgrade numpy
In the Terminal, run
pip install numpy==1.16.2
to drop numpy to 1.16.2 or below.
This is not the recommended fix, because the downgraded version may turn out to be incompatible with other code that uses it (to be verified?).
2. Change the numpy.load() call
Open the last file in the traceback and comment out the code in the red box:
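A gentler variant of the same idea, instead of editing numpy's own source, is to pass allow_pickle=True at the failing call site in data_loader.py (a hedged sketch; the exact line placement in the real file may differ):
import numpy as np

# In data_loader.py the failing call is roughly: pid_list = np.load(pid_fname)
# Explicitly allowing pickled object arrays restores the pre-1.16.3 behaviour:
def load_pid_list(pid_fname):
    return np.load(pid_fname, allow_pickle=True)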
..\aten\src\ATen\native\cuda\LegacyDefinitions.cpp:38: UserWarning: masked_scatter_ received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead.
..\aten\src\ATen\native\cuda\LegacyDefinitions.cpp:48: UserWarning: masked_select received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead.
...
masked_scatter_ received a mask with dtype torch.uint8; this behaviour is deprecated, please use a torch.bool mask instead.
masked_select received a mask with dtype torch.uint8; this behaviour is deprecated, please use a torch.bool mask instead.
These warnings are printed endlessly because the mask dtype is wrong. The fix is to add the two lines shown in the red box in triplet.py, converting the uint8 masks to bool.
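What those two added lines boil down to (a small standalone sketch with toy labels; the annotated triplet.py further down already contains the conversion):
import torch

label = torch.tensor([[0, 0, 1, 1]])                                # toy labels, shape [n=1, m=4]
hp_mask = (label.unsqueeze(1) == label.unsqueeze(2)).byte().view(-1)
hp_mask = hp_mask.bool()        # added line: masked_select now expects a bool mask
hn_mask = (label.unsqueeze(1) != label.unsqueeze(2)).byte().view(-1)
hn_mask = hn_mask.bool()        # added line
print(hp_mask.dtype)            # torch.bool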
ERROR4
RuntimeError: CUDA out of memory. Tried to allocate 660.00 MiB (GPU 0; 4.00 GiB total capacity; 2.95 GiB already allocated; 0 bytes free; 14.10 MiB cached)
CUDA is out of memory: it tried to allocate 660.00 MiB (GPU 0; 4.00 GiB total capacity; 2.95 GiB already allocated; 0 bytes free; 14.10 MiB cached).
The program needs more memory than the GPU has. There are generally two fixes:
1. Reduce the amount of computation
Shrink batch_size so that the data processed per batch fits within the available memory.
2. Free GPU memory held by unrelated processes
Open a cmd prompt and run
nvidia-smi
to list the processes currently using GPU memory, then use
taskkill -PID <process id> -F
to kill the ones you do not need and free up GPU memory.
Since my GPU is small to begin with — even with every other process killed the program would not fit — I went with the first option and reduced batch_size. With each batch training only 2 people × 16 samples, the program finally RUNs~
#batch_size=(8, 16)
batch_size=(2, 16)
Results
After training for nearly 7 hours……
A line is printed every 100 iterations, the elapsed time is printed every 1000 iterations, and the total running time is printed once all iterations finish.
Here hard and full denote the hard-example triplet loss and the all-pairs triplet loss, respectively.
hard: computed per strip; for each sample, take the largest distance among its positive pairs and the smallest distance among its negative pairs — effectively hard example mining.
full: computed per strip; the triplet loss is evaluated over every positive pair and every negative pair.
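Written out (my own summary of the two variants, not a quote from the paper), with margin $m = 0.2$ and $d_{a,p}$, $d_{a,n}$ the anchor–positive and anchor–negative distances within a strip:
$$L_{\mathrm{hard}} = \frac{1}{N}\sum_{a}\max\Bigl(m + \max_{p} d_{a,p} - \min_{n} d_{a,n},\ 0\Bigr), \qquad L_{\mathrm{full}} = \operatorname*{mean}_{(a,p,n)}\ \max\bigl(m + d_{a,p} - d_{a,n},\ 0\bigr)$$
(in the code, the full variant is averaged only over the triplets whose loss term is non-zero).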
3.5 test.py
3.5.1 Background: probe set vs. gallery set
Both the training and the test data contain a probe set and a gallery set. Literally, a probe is a query and a gallery is a collection; you can think of them as the query (verification) set and the enrollment (registration) set.
Take an identity recognition system as an example: the ID photos uploaded when users register form the gallery set; the photos taken the next time a user authenticates form the probe set. Gait recognition extracts a feature from the gallery and a feature from the probe, computes the distance between the two features (usually Euclidean), and returns the closest match (smallest distance / lowest loss) as the recognition result.
Note!
The training/test split and the probe/gallery split are not in one-to-one correspondence!!
The figure below gives a fairly intuitive picture of how these four sets relate.
GaitSet learns on the training set how to match a probe set against a gallery set, and this learned ability is then applied to matching on the test set. In deployment, the gallery (database) can therefore be changed at any time without retraining.
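A toy sketch of that matching step (plain numpy with made-up feature vectors; the real evaluation lives in evaluator.py):
import numpy as np

gallery_feats = np.random.rand(5, 256)     # 5 registered identities (hypothetical features)
gallery_ids = ['A', 'B', 'C', 'D', 'E']
probe_feat = np.random.rand(256)           # one query sample

# Euclidean distance to every gallery entry; the smallest distance wins (rank-1)
dists = np.linalg.norm(gallery_feats - probe_feat, axis=1)
print('predicted identity:', gallery_ids[int(dists.argmin())])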
from datetime import datetime
import numpy as np
import argparse
from initialization import initialization
from evaluator import evaluation
from config import conf
def boolean_string(s):
if s.upper() not in {'FALSE', 'TRUE'}:
raise ValueError('Not a valid boolean string')
return s.upper() == 'TRUE'
parser = argparse.ArgumentParser(description='Test')
parser.add_argument('--iter', default='80000', type=int,
help='iter: iteration of the checkpoint to load. Default: 80000')
parser.add_argument('--batch_size', default='1', type=int,
help='batch_size: batch size for parallel test. Default: 1')
parser.add_argument('--cache', default=False, type=boolean_string,
help='cache: if set as TRUE all the test data will be loaded at once'
' before the transforming start. Default: FALSE')
opt = parser.parse_args()
# Exclude identical-view cases
def de_diag(acc, each_angle=False):
result = np.sum(acc - np.diag(np.diag(acc)), 1) / 10.0
if not each_angle:
result = np.mean(result)
return result
m = initialization(conf, test=opt.cache)[0]
# load modelfile checkpoint of iteration opt.iter
print('Loading the modelfile of iteration %d...' % opt.iter)
m.load(opt.iter)
print('Transforming...')
time = datetime.now()
test = m.transform('test', opt.batch_size)
print('Evaluating...')
acc = evaluation(test, conf['data'])
print('Evaluation complete. Cost:', datetime.now() - time)
# Print rank-1 accuracy of the best modelfile
# e.g.
# ===Rank-1 (Include identical-view cases)===
# NM: 95.405, BG: 88.284, CL: 72.041
for i in range(1):
print('===Rank-%d (Include identical-view cases)===' % (i + 1))
print('NM: %.3f,\tBG: %.3f,\tCL: %.3f' % (
np.mean(acc[0, :, :, i]),
np.mean(acc[1, :, :, i]),
np.mean(acc[2, :, :, i])))
# Print rank-1 accuracy of the best modelfile,excluding identical-view cases
# e.g.
# ===Rank-1 (Exclude identical-view cases)===
# NM: 94.964, BG: 87.239, CL: 70.355
for i in range(1):
print('===Rank-%d (Exclude identical-view cases)===' % (i + 1))
print('NM: %.3f,\tBG: %.3f,\tCL: %.3f' % (
de_diag(acc[0, :, :, i]),
de_diag(acc[1, :, :, i]),
de_diag(acc[2, :, :, i])))
# Print rank-1 accuracy of the best modelfile (Each Angle)
# e.g.
# ===Rank-1 of each angle (Exclude identical-view cases)===
# NM: [90.80 97.90 99.40 96.90 93.60 91.70 95.00 97.80 98.90 96.80 85.80]
# BG: [83.80 91.20 91.80 88.79 83.30 81.00 84.10 90.00 92.20 94.45 79.00]
# CL: [61.40 75.40 80.70 77.30 72.10 70.10 71.50 73.50 73.50 68.40 50.00]
np.set_printoptions(precision=2, floatmode='fixed')
for i in range(1):
print('===Rank-%d of each angle (Exclude identical-view cases)===' % (i + 1))
print('NM:', de_diag(acc[0, :, :, i], True))
print('BG:', de_diag(acc[1, :, :, i], True))
print('CL:', de_diag(acc[2, :, :, i], True))
3.5.2 Results
Because I reduced batch_size earlier, my final Rank-1 numbers differ quite a bit from the authors' reported results.
4. Core Algorithm Code
First, the GaitSet pipeline diagram:
4.1 gaitset.py ☆
The GaitSet model is first initialized. The __init__ part only defines the operations of each layer; the actual order of operations is defined in the forward function below.
Convolution and pooling for the main (set) branch:
The input image has 1 channel; after convolution the channel counts are 32, 64, 128. Six layers C1–C6 are defined:
C1: in channels 1, out channels 32, 5×5 kernel, padding 2
C2(+P): in channels 32, out channels 32, 3×3 kernel, padding 1, 2×2 pooling
C3: in channels 32, out channels 64, 3×3 kernel, padding 1
C4(+P): in channels 64, out channels 64, 3×3 kernel, padding 1, 2×2 pooling
C5: in channels 64, out channels 128, 3×3 kernel, padding 1
C6: in channels 128, out channels 128, 3×3 kernel, padding 1
Convolution and pooling for the MGP branch:
Since its input comes from the C2 layer, the input has 32 channels; after convolution the channel counts are 64 and 128. Four layers G1–G4 are defined:
G1: in channels 32, out channels 64, 3×3 kernel, padding 1
G2: in channels 64, out channels 64, 3×3 kernel, padding 1
G3: in channels 64, out channels 128, 3×3 kernel, padding 1
G4: in channels 128, out channels 128, 3×3 kernel, padding 1
plus a max pooling layer with a 2×2 kernel.
The forward function:
The input is the preprocessed dataset with torch.Size [128, 30, 64, 44]: 128 (8×16) samples, 30 frames per sample, each frame 64×44.
The forward pass follows the diagram above. The input sequence goes through the C1 and C2 conv/pool layers and is then Set-Pooled (frame_max takes the element-wise maximum over the 30 frames produced by C1/C2 and collapses them into a single frame; that frame is the Set Pooling output, so the G1 input has torch.Size [128, 32, 32, 22]). From there the MGP branch above and the main branch below run in parallel and interact ("interact" means the main branch feeds its SP outputs into MGP, referred to as "fusion" below); fusion is implemented by element-wise addition. A two-line illustration of the Set Pooling step follows.
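Set Pooling in isolation boils down to a maximum over the frame dimension (a two-line illustration, not the repo's exact code):
import torch

x = torch.rand(128, 30, 32, 32, 22)   # (batch, frames, channels, h, w) after C2
sp = torch.max(x, 1)[0]               # element-wise max over the 30 frames -> (128, 32, 32, 22)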
The HPM part splits the feature map into 5 scales of 1, 2, 4, 8, and 16 strips, and converts the otherwise non-trainable tensors into Parameters so they can be registered with the module and trained (i.e. become part of the model).
It then performs horizontal pyramid pooling followed by the fully connected mapping. The feature map is first sliced into horizontal strips along the height (h): assuming $S$ scales, at scale $s \in \{1, 2, \dots, S\}$ the feature map height is split into $2^{s-1}$ strips, for $\sum_{s=1}^{S} 2^{s-1}$ strips in total. Global pooling is then applied to each strip:
$$f'_{s,t} = \mathrm{maxpool}(z_{s,t}) + \mathrm{avgpool}(z_{s,t})$$
where $\mathrm{maxpool}$ and $\mathrm{avgpool}$ denote global max pooling and global average pooling over the strip $z_{s,t}$. Finally a fully connected layer $fc$ maps the pooled features into the discriminative space.
import torch
import torch.nn as nn
import numpy as np
from basic_blocks import SetBlock, BasicConv2d
class SetNet(nn.Module):
def __init__(self, hidden_dim):
super(SetNet, self).__init__()
self.hidden_dim = hidden_dim
self.batch_frame = None
        #*** note: __init__ only defines the operations of each layer; the actual execution order is in the forward function
        #* conv / pooling layers of the main (set) branch
        _set_in_channels = 1  #* the input images have 1 channel
        _set_channels = [32, 64, 128]  #* list of channel counts
        self.set_layer1 = SetBlock(BasicConv2d(_set_in_channels, _set_channels[0], 5, padding=2))
        #* C1: in 1, out 32, 5×5 kernel, padding 2
        self.set_layer2 = SetBlock(BasicConv2d(_set_channels[0], _set_channels[0], 3, padding=1), True)
        #* C2: in 32, out 32, 3×3 kernel, padding 1, 2×2 pooling
        self.set_layer3 = SetBlock(BasicConv2d(_set_channels[0], _set_channels[1], 3, padding=1))
        #* C3: in 32, out 64, 3×3 kernel, padding 1
        self.set_layer4 = SetBlock(BasicConv2d(_set_channels[1], _set_channels[1], 3, padding=1), True)
        #* C4: in 64, out 64, 3×3 kernel, padding 1, 2×2 pooling
        self.set_layer5 = SetBlock(BasicConv2d(_set_channels[1], _set_channels[2], 3, padding=1))
        #* C5: in 64, out 128, 3×3 kernel, padding 1
        self.set_layer6 = SetBlock(BasicConv2d(_set_channels[2], _set_channels[2], 3, padding=1))
        #* C6: in 128, out 128, 3×3 kernel, padding 1
        #* conv / pooling layers of the MGP branch
        _gl_in_channels = 32  #* MGP takes the C2 output as its input, hence 32 channels
        _gl_channels = [64, 128]  #* list of channel counts
        self.gl_layer1 = BasicConv2d(_gl_in_channels, _gl_channels[0], 3, padding=1)
        #* G1: in 32, out 64, 3×3 kernel, padding 1
        self.gl_layer2 = BasicConv2d(_gl_channels[0], _gl_channels[0], 3, padding=1)
        #* G2: in 64, out 64, 3×3 kernel, padding 1
        self.gl_layer3 = BasicConv2d(_gl_channels[0], _gl_channels[1], 3, padding=1)
        #* G3: in 64, out 128, 3×3 kernel, padding 1
        self.gl_layer4 = BasicConv2d(_gl_channels[1], _gl_channels[1], 3, padding=1)
        #* G4: in 128, out 128, 3×3 kernel, padding 1
        self.gl_pooling = nn.MaxPool2d(2)
        #* max pooling layer, 2×2 kernel
        #* HPM part
        self.bin_num = [1, 2, 4, 8, 16]  #* the feature map is split at 5 scales: 1, 2, 4, 8, 16 strips
        #* wrap the otherwise non-trainable tensor as a Parameter
        #* so it is registered in the module and trained (becomes part of the model)
        #* init.xavier_uniform_ initializes the parameters; xavier keeps input/output variance consistent, uniform means uniform initialization
        self.fc_bin = nn.ParameterList([
            nn.Parameter(
                nn.init.xavier_uniform_(
                    torch.zeros(sum(self.bin_num) * 2, 128, hidden_dim)))])
        #* three dimensions: 31*2, 128, 256
        #* iterate over the modules and initialize them
for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.Conv1d)):
nn.init.xavier_uniform_(m.weight.data)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight.data)
nn.init.constant(m.bias.data, 0.0)
elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
nn.init.normal(m.weight.data, 1.0, 0.02)
nn.init.constant(m.bias.data, 0.0)
    #* frame_max and frame_median implement the Set Pooling operation
    #* maximum over the second dimension (the frame dimension)
    def frame_max(self, x):
        if self.batch_frame is None:
            return torch.max(x, 1)  #* returns the max over frames and the corresponding indices
else:
_tmp = [
torch.max(x[:, self.batch_frame[i]:self.batch_frame[i + 1], :, :, :], 1)
for i in range(len(self.batch_frame) - 1)
]
max_list = torch.cat([_tmp[i][0] for i in range(len(_tmp))], 0)
arg_max_list = torch.cat([_tmp[i][1] for i in range(len(_tmp))], 0)
return max_list, arg_max_list
    #* median over the second dimension (the frame dimension)
    def frame_median(self, x):
        if self.batch_frame is None:
            return torch.median(x, 1)  #* returns the median over frames and the corresponding indices
else:
_tmp = [
torch.median(x[:, self.batch_frame[i]:self.batch_frame[i + 1], :, :, :], 1)
for i in range(len(self.batch_frame) - 1)
]
median_list = torch.cat([_tmp[i][0] for i in range(len(_tmp))], 0)
arg_median_list = torch.cat([_tmp[i][1] for i in range(len(_tmp))], 0)
return median_list, arg_median_list
    #* the forward function
    #* this is the core algorithm's order of operations
    def forward(self, silho, batch_frame=None):  #* silho is the cropped dataset; torch.Size([128, 30, 64, 44]) means 128 (8*16) samples, 30 frames each, frame size 64*44
# n: batch_size, s: frame_num, k: keypoints_num, c: channel
if batch_frame is not None:
batch_frame = batch_frame[0].data.cpu().numpy().tolist()
_ = len(batch_frame)
for i in range(len(batch_frame)):
if batch_frame[-(i + 1)] != 0:
break
else:
_ -= 1
batch_frame = batch_frame[:_]
frame_sum = np.sum(batch_frame)
if frame_sum < silho.size(1):
silho = silho[:, :frame_sum, :, :]
self.batch_frame = [0] + np.cumsum(batch_frame).tolist()
        n = silho.size(0)  #* n = 128
        x = silho.unsqueeze(2)  #* insert a dimension at index 2 for the channel
        #* now x has torch.Size([128, 30, 1, 64, 44])
del silho
        x = self.set_layer1(x)  #* C1: [128,30,1,64,44] ---> [128,30,32,64,44]
        x = self.set_layer2(x)  #* C2: [128,30,32,64,44] ---> [128,30,32,32,22]; includes a pooling layer, so height and width are halved
        #* frame_max takes the maximum over the 30 frames coming out of C1/C2 and collapses them into one frame; that frame is the Set Pooling feature, so the G1 input has torch.Size [128,32,32,22]
        gl = self.gl_layer1(self.frame_max(x)[0])  #* G1: [128,30,32,32,22] ---> [128,64,32,22]
        gl = self.gl_layer2(gl)  #* G2: [128,64,32,22] ---> [128,64,32,22]
        gl = self.gl_pooling(gl)  #* [128,64,32,22] ---> [128,64,16,11]; one pooling, height and width halved
        x = self.set_layer3(x)  #* C3: [128,30,32,32,22] ---> [128,30,64,32,22]
        x = self.set_layer4(x)  #* C4: [128,30,64,32,22] ---> [128,30,64,16,11]; includes a pooling layer, height and width halved
        gl = self.gl_layer3(gl + self.frame_max(x)[0])  #* G3: fuse the Set-Pooled C4 output with G2 (addition), then convolve: [128,64,16,11] ---> [128,128,16,11]
        gl = self.gl_layer4(gl)  #* G4: [128,128,16,11] ---> [128,128,16,11]
        x = self.set_layer5(x)  #* C5: [128,30,64,16,11] ---> [128,30,128,16,11]
        x = self.set_layer6(x)  #* C6: [128,30,128,16,11] ---> [128,30,128,16,11]; no pooling here
        x = self.frame_max(x)[0]  #* one more SP: [128,30,128,16,11] ---> [128,128,16,11]
        gl = gl + x  #* fuse G4 with C6 (addition)
        #* HPM part
        feature = list()  #* feature is a list
        n, c, h, w = gl.size()  #* n, c, h, w correspond to torch.Size([128, 128, 16, 11])
        for num_bin in self.bin_num:  #* loop over the pyramid scales
            z = x.view(n, c, num_bin, -1)
            #* view works like numpy's reshape: it redefines the shape of the tensor
            #* the -1 lets that dimension be inferred so that the total number of elements stays the same
            #* so across this loop, z takes the sizes
            #* torch.Size([128, 128, 1, 176])
            #* torch.Size([128, 128, 2, 88])
            #* torch.Size([128, 128, 4, 44])
            #* torch.Size([128, 128, 8, 22])
            #* torch.Size([128, 128, 16, 11])
            z = z.mean(3) + z.max(3)[0]  #* take the mean and the max over the last dimension and add them
            #* this is the global pooling that turns each strip into a 1D feature: f = maxpool + avgpool
            #* where maxpool and avgpool are global max pooling and global average pooling
            #* (the authors use this combination because it gave the best results in their experiments; the exact reason is unclear??)
            feature.append(z)  #* append adds the computed z to the end of the feature list
            #* the same global pooling is applied to the MGP branch
            z = gl.view(n, c, num_bin, -1)
            z = z.mean(3) + z.max(3)[0]
            feature.append(z)
        #* HPP
        feature = torch.cat(feature, 2).permute(2, 0, 1).contiguous()
        #* cat concatenates the strips
        #* permute reorders the dimensions (2->0, 0->1, 1->2)
        #* contiguous makes a copy so the original data is left unchanged
        #* after this, feature has torch.Size([62, 128, 128])
        feature = feature.matmul(self.fc_bin[0])
        #* matrix multiplication; fc_bin is 62*128*256
        #* i.e. there are 62 strips, each 128-dimensional, and each strip gets its own fully connected mapping
feature = feature.permute(1, 0, 2).contiguous()
return feature, None
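A quick shape check of the network defined above (a sketch: it assumes gaitset.py and basic_blocks.py are importable from the current directory, as in the layout used in this write-up, and it runs on CPU with a tiny dummy batch):
import torch
from gaitset import SetNet   # assumed module name, matching the file above

net = SetNet(hidden_dim=256)
dummy = torch.rand(4, 30, 64, 44)   # 4 sequences, 30 frames each, 64x44 silhouettes
feature, _ = net(dummy)
print(feature.size())               # expected: torch.Size([4, 62, 256]) = (batch, strips, hidden_dim)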
4.2 model.py
This file mostly initializes (pre-defines?) the training, testing, and loss machinery plus assorted housekeeping; a brief look at a few of its custom functions:
collate_fn: defines how the DataLoader assembles gait images from the dataset into a batch
select_frame: defines how frames are picked — 30 frames are drawn at random with replacement
fit: trains the model, i.e. the backpropagation loop that learns the weights
np2var (via ts2var): converts numpy data to CUDA tensors
transform: runs the model in evaluation mode to extract features for testing
import math
import os
import os.path as osp
import random
import sys
from datetime import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.optim as optim
import torch.utils.data as tordata
# from model.network import TripletLoss, SetNet
from triplet import TripletLoss
from gaitset import SetNet
# from model.utils import TripletSampler
from sampler import TripletSampler
class Model:
def __init__(self,
hidden_dim,
lr,
hard_or_full_trip,
margin,
num_workers,
batch_size,
restore_iter,
total_iter,
save_name,
train_pid_num,
frame_num,
model_name,
train_source,
test_source,
img_size=64):
self.save_name = save_name
self.train_pid_num = train_pid_num
self.train_source = train_source
self.test_source = test_source
self.hidden_dim = hidden_dim
self.lr = lr
self.hard_or_full_trip = hard_or_full_trip
self.margin = margin
self.frame_num = frame_num
self.num_workers = num_workers
self.batch_size = batch_size
self.model_name = model_name
self.P, self.M = batch_size
self.restore_iter = restore_iter
self.total_iter = total_iter
self.img_size = img_size
self.encoder = SetNet(self.hidden_dim).float()
self.encoder = nn.DataParallel(self.encoder)
self.triplet_loss = TripletLoss(self.P * self.M, self.hard_or_full_trip, self.margin).float()
self.triplet_loss = nn.DataParallel(self.triplet_loss)
self.encoder.cuda()
self.triplet_loss.cuda()
self.optimizer = optim.Adam([
{'params': self.encoder.parameters()},
], lr=self.lr)
self.hard_loss_metric = []
self.full_loss_metric = []
self.full_loss_num = []
self.dist_list = []
self.mean_dist = 0.01
self.sample_type = 'all'
def collate_fn(self, batch):
batch_size = len(batch)
feature_num = len(batch[0][0])
seqs = [batch[i][0] for i in range(batch_size)]
frame_sets = [batch[i][1] for i in range(batch_size)]
view = [batch[i][2] for i in range(batch_size)]
seq_type = [batch[i][3] for i in range(batch_size)]
label = [batch[i][4] for i in range(batch_size)]
batch = [seqs, view, seq_type, label, None]
def select_frame(index):
sample = seqs[index]
frame_set = frame_sets[index]
if self.sample_type == 'random':
frame_id_list = random.choices(frame_set, k=self.frame_num)
_ = [feature.loc[frame_id_list].values for feature in sample]
else:
_ = [feature.values for feature in sample]
return _
seqs = list(map(select_frame, range(len(seqs))))
if self.sample_type == 'random':
seqs = [np.asarray([seqs[i][j] for i in range(batch_size)]) for j in range(feature_num)]
else:
gpu_num = min(torch.cuda.device_count(), batch_size)
batch_per_gpu = math.ceil(batch_size / gpu_num)
batch_frames = [[
len(frame_sets[i])
for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
if i < batch_size
] for _ in range(gpu_num)]
if len(batch_frames[-1]) != batch_per_gpu:
for _ in range(batch_per_gpu - len(batch_frames[-1])):
batch_frames[-1].append(0)
max_sum_frame = np.max([np.sum(batch_frames[_]) for _ in range(gpu_num)])
seqs = [[
np.concatenate([
seqs[i][j]
for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
if i < batch_size
], 0) for _ in range(gpu_num)]
for j in range(feature_num)]
seqs = [np.asarray([
np.pad(seqs[j][_],
((0, max_sum_frame - seqs[j][_].shape[0]), (0, 0), (0, 0)),
'constant',
constant_values=0)
for _ in range(gpu_num)])
for j in range(feature_num)]
batch[4] = np.asarray(batch_frames)
batch[0] = seqs
return batch
def fit(self):
if self.restore_iter != 0:
self.load(self.restore_iter)
self.encoder.train()
self.sample_type = 'random'
for param_group in self.optimizer.param_groups:
param_group['lr'] = self.lr
triplet_sampler = TripletSampler(self.train_source, self.batch_size)
train_loader = tordata.DataLoader(
dataset=self.train_source,
batch_sampler=triplet_sampler,
collate_fn=self.collate_fn,
num_workers=self.num_workers)
train_label_set = list(self.train_source.label_set)
train_label_set.sort()
_time1 = datetime.now()
_time0 = datetime.now()
for seq, view, seq_type, label, batch_frame in train_loader:
self.restore_iter += 1
self.optimizer.zero_grad()
for i in range(len(seq)):
seq[i] = self.np2var(seq[i]).float()
if batch_frame is not None:
batch_frame = self.np2var(batch_frame).int()
feature, label_prob = self.encoder(*seq, batch_frame)
target_label = [train_label_set.index(l) for l in label]
target_label = self.np2var(np.array(target_label)).long()
triplet_feature = feature.permute(1, 0, 2).contiguous()
triplet_label = target_label.unsqueeze(0).repeat(triplet_feature.size(0), 1)
(full_loss_metric, hard_loss_metric, mean_dist, full_loss_num
) = self.triplet_loss(triplet_feature, triplet_label)
if self.hard_or_full_trip == 'hard':
loss = hard_loss_metric.mean()
elif self.hard_or_full_trip == 'full':
loss = full_loss_metric.mean()
self.hard_loss_metric.append(hard_loss_metric.mean().data.cpu().numpy())
self.full_loss_metric.append(full_loss_metric.mean().data.cpu().numpy())
self.full_loss_num.append(full_loss_num.mean().data.cpu().numpy())
self.dist_list.append(mean_dist.mean().data.cpu().numpy())
if loss > 1e-9:
loss.backward()
self.optimizer.step()
if self.restore_iter == 80000:
print(datetime.now() - _time0)
if self.restore_iter % 1000 == 0:
print(datetime.now() - _time1)
_time1 = datetime.now()
if self.restore_iter % 100 == 0:
self.save()
print('iter {}:'.format(self.restore_iter), end='')
print(', hard_loss_metric={0:.8f}'.format(np.mean(self.hard_loss_metric)), end='')
print(', full_loss_metric={0:.8f}'.format(np.mean(self.full_loss_metric)), end='')
print(', full_loss_num={0:.8f}'.format(np.mean(self.full_loss_num)), end='')
self.mean_dist = np.mean(self.dist_list)
print(', mean_dist={0:.8f}'.format(self.mean_dist), end='')
print(', lr=%f' % self.optimizer.param_groups[0]['lr'], end='')
print(', hard or full=%r' % self.hard_or_full_trip)
sys.stdout.flush()
self.hard_loss_metric = []
self.full_loss_metric = []
self.full_loss_num = []
self.dist_list = []
# Visualization using t-SNE
# if self.restore_iter % 500 == 0:
# pca = TSNE(2)
# pca_feature = pca.fit_transform(feature.view(feature.size(0), -1).data.cpu().numpy())
# for i in range(self.P):
# plt.scatter(pca_feature[self.M * i:self.M * (i + 1), 0],
# pca_feature[self.M * i:self.M * (i + 1), 1], label=label[self.M * i])
#
# plt.show()
if self.restore_iter == self.total_iter:
break
def ts2var(self, x):
return autograd.Variable(x).cuda()
def np2var(self, x):
return self.ts2var(torch.from_numpy(x))
def transform(self, flag, batch_size=1):
self.encoder.eval()
source = self.test_source if flag == 'test' else self.train_source
self.sample_type = 'all'
data_loader = tordata.DataLoader(
dataset=source,
batch_size=batch_size,
sampler=tordata.sampler.SequentialSampler(source),
collate_fn=self.collate_fn,
num_workers=self.num_workers)
feature_list = list()
view_list = list()
seq_type_list = list()
label_list = list()
for i, x in enumerate(data_loader):
seq, view, seq_type, label, batch_frame = x
for j in range(len(seq)):
seq[j] = self.np2var(seq[j]).float()
if batch_frame is not None:
batch_frame = self.np2var(batch_frame).int()
# print(batch_frame, np.sum(batch_frame))
feature, _ = self.encoder(*seq, batch_frame)
n, num_bin, _ = feature.size()
feature_list.append(feature.view(n, -1).data.cpu().numpy())
view_list += view
seq_type_list += seq_type
label_list += label
return np.concatenate(feature_list, 0), view_list, seq_type_list, label_list
def save(self):
os.makedirs(osp.join('checkpoint', self.model_name), exist_ok=True)
torch.save(self.encoder.state_dict(),
osp.join('checkpoint', self.model_name,
'{}-{:0>5}-encoder.ptm'.format(
self.save_name, self.restore_iter)))
torch.save(self.optimizer.state_dict(),
osp.join('checkpoint', self.model_name,
'{}-{:0>5}-optimizer.ptm'.format(
self.save_name, self.restore_iter)))
# restore_iter: iteration index of the checkpoint to load
def load(self, restore_iter):
self.encoder.load_state_dict(torch.load(osp.join(
'checkpoint', self.model_name,
'{}-{:0>5}-encoder.ptm'.format(self.save_name, restore_iter))))
self.optimizer.load_state_dict(torch.load(osp.join(
'checkpoint', self.model_name,
'{}-{:0>5}-optimizer.ptm'.format(self.save_name, restore_iter))))
4.3 triplet.py
Defines the triplet loss; both the Batch All (full) and Batch Hard variants are computed:
import torch
import torch.nn as nn
import torch.nn.functional as F
class TripletLoss(nn.Module):
def __init__(self, batch_size, hard_or_full, margin):
super(TripletLoss, self).__init__()
self.batch_size = batch_size
self.margin = margin
def forward(self, feature, label):
# feature: [n, m, d], label: [n, m]
n, m, d = feature.size()
hp_mask = (label.unsqueeze(1) == label.unsqueeze(2)).byte().view(-1)
hp_mask = hp_mask.bool()
hn_mask = (label.unsqueeze(1) != label.unsqueeze(2)).byte().view(-1)
hn_mask = hn_mask.bool()
dist = self.batch_dist(feature)
mean_dist = dist.mean(1).mean(1)
dist = dist.view(-1)
# hard
hard_hp_dist = torch.max(torch.masked_select(dist, hp_mask).view(n, m, -1), 2)[0]
hard_hn_dist = torch.min(torch.masked_select(dist, hn_mask).view(n, m, -1), 2)[0]
hard_loss_metric = F.relu(self.margin + hard_hp_dist - hard_hn_dist).view(n, -1)
hard_loss_metric_mean = torch.mean(hard_loss_metric, 1)
# non-zero full
full_hp_dist = torch.masked_select(dist, hp_mask).view(n, m, -1, 1)
full_hn_dist = torch.masked_select(dist, hn_mask).view(n, m, 1, -1)
full_loss_metric = F.relu(self.margin + full_hp_dist - full_hn_dist).view(n, -1)
full_loss_metric_sum = full_loss_metric.sum(1)
full_loss_num = (full_loss_metric != 0).sum(1).float()
full_loss_metric_mean = full_loss_metric_sum / full_loss_num
full_loss_metric_mean[full_loss_num == 0] = 0
return full_loss_metric_mean, hard_loss_metric_mean, mean_dist, full_loss_num
def batch_dist(self, x):
x2 = torch.sum(x ** 2, 2)
dist = x2.unsqueeze(2) + x2.unsqueeze(2).transpose(1, 2) - 2 * torch.matmul(x, x.transpose(1, 2))
dist = torch.sqrt(F.relu(dist))
return dist
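The batch_dist helper is just the standard expansion of pairwise Euclidean distances (my own note, not from the paper):
$$d(x_i, x_j)^2 = \lVert x_i \rVert^2 + \lVert x_j \rVert^2 - 2\,x_i^{\top} x_j$$
evaluated for every pair in a strip at once; the F.relu before the square root merely clamps tiny negative values caused by floating-point error.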
5. (Original Authors') Results
Appendix
The packages and version numbers I configured for this project are listed in the table below. Most of them are never actually imported. If you start from a completely empty environment, you can install packages one by one as they are imported by the code; if you already have the packages installed but hit version-related errors, check the versions against this table and upgrade or downgrade as appropriate.
Package | Version |
---|---|
absl-py | 1.3.0 |
aiohttp | 3.8.1 |
aiosignal | 1.2.0 |
argon2-cffi | 20.1.0 |
async-timeout | 4.0.2 |
async_generator | 1.1 |
asynctest | 0.13.0 |
attrs | 21.4.0 |
backcall | 0.2.0 |
beautifulsoup4 | 4.11.1 |
blas | 1 |
bleach | 4.1.0 |
blinker | 1.4 |
brotli | 1.0.9 |
brotli-bin | 1.0.9 |
brotlipy | 0.7.0 |
ca-certificates | 2022.07.19 |
cachetools | 4.2.2 |
certifi | 2022.9.24 |
cffi | 1.14.6 |
charset-normalizer | 2.0.4 |
click | 8.0.4 |
colorama | 0.4.4 |
cryptography | 37.0.1 |
cudatoolkit | 10.0.130 |
cycler | 0.11.0 |
dataclasses | 0.8 |
decorator | 4.4.2 |
defusedxml | 0.7.1 |
dominate | 2.6.0 |
entrypoints | 0.3 |
fftw | 3.3.9 |
fonttools | 4.25.0 |
freetype | 2.10.4 |
frozenlist | 1.2.0 |
glib | 2.69.1 |
google-auth | 2.6.0 |
google-auth-oauthlib | 0.5.2 |
grpcio | 1.42.0 |
gst-plugins-base | 1.18.5 |
gstreamer | 1.18.5 |
h5py | 2.10.0 |
hdf5 | 1.10.4 |
icc_rt | 2022.1.0 |
icu | 58.2 |
idna | 3.3 |
imageio | 2.19.3 |
importlib-metadata | 4.11.3 |
intel-openmp | 2021.4.0 |
ipykernel | 5.3.4 |
ipython | 7.16.1 |
ipython_genutils | 0.2.0 |
jedi | 0.17.0 |
jinja2 | 3.0.3 |
joblib | 1.1.0 |
jpeg | 9e |
jsonschema | 3.0.2 |
jupyter-core | 4.11.1 |
jupyter_client | 7.1.2 |
jupyter_core | 4.8.1 |
jupyterlab_pygments | 0.1.2 |
kiwisolver | 1.3.1 |
lerc | 3 |
libbrotlicommon | 1.0.9 |
libbrotlidec | 1.0.9 |
libbrotlienc | 1.0.9 |
libclang | 12.0.0 |
libdeflate | 1.8 |
libffi | 3.4.2 |
libiconv | 1.16 |
libogg | 1.3.5 |
libpng | 1.6.37 |
libprotobuf | 3.20.1 |
libsodium | 1.0.18 |
libtiff | 4.4.0 |
libvorbis | 1.3.7 |
libwebp | 1.2.4 |
libwebp-base | 1.2.4 |
libxml2 | 2.9.14 |
libxslt | 1.1.35 |
lz4-c | 1.9.3 |
m2w64-gcc-libgfortran | 5.3.0 |
m2w64-gcc-libs | 5.3.0 |
m2w64-gcc-libs-core | 5.3.0 |
m2w64-gmp | 6.1.0 |
m2w64-libwinpthread-git | 5.0.0.4634.697f757 |
markdown | 3.3.4 |
markupsafe | 2.0.1 |
matplotlib | 3.5.2 |
matplotlib-base | 3.5.2 |
mistune | 0.8.4 |
mkl | 2021.4.0 |
mkl-service | 2.4.0 |
mkl_fft | 1.3.1 |
mkl_random | 1.2.2 |
msys2-conda-epoch | 20160418 |
multidict | 6.0.2 |
munkres | 1.1.4 |
nbclient | 0.5.13 |
nbconvert | 6.0.7 |
nbformat | 5.1.3 |
nest-asyncio | 1.5.1 |
networkx | 2.5.1 |
ninja | 1.10.2 |
ninja-base | 1.10.2 |
notebook | 6.4.3 |
numpy | 1.16.2 |
numpy-base | 1.21.5 |
oauthlib | 3.2.0 |
olefile | 0.46 |
opencv-python | 4.6.0.66 |
openssl | 1.1.1q |
packaging | 21.3 |
pandas | 1.1.5 |
pandoc | 2.12 |
pandocfilters | 1.5.0 |
parso | 0.8.3 |
pcre | 8.45 |
pickleshare | 0.7.5 |
pillow | 8.4.0 |
pip | 21.2.2 |
ply | 3.11 |
prettytable | 2.5.0 |
prometheus_client | 0.13.1 |
prompt-toolkit | 3.0.20 |
protobuf | 3.20.1 |
pyasn1 | 0.4.8 |
pyasn1-modules | 0.2.8 |
pycparser | 2.21 |
pyecharts | 1.9.1 |
pygments | 2.11.2 |
pyjwt | 2.4.0 |
pyopenssl | 22.0.0 |
pyparsing | 3.0.9 |
pyqt | 5.15.7 |
pyqt5-sip | 12.11.0 |
pyreadline | 2.1 |
pyrsistent | 0.17.3 |
pysnooper | 1.1.1 |
pysocks | 1.7.1 |
python | 3.7.13 |
python-dateutil | 2.8.2 |
python-fastjsonschema | 2.16.2 |
pytorch | 1.2.0 |
pytz | 2021.3 |
pywavelets | 1.1.1 |
pywin32 | 228 |
pywinpty | 0.5.7 |
pyzmq | 22.2.1 |
qt-main | 5.15.2 |
qt-webengine | 5.15.9 |
qtwebkit | 5.212 |
ranger | 0.1 |
requests | 2.27.1 |
requests-oauthlib | 1.3.0 |
rsa | 4.7.2 |
scikit-image | 0.17.2 |
scikit-learn | 0.24.2 |
scipy | 1.7.3 |
seaborn | 0.11.2 |
send2trash | 1.8.0 |
setuptools | 58.0.4 |
simplejson | 3.17.6 |
sip | 6.6.2 |
six | 1.16.0 |
sklearn | 0 |
soupsieve | 2.3.2.post1 |
sqlite | 3.39.3 |
tensorboard-data-server | 0.6.0 |
tensorboard-plugin-wit | 1.8.1 |
terminado | 0.9.4 |
testpath | 0.5.0 |
threadpoolctl | 2.2.0 |
tifffile | 2020.9.22 |
tk | 8.6.12 |
toml | 0.10.2 |
torch | 1.12.1 |
torchsnooper | 0.8 |
torchvision | 0.4.0 |
tornado | 6.2 |
tqdm | 4.61.1 |
traitlets | 4.3.3 |
typing-extensions | 4.1.1 |
typing_extensions | 4.1.1 |
urllib3 | 1.26.9 |
vc | 14.2 |
vs2015_runtime | 14.27.29016 |
wcwidth | 0.2.5 |
webencodings | 0.5.1 |
werkzeug | 2.0.3 |
wheel | 0.37.1 |
win_inet_pton | 1.1.0 |
wincertstore | 0.2 |
winpty | 0.4.3 |
xarray | 0.16.2 |
xz | 5.2.6 |
yarl | 1.8.1 |
zeromq | 4.3.4 |
zipp | 3.6.0 |
zlib | 1.2.12 |
zstd | 1.5.2 |
Reference blogs:
【論文翻譯】-- GaitSet: Regarding Gait as a Set for Cross-View Gait Recognition
【原創(chuàng)·論文翻譯】GaitSet-旨在用自己的語言表達(dá)出作者的真實意圖
跑通GaitSet(跑不通你來揍我)
GaitSet源代碼解讀(一)
GaitSet源代碼解讀(二)
GaitSet源代碼解讀(三)