使用Flask簡單部署深度學習模型
一、安裝 Flask
pip install Flask==2.0.2
pip install Flask_Cors==3.0.9
pip install Pillow
二、Flask程序運行過程
- 當客戶端想要獲取資源時,一般會通過瀏覽器發起HTTP請求。
- 此時,Web服務器會把來自客戶端的所有請求都交給Flask程序實例。
- 程序實例使用Werkzeug來做路由分發(即URL請求和視圖函數之間的對應關係)。
- 根據每個URL請求,找到具體的視圖函數并進行調用。在Flask程序中,路由一般通過程序實例的裝飾器(如 @app.route)實現。
- Flask調用視圖函數后,常見的返回內容有兩種(此外也可以用 jsonify 返回 JSON,見第四節):
- 字符串內容:將視圖函數的返回值作為響應的內容,返回給客戶端(瀏覽器)。
- HTML模板內容:獲取到數據后,把數據傳入HTML模板文件中,由模板引擎渲染出HTTP響應數據,然后返回給客戶端(瀏覽器)。
三、 Flask開發
# 1. Import the Flask class.
from flask import Flask

# 2. Create the Flask application instance.
# __name__ tells Flask where the application lives, so it can locate
# resources such as templates/ and static/ relative to this module.
app = Flask(__name__)


# 3. Define routes and view functions.
# Routes are registered with the app.route decorator. By default a route
# only accepts GET; list any additional HTTP methods explicitly.
@app.route('/', methods=['GET', 'POST'])
def index():
    """Return a plain-text greeting for the root URL."""
    return "hello flask"


# 4. Start the program.
if __name__ == '__main__':
    # app.run() serves the application on Flask's built-in development server.
    app.run()
四、 使用Flask框架完成前后端交互
import os
import io
import json
import time
import argparse
import cv2
import torch
import imageio
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request, render_template
from data.custom_transforms import FixedResize, AddIgnoreRegions, ToTensor, Normalize
import base64
from utils.utils import get_output, mkdir_if_missing
import numpy as np
from flask_cors import CORS
from utils.common_config import get_model
from utils.config import create_config
import torchvision.transforms as transforms
# Upload file extensions the service accepts.
ALLOWED_EXTENSIONS = {'png', 'jpg', 'JPG', 'PNG'}

# The Flask application object. CORS is enabled with flask_cors defaults so
# that a page served from another origin can call this API.
app = Flask(__name__)
CORS(app)
# Load the colour palette used to render semantic-segmentation masks.
# Each value in palette.json is expected to be a list of channel values
# (presumably [R, G, B] per class — TODO confirm). The lists are concatenated
# into the single flat list that PIL's Image.putpalette() expects.
palette_path = "palette.json"
if not os.path.exists(palette_path):
    # Raise rather than assert: assertions are stripped under ``python -O``.
    raise FileNotFoundError(f"palette {palette_path} not found.")
with open(palette_path, "r", encoding="utf-8") as f:
    pallette_dict = json.load(f)
pallette = [value for values in pallette_dict.values() for value in values]

# Checkpoint holding the trained multi-task model weights.
weights_path = "configs/PADResults/PASCALContext/hrnet_w18/pad_net/best_model.pth.tar"
if not os.path.exists(weights_path):
    raise FileNotFoundError(f"weights path {weights_path} does not exist.")
# Command-line options selecting the environment and experiment configs.
parser = argparse.ArgumentParser(description='Vanilla Training')
parser.add_argument('--config_env', default='configs/env.yml',
                    help='Config file for the environment')
parser.add_argument('--config_exp', default='configs/pascal/pad_net.yml',
                    help='Config file for the experiment')
# Use parse_known_args instead of parse_args. This lets the module also be
# imported by launchers such as ``flask run`` or a WSGI server. Their own
# command-line arguments would otherwise make argparse exit with
# "unrecognized arguments".
args, _unknown_args = parser.parse_known_args()

# Disable OpenCV's internal threading. The likely reason is to avoid
# oversubscribing CPU cores alongside PyTorch — TODO confirm intent.
cv2.setNumThreads(0)
# Build the experiment configuration from the environment and experiment files.
p = create_config(args.config_env, args.config_exp)
# Select the inference device. The deployment targets the second GPU
# ("cuda:1") when there is one. It falls back to the first GPU when only one
# is present, and to the CPU when CUDA is unavailable.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# Build the network described by the experiment config.
model = get_model(p)
# Wrapping in DataParallel prefixes state_dict keys with "module.". This
# presumably matches how the checkpoint was saved — TODO confirm.
# device_ids pins the wrapper to the selected GPU. Without CUDA, DataParallel
# ignores device_ids and simply calls the wrapped module.
model = torch.nn.DataParallel(
    model, device_ids=[device] if device.type == "cuda" else None)
