diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index aa6df245480cb..ddcc1db84b752 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -3009,63 +3009,18 @@ def generate_proposals(scores, im_info, anchors, variances) """ - if _non_static_mode(): - assert return_rois_num, "return_rois_num should be True in dygraph mode." - attrs = ('pre_nms_topN', pre_nms_top_n, 'post_nms_topN', post_nms_top_n, - 'nms_thresh', nms_thresh, 'min_size', min_size, 'eta', eta) - rpn_rois, rpn_roi_probs, rpn_rois_num = _C_ops.generate_proposals( - scores, bbox_deltas, im_info, anchors, variances, *attrs) - return rpn_rois, rpn_roi_probs, rpn_rois_num - - helper = LayerHelper('generate_proposals', **locals()) - - check_variable_and_dtype(scores, 'scores', ['float32'], - 'generate_proposals') - check_variable_and_dtype(bbox_deltas, 'bbox_deltas', ['float32'], - 'generate_proposals') - check_variable_and_dtype(im_info, 'im_info', ['float32', 'float64'], - 'generate_proposals') - check_variable_and_dtype(anchors, 'anchors', ['float32'], - 'generate_proposals') - check_variable_and_dtype(variances, 'variances', ['float32'], - 'generate_proposals') - - rpn_rois = helper.create_variable_for_type_inference( - dtype=bbox_deltas.dtype) - rpn_roi_probs = helper.create_variable_for_type_inference( - dtype=scores.dtype) - outputs = { - 'RpnRois': rpn_rois, - 'RpnRoiProbs': rpn_roi_probs, - } - if return_rois_num: - rpn_rois_num = helper.create_variable_for_type_inference(dtype='int32') - rpn_rois_num.stop_gradient = True - outputs['RpnRoisNum'] = rpn_rois_num - - helper.append_op(type="generate_proposals", - inputs={ - 'Scores': scores, - 'BboxDeltas': bbox_deltas, - 'ImInfo': im_info, - 'Anchors': anchors, - 'Variances': variances - }, - attrs={ - 'pre_nms_topN': pre_nms_top_n, - 'post_nms_topN': post_nms_top_n, - 'nms_thresh': nms_thresh, - 'min_size': min_size, - 'eta': eta - }, - outputs=outputs) - rpn_rois.stop_gradient = True - rpn_roi_probs.stop_gradient = True - - if return_rois_num: - return rpn_rois, rpn_roi_probs, rpn_rois_num - else: - return rpn_rois, rpn_roi_probs + return paddle.vision.ops.generate_proposals(scores=scores, + bbox_deltas=bbox_deltas, + img_size=im_info[:2], + anchors=anchors, + variances=variances, + pre_nms_top_n=pre_nms_top_n, + post_nms_top_n=post_nms_top_n, + nms_thresh=nms_thresh, + min_size=min_size, + eta=eta, + return_rois_num=return_rois_num, + name=name) def box_clip(input, im_info, name=None): diff --git a/python/paddle/fluid/tests/unittests/test_generate_proposals_v2_op.py b/python/paddle/fluid/tests/unittests/test_generate_proposals_v2_op.py index 32d7d308e5392..b1a4b45d7d257 100644 --- a/python/paddle/fluid/tests/unittests/test_generate_proposals_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_generate_proposals_v2_op.py @@ -254,6 +254,99 @@ def init_test_params(self): self.pixel_offset = False +class testGenerateProposalsAPI(unittest.TestCase): + + def setUp(self): + np.random.seed(678) + self.scores_np = np.random.rand(2, 3, 4, 4).astype('float32') + self.bbox_deltas_np = np.random.rand(2, 12, 4, 4).astype('float32') + self.img_size_np = np.array([[8, 8], [6, 6]]).astype('float32') + self.anchors_np = np.reshape(np.arange(4 * 4 * 3 * 4), + [4, 4, 3, 4]).astype('float32') + self.variances_np = np.ones((4, 4, 3, 4)).astype('float32') + + self.roi_expected, self.roi_probs_expected, self.rois_num_expected = generate_proposals_v2_in_python( + self.scores_np, + self.bbox_deltas_np, + self.img_size_np, + self.anchors_np, + self.variances_np, + pre_nms_topN=10, + post_nms_topN=5, + nms_thresh=0.5, + min_size=0.1, + eta=1.0, + pixel_offset=False) + self.roi_expected = np.array(self.roi_expected).squeeze(1) + self.roi_probs_expected = np.array(self.roi_probs_expected).squeeze(1) + self.rois_num_expected = np.array(self.rois_num_expected) + + def test_dynamic(self): + paddle.disable_static() + scores = paddle.to_tensor(self.scores_np) + bbox_deltas = paddle.to_tensor(self.bbox_deltas_np) + img_size = paddle.to_tensor(self.img_size_np) + anchors = paddle.to_tensor(self.anchors_np) + variances = paddle.to_tensor(self.variances_np) + + rois, roi_probs, rois_num = paddle.vision.ops.generate_proposals( + scores, + bbox_deltas, + img_size, + anchors, + variances, + pre_nms_top_n=10, + post_nms_top_n=5, + return_rois_num=True) + self.assertTrue(np.allclose(self.roi_expected, rois.numpy())) + self.assertTrue(np.allclose(self.roi_probs_expected, roi_probs.numpy())) + self.assertTrue(np.allclose(self.rois_num_expected, rois_num.numpy())) + + def test_static(self): + paddle.enable_static() + scores = paddle.static.data(name='scores', + shape=[2, 3, 4, 4], + dtype='float32') + bbox_deltas = paddle.static.data(name='bbox_deltas', + shape=[2, 12, 4, 4], + dtype='float32') + img_size = paddle.static.data(name='img_size', + shape=[2, 2], + dtype='float32') + anchors = paddle.static.data(name='anchors', + shape=[4, 4, 3, 4], + dtype='float32') + variances = paddle.static.data(name='variances', + shape=[4, 4, 3, 4], + dtype='float32') + rois, roi_probs, rois_num = paddle.vision.ops.generate_proposals( + scores, + bbox_deltas, + img_size, + anchors, + variances, + pre_nms_top_n=10, + post_nms_top_n=5, + return_rois_num=True) + exe = paddle.static.Executor() + rois, roi_probs, rois_num = exe.run( + paddle.static.default_main_program(), + feed={ + 'scores': self.scores_np, + 'bbox_deltas': self.bbox_deltas_np, + 'img_size': self.img_size_np, + 'anchors': self.anchors_np, + 'variances': self.variances_np, + }, + fetch_list=[rois.name, roi_probs.name, rois_num.name], + return_numpy=False) + + self.assertTrue(np.allclose(self.roi_expected, np.array(rois))) + self.assertTrue( + np.allclose(self.roi_probs_expected, np.array(roi_probs))) + self.assertTrue(np.allclose(self.rois_num_expected, np.array(rois_num))) + + if __name__ == '__main__': paddle.enable_static() unittest.main() diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py index 545ba25f5b420..cdb8417b6b9c2 100644 --- a/python/paddle/vision/ops.py +++ b/python/paddle/vision/ops.py @@ -29,6 +29,7 @@ 'deform_conv2d', 'DeformConv2D', 'distribute_fpn_proposals', + 'generate_proposals', 'read_file', 'decode_jpeg', 'roi_pool', @@ -1658,3 +1659,146 @@ def _nms(boxes, iou_threshold): return keep_boxes_idxs[topk_sub_indices] return keep_boxes_idxs[sorted_sub_indices][:top_k] + + +def generate_proposals(scores, + bbox_deltas, + img_size, + anchors, + variances, + pre_nms_top_n=6000, + post_nms_top_n=1000, + nms_thresh=0.5, + min_size=0.1, + eta=1.0, + pixel_offset=False, + return_rois_num=False, + name=None): + """ + This operation proposes RoIs according to each box with their + probability to be a foreground object. And + the proposals of RPN output are calculated by anchors, bbox_deltas and scores. Final proposals + could be used to train detection net. + + For generating proposals, this operation performs following steps: + + 1. Transpose and resize scores and bbox_deltas in size of + (H * W * A, 1) and (H * W * A, 4) + 2. Calculate box locations as proposals candidates. + 3. Clip boxes to image + 4. Remove predicted boxes with small area. + 5. Apply non-maximum suppression (NMS) to get final proposals as output. + + Args: + scores (Tensor): A 4-D Tensor with shape [N, A, H, W] represents + the probability for each box to be an object. + N is batch size, A is number of anchors, H and W are height and + width of the feature map. The data type must be float32. + bbox_deltas (Tensor): A 4-D Tensor with shape [N, 4*A, H, W] + represents the difference between predicted box location and + anchor location. The data type must be float32. + img_size (Tensor): A 2-D Tensor with shape [N, 2] represents origin + image shape information for N batch, including height and width of the input sizes. + The data type can be float32 or float64. + anchors (Tensor): A 4-D Tensor represents the anchors with a layout + of [H, W, A, 4]. H and W are height and width of the feature map, + num_anchors is the box count of each position. Each anchor is + in (xmin, ymin, xmax, ymax) format an unnormalized. The data type must be float32. + variances (Tensor): A 4-D Tensor. The expanded variances of anchors with a layout of + [H, W, num_priors, 4]. Each variance is in + (xcenter, ycenter, w, h) format. The data type must be float32. + pre_nms_top_n (float, optional): Number of total bboxes to be kept per + image before NMS. `6000` by default. + post_nms_top_n (float, optional): Number of total bboxes to be kept per + image after NMS. `1000` by default. + nms_thresh (float, optional): Threshold in NMS. The data type must be float32. `0.5` by default. + min_size (float, optional): Remove predicted boxes with either height or + width less than this value. `0.1` by default. + eta(float, optional): Apply in adaptive NMS, only works if adaptive `threshold > 0.5`, + `adaptive_threshold = adaptive_threshold * eta` in each iteration. 1.0 by default. + pixel_offset (bool, optional): Whether there is pixel offset. If True, the offset of `img_size` will be 1. 'False' by default. + return_rois_num (bool, optional): Whether to return `rpn_rois_num` . When setting True, it will return a 1D Tensor with shape [N, ] that includes Rois's + num of each image in one batch. 'False' by default. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Returns: + - rpn_rois (Tensor): The generated RoIs. 2-D Tensor with shape ``[N, 4]`` while ``N`` is the number of RoIs. The data type is the same as ``scores``. + - rpn_roi_probs (Tensor): The scores of generated RoIs. 2-D Tensor with shape ``[N, 1]`` while ``N`` is the number of RoIs. The data type is the same as ``scores``. + - rpn_rois_num (Tensor): Rois's num of each image in one batch. 1-D Tensor with shape ``[B,]`` while ``B`` is the batch size. And its sum equals to RoIs number ``N`` . + + Examples: + .. code-block:: python + + import paddle + + scores = paddle.rand((2,4,5,5), dtype=paddle.float32) + bbox_deltas = paddle.rand((2, 16, 5, 5), dtype=paddle.float32) + img_size = paddle.to_tensor([[224.0, 224.0], [224.0, 224.0]]) + anchors = paddle.rand((2,5,4,4), dtype=paddle.float32) + variances = paddle.rand((2,5,10,4), dtype=paddle.float32) + rois, roi_probs, roi_nums = paddle.vision.ops.generate_proposals(scores, bbox_deltas, + img_size, anchors, variances, return_rois_num=True) + print(rois, roi_probs, roi_nums) + """ + + if _non_static_mode(): + assert return_rois_num, "return_rois_num should be True in dygraph mode." + attrs = ('pre_nms_topN', pre_nms_top_n, 'post_nms_topN', post_nms_top_n, + 'nms_thresh', nms_thresh, 'min_size', min_size, 'eta', eta, + 'pixel_offset', pixel_offset) + rpn_rois, rpn_roi_probs, rpn_rois_num = _C_ops.generate_proposals_v2( + scores, bbox_deltas, img_size, anchors, variances, *attrs) + + return rpn_rois, rpn_roi_probs, rpn_rois_num + + helper = LayerHelper('generate_proposals_v2', **locals()) + + check_variable_and_dtype(scores, 'scores', ['float32'], + 'generate_proposals_v2') + check_variable_and_dtype(bbox_deltas, 'bbox_deltas', ['float32'], + 'generate_proposals_v2') + check_variable_and_dtype(img_size, 'img_size', ['float32', 'float64'], + 'generate_proposals_v2') + check_variable_and_dtype(anchors, 'anchors', ['float32'], + 'generate_proposals_v2') + check_variable_and_dtype(variances, 'variances', ['float32'], + 'generate_proposals_v2') + + rpn_rois = helper.create_variable_for_type_inference( + dtype=bbox_deltas.dtype) + rpn_roi_probs = helper.create_variable_for_type_inference( + dtype=scores.dtype) + outputs = { + 'RpnRois': rpn_rois, + 'RpnRoiProbs': rpn_roi_probs, + } + if return_rois_num: + rpn_rois_num = helper.create_variable_for_type_inference(dtype='int32') + rpn_rois_num.stop_gradient = True + outputs['RpnRoisNum'] = rpn_rois_num + + helper.append_op(type="generate_proposals_v2", + inputs={ + 'Scores': scores, + 'BboxDeltas': bbox_deltas, + 'ImShape': img_size, + 'Anchors': anchors, + 'Variances': variances + }, + attrs={ + 'pre_nms_topN': pre_nms_top_n, + 'post_nms_topN': post_nms_top_n, + 'nms_thresh': nms_thresh, + 'min_size': min_size, + 'eta': eta, + 'pixel_offset': pixel_offset + }, + outputs=outputs) + rpn_rois.stop_gradient = True + rpn_roi_probs.stop_gradient = True + if not return_rois_num: + rpn_rois_num = None + + return rpn_rois, rpn_roi_probs, rpn_rois_num