From 20db77ee614d9064b84ef8abdf06bd1465a5bff0 Mon Sep 17 00:00:00 2001 From: Chen Xin Date: Thu, 29 Sep 2022 14:25:08 +0800 Subject: [PATCH] Added two cases in strided_slice Signed-off-by: Chen Xin --- src/tim/vx/ops/stridedslice_test.cc | 78 +++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/src/tim/vx/ops/stridedslice_test.cc b/src/tim/vx/ops/stridedslice_test.cc index c23a743..42f54f6 100644 --- a/src/tim/vx/ops/stridedslice_test.cc +++ b/src/tim/vx/ops/stridedslice_test.cc @@ -109,3 +109,81 @@ TEST(StridedSlice, shape_) { EXPECT_TRUE(ret) << "Failed at execute"; } + +TEST(StridedSlice, shrinkmask_1) { + auto ctx = tim::vx::Context::Create(); + auto graph = ctx->CreateGraph(); + + tim::vx::ShapeType input_shape({3, 2, 1}); + tim::vx::ShapeType output_shape({3, 2}); + + tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape, + tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape, + tim::vx::TensorAttribute::OUTPUT); + + auto input_tensor = graph->CreateTensor(input_spec); + auto output_tensor = graph->CreateTensor(output_spec); + + std::vector in_data = {1, 1, 2, 2, 3, 3}; + std::vector golden = {1, 1, 2, 2, 3, 3}; + + std::vector begin = {0, 0, 0}; + std::vector end = {0, 0, 0}; + std::vector strides = {1, 1, 1}; + // The ith bits in MASK_SHRINK will mask input_shape[i]. + uint32_t MASK_BEGIN = 0, MASK_END = 0, MASK_SHRINK = 0b100; + + auto op = graph->CreateOperation( + begin, end, strides, MASK_BEGIN, MASK_END, MASK_SHRINK); + (*op).BindInputs({input_tensor}).BindOutputs({output_tensor}); + + input_tensor->CopyDataToTensor(in_data.data(), + in_data.size() * sizeof(float)); + + EXPECT_TRUE(graph->Compile()); + EXPECT_TRUE(graph->Run()); + + std::vector output(golden.size()); + EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); + EXPECT_EQ(golden, output); +} + +TEST(StridedSlice, endmask_1) { + auto ctx = tim::vx::Context::Create(); + auto graph = ctx->CreateGraph(); + + tim::vx::ShapeType input_shape({3, 2, 1}); + tim::vx::ShapeType output_shape({3, 2, 1}); + + tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape, + tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape, + tim::vx::TensorAttribute::OUTPUT); + + auto input_tensor = graph->CreateTensor(input_spec); + auto output_tensor = graph->CreateTensor(output_spec); + + std::vector in_data = {1, 1, 2, 2, 3, 3}; + std::vector golden = {1, 1, 2, 2, 3, 3}; + + std::vector begin = {2, 0, 0}; + std::vector end = {3, 2, 1}; + std::vector strides = {1, 1, 1}; + // The ith bits in MASK_BEGIN will mask begin[i]. MASK_END is similar. + uint32_t MASK_BEGIN = 0b001, MASK_END = 0, MASK_SHRINK = 0; + + auto op = graph->CreateOperation( + begin, end, strides, MASK_BEGIN, MASK_END, MASK_SHRINK); + (*op).BindInputs({input_tensor}).BindOutputs({output_tensor}); + + input_tensor->CopyDataToTensor(in_data.data(), + in_data.size() * sizeof(float)); + + EXPECT_TRUE(graph->Compile()); + EXPECT_TRUE(graph->Run()); + + std::vector output(golden.size()); + EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); + EXPECT_EQ(golden, output); +} \ No newline at end of file