diff --git a/src/tim/transform/layout_inference.cc b/src/tim/transform/layout_inference.cc index c9de876..7fd00dd 100644 --- a/src/tim/transform/layout_inference.cc +++ b/src/tim/transform/layout_inference.cc @@ -181,6 +181,7 @@ void LayoutInferContext::UpdateGraphOutputMap(const std::shared_ptr& REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_PROD, ReduceProd); \ REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_ANY, ReduceAny); \ REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_SUM, ReduceSum); \ + REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_ALL, ReduceAll); \ default: \ VSILOGW("Op %d: Default layout inference pass for reduce.", reduce_type);\ assert(false); \ diff --git a/src/tim/transform/ops/reduce_layout_inference.h b/src/tim/transform/ops/reduce_layout_inference.h index 2544f32..7a03463 100644 --- a/src/tim/transform/ops/reduce_layout_inference.h +++ b/src/tim/transform/ops/reduce_layout_inference.h @@ -89,6 +89,7 @@ using ReduceAnyLayoutInfer = ReduceLayoutInfer; using ReduceProdLayoutInfer = ReduceLayoutInfer; using ReduceMeanLayoutInfer = ReduceLayoutInfer; using ReduceSumLayoutInfer = ReduceLayoutInfer; +using ReduceAllLayoutInfer = ReduceLayoutInfer; } // namespace transform } // namespace tim diff --git a/src/tim/vx/ops/reduce_sum_test.cc b/src/tim/vx/ops/reduce_test.cc similarity index 60% rename from src/tim/vx/ops/reduce_sum_test.cc rename to src/tim/vx/ops/reduce_test.cc index c524c5d..b294db5 100644 --- a/src/tim/vx/ops/reduce_sum_test.cc +++ b/src/tim/vx/ops/reduce_test.cc @@ -128,4 +128,105 @@ TEST(Reduce_sum, KeepDims) { std::vector output(golden.size()); EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f)); +} + +TEST(Reduce_all, KeepDims) { + auto ctx = tim::vx::Context::Create(); + + auto graph = ctx->CreateGraph(); + tim::vx::ShapeType input_shape({2, 3, 2}); + tim::vx::ShapeType output_shape({1, 3, 1}); + + tim::vx::TensorSpec input_spec(tim::vx::DataType::BOOL8, input_shape, + tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec output_spec(tim::vx::DataType::BOOL8, output_shape, + tim::vx::TensorAttribute::OUTPUT); + + std::vector axis = {2, 0}; + auto input_tensor = graph->CreateTensor(input_spec); + auto output_tensor = graph->CreateTensor(output_spec); + auto reduce_all = graph->CreateOperation(axis, true); + (*reduce_all).BindInputs({input_tensor}).BindOutputs({output_tensor}); + + bool in_data[] = {true, true, true, true, true, true, + true, true, false, true, true, true}; + bool golden[] = {true, false, true}; + bool* p_in = in_data; + input_tensor->CopyDataToTensor(p_in, 12 * sizeof(bool)); + + EXPECT_TRUE(graph->Compile()); + EXPECT_TRUE(graph->Run()); + + bool output[3 * sizeof(bool)]; + bool* p_output = output; + EXPECT_TRUE(output_tensor->CopyDataFromTensor(p_output)); + for (int i = 0; i < 3; ++i) { + EXPECT_TRUE(golden[i] == output[i]); + } +} + +TEST(Reduce_all, NotKeepDims) { + auto ctx = tim::vx::Context::Create(); + + auto graph = ctx->CreateGraph(); + tim::vx::ShapeType input_shape({2, 3, 2}); + tim::vx::ShapeType output_shape({2}); + + tim::vx::TensorSpec input_spec(tim::vx::DataType::BOOL8, input_shape, + tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec output_spec(tim::vx::DataType::BOOL8, output_shape, + tim::vx::TensorAttribute::OUTPUT); + + std::vector axis = {1, 2, 2, 2}; + auto input_tensor = graph->CreateTensor(input_spec); + auto output_tensor = graph->CreateTensor(output_spec); + auto reduce_all = + graph->CreateOperation(axis, false); + (*reduce_all).BindInputs({input_tensor}).BindOutputs({output_tensor}); + + bool in_data[] = {true, true, true, true, true, false, + true, true, true, true, true, true}; + bool golden[] = {true, false}; + bool* p_in = in_data; + input_tensor->CopyDataToTensor(p_in, 12 * sizeof(bool)); + + EXPECT_TRUE(graph->Compile()); + EXPECT_TRUE(graph->Run()); + + bool output[2 * sizeof(bool)]; + bool* p_output = output; + EXPECT_TRUE(output_tensor->CopyDataFromTensor(p_output)); + for (int i = 0; i < 2; ++i) { + EXPECT_TRUE(golden[i] == output[i]); + } +} + +TEST(Reduce_max, NotKeepDims) { + auto ctx = tim::vx::Context::Create(); + + auto graph = ctx->CreateGraph(); + tim::vx::ShapeType input_shape({2, 3}); + tim::vx::ShapeType output_shape({3}); + + tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape, + tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape, + tim::vx::TensorAttribute::OUTPUT); + auto input_tensor = graph->CreateTensor(input_spec); + auto output_tensor = graph->CreateTensor(output_spec); + std::vector axis = {0}; + auto reduce_sum = graph->CreateOperation(axis, false); + (*reduce_sum).BindInputs({input_tensor}).BindOutputs({output_tensor}); + + std::vector in_data = {-1.0f, -2.0f, 3.0f, 4.0f, 5.0f, -6.0f}; + std::vector golden = {-1.0f, 4.0f, 5.0f}; + + EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(), in_data.size())); + + EXPECT_TRUE(graph->Compile()); + EXPECT_TRUE(graph->Run()); + + std::vector output(golden.size()); + EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); + EXPECT_EQ(golden, output); } \ No newline at end of file