# Use .to(device) rather than .cuda(). .cuda() targets the current default
# GPU instead of `device` and fails outright on CPU-only machines.
model = model.to(device)

# Load the trained weights, mapping their tensors onto the selected device.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(torch.load(weights_path, map_location=device))
model.eval()
def allowed_file(filename, allowed_extensions=None):
    """Return True if *filename* ends in one of the allowed extensions.

    The comparison is case-insensitive, so ``photo.Jpg`` is accepted whenever
    ``jpg`` (in any letter case) is allowed.

    Args:
        filename: Name of the uploaded file; may be empty or None.
        allowed_extensions: Iterable of permitted extensions without the
            leading dot. Defaults to the module-level ALLOWED_EXTENSIONS.

    Returns:
        bool: whether the text after the last '.' is an allowed extension.
    """
    if not filename or '.' not in filename:
        return False
    if allowed_extensions is None:
        allowed_extensions = ALLOWED_EXTENSIONS
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in {ext.lower() for ext in allowed_extensions}
# Preprocessing applied to every uploaded image:
# - resize to the fixed 512x512 network input;
# - convert to a float tensor;
# - normalise with the standard ImageNet channel mean/std.
# Built once at import time rather than on every request.
_PREPROCESS = transforms.Compose([
    transforms.Resize([512, 512]),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])


def transform_image(image_bytes):
    """Decode an uploaded image and turn it into a normalised input tensor.

    Args:
        image_bytes: Raw bytes of an encoded image file (e.g. PNG or JPEG).

    Returns:
        torch.Tensor: a 3x512x512 float tensor on the module-level ``device``
        (no batch dimension).

    Raises:
        ValueError: if the decoded image is not in RGB mode.
    """
    image = Image.open(io.BytesIO(image_bytes))
    if image.mode != "RGB":
        raise ValueError("input file is not an RGB image")
    # The PIL image is fed straight into the pipeline. A numpy round trip
    # followed by ToPILImage would reproduce the same RGB image.
    tensor = _PREPROCESS(image)
    # Run the pipeline once and reuse the result for the debug print.
    print(tensor.shape)
    return tensor.to(device)
def get_secondFloat(timestamp):
    """Return the fractional-second part of *timestamp* as a string.

    The fraction is rendered with four decimal places (0.1 ms resolution),
    without the leading zero, e.g. ``1.5 -> '.5000'``.
    """
    fraction = timestamp % 1
    return f"{fraction:.4f}"[1:]
def get_timeString():
    """Return the current local time formatted as ``YYYYmmdd_HHMMSS``.

    The resolution is one second; no fractional-second part is appended.
    """
    return time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
def get_prediction(p, image_bytes):
    """Run the multi-task model on one image and save each task's output.

    Args:
        p: Experiment config. ``p.TASKS.NAMES`` lists the tasks to export, and
            ``p.TASKS.INFER_FLAGVALS[task]`` is the OpenCV interpolation flag
            used to resize that task's output.
        image_bytes: Raw bytes of the uploaded image.

    Returns:
        dict: paths of the ``semseg`` and ``normals`` PNG files under
        ``static/results``. Both keys are always present, even if a task is
        not in ``p.TASKS.NAMES``.
    """
    model.eval()
    tasks = p.TASKS.NAMES
    results_dirPath = 'static/results'
    # Create the output directory unconditionally; exist_ok makes this a
    # no-op when it already exists.
    os.makedirs(results_dirPath, exist_ok=True)

    # transform_image yields one CHW tensor. Add the batch dimension and keep
    # it on the same device as the model (.cuda() would ignore `device`).
    inputs = transform_image(image_bytes=image_bytes)
    inputs = inputs.to(device).unsqueeze(0)
    print(inputs.shape)

    # Inference only: skip building an autograd graph, even when this is
    # called from outside the no_grad-decorated view.
    with torch.no_grad():
        output = model(inputs)

    # Save each supported task's prediction as a PNG image.
    for task in tasks:
        if task == 'normals':
            # NOTE(review): the output is believed to be (1, 512, 512, 3);
            # confirm against get_output.
            output_task = get_output(output[task], task).detach().cpu().numpy()
            for jj in range(inputs.size(0)):
                result = cv2.resize(output_task[jj], dsize=(512, 512),
                                    interpolation=p.TASKS.INFER_FLAGVALS[task])
                imageio.imwrite(os.path.join(results_dirPath, task + '.png'),
                                result.astype(np.uint8))
        elif task == 'semseg':
            # Class index per pixel: argmax over dim 1, presumably the class
            # channel dimension of an (N, C, H, W) output — TODO confirm.
            prediction = output['semseg'].argmax(1).squeeze(0)
            prediction = prediction.to("cpu").numpy().astype(np.uint8)
            mask = Image.fromarray(prediction)
            # Colour the class indices with the palette loaded at startup.
            mask.putpalette(pallette)
            mask.save(os.path.join(results_dirPath, task + '.png'))
        # Any other task is ignored.

    return {"semseg": os.path.join(results_dirPath, 'semseg.png'),
            "normals": os.path.join(results_dirPath, 'normals.png')}
