[Image Segmentation][Deep Learning] SAM Official PyTorch Code: A Walkthrough of Each Functional Module
Segment Anything: the project built the largest segmentation dataset to date, with over 1 billion masks on 11 million images; the model is designed and trained to be promptable, and its key property is zero-shot transfer to new image distributions and tasks. It is a new task, model, and dataset for image segmentation. SAM consists of three parts: a powerful image encoder (Image encoder) that computes an image embedding, a prompt encoder (Prompt encoder) that embeds prompts, and a lightweight mask decoder (Mask decoder) that combines the two sources of information to predict segmentation masks. This post gives a broad overview of what each SAM module does.
Preface
Before analyzing the SAM code in detail, the first task is to get the SAM code running successfully (see the [Win10 setup tutorial]); only then is the rest of this walkthrough useful. This post covers the functional code of each submodule at a high level and does not yet go into the neural-network code in detail.
The author has analyzed the code of each functional module in detail in separate posts; see the [Win10 setup tutorial], where the directory of links to those posts is collected in the preface.
Model Loading
Taking the [official SAM code example] as the example: the source code provides 3 models of different sizes.
# select the appropriate model and load the corresponding weights
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
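For context, the three names used above come from the surrounding demo script. A minimal setup sketch follows; the checkpoint filename shown is the official ViT-B download, so adjust it to whichever .pth file you actually fetched:

from segment_anything import sam_model_registry

model_type = "vit_b"                     # one of "default", "vit_h", "vit_l", "vit_b"
sam_checkpoint = "sam_vit_b_01ec64.pth"  # path to the downloaded weights
device = "cuda"                          # or "cpu" on machines without a GPU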
sam_model_registry is defined in the segment_anything/build_sam.py file.
The 3 SAM model variants are stored in dictionary form; note that the "default" key is simply an alias for the ViT-H builder.
sam_model_registry = {
    "default": build_sam_vit_h,
    "vit_h": build_sam_vit_h,
    "vit_l": build_sam_vit_l,
    "vit_b": build_sam_vit_b,
}
The 3 models in sam_model_registry share the same architecture; only a few parameters differ, which is what makes the model sizes differ.
def build_sam_vit_h(checkpoint=None):
    return _build_sam(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
        checkpoint=checkpoint,
    )

def build_sam_vit_l(checkpoint=None):
    return _build_sam(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
        checkpoint=checkpoint,
    )

def build_sam_vit_b(checkpoint=None):
    return _build_sam(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        checkpoint=checkpoint,
    )
Finally, the _build_sam function performs the initialization of the SAM model and loads the weights. Notice here that the SAM model is composed of three neural-network modules: ImageEncoderViT (the image encoder), PromptEncoder, and MaskDecoder. The purpose and meaning of the individual parameters will be explained later when each network is studied in detail.
def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    sam = Sam(
        image_encoder=ImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    sam.eval()
    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)
        sam.load_state_dict(state_dict)
    return sam
The paper's schematic of SAM's structure:
The SamPredictor Class
The sam model is wrapped in a SamPredictor object, which makes it convenient to use.
predictor = SamPredictor(sam)
predictor.set_image(image)
The image_encoder step is executed at set_image time, not at predict time.
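Because the expensive encoding is cached by set_image above, several prompts can then be predicted against the same image cheaply. A minimal sketch continuing from the two lines above (the point coordinate is hypothetical):

import numpy as np

input_point = np.array([[500, 375]])  # hypothetical (x, y) pixel coordinate in the original image
input_label = np.array([1])           # 1 = foreground point, 0 = background point
masks, scores, logits = predictor.predict(  # no re-encoding happens here
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)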
The SamPredictor class is in the segment_anything/predictor.py file:
__init__
Initializes the mask-prediction model sam and the data-preprocessing helper (ResizeLongestSide), and resets the image-related state.
def __init__(
    self,
    sam_model: Sam,
) -> None:
    super().__init__()
    # the SAM mask-prediction model
    self.model = sam_model
    # helper for input preprocessing
    self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
    # reset image-related state
    self.reset_image()
reset_image
self.is_image_set and self.features go hand in hand: self.features stores the feature data produced by passing the image through the image encoder, and self.is_image_set is a flag that indicates whether self.features currently holds such features. Right after initialization, self.features is None and self.is_image_set is False.
def reset_image(self) -> None:
    # flag: has an image been set?
    self.is_image_set = False
    # encoded image features
    self.features = None
    self.orig_h = None
    self.orig_w = None
    self.input_h = None
    self.input_w = None
set_image
First checks that the input is a three-channel RGB or BGR image and converts BGR input to RGB, then adjusts the image size (apply_image) and channel order to satisfy the network's input requirements.
def set_image(
    self,
    image: np.ndarray,
    image_format: str = "RGB",
) -> None:
    # raise if the image format is not one of ['RGB', 'BGR']
    assert image_format in [
        "RGB",
        "BGR",
    ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
    # image is H,W,C
    if image_format != self.model.image_format:
        image = image[..., ::-1]  # reverse the C axis of H,W,C (BGR <-> RGB)
    # Transform the image to the form expected by the model (resize)
    input_image = self.transform.apply_image(image)
    # convert to tensor (torch.as_tensor shares memory with the numpy array when possible)
    input_image_torch = torch.as_tensor(input_image, device=self.device)
    # permute H,W,C --> C,H,W
    # contiguous: ensure contiguous memory
    # [None, :, :, :]: C,H,W --> 1,C,H,W
    input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
    self.set_torch_image(input_image_torch, image.shape[:2])
set_torch_image
Pads the rescaled image so that H and W reach the standard size the network expects, then obtains the image feature data through the image_encoder model, saves it in self.features, and sets self.is_image_set to True.
Note that the image_encoder step is not executed in predict_torch alongside the prompt-encoder and mask-decoder steps; it has already been executed during set_image.
def set_torch_image(
    self,
    transformed_image: torch.Tensor,
    original_image_size: Tuple[int, ...],
) -> None:
    # the input must be 4-dimensional B,C,H,W with the long side at the model's input size
    assert (
        len(transformed_image.shape) == 4
        and transformed_image.shape[1] == 3
        and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
    ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
    self.reset_image()
    # size of the original image
    self.original_size = original_image_size
    # size of the resized tensor image
    self.input_size = tuple(transformed_image.shape[-2:])
    # pad the tensor image
    input_image = self.model.preprocess(transformed_image)
    # encode the image with the image_encoder network module
    self.features = self.model.image_encoder(input_image)
    # mark the image as set
    self.is_image_set = True
The code details of the image_encoder model can be set aside for now.
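To keep the shapes straight, here is the flow through set_image/set_torch_image under the default build (img_size=1024, patch size 16, out_chans=256); the 600x900 photo is a hypothetical example:

# hypothetical 600x900 (HxW) RGB photo
# apply_image        : (600, 900, 3)      -> (683, 1024, 3)     longest side scaled to 1024
# as_tensor + permute: (683, 1024, 3)     -> (1, 3, 683, 1024)
# model.preprocess   : (1, 3, 683, 1024)  -> (1, 3, 1024, 1024) normalize + zero-pad
# image_encoder      : (1, 3, 1024, 1024) -> (1, 256, 64, 64)   saved to self.features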
predict
predict preprocesses the prompt data fed to the model (point prompts via apply_coords, box prompts via apply_boxes), then receives and post-processes the model's prediction results.
def predict(
    self,
    # coordinates of the point prompts
    point_coords: Optional[np.ndarray] = None,
    # labels of the point prompts
    point_labels: Optional[np.ndarray] = None,
    # coordinates of the box prompt
    box: Optional[np.ndarray] = None,
    # input mask
    mask_input: Optional[np.ndarray] = None,
    # output multiple candidate masks to choose from
    multimask_output: bool = True,
    # True: return mask logits; False: return thresholded binary masks
    return_logits: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # raise if no image has been set
    if not self.is_image_set:
        raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
    # Transform input prompts to torch tensors
    coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
    if point_coords is not None:
        # point labels must accompany point coordinates
        assert (
            point_labels is not None
        ), "point_labels must be supplied if point_coords is supplied."
        # the image was resized, so the point coordinates change accordingly
        point_coords = self.transform.apply_coords(point_coords, self.original_size)
        # point coordinates and labels: np --> tensor
        coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
        labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
        # add a batch dimension
        # coords_torch: N,2 --> 1,N,2
        # labels_torch: N --> 1,N
        coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
    if box is not None:
        # the image was resized, so the box coordinates change accordingly
        box = self.transform.apply_boxes(box, self.original_size)
        # box coordinates: np --> tensor
        box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
        # add a batch dimension: N,4 --> 1,N,4
        box_torch = box_torch[None, :]
    if mask_input is not None:
        # mask: np --> tensor
        mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
        # add a batch dimension: 1,H,W --> B,1,H,W
        mask_input_torch = mask_input_torch[None, :, :, :]
    # inputs are fully preprocessed and can be fed to the network
    masks, iou_predictions, low_res_masks = self.predict_torch(
        coords_torch,
        labels_torch,
        box_torch,
        mask_input_torch,
        multimask_output,
        return_logits=return_logits,
    )
    # batch size is 1, so squeeze the batch dimension
    # masks
    masks = masks[0].detach().cpu().numpy()
    # scores
    iou_predictions = iou_predictions[0].detach().cpu().numpy()
    low_res_masks = low_res_masks[0].detach().cpu().numpy()
    return masks, iou_predictions, low_res_masks
The postprocess_masks source is in segment_anything/modeling/sam.py:
def postprocess_masks(
    self,
    masks: torch.Tensor,
    input_size: Tuple[int, ...],
    original_size: Tuple[int, ...],
) -> torch.Tensor:
    # upsample the masks to the size of the image fed into the model
    masks = F.interpolate(
        masks,
        (self.image_encoder.img_size, self.image_encoder.img_size),
        mode="bilinear",
        align_corners=False,
    )
    # crop away the padded region
    masks = masks[..., : input_size[0], : input_size[1]]
    # resize the masks to the size of the original, unprocessed image
    masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
    return masks
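A worked example of the three steps above, assuming a hypothetical 800x1200 (HxW) original photo, so that input_size = (683, 1024); the 256x256 low-resolution mask size follows from the default build (4x the 64x64 embedding grid):

# low_res_masks from the decoder: (B, C, 256, 256)
# 1st interpolate          -> (B, C, 1024, 1024)  model input resolution
# crop [..., :683, :1024]  -> (B, C, 683, 1024)   padding removed
# 2nd interpolate          -> (B, C, 800, 1200)   back to the original size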
predict_torch
The preprocessed inputs are passed to the model to predict the results.
The prompt-encoder and mask-decoder steps are executed during predict_torch.
def predict_torch(
    self,
    point_coords: Optional[torch.Tensor],
    point_labels: Optional[torch.Tensor],
    boxes: Optional[torch.Tensor] = None,
    mask_input: Optional[torch.Tensor] = None,
    multimask_output: bool = True,
    return_logits: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # raise if no image has been set
    if not self.is_image_set:
        raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
    # bundle point coordinates with their labels
    if point_coords is not None:
        points = (point_coords, point_labels)
    else:
        points = None
    # ----- Prompt encoder -----
    sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
        points=points,
        boxes=boxes,
        masks=mask_input,
    )
    # ----- Prompt encoder -----
    # ----- Mask decoder -----
    low_res_masks, iou_predictions = self.model.mask_decoder(
        image_embeddings=self.features,
        image_pe=self.model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=multimask_output,
    )
    # ----- Mask decoder -----
    # Upscale the masks to the original image resolution
    masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
    if not return_logits:
        masks = masks > self.model.mask_threshold
    return masks, iou_predictions, low_res_masks
The code details of the prompt encoder and mask decoder can likewise be set aside for now.
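For reference, mask_threshold is an attribute of the Sam class (0.0 in the official repo), so with return_logits=False the returned masks are boolean arrays obtained by thresholding the logits at zero:

binary_masks = masks > 0.0  # equivalent to masks > self.model.mask_threshold with the default value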
get_image_embedding
Returns the image features computed by the image_encoder.
def get_image_embedding(self) -> torch.Tensor:
    if not self.is_image_set:
        raise RuntimeError(
            "An image must be set with .set_image(...) to generate an embedding."
        )
    assert self.features is not None, "Features must exist if an image has been set."
    return self.features
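A quick sanity check of the cached features; the shape follows from the default build (1024 // 16 = 64 grid, 256 channels):

embedding = predictor.get_image_embedding()
print(embedding.shape)  # torch.Size([1, 256, 64, 64])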
device
Returns the device the model is on; in the source this is decorated as a @property, which is why it can be read as self.device elsewhere in the class.
@property
def device(self) -> torch.device:
    return self.model.device
The ResizeLongestSide Class
ResizeLongestSide is a utility class dedicated to transforming images, point prompts, and box prompts.
The ResizeLongestSide class is in the segment_anything/utils/transforms.py file:
__init__
Stores the standard size (target length of the longest side) toward which every image fed to the network is resized.
def __init__(self, target_length: int) -> None:
    self.target_length = target_length
apply_image
The new size is computed from the original size and the standard size (get_preprocess_shape), then the image is resized.
def apply_image(self, image: np.ndarray) -> np.ndarray:
    target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
    # to_pil_image converts the numpy array to a PIL.Image, which is then resized
    return np.array(resize(to_pil_image(image), target_size))
A simple schematic: the scale factor relative to the standard size is computed and the image is scaled by it; afterwards zero-padding (the dashed region in the figure) brings every image up to the standard size.
The reason for not resizing directly to the target size is to avoid distorting the proportions of the objects in the original image.
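The padding itself is done by Sam.preprocess in segment_anything/modeling/sam.py. A sketch of what it does (normalize with the registered mean/std buffers, then zero-pad on the bottom and right up to the square img_size); consult the repo for the exact code:

import torch.nn.functional as F

def preprocess(self, x: torch.Tensor) -> torch.Tensor:
    # normalize colors with the registered pixel_mean / pixel_std buffers
    x = (x - self.pixel_mean) / self.pixel_std
    # zero-pad on the bottom and right up to the square img_size
    h, w = x.shape[-2:]
    padh = self.image_encoder.img_size - h
    padw = self.image_encoder.img_size - w
    x = F.pad(x, (0, padw, 0, padh))
    return x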
apply_coords
Since the image was resized from its original dimensions, the point coordinates must be rescaled accordingly ([get_preprocess_shape](#get_preprocess_shape)).
def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
    old_h, old_w = original_size
    # the image was resized, so the point coordinates change accordingly
    new_h, new_w = self.get_preprocess_shape(
        original_size[0], original_size[1], self.target_length
    )
    # deep-copy coords
    coords = deepcopy(coords).astype(float)
    # rescale the point coordinates
    coords[..., 0] = coords[..., 0] * (new_w / old_w)
    coords[..., 1] = coords[..., 1] * (new_h / old_h)
    return coords
apply_boxes
Since the image was resized from its original dimensions, the box coordinates must be rescaled accordingly ([get_preprocess_shape](#get_preprocess_shape)).
def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
    # the image was resized, so the box coordinates change accordingly
    # reshape: N,4 --> N,2,2 so each box becomes a pair of corner points
    boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
    # reshape back: N,2,2 --> N,4
    return boxes.reshape(-1, 4)
get_preprocess_shape
@staticmethod
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
    # use the longer of H and W as the reference to compute the scale, then scale both
    scale = long_side_length * 1.0 / max(oldh, oldw)
    newh, neww = oldh * scale, oldw * scale
    # round to the nearest integer
    neww = int(neww + 0.5)
    newh = int(newh + 0.5)
    return (newh, neww)
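A quick worked example with a hypothetical 800x1200 image and target_length = 1024:

# scale = 1024 / max(800, 1200) = 1024 / 1200 ≈ 0.8533
# newh  = 800 * 0.8533 ≈ 682.7  -> 683 (rounded)
# neww  = 1200 * 0.8533 = 1024.0 -> 1024
print(ResizeLongestSide.get_preprocess_shape(800, 1200, 1024))  # (683, 1024)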
Summary
This post introduced the functional code of each SAM submodule as simply and thoroughly as possible; follow-up posts will cover the code of SAM's three deep-learning network modules.
One point worth emphasizing: during prediction the sam model is wrapped inside the SamPredictor class, which splits sam's forward prediction pipeline across different methods that run at different stages.
Inside sam, the forward function runs the Image encoder, Prompt encoder, and Mask decoder operations back to back, as shown in the figure below:
The training part of the source has not been open-sourced, so my personal take is that forward is only used during training and is not involved in prediction; I hope no one gets confused by this. Finally, if anyone writes their own training code, please give me a ping.