From bb10884f98dd6137cd7ad44d2f70df8fa339b429 Mon Sep 17 00:00:00 2001 From: Chen Feiyue <69809761+chenfeiyue-cfy@users.noreply.github.com> Date: Mon, 6 Nov 2023 09:58:03 +0800 Subject: [PATCH] Added scalar type support (#655) Added SetScalar api to support scalar input Added 2 cases for scalar index Gather Type: New Feature Signed-off-by: Feiyue Chen --- include/tim/vx/tensor.h | 2 + src/tim/vx/ops/gather_test.cc | 82 +++++++++++++++++++++++++++++++++++ src/tim/vx/tensor.cc | 8 ++++ src/tim/vx/tensor_private.h | 9 ++++ 4 files changed, 101 insertions(+) diff --git a/include/tim/vx/tensor.h b/include/tim/vx/tensor.h index 2623501..c62f5fa 100644 --- a/include/tim/vx/tensor.h +++ b/include/tim/vx/tensor.h @@ -161,7 +161,9 @@ class Tensor { virtual void unmap() = 0; virtual bool IsPlaceHolder() = 0; virtual bool IsConstTensor() = 0; + virtual bool IsScalar() = 0; virtual bool SaveTensorToTextByFp32(std::string filename) = 0; + virtual void SetScalar(int8_t is_scalar) = 0; virtual void* ConvertTensorToData(uint8_t* tensorData) = 0; virtual float* ConvertTensorToFloat32Data() = 0; }; diff --git a/src/tim/vx/ops/gather_test.cc b/src/tim/vx/ops/gather_test.cc index 8989b1c..205d1cd 100644 --- a/src/tim/vx/ops/gather_test.cc +++ b/src/tim/vx/ops/gather_test.cc @@ -113,3 +113,85 @@ TEST(Gather, shape_2_2_indices_2) { EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); EXPECT_EQ(golden, output); } + +TEST(Gather, scalar_index_input2D) { + auto ctx = tim::vx::Context::Create(); + + auto graph = ctx->CreateGraph(); + + tim::vx::ShapeType in_shape({5,2}); + tim::vx::ShapeType index_shape({1}); + tim::vx::ShapeType out_shape({5}); + tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, in_shape, + tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec index_spec(tim::vx::DataType::INT32, index_shape, + tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, out_shape, + tim::vx::TensorAttribute::OUTPUT); + + auto input_tensor = graph->CreateTensor(input_spec); + auto index_tensor = graph->CreateTensor(index_spec); + index_tensor->SetScalar(1); + auto output_tensor = graph->CreateTensor(output_spec); + + std::vector in_data = { + 1,2,3,4,5, + 6,7,8,9,10}; + + std::vector index = {1}; + std::vector golden = {6,7,8,9,10}; + + EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(), in_data.size())); + EXPECT_TRUE(index_tensor->CopyDataToTensor(index.data(), index.size())); + auto op = graph->CreateOperation(1, 0); + (*op).BindInputs({input_tensor, index_tensor}).BindOutputs({output_tensor}); + + EXPECT_TRUE(graph->Compile()); + EXPECT_TRUE(graph->Run()); + + std::vector output(golden.size()); + EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); + EXPECT_EQ(golden, output); +} + +TEST(Gather, scalar_index_input1D) { + auto ctx = tim::vx::Context::Create(); + + auto graph = ctx->CreateGraph(); + + tim::vx::ShapeType in_shape({5}); + tim::vx::ShapeType index_shape({1}); + tim::vx::ShapeType out_shape({}); + tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, in_shape, + tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec index_spec(tim::vx::DataType::INT32, index_shape, + tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec gatherout_spec(tim::vx::DataType::FLOAT32, out_shape, + tim::vx::TensorAttribute::OUTPUT); + tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, out_shape, + tim::vx::TensorAttribute::OUTPUT); + + auto input_tensor = graph->CreateTensor(input_spec); + auto index_tensor = graph->CreateTensor(index_spec); + index_tensor->SetScalar(1); + auto gatherout_tensor = graph->CreateTensor(gatherout_spec); + auto output_tensor = graph->CreateTensor(output_spec); + + std::vector in_data = { + 1,2,3,4,5}; + + std::vector index = {1}; + std::vector golden = {2}; + + EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(), in_data.size())); + EXPECT_TRUE(index_tensor->CopyDataToTensor(index.data(), index.size())); + auto gather = graph->CreateOperation(0); + (*gather).BindInputs({input_tensor, index_tensor}).BindOutputs({gatherout_tensor}); + + EXPECT_TRUE(graph->Compile()); + EXPECT_TRUE(graph->Run()); + + std::vector output(golden.size()); + EXPECT_TRUE(gatherout_tensor->CopyDataFromTensor(output.data())); + EXPECT_EQ(golden, output); +} \ No newline at end of file diff --git a/src/tim/vx/tensor.cc b/src/tim/vx/tensor.cc index ccde1f4..b6dd9ac 100644 --- a/src/tim/vx/tensor.cc +++ b/src/tim/vx/tensor.cc @@ -213,6 +213,14 @@ float* TensorImpl::ConvertTensorToFloat32Data() { graph_->graph(), vsi_nn_GetTensor(graph_->graph(), id_)); } +void TensorImpl::SetScalar(int8_t is_scalar) { + bool retn = vsi_nn_SetTensorIsScalar(vsi_nn_GetTensor(graph_->graph(), id_),is_scalar); + if (retn != VSI_SUCCESS) { + VSILOGE("Setting scalar fail!"); + } + return; +} + bool TensorImpl::SwapHandle(void* new_ptr, bool is_new_ptr_malloc_by_ovxlib, void** old_ptr) { bool retn = true; diff --git a/src/tim/vx/tensor_private.h b/src/tim/vx/tensor_private.h index d76cbef..bb72397 100644 --- a/src/tim/vx/tensor_private.h +++ b/src/tim/vx/tensor_private.h @@ -60,9 +60,14 @@ class TensorImpl : public Tensor { bool IsConstTensor() override { return spec_.attr_ == tim::vx::TensorAttribute::CONSTANT; } + bool IsScalar() override { + return vsi_nn_GetTensorIsScalar(vsi_nn_GetTensor(graph_->graph(), id_)); + } bool SaveTensorToTextByFp32(std::string filename) override; void* ConvertTensorToData(uint8_t* tensorData) override; float* ConvertTensorToFloat32Data() override; + void SetScalar(int8_t is_scalar) override; + GraphImpl* graph_; vsi_nn_tensor_id_t id_; TensorSpec spec_; @@ -114,6 +119,9 @@ class TensorPlaceholder : public Tensor { bool IsConstTensor() override { return spec_.attr_ == tim::vx::TensorAttribute::CONSTANT; } + bool IsScalar() override { + return false; + } bool SaveTensorToTextByFp32(std::string filename) override { (void)filename; return false; @@ -124,6 +132,7 @@ class TensorPlaceholder : public Tensor { } float* ConvertTensorToFloat32Data() override { return nullptr; } + void SetScalar(int8_t is_scalar) override { (void) is_scalar; return; } vsi_nn_tensor_id_t id_; TensorSpec spec_; };