diff --git a/src/tim/vx/builtin_op_impl.cc b/src/tim/vx/builtin_op_impl.cc index ef8ffd2..4c477a0 100644 --- a/src/tim/vx/builtin_op_impl.cc +++ b/src/tim/vx/builtin_op_impl.cc @@ -46,8 +46,7 @@ BuiltinOpImpl& BuiltinOpImpl::BindInput( uint32_t tensor_id = tensor->GetId(); node_->input.tensors[input_tensor_index++] = tensor_id; if (tensor->GetSpec().attr_ & TensorAttribute::INPUT) { - graph_->AddInput(tensor_id); - graph_->AddInput(tensor); + graph_->ConsumeInput(); } return *this; } @@ -57,9 +56,8 @@ BuiltinOpImpl& BuiltinOpImpl::BindOutput( outputs_tensor_.push_back(tensor); uint32_t tensor_id = tensor->GetId(); node_->output.tensors[output_tensor_index++] = tensor_id; - if (tensor->GetSpec().attr_ == TensorAttribute::OUTPUT) { - graph_->AddOutput(tensor_id); - graph_->AddOutput(tensor); + if (tensor->GetSpec().attr_ & TensorAttribute::OUTPUT) { + graph_->ConsumeOutput(); } return *this; } diff --git a/src/tim/vx/graph.cc b/src/tim/vx/graph.cc index 185b4a5..5c17296 100644 --- a/src/tim/vx/graph.cc +++ b/src/tim/vx/graph.cc @@ -49,6 +49,8 @@ GraphImpl::GraphImpl(ContextImpl* context, const CompileOption& options) : context_(context), graph_(vsi_nn_CreateGraph(context_->context(), 0, 0)), tensor_placeholder_(nullptr), + not_consumed_input_cnt_(0), + not_consumed_output_cnt_(0), options_(options){} GraphImpl::~GraphImpl() { vsi_nn_ReleaseGraph(&graph_); } @@ -133,17 +135,50 @@ void GraphImpl::PrintGraph() const { vsi_nn_PrintGraph(this->graph_); } std::shared_ptr GraphImpl::CreateTensor(const TensorSpec& spec, const void* data) { - return std::make_shared(this, spec, data); + auto tensor = std::make_shared(this, spec, data); + if (spec.attr_ & TensorAttribute::INPUT) { + this->AddInput(tensor); + this->AddInput(tensor->GetId()); + this->ProduceInput(); + } + if (spec.attr_ & TensorAttribute::OUTPUT) { + this->AddOutput(tensor); + this->AddOutput(tensor->GetId()); + this->ProduceOutput(); + } + return tensor; } std::shared_ptr GraphImpl::CreateTensor(const TensorSpec& spec, const DmaBufferDesc& dmafd) { - return std::make_shared(this, spec, dmafd); + auto tensor = std::make_shared(this, spec, dmafd); + if (spec.attr_ & TensorAttribute::INPUT) { + this->AddInput(tensor); + this->AddInput(tensor->GetId()); + this->ProduceInput(); + } + if (spec.attr_ & TensorAttribute::OUTPUT) { + this->AddOutput(tensor); + this->AddOutput(tensor->GetId()); + this->ProduceOutput(); + } + return tensor; } std::shared_ptr GraphImpl::CreateIOTensor(const TensorSpec& spec, void* data) { - return std::make_shared(this, spec, data); + auto tensor = std::make_shared(this, spec, data); + if (spec.attr_ & TensorAttribute::INPUT) { + this->AddInput(tensor); + this->AddInput(tensor->GetId()); + this->ProduceInput(); + } + if (spec.attr_ & TensorAttribute::OUTPUT) { + this->AddOutput(tensor); + this->AddOutput(tensor->GetId()); + this->ProduceOutput(); + } + return tensor; } std::shared_ptr GraphImpl::CreateTensorPlaceHolder() { @@ -185,7 +220,15 @@ bool GraphImpl::Setup() { bool GraphImpl::Compile() { bool status = true; - + if (not_consumed_input_cnt_ > 0 ) { + // Tensor can bind to different operations + VSILOGE("Graph has free input, INPUT tensor may be created with OUTPUT attr."); + return false; + } + if (not_consumed_output_cnt_ != 0) { + VSILOGE("Graph has free output, OUTPUT tensor may be created with INPUT attr."); + return false; + } status = Setup(); std::call_once(verify_graph_once_, [&status, this]() { status = (VSI_SUCCESS == vsi_nn_VerifyGraph(this->graph_)); diff --git a/src/tim/vx/graph_private.h b/src/tim/vx/graph_private.h index 6ade1cb..c693fb3 100644 --- a/src/tim/vx/graph_private.h +++ b/src/tim/vx/graph_private.h @@ -77,6 +77,10 @@ class GraphImpl : public Graph { bool Compile() override; bool CompileToBinary(void* buf, size_t* size) override; bool Run() override; + void ProduceInput() { not_consumed_input_cnt_++; } + void ProduceOutput() { not_consumed_output_cnt_++; } + void ConsumeInput() { not_consumed_input_cnt_--; } + void ConsumeOutput() { not_consumed_output_cnt_--; } protected: ContextImpl* context_; @@ -88,7 +92,9 @@ class GraphImpl : public Graph { std::vector inputs_; std::vector outputs_; std::vector> inputs_tensor_; + int32_t not_consumed_input_cnt_; std::vector> outputs_tensor_; + int32_t not_consumed_output_cnt_; std::map, std::vector>> tensor_consumers_; std::map, std::shared_ptr> tensor_producer_; diff --git a/src/tim/vx/ops/maxpoolgrad_test.cc b/src/tim/vx/ops/maxpoolgrad_test.cc index 153444f..2ed3e6b 100644 --- a/src/tim/vx/ops/maxpoolgrad_test.cc +++ b/src/tim/vx/ops/maxpoolgrad_test.cc @@ -95,7 +95,7 @@ TEST(Fuse_MaxpoolGrad, with_overlay) { auto input_tensor = graph->CreateTensor(input_spec); auto updates_tensor = graph->CreateTensor(updates_spec); - auto output_tensor = graph->CreateTensor(input_spec); + auto output_tensor = graph->CreateTensor(output_spec); std::vector in_data = { 7, 2, 5, 3, 8, @@ -124,7 +124,7 @@ TEST(Fuse_MaxpoolGrad, with_overlay) { EXPECT_TRUE(graph->Compile()); EXPECT_TRUE(graph->Run()); - + std::vector output_values(golden.size()); EXPECT_TRUE(output_tensor->CopyDataFromTensor(output_values.data())); EXPECT_EQ(golden, output_values); @@ -145,7 +145,7 @@ TEST(Fuse_MaxpoolGrad, with_overlay_multi_channel_multi_batch) { auto input_tensor = graph->CreateTensor(input_spec); auto updates_tensor = graph->CreateTensor(updates_spec); - auto output_tensor = graph->CreateTensor(input_spec); + auto output_tensor = graph->CreateTensor(output_spec); std::vector in_data = { 7, 2, 5, 3, 8, @@ -204,7 +204,7 @@ TEST(Fuse_MaxpoolGrad, with_overlay_multi_channel_multi_batch) { EXPECT_TRUE(graph->Compile()); EXPECT_TRUE(graph->Run()); - + std::vector output_values(golden.size()); EXPECT_TRUE(output_tensor->CopyDataFromTensor(output_values.data())); EXPECT_EQ(golden, output_values); diff --git a/src/tim/vx/ops/maxpoolwithargmax2_test.cc b/src/tim/vx/ops/maxpoolwithargmax2_test.cc index a90a0b0..e703a25 100644 --- a/src/tim/vx/ops/maxpoolwithargmax2_test.cc +++ b/src/tim/vx/ops/maxpoolwithargmax2_test.cc @@ -138,11 +138,13 @@ TEST(MaxpoolGrad, without_overlay) { out_shape, tim::vx::TensorAttribute::TRANSIENT); tim::vx::TensorSpec output_spec_values(tim::vx::DataType::FLOAT32, out_shape, tim::vx::TensorAttribute::OUTPUT); + tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, + in_shape, tim::vx::TensorAttribute::OUTPUT); auto input_tensor = graph->CreateTensor(input_spec); auto output_tensor_indices = graph->CreateTensor(output_spec_indices); auto output_tensor_values = graph->CreateTensor(output_spec_values); - auto output_tensor = graph->CreateTensor(input_spec); + auto output_tensor = graph->CreateTensor(output_spec); std::vector in_data = { 7, 2, 5, 3, 10, 2, @@ -210,11 +212,13 @@ TEST(MaxpoolGrad, with_overlay) { out_shape, tim::vx::TensorAttribute::TRANSIENT); tim::vx::TensorSpec output_spec_values(tim::vx::DataType::FLOAT32, out_shape, tim::vx::TensorAttribute::OUTPUT); + tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, + in_shape, tim::vx::TensorAttribute::OUTPUT); auto input_tensor = graph->CreateTensor(input_spec); auto output_tensor_indices = graph->CreateTensor(output_spec_indices); auto output_tensor_values = graph->CreateTensor(output_spec_values); - auto output_tensor = graph->CreateTensor(input_spec); + auto output_tensor = graph->CreateTensor(output_spec); std::vector in_data = { 7, 2, 5, 3, 8, @@ -282,11 +286,13 @@ TEST(MaxpoolGrad, with_overlay_multi_channel_multi_batch) { out_shape, tim::vx::TensorAttribute::TRANSIENT); tim::vx::TensorSpec output_spec_values(tim::vx::DataType::FLOAT32, out_shape, tim::vx::TensorAttribute::OUTPUT); + tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, + in_shape, tim::vx::TensorAttribute::OUTPUT); auto input_tensor = graph->CreateTensor(input_spec); auto output_tensor_indices = graph->CreateTensor(output_spec_indices); auto output_tensor_values = graph->CreateTensor(output_spec_values); - auto output_tensor = graph->CreateTensor(input_spec); + auto output_tensor = graph->CreateTensor(output_spec); std::vector in_data = { 7, 2, 5, 3, 8, diff --git a/src/tim/vx/tensor_private.h b/src/tim/vx/tensor_private.h index 2ab87d9..3c791d4 100644 --- a/src/tim/vx/tensor_private.h +++ b/src/tim/vx/tensor_private.h @@ -56,6 +56,7 @@ class TensorImpl : public Tensor { bool IsConstTensor() { return spec_.attr_ == tim::vx::TensorAttribute::CONSTANT; } + const void* GetDataRef() const { return data_; } GraphImpl* graph_; @@ -94,6 +95,7 @@ class TensorPlaceholder : public Tensor { bool IsConstTensor() { return spec_.attr_ == tim::vx::TensorAttribute::CONSTANT; } + const void* GetDataRef() const { return nullptr; } vsi_nn_tensor_id_t id_;