官網(wǎng)鏈接
Deploying PyTorch in Python via a REST API with Flask — PyTorch Tutorials 2.0.1+cu117 documentation
通過flask的rest API在python中部署pytorch
在本教程中,我們將使用Flask部署PyTorch模型,并開放用于模型推斷的REST API。特別是,我們將部署一個(gè)預(yù)訓(xùn)練的DenseNet 121模型來檢測(cè)圖像。
這是關(guān)于在生產(chǎn)環(huán)境中部署PyTorch模型的系列教程中的第一篇。使用Flask這種方式是迄今為止部署PyTorch模型的最簡(jiǎn)單方法,但它不適用于具有高性能要求的用例。
- 如果你已經(jīng)熟悉了TorchScript,你可以直接跳到我們的加載一個(gè)TorchScript模型在c++教程。(Loading a TorchScript Model in C++ )
- 如果你需要對(duì)TorchScript進(jìn)行復(fù)習(xí),請(qǐng)查看我們的TorchScript入門教程。(Intro a TorchScript )
API定義
我們將首先定義API 路徑、請(qǐng)求和響應(yīng)類型。我們的API路徑是 /predict ,它接受帶有包含圖像的文件參數(shù)的HTTP POST請(qǐng)求。響應(yīng)將是JSON響應(yīng),其中包含預(yù)測(cè)結(jié)果。
{"class_id": "n02124075", "class_name": "Egyptian_cat"}{"class_id": "n02124075", "class_name": "Egyptian_cat"}
依賴項(xiàng)
運(yùn)行以下命令安裝所需的依賴項(xiàng):
$ pip install Flask==2.0.1 torchvision==0.10.0
簡(jiǎn)單Web服務(wù)器
下面是一個(gè)簡(jiǎn)單的web服務(wù)器,摘自Flask的文檔
from flask import Flask
app = Flask(__name__)
@app.route('/')
def hello():
return 'Hello World!'
將上面的代碼片段保存在一個(gè)名為app.py的文件中,現(xiàn)在你可以通過輸入以下命令來運(yùn)行Flask開發(fā)服務(wù)器:
$ FLASK_ENV=development FLASK_APP=app.py flask run
當(dāng)您在web瀏覽器中訪問http://localhost:5000/時(shí),您將看到Hello World!文本
我們將對(duì)上面的代碼片段做一些修改,使其適合我們的API定義。首先,我們將把方法重命名為predict。我們將把請(qǐng)求路徑更新為/predict。由于圖像文件將通過HTTP POST請(qǐng)求發(fā)送,我們將更新它,使其也只接受POST請(qǐng)求。
@app.route('/predict', methods=['POST'])
def predict():
return 'Hello World!'
我們還將更改響應(yīng)類型,以便它返回一個(gè)包含ImageNet類id和名稱的JSON響應(yīng)。更新后的app.py文件現(xiàn)在將是:
from flask import Flask, jsonify
app = Flask(__name__)
@app.route('/predict', methods=['POST'])
def predict():
return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})
推理
在下一節(jié)中,我們將重點(diǎn)討論如何編寫推理代碼。這將涉及兩個(gè)部分,一個(gè)是我們準(zhǔn)備圖像,以便它可以饋送到DenseNet,接下來,我們將編寫代碼以從模型中獲得實(shí)際預(yù)測(cè)。
準(zhǔn)備圖像
DenseNet模型要求圖像為3通道RGB圖像,大小為224 x 224。我們還將用所需的均值和標(biāo)準(zhǔn)差值對(duì)圖像張量進(jìn)行歸一化。你可以在這里讀到更多(here)。
我們將使用torchvision 庫(kù)中的 transforms ,并構(gòu)建一個(gè)變換管道,它可以根據(jù)要求變換我們的圖像。你可以在這里閱讀更多關(guān)于變換的內(nèi)容(here)。
import io
import torchvision.transforms as transforms
from PIL import Image
def transform_image(image_bytes):
my_transforms = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
image = Image.open(io.BytesIO(image_bytes))
return my_transforms(image).unsqueeze(0)
上述方法接受字節(jié)數(shù)據(jù)的圖像,應(yīng)用一些列的transforms 并返回一個(gè)張量。為了測(cè)試上述方法,以字節(jié)模式讀取圖像文件(首先將../_static/img/sample_file.jpeg替換為計(jì)算機(jī)上文件的實(shí)際路徑)并查看是否返回一個(gè)張量:
with open("../_static/img/sample_file.jpeg", 'rb') as f:
image_bytes = f.read()
tensor = transform_image(image_bytes=image_bytes)
print(tensor)
預(yù)測(cè)
現(xiàn)在將使用預(yù)訓(xùn)練的DenseNet 121模型來預(yù)測(cè)圖像類別。我們將使用torchvision庫(kù),加載模型并獲得推理結(jié)果。雖然我們將在本例中使用預(yù)訓(xùn)練模型,但您可以對(duì)自己的模型使用相同的方法。了解更多關(guān)于加載模型的信息(tutorial)。
from torchvision import models
# Make sure to set `weights` as `'IMAGENET1K_V1'` to use the pretrained weights:
model = models.densenet121(weights='IMAGENET1K_V1')
# Since we are using our model only for inference, switch to `eval` mode:
model.eval()
def get_prediction(image_bytes):
tensor = transform_image(image_bytes=image_bytes)
outputs = model.forward(tensor)
_, y_hat = outputs.max(1)
return y_hat
張量y_hat 將包含預(yù)測(cè)類別的索引id, 然而,我們需要一個(gè)人類可讀的類名。為此,我們需要一個(gè)類別id和命名的映射。下載 imagenet_class_index.json 這個(gè)文件( this file),并記住保存它的位置(或者,如果您遵循本教程中的確切步驟,將其保存在教程/_static中)。這個(gè)文件包含ImageNet類別id到ImageNet類名的映射。我們將加載這個(gè)JSON文件并獲取預(yù)測(cè)類別索引的類名。
import json
imagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))
def get_prediction(image_bytes):
tensor = transform_image(image_bytes=image_bytes)
outputs = model.forward(tensor)
_, y_hat = outputs.max(1)
predicted_idx = str(y_hat.item())
return imagenet_class_index[predicted_idx]
在使用imagenet_class_index字典之前,首先我們將把張量值轉(zhuǎn)換為字符串值,因?yàn)?strong>imagenet_class_index字典中的鍵是字符串。我們將測(cè)試上述方法:
with open("../_static/img/sample_file.jpeg", 'rb') as f:
image_bytes = f.read()
print(get_prediction(image_bytes=image_bytes))
你應(yīng)該得到這樣的返回:
['n02124075', 'Egyptian_cat']
數(shù)組中的第一項(xiàng)是ImageNet類別id,第二項(xiàng)是人類可讀的名稱。
注意
您是否注意到model變量不是get_prediction方法的局部變量,或者說為什么model是一個(gè)全局變量?就內(nèi)存和計(jì)算而言,加載模型可能是一項(xiàng)昂貴的操作。如果我們?cè)趃et_prediction方法中加載模型,那么每次調(diào)用該方法時(shí)都會(huì)不必要地加載模型。因?yàn)槲覀冋跇?gòu)建一個(gè)web服務(wù)器,每秒可能有數(shù)千個(gè)請(qǐng)求,我們不應(yīng)該浪費(fèi)時(shí)間為每個(gè)推理加載模型。因此,我們只將模型加載到內(nèi)存中一次。在生產(chǎn)系統(tǒng)中,為了能夠大規(guī)模地處理請(qǐng)求,必須高效地使用計(jì)算,因此通常應(yīng)該在處理請(qǐng)求之前加載模型。
在我們的API服務(wù)器中集成模型
在最后一部分中,我們將把模型添加到Flask API服務(wù)器中。由于我們的API服務(wù)器應(yīng)該接受一個(gè)圖像文件,我們將更新我們的預(yù)測(cè)方法來從請(qǐng)求中讀取文件:
from flask import request
@app.route('/predict', methods=['POST'])
def predict():
if request.method == 'POST':
# we will get the file from the request
file = request.files['file']
# convert that to bytes
img_bytes = file.read()
class_id, class_name = get_prediction(image_bytes=img_bytes)
return jsonify({'class_id': class_id, 'class_name': class_name})
app.py文件現(xiàn)在已經(jīng)完成。以下是完整版本;將路徑替換為您保存文件的路徑,它應(yīng)該運(yùn)行:
import io
import json
from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request
app = Flask(__name__)
imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json'))
model = models.densenet121(weights='IMAGENET1K_V1')
model.eval()
def transform_image(image_bytes):
my_transforms = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
image = Image.open(io.BytesIO(image_bytes))
return my_transforms(image).unsqueeze(0)
def get_prediction(image_bytes):
tensor = transform_image(image_bytes=image_bytes)
outputs = model.forward(tensor)
_, y_hat = outputs.max(1)
predicted_idx = str(y_hat.item())
return imagenet_class_index[predicted_idx]
@app.route('/predict', methods=['POST'])
def predict():
if request.method == 'POST':
file = request.files['file']
img_bytes = file.read()
class_id, class_name = get_prediction(image_bytes=img_bytes)
return jsonify({'class_id': class_id, 'class_name': class_name})
if __name__ == '__main__':
app.run()
讓我們測(cè)試一下我們的web服務(wù)器!運(yùn)行:
$ FLASK_ENV=development FLASK_APP=app.py flask run
我們可以使用requests庫(kù)向我們的應(yīng)用發(fā)送POST請(qǐng)求:
import requests
resp = requests.post("http://localhost:5000/predict",
files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})
打印rep .json()將顯示以下內(nèi)容:文章來源:http://www.zghlxwxcb.cn/news/detail-564507.html
{"class_id": "n02124075", "class_name": "Egyptian_cat"}
下一個(gè)步驟
我們編寫的服務(wù)器非常簡(jiǎn)單,可能無法完成生產(chǎn)應(yīng)用程序所需的所有功能。所以,這里有一些你可以做的事情來讓它變得更好:文章來源地址http://www.zghlxwxcb.cn/news/detail-564507.html
- 請(qǐng)求路徑 /predict假定請(qǐng)求中總是有一個(gè)圖像文件。這可能并不適用于所有請(qǐng)求。我們的用戶可以發(fā)送帶有不同參數(shù)的圖像或根本不發(fā)送圖像。
- 用戶也可以發(fā)送非圖像類型的文件。由于我們不處理錯(cuò)誤,這將破壞我們的服務(wù)器。顯式添加異常的錯(cuò)誤處理路徑,將使我們能夠更好地處理錯(cuò)誤輸入
- 盡管該模型可以識(shí)別大量的圖像類別,但它可能無法識(shí)別所有的圖像。優(yōu)化實(shí)現(xiàn)以處理模型無法識(shí)別圖像中的任何內(nèi)容的情況。
- 我們以開發(fā)模式運(yùn)行Flask服務(wù)器,這種模式不適合部署到生產(chǎn)環(huán)境中。您可以查看本教程,了解如何在生產(chǎn)環(huán)境中部署Flask服務(wù)器。(this tutorial )
- 您還可以通過創(chuàng)建帶有表單的頁(yè)面來添加UI,該表單接受圖像并顯示預(yù)測(cè)結(jié)果。請(qǐng)查看類似項(xiàng)目的演示及其源代碼。(source code.)
- 在本教程中,我們只展示了如何構(gòu)建一個(gè)每次可以返回單個(gè)圖像預(yù)測(cè)的服務(wù)。我們可以修改我們的服務(wù),使其能夠一次返回多個(gè)圖像的預(yù)測(cè)結(jié)果。此外,service-streamer庫(kù)會(huì)自動(dòng)將請(qǐng)求排隊(duì)到您的服務(wù)中,并將它們采樣到可以饋送到模型中的小批量中。您可以查看本教程(this tutorial.)。
- 最后,我們鼓勵(lì)您查看頁(yè)面頂部鏈接的關(guān)于部署PyTorch模型的其他教程.
到了這里,關(guān)于PyTorch翻譯官網(wǎng)教程-DEPLOYING PYTORCH IN PYTHON VIA A REST API WITH FLASK的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!