42 錨框【動(dòng)手學(xué)深度學(xué)習(xí)v2】

邊框的加入使得圖像處理任務(wù)變得復(fù)雜了許多
邊緣框表示的是物體的真實(shí)位置。
錨框則是算法對(duì)物體位置的猜測(cè)。具體步驟見(jiàn)下:

計(jì)算兩個(gè)框的相似度方法:IoU-交并比
運(yùn)算方法是兩框交集除以兩框并集

我的理解應(yīng)該是拿框住的范圍與真實(shí)物體框的范圍求雅可比值,再設(shè)置一個(gè)雅可比值,小于則負(fù)類(背景)大于(關(guān)聯(lián))

例子:一個(gè)圖片中含有4個(gè)邊緣框(對(duì)應(yīng)下圖中列數(shù))和9個(gè)錨框(對(duì)應(yīng)下圖中行數(shù))。則矩陣中的每個(gè)值就是該邊緣框與錨框的IoU值。
1、找出該矩陣中的最高值,即相似度最高的一對(duì)邊緣框和錨框,并將該數(shù)值所在的行與列中的其他數(shù)值刪掉。
2、尋找剩余數(shù)值中的最大值。
3、直到所有邊緣框都找到了對(duì)應(yīng)的錨框,則剩下的錨框都為負(fù)類(背景)。
注:一個(gè)錨框就是一個(gè)訓(xùn)練樣本

非極大值抑制:將多個(gè)相似的框進(jìn)行合并
具體操作方法:找到非背景類的最大預(yù)測(cè)值,去除與它相似度較高的其他框


代碼實(shí)現(xiàn)
%matplotlib inline import torch from d2l import torch as d2l torch.set_printoptions(2) # 精簡(jiǎn)輸出精度
錨框的生成:以一個(gè)像素為中心生成多個(gè)高寬不同的框

#@save def multibox_prior(data, sizes, ratios): """生成以每個(gè)像素為中心具有不同形狀的錨框""" in_height, in_width = data.shape[-2:] device, num_sizes, num_ratios = data.device, len(sizes), len(ratios) boxes_per_pixel = (num_sizes + num_ratios - 1) size_tensor = torch.tensor(sizes, device=device) ratio_tensor = torch.tensor(ratios, device=device) # 為了將錨點(diǎn)移動(dòng)到像素的中心,需要設(shè)置偏移量。 # 因?yàn)橐粋€(gè)像素的高為1且寬為1,我們選擇偏移我們的中心0.5 offset_h, offset_w = 0.5, 0.5 steps_h = 1.0 / in_height # 在y軸上縮放步長(zhǎng) steps_w = 1.0 / in_width # 在x軸上縮放步長(zhǎng) # 生成錨框的所有中心點(diǎn) center_h = (torch.arange(in_height, device=device) + offset_h) * steps_h center_w = (torch.arange(in_width, device=device) + offset_w) * steps_w shift_y, shift_x = torch.meshgrid(center_h, center_w, indexing='ij') shift_y, shift_x = shift_y.reshape(-1), shift_x.reshape(-1) # 生成“boxes_per_pixel”個(gè)高和寬, # 之后用于創(chuàng)建錨框的四角坐標(biāo)(xmin,xmax,ymin,ymax) w = torch.cat((size_tensor * torch.sqrt(ratio_tensor[0]), sizes[0] * torch.sqrt(ratio_tensor[1:])))\ * in_height / in_width # 處理矩形輸入 h = torch.cat((size_tensor / torch.sqrt(ratio_tensor[0]), sizes[0] / torch.sqrt(ratio_tensor[1:]))) # 除以2來(lái)獲得半高和半寬 anchor_manipulations = torch.stack((-w, -h, w, h)).T.repeat( in_height * in_width, 1) / 2 # 每個(gè)中心點(diǎn)都將有“boxes_per_pixel”個(gè)錨框, # 所以生成含所有錨框中心的網(wǎng)格,重復(fù)了“boxes_per_pixel”次 out_grid = torch.stack([shift_x, shift_y, shift_x, shift_y], dim=1).repeat_interleave(boxes_per_pixel, dim=0) output = out_grid + anchor_manipulations return output.unsqueeze(0)
可以看到返回的錨框變量Y
的形狀是(批量大小,錨框的數(shù)量,4)。
img = d2l.plt.imread('../img/catdog.jpg') h, w = img.shape[:2] print(h, w) X = torch.rand(size=(1, 3, h, w)) Y = multibox_prior(X, sizes=[0.75, 0.5, 0.25], ratios=[1, 2, 0.5]) Y.shape
561 728 torch.Size([1, 2042040, 4])
2042040是錨框數(shù)
def show_bboxes(axes, bboxes, labels=None, colors=None): """顯示所有邊界框""" def _make_list(obj, default_values=None): if obj is None: obj = default_values elif not isinstance(obj, (list, tuple)): obj = [obj] return obj labels = _make_list(labels) colors = _make_list(colors, ['b', 'g', 'r', 'm', 'c']) for i, bbox in enumerate(bboxes): color = colors[i % len(colors)] rect = d2l.bbox_to_rect(bbox.detach().numpy(), color) axes.add_patch(rect) if labels and len(labels) > i: text_color = 'k' if color == 'w' else 'w' axes.text(rect.xy[0], rect.xy[1], labels[i], va='center', ha='center', fontsize=9, color=text_color, bbox=dict(facecolor=color, lw=0))
d2l.set_figsize() bbox_scale = torch.tensor((w, h, w, h)) fig = d2l.plt.imshow(img) show_bboxes(fig.axes, boxes[250, 250, :, :] * bbox_scale, ['s=0.75, r=1', 's=0.5, r=1', 's=0.25, r=1', 's=0.75, r=2', 's=0.75, r=0.5'])
知識(shí)補(bǔ)充:
一個(gè)錨框只能對(duì)應(yīng)一個(gè)真實(shí)框
錨框是以像素為中心,變換高度和寬度生成的
標(biāo)簽: