Rename RoiAlign & RoiPool (#446)

This commit is contained in:
Antkillerfarm 2022-07-29 11:10:25 +08:00 committed by GitHub
parent 96c9d5df01
commit 32241dc4ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 20 additions and 16 deletions

View File

@ -30,7 +30,7 @@ namespace vx {
namespace ops { namespace ops {
/** /**
* ## ROI_ALIGN * ## RoiAlign
* *
* Select and scale the feature map of each region of interest to a unified output * Select and scale the feature map of each region of interest to a unified output
* size by average pooling sampling points from bilinear interpolation. * size by average pooling sampling points from bilinear interpolation.
@ -47,9 +47,9 @@ namespace ops {
* used to compute the output. * used to compute the output.
*/ */
class ROI_Align : public DirectMapOp { class RoiAlign : public DirectMapOp {
public: public:
ROI_Align(Graph* graph, int32_t output_height, int32_t output_width, RoiAlign(Graph* graph, int32_t output_height, int32_t output_width,
float height_ratio, float width_ratio, int32_t height_sample_num, float height_ratio, float width_ratio, int32_t height_sample_num,
int32_t width_sample_num); int32_t width_sample_num);
@ -65,6 +65,8 @@ class ROI_Align : public DirectMapOp {
int32_t width_sample_num_; int32_t width_sample_num_;
}; };
using ROI_Align = RoiAlign;
} // namespace ops } // namespace ops
} // namespace vx } // namespace vx
} // namespace tim } // namespace tim

View File

@ -33,7 +33,7 @@ namespace vx {
namespace ops { namespace ops {
/** /**
* ## ROI_POOL * ## RoiPool
* *
* Select and scale the feature map of each region of interest to a unified output * Select and scale the feature map of each region of interest to a unified output
* size by max-pooling. * size by max-pooling.
@ -44,9 +44,9 @@ namespace ops {
* *
*/ */
class ROI_Pool : public DirectMapOp { class RoiPool : public DirectMapOp {
public: public:
ROI_Pool(Graph* graph, PoolType type, float scale, RoiPool(Graph* graph, PoolType type, float scale,
const std::array<uint32_t, 2>& size); const std::array<uint32_t, 2>& size);
std::shared_ptr<Operation> Clone( std::shared_ptr<Operation> Clone(
@ -58,6 +58,8 @@ class ROI_Pool : public DirectMapOp {
std::array<uint32_t, 2> size_; std::array<uint32_t, 2> size_;
}; };
using ROI_Pool = RoiPool;
} // namespace ops } // namespace ops
} // namespace vx } // namespace vx
} // namespace tim } // namespace tim

View File

@ -30,7 +30,7 @@ namespace tim {
namespace vx { namespace vx {
namespace ops { namespace ops {
ROI_Align::ROI_Align(Graph* graph, int32_t output_height, int32_t output_width, RoiAlign::RoiAlign(Graph* graph, int32_t output_height, int32_t output_width,
float height_ratio, float width_ratio, int32_t height_sample_num, float height_ratio, float width_ratio, int32_t height_sample_num,
int32_t width_sample_num) int32_t width_sample_num)
: DirectMapOp(graph, VSI_NN_OP_ROI_ALIGN), : DirectMapOp(graph, VSI_NN_OP_ROI_ALIGN),
@ -49,9 +49,9 @@ ROI_Align::ROI_Align(Graph* graph, int32_t output_height, int32_t output_width,
this->impl()->node()->nn_param.roi_align.width_sample_num = width_sample_num; this->impl()->node()->nn_param.roi_align.width_sample_num = width_sample_num;
} }
std::shared_ptr<Operation> ROI_Align::Clone( std::shared_ptr<Operation> RoiAlign::Clone(
std::shared_ptr<Graph>& graph) const { std::shared_ptr<Graph>& graph) const {
return graph->CreateOperation<ROI_Align>( return graph->CreateOperation<RoiAlign>(
this->output_height_, this->output_width_, this->height_ratio_, this->output_height_, this->output_width_, this->height_ratio_,
this->width_ratio_, this->height_sample_num_, this->width_sample_num_); this->width_ratio_, this->height_sample_num_, this->width_sample_num_);
} }

View File

@ -29,7 +29,7 @@
#include "tim/vx/graph.h" #include "tim/vx/graph.h"
#include "tim/vx/types.h" #include "tim/vx/types.h"
TEST(ROI_Align, shape_4_2_1_1_float32) { TEST(RoiAlign, shape_4_2_1_1_float32) {
auto ctx = tim::vx::Context::Create(); auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph(); auto graph = ctx->CreateGraph();
@ -83,7 +83,7 @@ TEST(ROI_Align, shape_4_2_1_1_float32) {
graph->CreateTensor(batch_index_spec, batch_index_data.data()); graph->CreateTensor(batch_index_spec, batch_index_data.data());
auto output_tensor = graph->CreateTensor(output_spec); auto output_tensor = graph->CreateTensor(output_spec);
auto roi_align = graph->CreateOperation<tim::vx::ops::ROI_Align>( auto roi_align = graph->CreateOperation<tim::vx::ops::RoiAlign>(
out_height, out_width, height_ratio, width_ratio, height_sample_num, out_height, out_width, height_ratio, width_ratio, height_sample_num,
width_sample_num); width_sample_num);
(*roi_align) (*roi_align)

View File

@ -31,7 +31,7 @@ namespace tim {
namespace vx { namespace vx {
namespace ops { namespace ops {
ROI_Pool::ROI_Pool(Graph* graph, PoolType type, float scale, RoiPool::RoiPool(Graph* graph, PoolType type, float scale,
const std::array<uint32_t, 2>& size) const std::array<uint32_t, 2>& size)
: DirectMapOp(graph, VSI_NN_OP_ROI_POOL), : DirectMapOp(graph, VSI_NN_OP_ROI_POOL),
type_(type), type_(type),
@ -43,9 +43,9 @@ ROI_Pool::ROI_Pool(Graph* graph, PoolType type, float scale,
this->impl()->node()->nn_param.roi_pool.size[1] = size[1]; this->impl()->node()->nn_param.roi_pool.size[1] = size[1];
} }
std::shared_ptr<Operation> ROI_Pool::Clone( std::shared_ptr<Operation> RoiPool::Clone(
std::shared_ptr<Graph>& graph) const { std::shared_ptr<Graph>& graph) const {
return graph->CreateOperation<ROI_Pool>( return graph->CreateOperation<RoiPool>(
this->type_, this->scale_, this->size_); this->type_, this->scale_, this->size_);
} }

View File

@ -29,7 +29,7 @@
#include "tim/vx/graph.h" #include "tim/vx/graph.h"
#include "tim/vx/types.h" #include "tim/vx/types.h"
TEST(ROI_Pool, shape_4_2_1_1_float32) { TEST(RoiPool, shape_4_2_1_1_float32) {
auto ctx = tim::vx::Context::Create(); auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph(); auto graph = ctx->CreateGraph();
@ -81,7 +81,7 @@ TEST(ROI_Pool, shape_4_2_1_1_float32) {
std::array<uint32_t, 2> size; std::array<uint32_t, 2> size;
size[0] = out_height; size[0] = out_height;
size[1] = out_width; size[1] = out_width;
auto roi_pool = graph->CreateOperation<tim::vx::ops::ROI_Pool>(tim::vx::PoolType::MAX, scale, size); auto roi_pool = graph->CreateOperation<tim::vx::ops::RoiPool>(tim::vx::PoolType::MAX, scale, size);
(*roi_pool) (*roi_pool)
.BindInput(input_tensor) .BindInput(input_tensor)
.BindInput(regions_tensor) .BindInput(regions_tensor)