From e2180a6341d97385b4f7fde7a732a8e5e7b8b6e8 Mon Sep 17 00:00:00 2001 From: liyuenan <37231553+liyuenan2333@users.noreply.github.com> Date: Fri, 14 Jan 2022 12:34:38 +0800 Subject: [PATCH] Support that op's all inputs are constant (#264) Signed-off-by: yuenan.li Co-authored-by: yuenan.li --- include/tim/vx/graph.h | 4 +++- include/tim/vx/operation.h | 3 ++- include/tim/vx/ops/conv2d.h | 1 + src/tim/transform/layout_inference.cc | 10 ++++++++++ src/tim/vx/graph.cc | 13 +++++++++++-- src/tim/vx/graph_private.h | 4 ++-- src/tim/vx/operation.cc | 15 +++++++++++++++ src/tim/vx/ops/conv2d.cc | 8 ++++++++ 8 files changed, 52 insertions(+), 6 deletions(-) diff --git a/include/tim/vx/graph.h b/include/tim/vx/graph.h index be6fe37..d71483a 100644 --- a/include/tim/vx/graph.h +++ b/include/tim/vx/graph.h @@ -76,11 +76,13 @@ class Graph { virtual const std::vector> GetConsumersOp( std::shared_ptr tensor) const = 0; - virtual std::vector> GetProducerOp( + virtual std::shared_ptr GetProducerOp( std::shared_ptr tensor) = 0; virtual void PrintGraph() const = 0; + const std::vector> GetConstantInputs() const; + protected: std::vector> op_vector_; }; diff --git a/include/tim/vx/operation.h b/include/tim/vx/operation.h index eff4077..b2af4a4 100644 --- a/include/tim/vx/operation.h +++ b/include/tim/vx/operation.h @@ -48,8 +48,9 @@ class Operation { uint32_t accumulator_bits = 0); std::unique_ptr& impl(); const std::unique_ptr& impl() const; - + virtual const std::vector> ConstantInputsTensor() const; protected: + bool IsAllInputsConst() const; std::unique_ptr impl_; }; diff --git a/include/tim/vx/ops/conv2d.h b/include/tim/vx/ops/conv2d.h index 5c1b6fd..5e77ae3 100644 --- a/include/tim/vx/ops/conv2d.h +++ b/include/tim/vx/ops/conv2d.h @@ -85,6 +85,7 @@ class Conv2d : public DirectMapOp { std::shared_ptr Clone(std::shared_ptr& graph) const override; + const std::vector> ConstantInputsTensor() const override; protected: const uint32_t weights_; const PadType padding_; diff --git a/src/tim/transform/layout_inference.cc b/src/tim/transform/layout_inference.cc index a576a00..1e29e19 100644 --- a/src/tim/transform/layout_inference.cc +++ b/src/tim/transform/layout_inference.cc @@ -298,6 +298,16 @@ std::pair, MakeShared(t_src->GetShape().size())); } + auto const_inputs = src_graph->GetConstantInputs(); + for (auto const_in : const_inputs) { + auto input = + infer_graph->CreateTensor(const_in->GetSpec(), const_in->GetDataRef()); + layout_infer_ctx->UpdateTensorMap(const_in, input); + tensor_queue.push_back(const_in); + layout_infer_ctx->SetPermuteVector(const_in, + MakeShared(const_in->GetShape().size())); + } + while (!tensor_queue.empty()) { const auto& tensor = tensor_queue.front(); tensor_queue.pop_front(); diff --git a/src/tim/vx/graph.cc b/src/tim/vx/graph.cc index 06a6d9e..4819496 100644 --- a/src/tim/vx/graph.cc +++ b/src/tim/vx/graph.cc @@ -35,6 +35,15 @@ namespace tim { namespace vx { +const std::vector> Graph::GetConstantInputs() const { + std::vector> const_inputs; + for (auto op : op_vector_) { + auto const_i = op->ConstantInputsTensor(); + const_inputs.insert(const_inputs.end(), const_i.begin(), const_i.end()); + } + return const_inputs; + } + GraphImpl::GraphImpl(ContextImpl* context) : context_(context), graph_(vsi_nn_CreateGraph(context_->context(), 0, 0)), @@ -91,7 +100,7 @@ void GraphImpl::UpdateTensorProducerMap(const std::shared_ptr& tensor, const Operation* op) { for (const auto& added_op : op_vector_) { if (added_op.get() == op) { - tensor_producer_[tensor].push_back(added_op); + tensor_producer_[tensor] = added_op; } } } @@ -107,7 +116,7 @@ const std::vector> GraphImpl::GetConsumersOp( } } -std::vector> GraphImpl::GetProducerOp( +std::shared_ptr GraphImpl::GetProducerOp( std::shared_ptr tensor) { auto producer = tensor_producer_.find(tensor); if (tensor_producer_.end() != producer) { diff --git a/src/tim/vx/graph_private.h b/src/tim/vx/graph_private.h index e0242cc..9449659 100644 --- a/src/tim/vx/graph_private.h +++ b/src/tim/vx/graph_private.h @@ -60,7 +60,7 @@ class GraphImpl : public Graph { const Operation* op) override; const std::vector> GetConsumersOp( std::shared_ptr tensor) const override; - std::vector> GetProducerOp( + std::shared_ptr GetProducerOp( std::shared_ptr tensor) override; void PrintGraph() const override; @@ -87,7 +87,7 @@ class GraphImpl : public Graph { std::vector> inputs_tensor_; std::vector> outputs_tensor_; std::map, std::vector>> tensor_consumers_; - std::map, std::vector>> tensor_producer_; + std::map, std::shared_ptr> tensor_producer_; }; } // namespace vx diff --git a/src/tim/vx/operation.cc b/src/tim/vx/operation.cc index c24ab90..3369d92 100644 --- a/src/tim/vx/operation.cc +++ b/src/tim/vx/operation.cc @@ -76,5 +76,20 @@ Operation& Operation::BindOutputs( return *this; } +bool Operation::IsAllInputsConst() const{ + for (auto tensor : impl_->inputs_tensor_) { + if (!tensor->IsConstTensor()) return false; + } + return true; +} + +const std::vector> Operation::ConstantInputsTensor() const{ + if (this->IsAllInputsConst()) { + return impl_->inputs_tensor_; + } else { + return {}; + } +} + } // namespace vx } // namespace tim \ No newline at end of file diff --git a/src/tim/vx/ops/conv2d.cc b/src/tim/vx/ops/conv2d.cc index 14b50d8..a71494e 100644 --- a/src/tim/vx/ops/conv2d.cc +++ b/src/tim/vx/ops/conv2d.cc @@ -88,6 +88,14 @@ std::shared_ptr Conv2d::Clone(std::shared_ptr& graph) const { this->kernel_layout_); } +const std::vector> Conv2d::ConstantInputsTensor() const { + if (this->IsAllInputsConst()) { + return {this->impl_->inputs_tensor_[0]}; + } else { + return {}; + } +} + } // namespace ops } // namespace vx } // namespace tim \ No newline at end of file