From f6948657f6999fe6ed6843fd32b96fa0d8927861 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Fri, 17 Jun 2022 03:27:34 +0000 Subject: [PATCH 1/4] add generate_proposals into paddle.vision --- .../test_generate_proposals_v2_op.py | 180 +++++++++++++ python/paddle/vision/ops.py | 240 +++++++++++++++++- 2 files changed, 407 insertions(+), 13 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_generate_proposals_v2_op.py b/python/paddle/fluid/tests/unittests/test_generate_proposals_v2_op.py index 32d7d308e5392..99fd57700b982 100644 --- a/python/paddle/fluid/tests/unittests/test_generate_proposals_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_generate_proposals_v2_op.py @@ -254,6 +254,186 @@ def init_test_params(self): self.pixel_offset = False +class testGenerateProposalsFunctionAPI(unittest.TestCase): + + def setUp(self): + np.random.seed(678) + self.scores_np = np.random.rand(2, 3, 4, 4).astype('float32') + self.bbox_deltas_np = np.random.rand(2, 12, 4, 4).astype('float32') + self.im_shape_np = np.array([[8, 8], [6, 6]]).astype('float32') + self.anchors_np = np.reshape(np.arange(4 * 4 * 3 * 4), + [4, 4, 3, 4]).astype('float32') + self.variances_np = np.ones((4, 4, 3, 4)).astype('float32') + + self.roi_expected, self.roi_probs_expected, self.rois_num_expected = generate_proposals_v2_in_python( + self.scores_np, + self.bbox_deltas_np, + self.im_shape_np, + self.anchors_np, + self.variances_np, + pre_nms_topN=10, + post_nms_topN=5, + nms_thresh=0.5, + min_size=0.1, + eta=1.0, + pixel_offset=False) + self.roi_expected = np.array(self.roi_expected).squeeze(1) + self.roi_probs_expected = np.array(self.roi_probs_expected).squeeze(1) + self.rois_num_expected = np.array(self.rois_num_expected) + + def test_dynamic(self): + paddle.disable_static() + scores = paddle.to_tensor(self.scores_np) + bbox_deltas = paddle.to_tensor(self.bbox_deltas_np) + im_shape = paddle.to_tensor(self.im_shape_np) + anchors = paddle.to_tensor(self.anchors_np) + variances = paddle.to_tensor(self.variances_np) + + rois, roi_probs, rois_num = paddle.vision.ops.generate_proposals( + scores, + bbox_deltas, + im_shape, + anchors, + variances, + pre_nms_top_n=10, + post_nms_top_n=5, + return_rois_num=True) + self.assertTrue(np.allclose(self.roi_expected, rois.numpy())) + self.assertTrue(np.allclose(self.roi_probs_expected, roi_probs.numpy())) + self.assertTrue(np.allclose(self.rois_num_expected, rois_num.numpy())) + + def test_static(self): + paddle.enable_static() + scores = paddle.static.data(name='scores', + shape=[2, 3, 4, 4], + dtype='float32') + bbox_deltas = paddle.static.data(name='bbox_deltas', + shape=[2, 12, 4, 4], + dtype='float32') + im_shape = paddle.static.data(name='im_shape', + shape=[2, 2], + dtype='float32') + anchors = paddle.static.data(name='anchors', + shape=[4, 4, 3, 4], + dtype='float32') + variances = paddle.static.data(name='variances', + shape=[4, 4, 3, 4], + dtype='float32') + rois, roi_probs, rois_num = paddle.vision.ops.generate_proposals( + scores, + bbox_deltas, + im_shape, + anchors, + variances, + pre_nms_top_n=10, + post_nms_top_n=5, + return_rois_num=True) + exe = paddle.static.Executor() + rois, roi_probs, rois_num = exe.run( + paddle.static.default_main_program(), + feed={ + 'scores': self.scores_np, + 'bbox_deltas': self.bbox_deltas_np, + 'im_shape': self.im_shape_np, + 'anchors': self.anchors_np, + 'variances': self.variances_np, + }, + fetch_list=[rois.name, roi_probs.name, rois_num.name], + return_numpy=False) + + self.assertTrue(np.allclose(self.roi_expected, np.array(rois))) + self.assertTrue( + np.allclose(self.roi_probs_expected, np.array(roi_probs))) + self.assertTrue(np.allclose(self.rois_num_expected, np.array(rois_num))) + + +class testGenerateProposalsClassAPI(unittest.TestCase): + + def setUp(self): + np.random.seed(678) + self.scores_np = np.random.rand(2, 3, 4, 4).astype('float32') + self.bbox_deltas_np = np.random.rand(2, 12, 4, 4).astype('float32') + self.im_shape_np = np.array([[8, 8], [6, 6]]).astype('float32') + self.anchors_np = np.reshape(np.arange(4 * 4 * 3 * 4), + [4, 4, 3, 4]).astype('float32') + self.variances_np = np.ones((4, 4, 3, 4)).astype('float32') + + self.roi_expected, self.roi_probs_expected, self.rois_num_expected = generate_proposals_v2_in_python( + self.scores_np, + self.bbox_deltas_np, + self.im_shape_np, + self.anchors_np, + self.variances_np, + pre_nms_topN=10, + post_nms_topN=5, + nms_thresh=0.5, + min_size=0.1, + eta=1.0, + pixel_offset=False) + self.roi_expected = np.array(self.roi_expected).squeeze(1) + self.roi_probs_expected = np.array(self.roi_probs_expected).squeeze(1) + self.rois_num_expected = np.array(self.rois_num_expected) + + def test_dynamic(self): + paddle.disable_static() + + generate_proposals_module = paddle.vision.ops.GenerateProposals( + pre_nms_top_n=10, post_nms_top_n=5, return_rois_num=True) + scores = paddle.to_tensor(self.scores_np) + bbox_deltas = paddle.to_tensor(self.bbox_deltas_np) + im_shape = paddle.to_tensor(self.im_shape_np) + anchors = paddle.to_tensor(self.anchors_np) + variances = paddle.to_tensor(self.variances_np) + + rois, roi_probs, rois_num = generate_proposals_module( + scores, bbox_deltas, im_shape, anchors, variances) + + self.assertTrue(np.allclose(self.roi_expected, rois.numpy())) + self.assertTrue(np.allclose(self.roi_probs_expected, roi_probs.numpy())) + self.assertTrue(np.allclose(self.rois_num_expected, rois_num.numpy())) + + def test_static(self): + paddle.enable_static() + + generate_proposals_module = paddle.vision.ops.GenerateProposals( + pre_nms_top_n=10, post_nms_top_n=5, return_rois_num=True) + + scores = paddle.static.data(name='scores', + shape=[2, 3, 4, 4], + dtype='float32') + bbox_deltas = paddle.static.data(name='bbox_deltas', + shape=[2, 12, 4, 4], + dtype='float32') + im_shape = paddle.static.data(name='im_shape', + shape=[2, 2], + dtype='float32') + anchors = paddle.static.data(name='anchors', + shape=[4, 4, 3, 4], + dtype='float32') + variances = paddle.static.data(name='variances', + shape=[4, 4, 3, 4], + dtype='float32') + rois, roi_probs, rois_num = generate_proposals_module( + scores, bbox_deltas, im_shape, anchors, variances) + exe = paddle.static.Executor() + rois, roi_probs, rois_num = exe.run( + paddle.static.default_main_program(), + feed={ + 'scores': self.scores_np, + 'bbox_deltas': self.bbox_deltas_np, + 'im_shape': self.im_shape_np, + 'anchors': self.anchors_np, + 'variances': self.variances_np, + }, + fetch_list=[rois.name, roi_probs.name, rois_num.name], + return_numpy=False) + + self.assertTrue(np.allclose(self.roi_expected, np.array(rois))) + self.assertTrue( + np.allclose(self.roi_probs_expected, np.array(roi_probs))) + self.assertTrue(np.allclose(self.rois_num_expected, np.array(rois_num))) + + if __name__ == '__main__': paddle.enable_static() unittest.main() diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py index 69fba204dd314..510aaeb6ecb85 100644 --- a/python/paddle/vision/ops.py +++ b/python/paddle/vision/ops.py @@ -24,19 +24,9 @@ from paddle import _C_ops __all__ = [ #noqa - 'yolo_loss', - 'yolo_box', - 'deform_conv2d', - 'DeformConv2D', - 'read_file', - 'decode_jpeg', - 'roi_pool', - 'RoIPool', - 'psroi_pool', - 'PSRoIPool', - 'roi_align', - 'RoIAlign', - 'nms', + 'yolo_loss', 'yolo_box', 'deform_conv2d', 'DeformConv2D', 'read_file', + 'decode_jpeg', 'roi_pool', 'RoIPool', 'psroi_pool', 'PSRoIPool', + 'roi_align', 'RoIAlign', 'nms', 'generate_proposals', 'GenerateProposals' ] @@ -1538,3 +1528,227 @@ def _nms(boxes, iou_threshold): return keep_boxes_idxs[topk_sub_indices] return keep_boxes_idxs[sorted_sub_indices][:top_k] + + +def generate_proposals(scores, + bbox_deltas, + im_shape, + anchors, + variances, + pre_nms_top_n=6000, + post_nms_top_n=1000, + nms_thresh=0.5, + min_size=0.1, + eta=1.0, + pixel_offset=False, + return_rois_num=False, + name=None): + """ + This operation proposes RoIs according to each box with their + probability to be a foreground object and + the box can be calculated by anchors. Bbox_deltais and scores + to be an object are the output of RPN. Final proposals + could be used to train detection net. + + For generating proposals, this operation performs following steps: + + 1. Transposes and resizes scores and bbox_deltas in size of + (H*W*A, 1) and (H*W*A, 4) + 2. Calculate box locations as proposals candidates. + 3. Clip boxes to image + 4. Remove predicted boxes with small area. + 5. Apply NMS to get final proposals as output. + + Args: + scores (Tensor): A 4-D Tensor with shape [N, A, H, W] represents + the probability for each box to be an object. + N is batch size, A is number of anchors, H and W are height and + width of the feature map. The data type must be float32. + bbox_deltas (Tensor): A 4-D Tensor with shape [N, 4*A, H, W] + represents the difference between predicted box location and + anchor location. The data type must be float32. + im_info (Tensor): A 2-D Tensor with shape [N, 2] represents origin + image shape information for N batch, including height and width of the input sizes. + The data type can be float32 or float64. + anchors (Tensor): A 4-D Tensor represents the anchors with a layout + of [H, W, A, 4]. H and W are height and width of the feature map, + num_anchors is the box count of each position. Each anchor is + in (xmin, ymin, xmax, ymax) format an unnormalized. The data type must be float32. + variances (Tensor): A 4-D Tensor. The expanded variances of anchors with a layout of + [H, W, num_priors, 4]. Each variance is in + (xcenter, ycenter, w, h) format. The data type must be float32. + pre_nms_top_n (float): Number of total bboxes to be kept per + image before NMS. The data type must be float32. `6000` by default. + post_nms_top_n (float): Number of total bboxes to be kept per + image after NMS. The data type must be float32. `1000` by default. + nms_thresh (float): Threshold in NMS. The data type must be float32. `0.5` by default. + min_size (float): Remove predicted boxes with either height or + width < min_size. The data type must be float32. `0.1` by default. + eta(float): Apply in adaptive NMS, if adaptive `threshold > 0.5`, + `adaptive_threshold = adaptive_threshold * eta` in each iteration. + return_rois_num (bool): When setting True, it will return a 1D Tensor with shape [N, ] that includes Rois's + num of each image in one batch. The N is the image's num. For example, the tensor has values [4,5] that represents + the first image has 4 Rois, the second image has 5 Rois. It only used in rcnn model. + 'False' by default. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Returns: + tuple, result with format ``(rpn_rois, rpn_roi_probs, rpn_rois_num)``. + - **rpn_rois**: The generated RoIs. 2-D Tensor with shape ``[N, 4]`` while ``N`` is the number of RoIs. The data type is the same as ``scores``. + - **rpn_roi_probs**: The scores of generated RoIs. 2-D Tensor with shape ``[N, 1]`` while ``N`` is the number of RoIs. The data type is the same as ``scores``. + - **rpn_rois_num**: Rois's num of each image in one batch. 1-D Tensor with shape ``[B,]`` while ``B`` is the batch size. And its sum equals to RoIs number ``N`` . + + Examples: + .. code-block:: python + + import paddle + + scores = paddle.rand((2,4,5,5), dtype=paddle.float32) + bbox_deltas = paddle.rand((2, 16, 5, 5), dtype=paddle.float32) + im_shape = paddle.to_tensor([[224.0, 224.0], [224.0, 224.0]]) + anchors = paddle.rand((2,5,4,4), dtype=paddle.float32) + variances = paddle.rand((2,5,10,4), dtype=paddle.float32) + rois, roi_probs, roi_nums = paddle.vision.ops.generate_proposals(scores, bbox_deltas, + im_shape, anchors, variances, return_rois_num=True) + print(rois, roi_probs, roi_nums) + """ + + if _non_static_mode(): + assert return_rois_num, "return_rois_num should be True in dygraph mode." + attrs = ('pre_nms_topN', pre_nms_top_n, 'post_nms_topN', post_nms_top_n, + 'nms_thresh', nms_thresh, 'min_size', min_size, 'eta', eta, + 'pixel_offset', pixel_offset) + rpn_rois, rpn_roi_probs, rpn_rois_num = _C_ops.generate_proposals_v2( + scores, bbox_deltas, im_shape, anchors, variances, *attrs) + + return rpn_rois, rpn_roi_probs, rpn_rois_num + + helper = LayerHelper('generate_proposals_v2', **locals()) + + check_variable_and_dtype(scores, 'scores', ['float32'], + 'generate_proposals_v2') + check_variable_and_dtype(bbox_deltas, 'bbox_deltas', ['float32'], + 'generate_proposals_v2') + check_variable_and_dtype(im_shape, 'im_shape', ['float32', 'float64'], + 'generate_proposals_v2') + check_variable_and_dtype(anchors, 'anchors', ['float32'], + 'generate_proposals_v2') + check_variable_and_dtype(variances, 'variances', ['float32'], + 'generate_proposals_v2') + + rpn_rois = helper.create_variable_for_type_inference( + dtype=bbox_deltas.dtype) + rpn_roi_probs = helper.create_variable_for_type_inference( + dtype=scores.dtype) + outputs = { + 'RpnRois': rpn_rois, + 'RpnRoiProbs': rpn_roi_probs, + } + if return_rois_num: + rpn_rois_num = helper.create_variable_for_type_inference(dtype='int32') + rpn_rois_num.stop_gradient = True + outputs['RpnRoisNum'] = rpn_rois_num + + helper.append_op(type="generate_proposals_v2", + inputs={ + 'Scores': scores, + 'BboxDeltas': bbox_deltas, + 'ImShape': im_shape, + 'Anchors': anchors, + 'Variances': variances + }, + attrs={ + 'pre_nms_topN': pre_nms_top_n, + 'post_nms_topN': post_nms_top_n, + 'nms_thresh': nms_thresh, + 'min_size': min_size, + 'eta': eta, + 'pixel_offset': pixel_offset + }, + outputs=outputs) + rpn_rois.stop_gradient = True + rpn_roi_probs.stop_gradient = True + if not return_rois_num: + rpn_rois_num = None + + return rpn_rois, rpn_roi_probs, rpn_rois_num + + +class GenerateProposals(Layer): + """ + This interface is used to construct a callable object of the ``GenerateProposals`` class. Please + refer to :ref:`api_paddle_vision_ops_generate_proposals`. + + Args: + pre_nms_top_n (float): Number of total bboxes to be kept per + image before NMS. The data type must be float32. `6000` by default. + post_nms_top_n (float): Number of total bboxes to be kept per + image after NMS. The data type must be float32. `1000` by default. + nms_thresh (float): Threshold in NMS. The data type must be float32. `0.5` by default. + min_size (float): Remove predicted boxes with either height or + width < min_size. The data type must be float32. `0.1` by default. + eta(float): Apply in adaptive NMS, if adaptive `threshold > 0.5`, + `adaptive_threshold = adaptive_threshold * eta` in each iteration. + return_rois_num (bool): When setting True, it will return a 1D Tensor with shape [N, ] that includes Rois's + num of each image in one batch. The N is the image's num. For example, the tensor has values [4,5] that represents + the first image has 4 Rois, the second image has 5 Rois. It only used in rcnn model. + 'False' by default. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Shape: + - scores: 4-D Tensor with shape (N, A, H, W). + - bbox_deltas: 2-D Tensor with shape (N, 4*A, H, W). + - im_info: 4-D Tensor with shape (N, 2). + - anchors: 4-D Tensor with shape (H, W, A, A). + - variances: 4-D Tensor with shape (H, W, num_priors, 4) + + Returns: + None + + Examples: + .. code-block:: python + + import paddle + + generate_proposals_module = paddle.vision.ops.GenerateProposals(return_rois_num=True) + + scores = paddle.rand((2,4,5,5), dtype=paddle.float32) + bbox_deltas = paddle.rand((2, 16, 5, 5), dtype=paddle.float32) + im_shape = paddle.to_tensor([[224.0, 224.0], [224.0, 224.0]]) + anchors = paddle.rand((2,5,4,4), dtype=paddle.float32) + variances = paddle.rand((2,5,10,4), dtype=paddle.float32) + rois, roi_probs, roi_nums = generate_proposals_module(scores, bbox_deltas, + im_shape, anchors, variances) + print(rois, roi_probs, roi_nums) + + """ + + def __init__(self, + pre_nms_top_n=6000, + post_nms_top_n=1000, + nms_thresh=0.5, + min_size=0.1, + eta=1.0, + pixel_offset=False, + return_rois_num=False, + name=None): + super(GenerateProposals, self).__init__() + self.pre_nms_top_n = pre_nms_top_n + self.post_nms_top_n = post_nms_top_n + self.nms_thresh = nms_thresh + self.min_size = min_size + self.eta = eta + self.pixel_offset = pixel_offset + self.return_rois_num = return_rois_num + self.name = name + + def forward(self, scores, bbox_deltas, im_info, anchors, variances): + return generate_proposals(scores, bbox_deltas, im_info, anchors, + variances, self.pre_nms_top_n, + self.post_nms_top_n, self.nms_thresh, + self.min_size, self.eta, self.pixel_offset, + self.return_rois_num, self.name) From c69436a65dc9111cc0059a563763e8460de6ef55 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Thu, 30 Jun 2022 02:49:48 +0000 Subject: [PATCH 2/4] remove class api --- .../test_generate_proposals_v2_op.py | 89 +------------------ python/paddle/vision/ops.py | 81 +---------------- 2 files changed, 2 insertions(+), 168 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_generate_proposals_v2_op.py b/python/paddle/fluid/tests/unittests/test_generate_proposals_v2_op.py index 99fd57700b982..ab6d17c38289b 100644 --- a/python/paddle/fluid/tests/unittests/test_generate_proposals_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_generate_proposals_v2_op.py @@ -254,7 +254,7 @@ def init_test_params(self): self.pixel_offset = False -class testGenerateProposalsFunctionAPI(unittest.TestCase): +class testGenerateProposalsAPI(unittest.TestCase): def setUp(self): np.random.seed(678) @@ -347,93 +347,6 @@ def test_static(self): self.assertTrue(np.allclose(self.rois_num_expected, np.array(rois_num))) -class testGenerateProposalsClassAPI(unittest.TestCase): - - def setUp(self): - np.random.seed(678) - self.scores_np = np.random.rand(2, 3, 4, 4).astype('float32') - self.bbox_deltas_np = np.random.rand(2, 12, 4, 4).astype('float32') - self.im_shape_np = np.array([[8, 8], [6, 6]]).astype('float32') - self.anchors_np = np.reshape(np.arange(4 * 4 * 3 * 4), - [4, 4, 3, 4]).astype('float32') - self.variances_np = np.ones((4, 4, 3, 4)).astype('float32') - - self.roi_expected, self.roi_probs_expected, self.rois_num_expected = generate_proposals_v2_in_python( - self.scores_np, - self.bbox_deltas_np, - self.im_shape_np, - self.anchors_np, - self.variances_np, - pre_nms_topN=10, - post_nms_topN=5, - nms_thresh=0.5, - min_size=0.1, - eta=1.0, - pixel_offset=False) - self.roi_expected = np.array(self.roi_expected).squeeze(1) - self.roi_probs_expected = np.array(self.roi_probs_expected).squeeze(1) - self.rois_num_expected = np.array(self.rois_num_expected) - - def test_dynamic(self): - paddle.disable_static() - - generate_proposals_module = paddle.vision.ops.GenerateProposals( - pre_nms_top_n=10, post_nms_top_n=5, return_rois_num=True) - scores = paddle.to_tensor(self.scores_np) - bbox_deltas = paddle.to_tensor(self.bbox_deltas_np) - im_shape = paddle.to_tensor(self.im_shape_np) - anchors = paddle.to_tensor(self.anchors_np) - variances = paddle.to_tensor(self.variances_np) - - rois, roi_probs, rois_num = generate_proposals_module( - scores, bbox_deltas, im_shape, anchors, variances) - - self.assertTrue(np.allclose(self.roi_expected, rois.numpy())) - self.assertTrue(np.allclose(self.roi_probs_expected, roi_probs.numpy())) - self.assertTrue(np.allclose(self.rois_num_expected, rois_num.numpy())) - - def test_static(self): - paddle.enable_static() - - generate_proposals_module = paddle.vision.ops.GenerateProposals( - pre_nms_top_n=10, post_nms_top_n=5, return_rois_num=True) - - scores = paddle.static.data(name='scores', - shape=[2, 3, 4, 4], - dtype='float32') - bbox_deltas = paddle.static.data(name='bbox_deltas', - shape=[2, 12, 4, 4], - dtype='float32') - im_shape = paddle.static.data(name='im_shape', - shape=[2, 2], - dtype='float32') - anchors = paddle.static.data(name='anchors', - shape=[4, 4, 3, 4], - dtype='float32') - variances = paddle.static.data(name='variances', - shape=[4, 4, 3, 4], - dtype='float32') - rois, roi_probs, rois_num = generate_proposals_module( - scores, bbox_deltas, im_shape, anchors, variances) - exe = paddle.static.Executor() - rois, roi_probs, rois_num = exe.run( - paddle.static.default_main_program(), - feed={ - 'scores': self.scores_np, - 'bbox_deltas': self.bbox_deltas_np, - 'im_shape': self.im_shape_np, - 'anchors': self.anchors_np, - 'variances': self.variances_np, - }, - fetch_list=[rois.name, roi_probs.name, rois_num.name], - return_numpy=False) - - self.assertTrue(np.allclose(self.roi_expected, np.array(rois))) - self.assertTrue( - np.allclose(self.roi_probs_expected, np.array(roi_probs))) - self.assertTrue(np.allclose(self.rois_num_expected, np.array(rois_num))) - - if __name__ == '__main__': paddle.enable_static() unittest.main() diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py index 510aaeb6ecb85..7f371eca1d01a 100644 --- a/python/paddle/vision/ops.py +++ b/python/paddle/vision/ops.py @@ -26,7 +26,7 @@ __all__ = [ #noqa 'yolo_loss', 'yolo_box', 'deform_conv2d', 'DeformConv2D', 'read_file', 'decode_jpeg', 'roi_pool', 'RoIPool', 'psroi_pool', 'PSRoIPool', - 'roi_align', 'RoIAlign', 'nms', 'generate_proposals', 'GenerateProposals' + 'roi_align', 'RoIAlign', 'nms', 'generate_proposals' ] @@ -1595,7 +1595,6 @@ def generate_proposals(scores, None by default. Returns: - tuple, result with format ``(rpn_rois, rpn_roi_probs, rpn_rois_num)``. - **rpn_rois**: The generated RoIs. 2-D Tensor with shape ``[N, 4]`` while ``N`` is the number of RoIs. The data type is the same as ``scores``. - **rpn_roi_probs**: The scores of generated RoIs. 2-D Tensor with shape ``[N, 1]`` while ``N`` is the number of RoIs. The data type is the same as ``scores``. - **rpn_rois_num**: Rois's num of each image in one batch. 1-D Tensor with shape ``[B,]`` while ``B`` is the batch size. And its sum equals to RoIs number ``N`` . @@ -1674,81 +1673,3 @@ def generate_proposals(scores, rpn_rois_num = None return rpn_rois, rpn_roi_probs, rpn_rois_num - - -class GenerateProposals(Layer): - """ - This interface is used to construct a callable object of the ``GenerateProposals`` class. Please - refer to :ref:`api_paddle_vision_ops_generate_proposals`. - - Args: - pre_nms_top_n (float): Number of total bboxes to be kept per - image before NMS. The data type must be float32. `6000` by default. - post_nms_top_n (float): Number of total bboxes to be kept per - image after NMS. The data type must be float32. `1000` by default. - nms_thresh (float): Threshold in NMS. The data type must be float32. `0.5` by default. - min_size (float): Remove predicted boxes with either height or - width < min_size. The data type must be float32. `0.1` by default. - eta(float): Apply in adaptive NMS, if adaptive `threshold > 0.5`, - `adaptive_threshold = adaptive_threshold * eta` in each iteration. - return_rois_num (bool): When setting True, it will return a 1D Tensor with shape [N, ] that includes Rois's - num of each image in one batch. The N is the image's num. For example, the tensor has values [4,5] that represents - the first image has 4 Rois, the second image has 5 Rois. It only used in rcnn model. - 'False' by default. - name(str, optional): For detailed information, please refer - to :ref:`api_guide_Name`. Usually name is no need to set and - None by default. - - Shape: - - scores: 4-D Tensor with shape (N, A, H, W). - - bbox_deltas: 2-D Tensor with shape (N, 4*A, H, W). - - im_info: 4-D Tensor with shape (N, 2). - - anchors: 4-D Tensor with shape (H, W, A, A). - - variances: 4-D Tensor with shape (H, W, num_priors, 4) - - Returns: - None - - Examples: - .. code-block:: python - - import paddle - - generate_proposals_module = paddle.vision.ops.GenerateProposals(return_rois_num=True) - - scores = paddle.rand((2,4,5,5), dtype=paddle.float32) - bbox_deltas = paddle.rand((2, 16, 5, 5), dtype=paddle.float32) - im_shape = paddle.to_tensor([[224.0, 224.0], [224.0, 224.0]]) - anchors = paddle.rand((2,5,4,4), dtype=paddle.float32) - variances = paddle.rand((2,5,10,4), dtype=paddle.float32) - rois, roi_probs, roi_nums = generate_proposals_module(scores, bbox_deltas, - im_shape, anchors, variances) - print(rois, roi_probs, roi_nums) - - """ - - def __init__(self, - pre_nms_top_n=6000, - post_nms_top_n=1000, - nms_thresh=0.5, - min_size=0.1, - eta=1.0, - pixel_offset=False, - return_rois_num=False, - name=None): - super(GenerateProposals, self).__init__() - self.pre_nms_top_n = pre_nms_top_n - self.post_nms_top_n = post_nms_top_n - self.nms_thresh = nms_thresh - self.min_size = min_size - self.eta = eta - self.pixel_offset = pixel_offset - self.return_rois_num = return_rois_num - self.name = name - - def forward(self, scores, bbox_deltas, im_info, anchors, variances): - return generate_proposals(scores, bbox_deltas, im_info, anchors, - variances, self.pre_nms_top_n, - self.post_nms_top_n, self.nms_thresh, - self.min_size, self.eta, self.pixel_offset, - self.return_rois_num, self.name) From cc783563a610922c0692c5ae9fd12b4cb8185721 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Thu, 30 Jun 2022 08:06:08 +0000 Subject: [PATCH 3/4] im_info -> img_size --- .../test_generate_proposals_v2_op.py | 14 ++--- python/paddle/vision/ops.py | 56 +++++++++---------- 2 files changed, 34 insertions(+), 36 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_generate_proposals_v2_op.py b/python/paddle/fluid/tests/unittests/test_generate_proposals_v2_op.py index ab6d17c38289b..b1a4b45d7d257 100644 --- a/python/paddle/fluid/tests/unittests/test_generate_proposals_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_generate_proposals_v2_op.py @@ -260,7 +260,7 @@ def setUp(self): np.random.seed(678) self.scores_np = np.random.rand(2, 3, 4, 4).astype('float32') self.bbox_deltas_np = np.random.rand(2, 12, 4, 4).astype('float32') - self.im_shape_np = np.array([[8, 8], [6, 6]]).astype('float32') + self.img_size_np = np.array([[8, 8], [6, 6]]).astype('float32') self.anchors_np = np.reshape(np.arange(4 * 4 * 3 * 4), [4, 4, 3, 4]).astype('float32') self.variances_np = np.ones((4, 4, 3, 4)).astype('float32') @@ -268,7 +268,7 @@ def setUp(self): self.roi_expected, self.roi_probs_expected, self.rois_num_expected = generate_proposals_v2_in_python( self.scores_np, self.bbox_deltas_np, - self.im_shape_np, + self.img_size_np, self.anchors_np, self.variances_np, pre_nms_topN=10, @@ -285,14 +285,14 @@ def test_dynamic(self): paddle.disable_static() scores = paddle.to_tensor(self.scores_np) bbox_deltas = paddle.to_tensor(self.bbox_deltas_np) - im_shape = paddle.to_tensor(self.im_shape_np) + img_size = paddle.to_tensor(self.img_size_np) anchors = paddle.to_tensor(self.anchors_np) variances = paddle.to_tensor(self.variances_np) rois, roi_probs, rois_num = paddle.vision.ops.generate_proposals( scores, bbox_deltas, - im_shape, + img_size, anchors, variances, pre_nms_top_n=10, @@ -310,7 +310,7 @@ def test_static(self): bbox_deltas = paddle.static.data(name='bbox_deltas', shape=[2, 12, 4, 4], dtype='float32') - im_shape = paddle.static.data(name='im_shape', + img_size = paddle.static.data(name='img_size', shape=[2, 2], dtype='float32') anchors = paddle.static.data(name='anchors', @@ -322,7 +322,7 @@ def test_static(self): rois, roi_probs, rois_num = paddle.vision.ops.generate_proposals( scores, bbox_deltas, - im_shape, + img_size, anchors, variances, pre_nms_top_n=10, @@ -334,7 +334,7 @@ def test_static(self): feed={ 'scores': self.scores_np, 'bbox_deltas': self.bbox_deltas_np, - 'im_shape': self.im_shape_np, + 'img_size': self.img_size_np, 'anchors': self.anchors_np, 'variances': self.variances_np, }, diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py index 8d1a2cf13339f..cc5a0caf71f47 100644 --- a/python/paddle/vision/ops.py +++ b/python/paddle/vision/ops.py @@ -1534,7 +1534,7 @@ def _nms(boxes, iou_threshold): def generate_proposals(scores, bbox_deltas, - im_shape, + img_size, anchors, variances, pre_nms_top_n=6000, @@ -1547,19 +1547,18 @@ def generate_proposals(scores, name=None): """ This operation proposes RoIs according to each box with their - probability to be a foreground object and - the box can be calculated by anchors. Bbox_deltais and scores - to be an object are the output of RPN. Final proposals + probability to be a foreground object. And + the proposals of RPN output are calculated by anchors, bbox_deltas and scores. Final proposals could be used to train detection net. For generating proposals, this operation performs following steps: - 1. Transposes and resizes scores and bbox_deltas in size of - (H*W*A, 1) and (H*W*A, 4) + 1. Transpose and resize scores and bbox_deltas in size of + (H * W * A, 1) and (H * W * A, 4) 2. Calculate box locations as proposals candidates. 3. Clip boxes to image 4. Remove predicted boxes with small area. - 5. Apply NMS to get final proposals as output. + 5. Apply non-maximum suppression (NMS) to get final proposals as output. Args: scores (Tensor): A 4-D Tensor with shape [N, A, H, W] represents @@ -1569,7 +1568,7 @@ def generate_proposals(scores, bbox_deltas (Tensor): A 4-D Tensor with shape [N, 4*A, H, W] represents the difference between predicted box location and anchor location. The data type must be float32. - im_info (Tensor): A 2-D Tensor with shape [N, 2] represents origin + img_size (Tensor): A 2-D Tensor with shape [N, 2] represents origin image shape information for N batch, including height and width of the input sizes. The data type can be float32 or float64. anchors (Tensor): A 4-D Tensor represents the anchors with a layout @@ -1579,27 +1578,26 @@ def generate_proposals(scores, variances (Tensor): A 4-D Tensor. The expanded variances of anchors with a layout of [H, W, num_priors, 4]. Each variance is in (xcenter, ycenter, w, h) format. The data type must be float32. - pre_nms_top_n (float): Number of total bboxes to be kept per - image before NMS. The data type must be float32. `6000` by default. - post_nms_top_n (float): Number of total bboxes to be kept per - image after NMS. The data type must be float32. `1000` by default. - nms_thresh (float): Threshold in NMS. The data type must be float32. `0.5` by default. - min_size (float): Remove predicted boxes with either height or - width < min_size. The data type must be float32. `0.1` by default. - eta(float): Apply in adaptive NMS, if adaptive `threshold > 0.5`, - `adaptive_threshold = adaptive_threshold * eta` in each iteration. - return_rois_num (bool): When setting True, it will return a 1D Tensor with shape [N, ] that includes Rois's - num of each image in one batch. The N is the image's num. For example, the tensor has values [4,5] that represents - the first image has 4 Rois, the second image has 5 Rois. It only used in rcnn model. - 'False' by default. + pre_nms_top_n (float, optional): Number of total bboxes to be kept per + image before NMS. `6000` by default. + post_nms_top_n (float, optional): Number of total bboxes to be kept per + image after NMS. `1000` by default. + nms_thresh (float, optional): Threshold in NMS. The data type must be float32. `0.5` by default. + min_size (float, optional): Remove predicted boxes with either height or + width less than this value. `0.1` by default. + eta(float, optional): Apply in adaptive NMS, only works if adaptive `threshold > 0.5`, + `adaptive_threshold = adaptive_threshold * eta` in each iteration. 1.0 by default. + pixel_offset (bool, optional): Whether there is pixel offset. If True, the offset of `img_size` will be 1. 'False' by default. + return_rois_num (bool, optional): Whether to return `rpn_rois_num` . When setting True, it will return a 1D Tensor with shape [N, ] that includes Rois's + num of each image in one batch. 'False' by default. name(str, optional): For detailed information, please refer to :ref:`api_guide_Name`. Usually name is no need to set and None by default. Returns: - - **rpn_rois**: The generated RoIs. 2-D Tensor with shape ``[N, 4]`` while ``N`` is the number of RoIs. The data type is the same as ``scores``. - - **rpn_roi_probs**: The scores of generated RoIs. 2-D Tensor with shape ``[N, 1]`` while ``N`` is the number of RoIs. The data type is the same as ``scores``. - - **rpn_rois_num**: Rois's num of each image in one batch. 1-D Tensor with shape ``[B,]`` while ``B`` is the batch size. And its sum equals to RoIs number ``N`` . + - rpn_rois (Tensor): The generated RoIs. 2-D Tensor with shape ``[N, 4]`` while ``N`` is the number of RoIs. The data type is the same as ``scores``. + - rpn_roi_probs (Tensor): The scores of generated RoIs. 2-D Tensor with shape ``[N, 1]`` while ``N`` is the number of RoIs. The data type is the same as ``scores``. + - rpn_rois_num (Tensor): Rois's num of each image in one batch. 1-D Tensor with shape ``[B,]`` while ``B`` is the batch size. And its sum equals to RoIs number ``N`` . Examples: .. code-block:: python @@ -1608,11 +1606,11 @@ def generate_proposals(scores, scores = paddle.rand((2,4,5,5), dtype=paddle.float32) bbox_deltas = paddle.rand((2, 16, 5, 5), dtype=paddle.float32) - im_shape = paddle.to_tensor([[224.0, 224.0], [224.0, 224.0]]) + img_size = paddle.to_tensor([[224.0, 224.0], [224.0, 224.0]]) anchors = paddle.rand((2,5,4,4), dtype=paddle.float32) variances = paddle.rand((2,5,10,4), dtype=paddle.float32) rois, roi_probs, roi_nums = paddle.vision.ops.generate_proposals(scores, bbox_deltas, - im_shape, anchors, variances, return_rois_num=True) + img_size, anchors, variances, return_rois_num=True) print(rois, roi_probs, roi_nums) """ @@ -1622,7 +1620,7 @@ def generate_proposals(scores, 'nms_thresh', nms_thresh, 'min_size', min_size, 'eta', eta, 'pixel_offset', pixel_offset) rpn_rois, rpn_roi_probs, rpn_rois_num = _C_ops.generate_proposals_v2( - scores, bbox_deltas, im_shape, anchors, variances, *attrs) + scores, bbox_deltas, img_size, anchors, variances, *attrs) return rpn_rois, rpn_roi_probs, rpn_rois_num @@ -1632,7 +1630,7 @@ def generate_proposals(scores, 'generate_proposals_v2') check_variable_and_dtype(bbox_deltas, 'bbox_deltas', ['float32'], 'generate_proposals_v2') - check_variable_and_dtype(im_shape, 'im_shape', ['float32', 'float64'], + check_variable_and_dtype(img_size, 'img_size', ['float32', 'float64'], 'generate_proposals_v2') check_variable_and_dtype(anchors, 'anchors', ['float32'], 'generate_proposals_v2') @@ -1656,7 +1654,7 @@ def generate_proposals(scores, inputs={ 'Scores': scores, 'BboxDeltas': bbox_deltas, - 'ImShape': im_shape, + 'ImShape': img_size, 'Anchors': anchors, 'Variances': variances }, From a5488b7b04c5e1b09ba9b5e0c8212fe2850bc4f3 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Wed, 13 Jul 2022 04:10:48 +0000 Subject: [PATCH 4/4] change fluid impl to current version --- python/paddle/fluid/layers/detection.py | 71 +++++-------------------- 1 file changed, 14 insertions(+), 57 deletions(-) diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index f89c95b93a1d3..9a7ab0ebbb5aa 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -17,6 +17,8 @@ from __future__ import print_function +import paddle + from .layer_function_generator import generate_layer_fn from .layer_function_generator import autodoc, templatedoc from ..layer_helper import LayerHelper @@ -3007,63 +3009,18 @@ def generate_proposals(scores, im_info, anchors, variances) """ - if _non_static_mode(): - assert return_rois_num, "return_rois_num should be True in dygraph mode." - attrs = ('pre_nms_topN', pre_nms_top_n, 'post_nms_topN', post_nms_top_n, - 'nms_thresh', nms_thresh, 'min_size', min_size, 'eta', eta) - rpn_rois, rpn_roi_probs, rpn_rois_num = _C_ops.generate_proposals( - scores, bbox_deltas, im_info, anchors, variances, *attrs) - return rpn_rois, rpn_roi_probs, rpn_rois_num - - helper = LayerHelper('generate_proposals', **locals()) - - check_variable_and_dtype(scores, 'scores', ['float32'], - 'generate_proposals') - check_variable_and_dtype(bbox_deltas, 'bbox_deltas', ['float32'], - 'generate_proposals') - check_variable_and_dtype(im_info, 'im_info', ['float32', 'float64'], - 'generate_proposals') - check_variable_and_dtype(anchors, 'anchors', ['float32'], - 'generate_proposals') - check_variable_and_dtype(variances, 'variances', ['float32'], - 'generate_proposals') - - rpn_rois = helper.create_variable_for_type_inference( - dtype=bbox_deltas.dtype) - rpn_roi_probs = helper.create_variable_for_type_inference( - dtype=scores.dtype) - outputs = { - 'RpnRois': rpn_rois, - 'RpnRoiProbs': rpn_roi_probs, - } - if return_rois_num: - rpn_rois_num = helper.create_variable_for_type_inference(dtype='int32') - rpn_rois_num.stop_gradient = True - outputs['RpnRoisNum'] = rpn_rois_num - - helper.append_op(type="generate_proposals", - inputs={ - 'Scores': scores, - 'BboxDeltas': bbox_deltas, - 'ImInfo': im_info, - 'Anchors': anchors, - 'Variances': variances - }, - attrs={ - 'pre_nms_topN': pre_nms_top_n, - 'post_nms_topN': post_nms_top_n, - 'nms_thresh': nms_thresh, - 'min_size': min_size, - 'eta': eta - }, - outputs=outputs) - rpn_rois.stop_gradient = True - rpn_roi_probs.stop_gradient = True - - if return_rois_num: - return rpn_rois, rpn_roi_probs, rpn_rois_num - else: - return rpn_rois, rpn_roi_probs + return paddle.vision.ops.generate_proposals(scores=scores, + bbox_deltas=bbox_deltas, + img_size=im_info[:2], + anchors=anchors, + variances=variances, + pre_nms_top_n=pre_nms_top_n, + post_nms_top_n=post_nms_top_n, + nms_thresh=nms_thresh, + min_size=min_size, + eta=eta, + return_rois_num=return_rois_num, + name=name) def box_clip(input, im_info, name=None):