# Front-end/back-end interaction: the page uploads an image and receives the
# URLs of the rendered predictions.
@app.route('/predict', methods=['GET', 'POST'])
@torch.no_grad()
def predict():
    """Run the model on the uploaded image and return the result image URLs.

    Expects a multipart form field named ``file``.

    On success, responds with JSON
    ``{'status': 1, 'semseg_url': ..., 'normals_url': ...}``. When the upload
    is missing, has a disallowed extension, or is rejected by preprocessing,
    responds with ``{'status': 0, 'msg': ...}`` and HTTP 400.
    """
    image = request.files.get('file')
    if image is None or not allowed_file(image.filename):
        return jsonify({'status': 0,
                        'msg': 'a PNG or JPG image must be uploaded as "file"'}), 400
    print(image.filename)
    img_bytes = image.read()
    try:
        result_info = get_prediction(p, img_bytes)
    except ValueError as err:
        # transform_image raises ValueError for images that are not RGB;
        # report that as a client error instead of a 500.
        return jsonify({'status': 0, 'msg': str(err)}), 400
    print(result_info)
    return jsonify({'status': 1,
                    'semseg_url': result_info['semseg'],
                    'normals_url': result_info['normals']})
@app.route('/', methods=["GET", "POST"])
def root():
    """Serve the demo page, predict.html, from Flask's template folder."""
    page = "./predict.html"
    return render_template(page)
if __name__ == '__main__':
    # Serve on the loopback interface only (port 5005) using Flask's
    # development server.
    server_options = {"host": "127.0.0.1", "port": 5005}
    app.run(**server_options)
五、 前端HTML頁面
這部分借鑒了別人的代碼。請將下面的頁面保存為 templates/predict.html(首頁路由通過 render_template 渲染該文件)。文章來源地址:http://www.zghlxwxcb.cn/news/detail-424484.html
<!DOCTYPE html>
<html>
<head>
<title>多任務學習展示</title>
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8">
<script src="https://apps.bdimg.com/libs/jquery/2.1.4/jquery.min.js"></script>
</head>
<body>
<!-- Multi-task learning demo: pick an image on the left, run the model with the
     middle button, and view the segmentation / surface-normal results on the right. -->
<!--<h3>請選擇圖片文件:PNG/JPG/JPEG/SVG/GIF</h3>-->
<h1 style="background-color:lightcoral;text-align:center;font-family:arial;color:cornflowerblue;font-size:50px;">多任務學習</h1>
<div style="text-align: left;margin-left: 0px;margin-top: 0px;/* width: 60px; */">
    <!-- Left column: preview of the input image and the file picker. -->
    <div style="float:left; margin-left: 100px;margin-top: 150px;">
        <img src="static/2008_000036.jpg" id="img0" style="margin-left:10px;width: 20rem;height: 20rem;">
        <br>
        <a href="javascript:;" class="file" style="text-align: center">選擇文件
            <input type="file" name="file" id="file0" style="text-align: center"><br>
        </a>
    </div>
    <!-- Middle: calls test(), which uploads the selected image to /predict. -->
    <div style="margin-left: 525px; margin-top: 0px;width: 20px;height: 0px;">
        <input type="button" id="b0" onclick="test()" value="使用多任務模型進行預測" style="margin-top: 250px;margin-left: 75px;width: auto;">
    </div>
    <!-- Right column: prediction results, replaced after a successful request. -->
    <div style="margin-left: 880px;margin-top: 0px;">
        <!--<pre id="out" style="width:320px;height:50px;line-height: 50px;margin-top:20px;"></pre>-->
        <div style="margin-right: 50px;margin-top: 0px;">
            <img src="static/sem_2008_000036.png" id="img1" style="width: 20rem;height: 20rem;margin-top: 0px;">
            語義分割
        </div>
        <div style="margin-right: 50px">
            <!-- NOTE(review): the sibling image is sem_2008_000036.png; confirm this file name is not missing a "2". -->
            <img src="static/nor_008_000036.png" id="img2" style="margin-top:20px;width: 20rem;height: 20rem;">
            表面法線估計
        </div>
    </div>
