From 75d39e2cfd6f91c73bf62cefc6411a420b7135d7 Mon Sep 17 00:00:00 2001 From: liyuenan <37231553+liyuenan2333@users.noreply.github.com> Date: Wed, 29 Dec 2021 11:06:28 +0800 Subject: [PATCH] Support layout inference for transpose (#250) Added interface GetProdeucerOp(tensor) in Graph Signed-off-by: yuenan.li --- include/tim/vx/graph.h | 7 ++ src/tim/transform/layout_inference.cc | 2 + .../ops/transpose_layout_inference.h | 76 +++++++++++++++++++ src/tim/vx/graph.cc | 20 +++++ src/tim/vx/graph_private.h | 26 ++++--- src/tim/vx/operation.cc | 1 + 6 files changed, 122 insertions(+), 10 deletions(-) create mode 100644 src/tim/transform/ops/transpose_layout_inference.h diff --git a/include/tim/vx/graph.h b/include/tim/vx/graph.h index 4527320..be6fe37 100644 --- a/include/tim/vx/graph.h +++ b/include/tim/vx/graph.h @@ -69,9 +69,16 @@ class Graph { const std::shared_ptr& tensor, const Operation* op) = 0; + virtual void UpdateTensorProducerMap( + const std::shared_ptr& tensor, + const Operation* op) = 0; + virtual const std::vector> GetConsumersOp( std::shared_ptr tensor) const = 0; + virtual std::vector> GetProducerOp( + std::shared_ptr tensor) = 0; + virtual void PrintGraph() const = 0; protected: diff --git a/src/tim/transform/layout_inference.cc b/src/tim/transform/layout_inference.cc index 088c012..7a94277 100644 --- a/src/tim/transform/layout_inference.cc +++ b/src/tim/transform/layout_inference.cc @@ -59,6 +59,7 @@ #include "ops/deconv2d_layout_inference.h" #include "ops/batchnorm_layout_inference.h" #include "ops/default_layout_inference.h" +#include "ops/transpose_layout_inference.h" #include #include @@ -257,6 +258,7 @@ std::vector> HandleLayoutInfer( REGIST_LAYOUT_INFERENCE(VSI_NN_OP_ARGMIN, Arg); REGIST_LAYOUT_INFERENCE(VSI_NN_OP_DECONVOLUTION, DeConv2d); REGIST_LAYOUT_INFERENCE(VSI_NN_OP_BATCH_NORM, BatchNorm); + REGIST_LAYOUT_INFERENCE(VSI_NN_OP_PERMUTE, Transpose); REGIST_LOGICAL_LAYOUT_INFERENCE(VSI_NN_OP_LOGICAL_OPS); REGIST_REDUCE_LAYOUT_INFERENCE(VSI_NN_OP_REDUCE); // use default layout inference diff --git a/src/tim/transform/ops/transpose_layout_inference.h b/src/tim/transform/ops/transpose_layout_inference.h new file mode 100644 index 0000000..1f332e2 --- /dev/null +++ b/src/tim/transform/ops/transpose_layout_inference.h @@ -0,0 +1,76 @@ +/**************************************************************************** + * + * Copyright (c) 2020 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#ifndef TIM_LAYOUT_INFER_TRANSPOSE_LAYOUT_INFERENCE_H_ +#define TIM_LAYOUT_INFER_TRANSPOSE_LAYOUT_INFERENCE_H_ + +#include "tim/vx/ops/transpose.h" + +#include "ops/op_layout_inference.h" +#include "permute_vector.h" +#include "operation_private.h" + +namespace tim { +namespace transform { +class TransposeLayoutInfer : public OpLayoutInfer { + public: + TransposeLayoutInfer( + const std::shared_ptr op, + std::shared_ptr& context) + : OpLayoutInfer(op, context) {} + + void OnInputs( + std::vector>& next_tensors) override { + auto src_input = op_->impl()->InputsTensor()[0]; + auto infer_input = context_->GetMapedTensor(src_input); + auto input_pv = context_->GetPermuteVector(src_input); + + std::vector perm(op_->impl()->node()->nn_param.permute.dim_num); + memcpy(perm.data(), op_->impl()->node()->nn_param.permute.perm, + op_->impl()->node()->nn_param.permute.dim_num * sizeof(uint32_t)); + IPermuteVectorPtr perm_pv = MakeShared(perm.size()); + for (uint32_t i = 0; i < perm.size(); i++) { + perm_pv->At(i) = perm[i]; + } + + IPermuteVectorPtr final_pv = input_pv->Reverse()->Add(perm_pv); + + if (final_pv->IsAligned()) { + //skip transpose op by treating its input as its output. + context_->UpdateTensorMap(op_->impl()->OutputsTensor()[0], infer_input); + } else { + auto transpose_op = + context_->infer_graph_->CreateOperation( + final_pv->AsStdVec()); + transpose_op->BindInput(infer_input); + auto infer_out = CreateOutputsTensor(final_pv); + transpose_op->BindOutput(infer_out[0]); + } + context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], MakeShared(perm.size())); + next_tensors.push_back(op_->impl()->OutputsTensor()[0]); + } +}; + +} // namespace transform +} // namespace tim +#endif \ No newline at end of file diff --git a/src/tim/vx/graph.cc b/src/tim/vx/graph.cc index abda269..9a47c50 100644 --- a/src/tim/vx/graph.cc +++ b/src/tim/vx/graph.cc @@ -87,6 +87,15 @@ void GraphImpl::UpdateTensorConsumersMap(const std::shared_ptr& tensor, } } +void GraphImpl::UpdateTensorProducerMap(const std::shared_ptr& tensor, + const Operation* op) { + for (const auto& added_op : op_vector_) { + if (added_op.get() == op) { + tensor_producer_[tensor].push_back(added_op); + } + } +} + const std::vector> GraphImpl::GetConsumersOp( std::shared_ptr tensor) const { auto consumers = tensor_consumers_.find(tensor); @@ -98,6 +107,17 @@ const std::vector> GraphImpl::GetConsumersOp( } } +std::vector> GraphImpl::GetProducerOp( + std::shared_ptr tensor) { + auto producer = tensor_producer_.find(tensor); + if (tensor_producer_.end() != producer) { + return producer->second; + } else { + VSILOGD("Tensor has no producer, may be graph input."); + return {}; + } +} + void GraphImpl::PrintGraph() const { vsi_nn_PrintGraph(this->graph_); } std::shared_ptr GraphImpl::CreateTensor(const TensorSpec& spec, diff --git a/src/tim/vx/graph_private.h b/src/tim/vx/graph_private.h index 241536f..e0242cc 100644 --- a/src/tim/vx/graph_private.h +++ b/src/tim/vx/graph_private.h @@ -56,19 +56,24 @@ class GraphImpl : public Graph { void UpdateTensorConsumersMap(const std::shared_ptr& tensor, const Operation* op) override; + void UpdateTensorProducerMap(const std::shared_ptr& tensor, + const Operation* op) override; const std::vector> GetConsumersOp( std::shared_ptr tensor) const override; - void PrintGraph() const override; - /// Implement parents' virtual functions - std::shared_ptr CreateTensor(const TensorSpec& spec, - const void* data = nullptr) override; - std::shared_ptr CreateTensor(const TensorSpec& spec, - const DmaBufferDesc& dmafd) override; - std::shared_ptr CreateTensorPlaceHolder() override; - bool Compile() override; + std::vector> GetProducerOp( + std::shared_ptr tensor) override; - bool CompileToBinary(void* buf, size_t* size) override; - bool Run() override; + void PrintGraph() const override; + + std::shared_ptr CreateTensor(const TensorSpec& spec, + const void* data = nullptr) override; + std::shared_ptr CreateTensor(const TensorSpec& spec, + const DmaBufferDesc& dmafd) override; + std::shared_ptr CreateTensorPlaceHolder() override; + + bool Compile() override; + bool CompileToBinary(void* buf, size_t* size) override; + bool Run() override; protected: ContextImpl* context_; @@ -82,6 +87,7 @@ class GraphImpl : public Graph { std::vector> inputs_tensor_; std::vector> outputs_tensor_; std::map, std::vector>> tensor_consumers_; + std::map, std::vector>> tensor_producer_; }; } // namespace vx diff --git a/src/tim/vx/operation.cc b/src/tim/vx/operation.cc index a5f8572..f3c825a 100644 --- a/src/tim/vx/operation.cc +++ b/src/tim/vx/operation.cc @@ -100,6 +100,7 @@ Operation& Operation::BindInput(const std::shared_ptr& tensor) { Operation& Operation::BindOutput(const std::shared_ptr& tensor) { impl_->BindOutput(tensor); + impl_->graph_->UpdateTensorProducerMap(tensor, this); return *this; }