From 0dc7a3465e391749fffa54df6879104244f13cc0 Mon Sep 17 00:00:00 2001 From: chxin66 <57057788+chxin66@users.noreply.github.com> Date: Mon, 11 Dec 2023 16:59:37 +0800 Subject: [PATCH] fix const tensor align bug in AlignPermuteVectorForElementWise (#666) * fix const tensor align bug in AlignPermuteVectorForElementWise Signed-off-by: Chen * fix build issue use android ndk Type: Bug fix Signed-off-by: Chen * Fix inappropriate comments for reduce layoutinfer Type: Code refine Signed-off-by: Chen --------- Signed-off-by: Chen Co-authored-by: Chen --- include/tim/vx/platform/platform.h | 4 ++++ src/tim/transform/ops/op_layout_inference.cc | 19 +++++++++---------- .../transform/ops/reduce_layout_inference.h | 8 +++++--- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/include/tim/vx/platform/platform.h b/include/tim/vx/platform/platform.h index cc2799d..263042b 100644 --- a/include/tim/vx/platform/platform.h +++ b/include/tim/vx/platform/platform.h @@ -59,7 +59,11 @@ std::shared_ptr CreateExecutableSet( class IDevice { public: using device_id_t = uint32_t; + #ifdef __ANDROID_NDK__ + typedef bool (*async_callback)(const void*); + #else using async_callback = std::function; + #endif using data_t = const void*; virtual ~IDevice(){}; virtual bool Submit(const std::shared_ptr& graph) = 0; diff --git a/src/tim/transform/ops/op_layout_inference.cc b/src/tim/transform/ops/op_layout_inference.cc index 1d1f0e8..7275a28 100644 --- a/src/tim/transform/ops/op_layout_inference.cc +++ b/src/tim/transform/ops/op_layout_inference.cc @@ -231,11 +231,13 @@ OpLayoutInfer::AlignPermuteVectorForElementWise() { auto src_inputs = op_->impl()->InputsTensor(); std::shared_ptr required_pv = nullptr; std::shared_ptr ref_input; + int32_t ref_rank = 0; for (const auto& in : src_inputs) { - if (!in->IsConstTensor()) { + int32_t rank = in->GetShape().size(); + if (!in->IsConstTensor() && rank > ref_rank) { required_pv = context_->GetPermuteVector(in); ref_input = in; - break; + ref_rank = rank; } } @@ -297,14 +299,11 @@ void OpLayoutInfer::ReverseInputsPermuteVector() { std::vector OpLayoutInfer::GetExpandedShape( const std::vector& ref_shape, const std::vector& origin_shape) { - std::vector expanded_shape; - for (uint32_t i = 0, j = 0; i < ref_shape.size(); ++i) { - if (ref_shape[i] == origin_shape[j] && j < origin_shape.size()) { - expanded_shape.push_back(origin_shape[j]); - ++j; - } else { - expanded_shape.push_back(1); - } + std::vector expanded_shape(origin_shape); + int32_t ref_rank = ref_shape.size(); + int32_t origin_rank = origin_shape.size(); + for (int32_t i = 0; i < ref_rank; ++i) { + if (i >= origin_rank) expanded_shape.push_back(1); } return expanded_shape; } diff --git a/src/tim/transform/ops/reduce_layout_inference.h b/src/tim/transform/ops/reduce_layout_inference.h index 88ae721..6766985 100644 --- a/src/tim/transform/ops/reduce_layout_inference.h +++ b/src/tim/transform/ops/reduce_layout_inference.h @@ -45,21 +45,23 @@ class ReduceLayoutInfer : public OpLayoutInfer { std::vector>& next_tensor) override { auto t_src = op_->impl()->InputsTensor()[0]; auto pv = context_->GetPermuteVector(op_->impl()->InputsTensor()[0]); - std::set axis_set; //Same value as new_axis, convenient for searching + std::set axis_set; // Save unique axis values std::vector new_axis, pv_reduced; - for (uint32_t i = 0; i < op_->impl()->node()->nn_param.reduce.axis_num; - ++i) { + uint32_t axis_num = op_->impl()->node()->nn_param.reduce.axis_num; + for (uint32_t i = 0; i < axis_num; ++i) { int32_t axis = op_->impl()->node()->nn_param.reduce.axis[i]; if (axis < 0) { axis += pv->Rank(); } axis = MapAxis(pv->AsStdVec(), axis); + // Save unique axis values for calculating pv length axis_set.insert(axis); new_axis.push_back(axis); } auto reduce = context_->infer_graph_->CreateOperation( new_axis, op_->impl()->node()->nn_param.reduce.keep_dim); (*reduce).BindInput(context_->GetMapedTensor(t_src)); + if (op_->impl()->node()->nn_param.reduce.keep_dim) { auto otensor_infer = CreateOutputsTensor(pv); (*reduce).BindOutput(otensor_infer[0]);