</div>
<script type="text/javascript">
// Show a local preview of the chosen file in #img0 as soon as it is selected.
$("#file0").on("change", function () {
    var previewUrl = getObjectURL(this.files[0]);
    console.log("objUrl = " + previewUrl);
    if (previewUrl) {
        $("#img0").attr("src", previewUrl);
    }
});
function test() {
    // Upload the selected image to /predict and display the returned results.
    var fileobj = $("#file0")[0].files[0];
    console.log(fileobj);
    if (!fileobj) {
        alert("請先選擇圖片文件");
        return;
    }
    var form = new FormData();
    form.append("file", fileobj);
    // Asynchronous request: synchronous XHR freezes the page and is
    // deprecated on the main thread.
    $.ajax({
        type: 'POST',
        url: "predict",
        data: form,
        processData: false, // send the FormData object as-is
        contentType: false, // let the browser set the multipart boundary header
        dataType: "json",
        success: function (out) {
            console.log(out);
            var r = window.confirm("預測完成,顯示圖片");
            if (r == true) {
                // The server overwrites the same file names on every request,
                // so add a timestamp query to bypass the browser cache.
                var stamp = "?t=" + Date.now();
                document.getElementById("img1").src = out['semseg_url'] + stamp;
                document.getElementById("img2").src = out['normals_url'] + stamp;
            }
        },
        error: function () {
            console.log("后臺處理錯誤");
        }
    });
}
function getObjectURL(file) {
    // Create a temporary object URL for `file` using whichever implementation
    // this browser exposes; returns null when none is available.
    var factory = null;
    if (window.createObjectURL != undefined) {
        factory = window;
    } else if (window.URL != undefined) {          // standard / Firefox
        factory = window.URL;
    } else if (window.webkitURL != undefined) {    // older WebKit / Chrome
        factory = window.webkitURL;
    }
    return factory ? factory.createObjectURL(file) : null;
}
</script>
<style>
    /* "Choose file" label: a grey rounded box that hosts the hidden file input. */
    .file {
        position: relative;
        background: #CCC;
        border: 1px solid #CCC;
        padding: 4px;
        overflow: hidden;
        text-decoration: none;
        text-indent: 0;
        width: 100px;
        height: 30px;
        line-height: 30px;
        border-radius: 5px;
        color: #333;
        font-size: 13px;
    }

    /* The real <input type="file">, stretched over the label and made fully
       transparent so clicks on the label open the file dialog. */
    .file input {
        position: absolute;
        font-size: 13px;
        right: 0;
        top: 0;
        opacity: 0;
        border: 1px solid #333;
        padding: 4px;
        overflow: hidden;
        text-indent: 0;
        width: 100px;
        height: 30px;
        line-height: 30px;
        border-radius: 5px;
        color: #FFFFFF;
    }

    /* "Run prediction" button. */
    #b0 {
        background: #1899FF;
        border: 1px solid #CCC;
        padding: 4px 10px;
        overflow: hidden;
        text-indent: 0;
        width: 60px;
        height: 28px;
        line-height: 20px;
        border-radius: 5px;
        color: #FFFFFF;
        font-size: 13px;
    }

    body {
        background: paleturquoise;
    }
</style>
</body>
</html>
六、 debug
1. 將需要讀取的圖片放在static文件夾下,否則讀取不到(Flask默認只通過 /static 路徑提供靜態文件)。
2. 服務只監聽127.0.0.1:5005。在本地電腦上測試遠程服務器時,需要通過端口映射訪問服務器的127.0.0.1,例如 ssh -L 5005:127.0.0.1:5005 user@server。
文章來源:http://www.zghlxwxcb.cn/news/detail-424484.html
到了這里,關于使用Flask簡單部署深度學習模型的文章就介紹完了。如果您還想了解更多內容,請在右上角搜索TOY模板網以前的文章或繼續瀏覽下面的相關文章,希望大家以后多多支持TOY模板網!