下文中的代碼都使用參考教程中的例子。
會(huì)給出一點(diǎn)自己的解釋。
參考教程:
mixing-scripting-and-tracing
script and optimize for mobile recipe
https://pytorch.org/docs/stable/jit.html
OPTIMIZING VISION TRANSFORMER MODEL FOR DEPLOYMEN
Introduction
我們訓(xùn)練好并保存的pytorch,支持在python語(yǔ)言下的使用,但是不支持在一些c++語(yǔ)言下使用。為了能讓我們的模型在high-performance environment c++環(huán)境下使用,我們需要對(duì)模型進(jìn)行格式轉(zhuǎn)換。
好消息!torch本身是有模型格式轉(zhuǎn)換的功能的,所以我們不需要下載額外的包,就可以把它轉(zhuǎn)為能在c++使用的torchscript模型。
復(fù)習(xí)一下nn.Module()
之前的章節(jié)中有講過(guò),torch中所有模型都是基于nn.Module()這個(gè)類,模型的定義都繼承了這個(gè)類的屬性與方法。
一個(gè)完整的模型要包括以下三個(gè)基本的部分:
- 一個(gè)構(gòu)造函數(shù),用于調(diào)用模型模塊
- parameters和sub-modules。它們?cè)跇?gòu)造函數(shù)中被初始化,并能在調(diào)用中被使用。
- forward()函數(shù),決定了模型調(diào)用的順序。
教程中給出了下面一個(gè)簡(jiǎn)單的例子。
例子中定義了一個(gè)名為MyCell的類,它繼承了torch.nn.Module()的功能。因?yàn)檫@個(gè)模型中沒(méi)有需要訓(xùn)練的參數(shù)和網(wǎng)絡(luò)層,所以先跳過(guò)parameters和sub-modules這一步。要注意這里使用了super,調(diào)用了父類的構(gòu)造函數(shù)。
在forward()的部分,該方法的傳入?yún)?shù)為x和h(忽略了self)。計(jì)算過(guò)程中只使用了torch.tanh(x+h),這一步?jīng)]有參數(shù)需要更新。返回的結(jié)果為new_h。
class MyCell(torch.nn.Module):
def __init__(self):
super(MyCell, self).__init__()
def forward(self, x, h):
new_h = torch.tanh(x + h)
return new_h
my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)
print(my_cell(x, h))
接下來(lái)對(duì)這個(gè)小模型進(jìn)行一些改動(dòng),增加一些需要訓(xùn)練的參數(shù)。在教程例子中,它給這個(gè)模型增加了一個(gè)線性層。
class MyCell(torch.nn.Module):
def __init__(self):
super(MyCell, self).__init__()
self.linear = torch.nn.Linear(4,4) # 在這部分增加了一個(gè)線性層
def forward(self, x, h):
new_h = torch.tanh(self.linear(x) + h) # 在調(diào)用的時(shí)候也使用了線性層,這里的參數(shù)需要在訓(xùn)練中更新
return new_h, new_h
my_cell = MyCell()
x = torch.rand(3,4)
h = torch.rand(3,4)
print(my_cell(x,h))
可以看一下我們輸出的結(jié)果中,多了一個(gè)grad_fn,之前我們?cè)?jīng)解釋過(guò),這個(gè)是反向傳播中梯度計(jì)算的方法,因?yàn)楝F(xiàn)在有了要學(xué)習(xí)的參數(shù),所以增加了這個(gè)方法。
pytorch具有很高的靈活性。在教程中提到了重要的一點(diǎn)是,很多框架都會(huì)在給出完整定義的情況下再進(jìn)行求導(dǎo)的計(jì)算,而在pytorch中不是的,pytorch會(huì)在計(jì)算進(jìn)行的時(shí)候記錄這個(gè)操作,并在求導(dǎo)的過(guò)程中replay。所以pytorch時(shí)并沒(méi)有很明確的對(duì)這些求導(dǎo)操作做出定義。
我自己也不是太理解這些話。我的個(gè)人理解是在backwards過(guò)程中tensor的grad_fn是隨著當(dāng)前步更新的,而不是預(yù)設(shè)好的。下面放出原文。
Many frameworks take the approach of computing symbolic derivatives given a full program representation. However, in PyTorch, we use a gradient tape. We record operations as they occur, and replay them backwards in computing derivatives. In this way, the framework does not have to explicitly define derivatives for all constructs in the language.
Torchscript
torchscript的作用就是根據(jù)pytorch code來(lái)創(chuàng)建一個(gè)模型,這個(gè)模型可以在非python環(huán)境下被使用。所以在pytorch中訓(xùn)練的模型,能夠很容易地被應(yīng)用到一個(gè)非python依賴的生產(chǎn)環(huán)境中去。
我們先來(lái)看一下代碼,熟悉一下其中的方法的作用。
torch.jit.ScriptModule()
ScriptModule()也繼承了nn.Module()類,所以它也有很多和nn.Module()一樣的方法。比如children(),named_children()等。
它還包括一些神秘的方法。比如:
PROPERTY code 返回forward()函數(shù)中代碼。這個(gè)功能是nn.Module()中沒(méi)有的。
PROPERTY graph 返回forward()函數(shù)中的graph。
torch.jit.script()
torch.jit.script(obj, optimize=None, _frames_up=0, _rcb=None, example_inputs=None)
script() 的作用是檢查一個(gè)function或者nn.Module()的源碼,并把它編譯成torchscript code并返回一個(gè)ScriptModule或者ScriptFunctions。
TorchScript本身是python language的一個(gè)子集,所以它并不能完全支持python中的所有功能,但是一些模型相關(guān)的計(jì)算它都是支持的。
更詳細(xì)的介紹可以參考。
https://pytorch.org/docs/stable/jit_language_reference.html#language-reference
里面提到了一些對(duì)torchscript的限制,比如函數(shù)中的參數(shù)類型是不可以發(fā)生改變的,在python語(yǔ)言中你可以判斷參數(shù)的種類并作出對(duì)應(yīng)的操作,在torchscript中這是一個(gè)錯(cuò)誤操作。torchscript中的參數(shù)為做特別說(shuō)明的情況下,均默認(rèn)為tensor。
這里的輸入可以是一個(gè)function也可以是一個(gè)nn.Module(),要注意這里的example_inputs是有格式要求的:
(Union[List[Tuple], Dict[Callable, List[Tuple]], None])。
我們對(duì)我們定義的MyCell進(jìn)行script,輸入是一個(gè)nn.Module(),返回結(jié)果是一個(gè)ScriptModule()。
torch.jit.trace()
torch.jit.trace(func, example_inputs=None, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_kwarg_inputs=None, _store_inputs=True)
torch.jit.trace()可以對(duì)一個(gè)function進(jìn)行追蹤,并返回一個(gè)可執(zhí)行object或者一個(gè)ScriptFunction。你必須提供一個(gè)example_inputs。
- The resulting recording of a standalone function produces ScriptFunction.
- The resulting recording of nn.Module.forward or nn.Module produces ScriptModule.
當(dāng)傳入的是一個(gè)普調(diào)的function時(shí),如下圖,返回的結(jié)果是一個(gè)scriptfunction。
不管傳入的是nn.Module還是它本身的forward函數(shù),返回的結(jié)果都是一樣的。
一些注意事項(xiàng)
trace方法和script方法存在一些區(qū)別。比如說(shuō)trace方法只會(huì)跟蹤你的輸入走過(guò)的路徑,當(dāng)你的模型中存在if-else或者別的分支時(shí),trace方法不會(huì)記錄你的輸入沒(méi)有經(jīng)過(guò)的那些分支。而script方法會(huì)分析你的源碼,并進(jìn)行完整的記錄。
這樣看起來(lái),似乎只要使用script方法就足夠了,完全沒(méi)必須要使用trace方法。接下來(lái)我們來(lái)看一看使用script方法可能遇到哪些問(wèn)題。
- RuntimeError attribute lookup is not defined on python value of type
當(dāng)被script的模型接收另一個(gè)模型作為參數(shù)時(shí),這個(gè)模型傳入的類型實(shí)際上是TracedModule或者ScriptModule。這種情況下,被script的模型無(wú)法使用另一個(gè)模型中的一些Module格式下可以使用的參數(shù)。比如說(shuō)它想要使用 model2.n_layers,這樣就會(huì)出現(xiàn)錯(cuò)誤,它應(yīng)該把n_layers作為參數(shù)傳進(jìn)去。 - RuntimeError python value of type ‘…’ cannot be used as a value.
使用全局變量時(shí)會(huì)出現(xiàn)這種問(wèn)題。 - RuntimeError all inputs of range must be ‘…’, found Tensor (inferred) in argument。
torchscript函數(shù)默認(rèn)的參數(shù)類型都是torch.tensor。當(dāng)你想使用別的類型時(shí),你需要明確的給出指定。比如
def forward(self, input_seq, input_length, max_length : int):
使用示例
tracing Modules
torchscript提供了一個(gè)方法,幫你獲取你的模型的完整定義。首先來(lái)看一下tracing方法的作用。
使用上方定義的帶線性層的小模型。
來(lái)看一下jit.trace做了什么操作,它首先傳入了my_cell,然后傳入了對(duì)應(yīng)的輸入。trace方法會(huì)調(diào)用這個(gè)Module,并且記錄其中的每一步操作,并創(chuàng)造一個(gè)ScriptModule的實(shí)例。
我們可以看一下它的code。
使用trace方法會(huì)有一些天然的缺陷。它追蹤了你的輸入在function中經(jīng)過(guò)的每一步操作,所以如果你的function中存在判斷語(yǔ)句時(shí),未被觸發(fā)的操作就會(huì)被忽略掉。
使用教程中給出的例子。
class MyDecisionGate(torch.nn.Module):
def forward(self, x):
if x.sum() > 0:
return x
else:
return -x
class MyCell(torch.nn.Module):
def __init__(self, dg):
super(MyCell, self).__init__()
self.dg = dg
self.linear = torch.nn.Linear(4, 4)
def forward(self, x, h):
new_h = torch.tanh(self.dg(self.linear(x)) + h)
return new_h, new_h
my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell.dg.code)
print(traced_cell.code)
在這個(gè)例子中,MyDecisionGate函數(shù)進(jìn)行了一個(gè)判斷,假如傳入的x的總和大于0,就返回x本身,假如x的總和小于0,就返回-x。
Converting a tensor to a Python boolean might cause the trace to be incorrect. We can’t record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
我們可以看到因?yàn)槲覀兊妮斎氩⒉荒茏哌^(guò)if-else的兩條路徑,所以我們trace的結(jié)果中也只有一條路。我們的if-else方法不見了。
scripting Module
在上面的trace方法中,它對(duì)你的輸入走過(guò)的路徑進(jìn)行記錄,所以它看不到輸入沒(méi)有經(jīng)過(guò)的地方。而我們的第二個(gè)方法,script() 則是直接對(duì)你的源碼進(jìn)行分析,所以能夠保留比較完整的結(jié)果。
Mixing scripting and tracing
假如你的代碼中有些不希望被torch.jit.script記錄的常量,你可以使用trace和script的組合,將這些常量隱藏。
對(duì)這部分的理解是,對(duì)于有多個(gè)分支并且又有你想要隱藏的參數(shù)的情況下,可以使用trace和script的組合。多分支的部分用script記錄,隱藏參數(shù)的部分用trace記錄。
class MyDecisionGate(torch.nn.Module):
def forward(self, x):
if x.sum() > 0:
return x
else:
return -x
scripted_gate = torch.jit.script(MyDecisionGate())
class MyCell(torch.nn.Module):
def __init__(self, dg):
super(MyCell, self).__init__()
self.dg = dg
self.linear = torch.nn.Linear(4, 4)
def forward(self, x, h):
new_h = torch.tanh(self.dg(self.linear(x)) + h)
return new_h, new_h
第一個(gè)例子,torch.jit.script和traced module內(nèi)聯(lián)。
class MyRNNLoop(torch.nn.Module):
def __init__(self):
super(MyRNNLoop, self).__init__()
self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))
def forward(self, xs):
h, y = torch.zeros(3, 4), torch.zeros(3, 4)
for i in range(xs.size(0)):
y, h = self.cell(xs[i], h)
return y, h
rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)
第二個(gè)例子,torch.jit.trace()和scripted module內(nèi)聯(lián)。
class WrapRNN(torch.nn.Module):
def __init__(self):
super(WrapRNN, self).__init__()
self.loop = torch.jit.script(MyRNNLoop())
def forward(self, xs):
y, h = self.loop(xs)
return torch.relu(y)
traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)
我們觀察一下第二個(gè)例子,比較一下最后使用jit.trace和jit.script有什么區(qū)別。
大家可以看到使用trace時(shí),loop的返回結(jié)果是_0, y;使用script時(shí),lopp返回的結(jié)果是y, h。
保存和加載模型
torchscript可以將模型獨(dú)立地保存下來(lái),保存的信息包括模型的code,parameters, attribute和debug information。這些完整的信息讓我們的模型可以獨(dú)立地表達(dá),并在一個(gè)完全不同的進(jìn)程中被加載,下面給出了代碼例子。
traced.save('wrapped_rnn.pt')
loaded = torch.jit.load('wrapped_rnn.pt')
print(loaded)
print(loaded.code)
實(shí)踐與優(yōu)化
放一下源碼的鏈接OPTIMIZING VISION TRANSFORMER MODEL FOR DEPLOYMENT。鏈接里內(nèi)容更詳細(xì),有條件的直接看源碼。我只是crop出來(lái)了中間和torchscript相關(guān)的部分。
from PIL import Image
import torch
import timm
import requests
import torchvision.transforms as transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
print(torch.__version__)
# should be 1.8.0
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()
transform = transforms.Compose([
transforms.Resize(256, interpolation=3),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])
img = Image.open(requests.get("https://raw.githubusercontent.com/pytorch/ios-demo-app/master/HelloWorld/HelloWorld/HelloWorld/image.png", stream=True).raw)
ten = transform(img)[None,]
out = model(ten)
clsidx = torch.argmax(out)
print(clsidx.item())
將模型以script 的形式保存下來(lái)文章來(lái)源:http://www.zghlxwxcb.cn/news/detail-495399.html
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()
scripted_model = torch.jit.script(model)
scripted_model.save("fbdeit_scripted.pt")
比較一下兩者的時(shí)間,兩者在時(shí)間上是沒(méi)有什么明顯差別的。在教程中使用了一些模型加速的方法,所以inference的時(shí)間會(huì)變快。文章來(lái)源地址http://www.zghlxwxcb.cn/news/detail-495399.html
到了這里,關(guān)于第九章 番外篇:TORCHSCRIPT的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!