From fae5cede7a6638611b56c2ecad19dd3d9901d0cd Mon Sep 17 00:00:00 2001 From: liyuenan <37231553+liyuenan2333@users.noreply.github.com> Date: Thu, 27 May 2021 10:33:44 +0800 Subject: [PATCH] Support layout inference for ops (#77) Signed-off-by: yuenan.li Co-authored-by: yuenan.li --- include/tim/vx/tensor.h | 2 +- src/tim/transform/layout_inference.cc | 11 +++ .../ops/activation_layout_inference.h | 24 ++++- src/tim/transform/ops/addn_layout_inference.h | 59 +++++++++++++ .../ops/elementwise_layout_inference.h | 4 +- .../ops/l2normalization_layout_inference.h | 59 +++++++++++++ src/tim/transform/ops/lrn_layout_inference.h | 65 ++++++++++++++ src/tim/transform/ops/op_layout_inference.cc | 85 ++++++++++++++++-- src/tim/transform/ops/op_layout_inference.h | 12 ++- src/tim/transform/ops/pad_layout_inference.h | 4 +- .../transform/ops/split_layout_inference.h | 64 ++++++++++++++ .../ops/stridedslice_layout_inference.h | 87 +++++++++++++++++++ src/tim/vx/tensor_private.h | 4 +- 13 files changed, 463 insertions(+), 17 deletions(-) create mode 100644 src/tim/transform/ops/addn_layout_inference.h create mode 100644 src/tim/transform/ops/l2normalization_layout_inference.h create mode 100644 src/tim/transform/ops/lrn_layout_inference.h create mode 100644 src/tim/transform/ops/split_layout_inference.h create mode 100644 src/tim/transform/ops/stridedslice_layout_inference.h diff --git a/include/tim/vx/tensor.h b/include/tim/vx/tensor.h index a2defaf..b334840 100644 --- a/include/tim/vx/tensor.h +++ b/include/tim/vx/tensor.h @@ -142,7 +142,7 @@ class Tensor { virtual const ShapeType& GetShape() = 0; virtual DataType GetDataType() = 0; virtual const Quantization& GetQuantization() = 0; - virtual const TensorSpec& GetSpec() = 0; + virtual TensorSpec& GetSpec() = 0; virtual uint32_t GetId() = 0; virtual bool CopyDataToTensor(const void* data, uint32_t size_in_bytes = 0) = 0; virtual bool CopyDataFromTensor(void* data) = 0; diff --git a/src/tim/transform/layout_inference.cc 
b/src/tim/transform/layout_inference.cc index e936ced..2b290be 100644 --- a/src/tim/transform/layout_inference.cc +++ b/src/tim/transform/layout_inference.cc @@ -45,6 +45,11 @@ #include "ops/reduce_layout_inference.h" #include "ops/fullyconnected_layout_inference.h" #include "ops/resize_layout_inference.h" +#include "ops/split_layout_inference.h" +#include "ops/stridedslice_layout_inference.h" +#include "ops/lrn_layout_inference.h" +#include "ops/l2normalization_layout_inference.h" +#include "ops/addn_layout_inference.h" #include #include @@ -211,6 +216,12 @@ std::vector> HandleLayoutInfer( REGIST_REDUCE_LAYOUT_INFERENCE(VSI_NN_OP_REDUCE); REGIST_LAYOUT_INFERENCE(VSI_NN_OP_FCL2, FullyConnected); REGIST_LAYOUT_INFERENCE(VSI_NN_OP_RESIZE, Resize); + REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SPLIT, Split); + REGIST_LAYOUT_INFERENCE(VSI_NN_OP_STRIDED_SLICE, StridedSlice); + REGIST_LAYOUT_INFERENCE(VSI_NN_OP_LRN2, LRN); + REGIST_LAYOUT_INFERENCE(VSI_NN_OP_L2_NORMALIZE, L2Normalization); + REGIST_LAYOUT_INFERENCE(VSI_NN_OP_ADDN, AddN); + REGIST_LAYOUT_INFERENCE(VSI_NN_OP_PRELU, PRelu); default: VSILOGW("Op %d: Default layout inference pass.", op_id); assert(false); diff --git a/src/tim/transform/ops/activation_layout_inference.h b/src/tim/transform/ops/activation_layout_inference.h index 1b56d77..a8d06a3 100644 --- a/src/tim/transform/ops/activation_layout_inference.h +++ b/src/tim/transform/ops/activation_layout_inference.h @@ -80,7 +80,29 @@ class LeakyReluLayoutInfer : public OpLayoutInfer { } }; -// TODO(yzw): Add Prelu +class PReluLayoutInfer : public OpLayoutInfer { + public: + PReluLayoutInfer( + const std::shared_ptr op, + std::shared_ptr& context) + : OpLayoutInfer(op, context) {} + + void OnInputs( + std::vector>& next_tensors) override { + ReverseInputsPermuteVector(); + auto src_input = op_->impl()->InputsTensor()[0]; + auto input_pv = context_->GetPermuteVector(src_input); + auto prelu = context_->infer_graph_->CreateOperation( + 
op_->impl()->node()->nn_param.prelu.axis); + auto out_infer = CreateOutputsTensor(input_pv); + for (const auto& i_src : op_->impl()->InputsTensor()) { + (*prelu).BindInput(context_->GetMapedTensor(i_src)); + } + (*prelu).BindOutput(out_infer[0]); + context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], input_pv); + next_tensors.push_back(op_->impl()->OutputsTensor()[0]); + } +}; using ReluLayoutInfer = ActivationLayoutInfer; using Relu1LayoutInfer = ActivationLayoutInfer; diff --git a/src/tim/transform/ops/addn_layout_inference.h b/src/tim/transform/ops/addn_layout_inference.h new file mode 100644 index 0000000..aaabdc7 --- /dev/null +++ b/src/tim/transform/ops/addn_layout_inference.h @@ -0,0 +1,59 @@ +/**************************************************************************** + * + * Copyright (c) 2020 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. 
+ * + *****************************************************************************/ +#ifndef TIM_LAYOUT_INFER_ADDN_LAYOUT_INFERENCE_H_ +#define TIM_LAYOUT_INFER_ADDN_LAYOUT_INFERENCE_H_ + +#include "src/tim/transform/ops/op_layout_inference.h" +#include "src/tim/vx/operation_private.h" +#include "tim/vx/ops/addn.h" + +namespace tim { +namespace transform { +class AddNLayoutInfer : public OpLayoutInfer { + public: + AddNLayoutInfer( + const std::shared_ptr& op, + std::shared_ptr& context) + : OpLayoutInfer(op, context) {} + void OnInputs( + std::vector>& next_tensors) override { + auto required_pv = AlignPermuteVectorForMutilInputs(); + uint32_t num_inputs = op_->impl()->input_cnt_; + + auto addn = + context_->infer_graph_->CreateOperation(num_inputs); + + for (const auto& i_src : op_->impl()->InputsTensor()) { + (*addn).BindInput(context_->GetMapedTensor(i_src)); + } + auto infer_out = CreateOutputsTensor(required_pv); + (*addn).BindOutput(infer_out[0]); + + context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv); + next_tensors.push_back(op_->impl()->OutputsTensor()[0]); + } +}; +} // namespace transform +} // namespace tim +#endif \ No newline at end of file diff --git a/src/tim/transform/ops/elementwise_layout_inference.h b/src/tim/transform/ops/elementwise_layout_inference.h index bc8fbff..43821ba 100644 --- a/src/tim/transform/ops/elementwise_layout_inference.h +++ b/src/tim/transform/ops/elementwise_layout_inference.h @@ -42,7 +42,7 @@ class ElementWiseLayoutInfer : public OpLayoutInfer { void OnInputs( std::vector>& next_tensors) override { - auto required_pv = AlignPermuteVectorForMutilInputs(); + auto required_pv = AlignPermuteVectorForElementWise(); auto elementwise = context_->infer_graph_->CreateOperation(); for (const auto& i_src : op_->impl()->InputsTensor()) { (*elementwise).BindInput(context_->GetMapedTensor(i_src)); @@ -63,7 +63,7 @@ class MultiplyLayoutInfer : public OpLayoutInfer { void OnInputs( std::vector>& next_tensors) 
override { - auto required_pv = AlignPermuteVectorForMutilInputs(); + auto required_pv = AlignPermuteVectorForElementWise(); auto multiply = context_->infer_graph_->CreateOperation( op_->impl()->node()->nn_param.multiply.scale); diff --git a/src/tim/transform/ops/l2normalization_layout_inference.h b/src/tim/transform/ops/l2normalization_layout_inference.h new file mode 100644 index 0000000..a9c5f6e --- /dev/null +++ b/src/tim/transform/ops/l2normalization_layout_inference.h @@ -0,0 +1,59 @@ +/**************************************************************************** + * + * Copyright (c) 2020 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. 
+ *
+ *****************************************************************************/
+#ifndef TIM_LAYOUT_INFER_L2_NORMALIZATION_LAYOUT_INFERENCE_H_
+#define TIM_LAYOUT_INFER_L2_NORMALIZATION_LAYOUT_INFERENCE_H_
+
+#include "src/tim/transform/ops/op_layout_inference.h"
+#include "src/tim/vx/operation_private.h"
+#include "tim/vx/ops/l2normalization.h"
+
+namespace tim {
+namespace transform {
+class L2NormalizationLayoutInfer : public OpLayoutInfer {  // layout inference for VSI_NN_OP_L2_NORMALIZE
+ public:
+  L2NormalizationLayoutInfer(
+      const std::shared_ptr& op,
+      std::shared_ptr& context)
+      : OpLayoutInfer(op, context) {}
+  void OnInputs(  // rebuilds the op in the inferred graph, keeping the input's permute vector
+      std::vector>& next_tensors) override {
+    auto src_input = op_->impl()->InputsTensor()[0];
+    auto input_pv = context_->GetPermuteVector(src_input);
+
+    int32_t axis =  // read this op's own l2_normalize parameters, not the LRN ones
+        MapAxis(input_pv->AsStdVec(), op_->impl()->node()->nn_param.l2_normalize.axis);
+
+    auto l2norm =
+        context_->infer_graph_->CreateOperation(axis);
+    auto infer_out = CreateOutputsTensor(input_pv);
+    (*l2norm).BindInput(context_->GetMapedTensor(src_input));
+    (*l2norm).BindOutput(infer_out[0]);
+
+    context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], input_pv);  // output inherits the input layout
+    next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
+  }
+};
+}  // namespace transform
+}  // namespace tim
+#endif
\ No newline at end of file
diff --git a/src/tim/transform/ops/lrn_layout_inference.h b/src/tim/transform/ops/lrn_layout_inference.h
new file mode 100644
index 0000000..c3cb2f8
--- /dev/null
+++ b/src/tim/transform/ops/lrn_layout_inference.h
@@ -0,0 +1,65 @@
+/****************************************************************************
+ *
+ * Copyright (c) 2020 Vivante Corporation
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the "Software"),
+ * to deal in the Software without restriction, including without limitation
+ * the rights to use, copy, modify, merge, publish, distribute, sublicense,
+ * and/or sell copies of the Software, and to
permit persons to whom the
+ * Software is furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ * DEALINGS IN THE SOFTWARE.
+ *
+ *****************************************************************************/
+#ifndef TIM_LAYOUT_INFER_LRN_LAYOUT_INFERENCE_H_
+#define TIM_LAYOUT_INFER_LRN_LAYOUT_INFERENCE_H_
+
+#include "tim/vx/ops/localresponsenormalization.h"
+
+#include "src/tim/transform/ops/op_layout_inference.h"
+#include "src/tim/vx/operation_private.h"
+
+namespace tim {
+namespace transform {
+class LRNLayoutInfer : public OpLayoutInfer {
+ public:
+  LRNLayoutInfer(
+      const std::shared_ptr& op,
+      std::shared_ptr& context)
+      : OpLayoutInfer(op, context) {}
+  void OnInputs(
+      std::vector>& next_tensors) override {
+    auto src_input = op_->impl()->InputsTensor()[0];
+    auto input_pv = context_->GetPermuteVector(src_input);
+
+    uint32_t size = op_->impl()->node()->nn_param.lrn.size;
+    float alpha = op_->impl()->node()->nn_param.lrn.alpha;
+    float beta = op_->impl()->node()->nn_param.lrn.beta;
+    float bias = op_->impl()->node()->nn_param.lrn.bias;
+    int32_t axis =
+        MapAxis(input_pv->AsStdVec(), op_->impl()->node()->nn_param.lrn.axis);
+
+    auto lrn = context_->infer_graph_
+                   ->CreateOperation(
+                       size, alpha, beta, bias, axis);
+    auto infer_out = CreateOutputsTensor(input_pv);
+    (*lrn).BindInput(context_->GetMapedTensor(src_input));
+    (*lrn).BindOutput(infer_out[0]);
+
+    context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], input_pv);
+    next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
+  }
+};
+}  // namespace transform
+}  // namespace tim
+#endif
\ No newline at end of file
diff --git a/src/tim/transform/ops/op_layout_inference.cc b/src/tim/transform/ops/op_layout_inference.cc
index d9b1ef7..e9008b9 100644
--- a/src/tim/transform/ops/op_layout_inference.cc
+++ b/src/tim/transform/ops/op_layout_inference.cc
@@ -180,7 +180,7 @@ OpLayoutInfer::AlignPermuteVectorForMutilInputs() {
                                  : perm_out = PermuteConstTensor(i_src, required_pv);
     } else {
       auto final_pv =
-        context_->GetPermuteVector(i_src)->Reverse()->Add(required_pv);
+          context_->GetPermuteVector(i_src)->Reverse()->Add(required_pv);
       final_pv->IsAligned() ? perm_out = context_->GetMapedTensor(i_src)
                             : perm_out = InsertPermute(
                                   context_->GetMapedTensor(i_src), final_pv);
@@ -192,6 +192,49 @@ OpLayoutInfer::AlignPermuteVectorForMutilInputs() {
   return required_pv;
 }
 
+std::shared_ptr
+OpLayoutInfer::AlignPermuteVectorForElementWise() {
+  auto src_inputs = op_->impl()->InputsTensor();
+  std::shared_ptr required_pv = nullptr;
+  std::shared_ptr ref_input;
+  for (const auto& in : src_inputs) {  // the first non-constant input dictates the layout
+    if (!in->IsConstTensor()) {
+      required_pv = context_->GetPermuteVector(in);
+      ref_input = in;
+      break;
+    }
+  }
+
+  for (auto i_src : src_inputs) {
+    std::shared_ptr perm_out;
+    if (i_src->IsConstTensor()) {
+      if (required_pv->IsAligned()) {
+        perm_out = context_->infer_graph_->CreateTensor(i_src->GetSpec(),
+                                                        i_src->GetDataRef());
+      } else if (i_src->GetShape().size() == required_pv->Rank()) {
+        perm_out = PermuteConstTensor(i_src, required_pv);
+      } else {
+        // need shape expansion
+        auto ref_shape = ref_input->GetShape();
+        auto origin_shape = i_src->GetShape();
+        auto expanded_shape = GetExpandedShape(ref_shape, origin_shape);
+        i_src->GetSpec().SetShape(expanded_shape);
+        perm_out = PermuteConstTensor(i_src, required_pv);
+      }
+    } else {
+      auto final_pv =
+          context_->GetPermuteVector(i_src)->Reverse()->Add(required_pv);
+      final_pv->IsAligned()
+          ? perm_out = context_->GetMapedTensor(i_src)
+          : perm_out = InsertPermute(context_->GetMapedTensor(i_src), final_pv);
+    }
+    context_->UpdateTensorMap(i_src, perm_out);
+    context_->SetPermuteVector(i_src, required_pv);
+  }
+  return required_pv;
+}
+
+
 void OpLayoutInfer::ReverseInputsPermuteVector() {
   for (const auto& i_src : op_->impl()->InputsTensor()) {
     std::shared_ptr perm_out;
@@ -213,6 +256,21 @@ void OpLayoutInfer::ReverseInputsPermuteVector() {
   }
 }
 
+std::vector OpLayoutInfer::GetExpandedShape(  // pads origin_shape with 1s up to ref_shape's rank
+    const std::vector& ref_shape,
+    const std::vector& origin_shape) {
+  std::vector expanded_shape;
+  for (uint32_t i = 0, j = 0; i < ref_shape.size(); ++i) {
+    if (j < origin_shape.size() && ref_shape[i] == origin_shape[j]) {  // bound-check j before indexing
+      expanded_shape.push_back(origin_shape[j]);
+      ++j;
+    } else {
+      expanded_shape.push_back(1);  // no matching source dim left: insert a unit dim
+    }
+  }
+  return expanded_shape;
+}
+
 bool OpLayoutInfer::TransposeConstTensorData(
     const std::shared_ptr& input, const std::shared_ptr& pv,
     std::vector& out_data) {
@@ -265,16 +323,29 @@ std::shared_ptr OpLayoutInfer::PermuteConstTensor(
   return context_->infer_graph_->CreateTensor(dst_spec, data.data());
 }
 
-std::vector OpLayoutInfer::MapPadding(const std::vector& perm,
-                                      const std::vector& padding) {
-  assert(perm.size() == padding.size());
-  std::vector r(padding.size());
+std::vector OpLayoutInfer::MapMultipleAxis(
+    const std::vector& perm, const std::vector& axises) {
+  assert(perm.size() == axises.size());
+  std::vector r(axises.size());
 
-  for (uint32_t i = 0; i < padding.size(); ++i) {
-    r[i] = padding[perm[i]];
+  for (uint32_t i = 0; i < axises.size(); ++i) {
+    r[i] = axises[perm[i]];
   }
 
   return r;
 }
+
+std::vector OpLayoutInfer::MapMultipleAxis(
+    const std::vector& perm, const std::vector& axises) {
+  assert(perm.size() == axises.size());
+  std::vector r(axises.size());
+
+  for (uint32_t i = 0; i < axises.size(); ++i) {
+    r[i] = axises[perm[i]];
+  }
+
+  return r;
+}
+
 }  // namespace transform
 }  // namespace tim
\ No newline at end of file
diff --git a/src/tim/transform/ops/op_layout_inference.h b/src/tim/transform/ops/op_layout_inference.h
index 61cb6cb..b20f08f 100644
--- a/src/tim/transform/ops/op_layout_inference.h
+++ b/src/tim/transform/ops/op_layout_inference.h
@@ -71,8 +71,14 @@ class OpLayoutInfer {
 
   std::shared_ptr AlignPermuteVectorForMutilInputs();
 
+  std::shared_ptr AlignPermuteVectorForElementWise();
+
   void ReverseInputsPermuteVector();
 
+  std::vector GetExpandedShape(
+      const std::vector& ref_shape,
+      const std::vector& origin_shape);
+
   bool TransposeConstTensorData(const std::shared_ptr& input,
                                 const std::shared_ptr& pv,
                                 std::vector& out_data);
@@ -81,8 +87,10 @@ class OpLayoutInfer {
       const std::shared_ptr& input,
       const std::shared_ptr& pv);
 
-  std::vector MapPadding(const std::vector& perm,
-                         const std::vector& padding);
+  std::vector MapMultipleAxis(const std::vector& perm,
+                              const std::vector& axises);
+  std::vector MapMultipleAxis(const std::vector& perm,
+                              const std::vector& axises);
 
  protected:
   const std::shared_ptr op_;
diff --git a/src/tim/transform/ops/pad_layout_inference.h b/src/tim/transform/ops/pad_layout_inference.h
index 4d62b26..3388604 100644
--- a/src/tim/transform/ops/pad_layout_inference.h
+++ b/src/tim/transform/ops/pad_layout_inference.h
@@ -54,8 +54,8 @@ class PadLayoutInfer : public OpLayoutInfer {
     int32_t pad_value = op_->impl()->node()->nn_param.pad.const_val;
 
     if (!input_pv->IsAligned()) {
-      front_size = MapPadding(input_pv->AsStdVec(), front_size);
-      back_size = MapPadding(input_pv->AsStdVec(), back_size);
+      front_size = MapMultipleAxis(input_pv->AsStdVec(), front_size);
+      back_size = MapMultipleAxis(input_pv->AsStdVec(), back_size);
     }
 
     auto pad = context_->infer_graph_->CreateOperation(
diff --git a/src/tim/transform/ops/split_layout_inference.h b/src/tim/transform/ops/split_layout_inference.h
new file mode 100644
index 0000000..d76d10e
--- /dev/null
+++
b/src/tim/transform/ops/split_layout_inference.h @@ -0,0 +1,64 @@ +/**************************************************************************** + * + * Copyright (c) 2020 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. 
+ * + *****************************************************************************/ +#ifndef TIM_LAYOUT_INFER_SPLIT_LAYOUT_INFERENCE_H_ +#define TIM_LAYOUT_INFER_SPLIT_LAYOUT_INFERENCE_H_ + +#include "tim/vx/ops/split.h" + +#include "src/tim/transform/ops/op_layout_inference.h" +#include "src/tim/transform/permute_vector.h" +#include "src/tim/vx/operation_private.h" + +namespace tim { +namespace transform { +class SplitLayoutInfer : public OpLayoutInfer { + public: + SplitLayoutInfer( + const std::shared_ptr op, + std::shared_ptr& context) + : OpLayoutInfer(op, context) {} + void OnInputs( + std::vector>& next_tensors) override { + auto input_tensor = op_->impl()->InputsTensor()[0]; + uint32_t slices_num = op_->impl()->node()->nn_param.split.slices_num; + std::vector slices(slices_num); + memcpy(slices.data(), op_->impl()->node()->nn_param.split.slices, + slices_num * sizeof(uint32_t)); + auto input_pv = context_->GetPermuteVector(input_tensor); + uint32_t axis = + MapAxis(input_pv->AsStdVec(), op_->impl()->node()->nn_param.split.axis); + auto split = + context_->infer_graph_->CreateOperation(axis, slices); + auto infer_out = CreateOutputsTensor(input_pv); + (*split).BindInput(context_->GetMapedTensor(input_tensor)); + (*split).BindOutputs(infer_out); + for (const auto& out : op_->impl()->OutputsTensor()) { + context_->SetPermuteVector(out, input_pv); + next_tensors.push_back(out); + } + } +}; +} // namespace transform +} // namespace tim +#endif \ No newline at end of file diff --git a/src/tim/transform/ops/stridedslice_layout_inference.h b/src/tim/transform/ops/stridedslice_layout_inference.h new file mode 100644 index 0000000..04bea42 --- /dev/null +++ b/src/tim/transform/ops/stridedslice_layout_inference.h @@ -0,0 +1,87 @@ +/**************************************************************************** + * + * Copyright (c) 2020 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and 
associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#ifndef TIM_LAYOUT_INFER_STRIDEDSLICE_LAYOUT_INFERENCE_H_ +#define TIM_LAYOUT_INFER_STRIDEDSLICE_LAYOUT_INFERENCE_H_ + +#include "tim/vx/ops/stridedslice.h" + +#include "src/tim/transform/ops/op_layout_inference.h" +#include "src/tim/transform/permute_vector.h" +#include "src/tim/vx/operation_private.h" + +namespace tim { +namespace transform { +class StridedSliceLayoutInfer : public OpLayoutInfer { + public: + StridedSliceLayoutInfer( + const std::shared_ptr op, + std::shared_ptr& context) + : OpLayoutInfer(op, context) {} + void OnInputs( + std::vector>& next_tensors) override { + auto src_input = op_->impl()->InputsTensor()[0]; + auto input_pv = context_->GetPermuteVector(src_input); + + int32_t begin_mask = op_->impl()->node()->nn_param.strided_slice.begin_mask; + int32_t end_mask = op_->impl()->node()->nn_param.strided_slice.end_mask; + int32_t shrink_axis_mask = + 
op_->impl()->node()->nn_param.strided_slice.shrink_axis_mask; + uint32_t begin_dims_num = + op_->impl()->node()->nn_param.strided_slice.begin_dims_num; + std::vector begin_dims(begin_dims_num); + memcpy(begin_dims.data(), + op_->impl()->node()->nn_param.strided_slice.begin_dims, + begin_dims_num * sizeof(uint32_t)); + uint32_t end_dims_num = + op_->impl()->node()->nn_param.strided_slice.end_dims_num; + std::vector end_dims(end_dims_num); + memcpy(end_dims.data(), + op_->impl()->node()->nn_param.strided_slice.end_dims, + end_dims_num * sizeof(uint32_t)); + uint32_t stride_dims_num = + op_->impl()->node()->nn_param.strided_slice.stride_dims_num; + std::vector stride_dims(stride_dims_num); + memcpy(stride_dims.data(), + op_->impl()->node()->nn_param.strided_slice.stride_dims, + stride_dims_num * sizeof(uint32_t)); + + begin_dims = MapMultipleAxis(input_pv->AsStdVec(), begin_dims); + end_dims = MapMultipleAxis(input_pv->AsStdVec(), end_dims); + stride_dims = MapMultipleAxis(input_pv->AsStdVec(), stride_dims); + + auto strided_slice = + context_->infer_graph_->CreateOperation( + begin_dims, end_dims, stride_dims, begin_mask, end_mask, + shrink_axis_mask); + auto infer_out = CreateOutputsTensor(input_pv); + (*strided_slice).BindInput(context_->GetMapedTensor(src_input)); + (*strided_slice).BindOutput(infer_out[0]); + + context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], input_pv); + next_tensors.push_back(op_->impl()->OutputsTensor()[0]); + } +}; +} // namespace transform +} // namespace tim +#endif \ No newline at end of file diff --git a/src/tim/vx/tensor_private.h b/src/tim/vx/tensor_private.h index 602a69a..367586e 100644 --- a/src/tim/vx/tensor_private.h +++ b/src/tim/vx/tensor_private.h @@ -42,7 +42,7 @@ class TensorImpl : public Tensor { const ShapeType& GetShape() { return spec_.shape_; } DataType GetDataType() { return spec_.datatype_; } const Quantization& GetQuantization() { return spec_.quantization_; } - const TensorSpec& GetSpec() { return spec_; } 
+ TensorSpec& GetSpec() { return spec_; } uint32_t GetId(); bool CopyDataToTensor(const void* data, uint32_t size = 0); bool CopyDataFromTensor(void* data); @@ -66,7 +66,7 @@ class TensorPlaceholder : public Tensor { const ShapeType& GetShape() { return spec_.shape_; } DataType GetDataType() { return spec_.datatype_; } const Quantization& GetQuantization() { return spec_.quantization_; } - const TensorSpec& GetSpec() { return spec_; } + TensorSpec& GetSpec() { return spec_; } uint32_t GetId() { return id_; }; bool CopyDataToTensor(const void* data, uint32_t size = 0) { (void)data, void(size);