From 75882d4195e71eab6a521e9959bb05345851e262 Mon Sep 17 00:00:00 2001 From: Chen Feiyue <69809761+chenfeiyue-cfy@users.noreply.github.com> Date: Sun, 25 Jun 2023 09:24:41 +0800 Subject: [PATCH] Added new_axis_mask param for stridedslice (#600) Add another constructor for stridedslice when new_axis_mask is set The layout inference need to reconstruct the axis mapping when new_axis_mask is set(TODO) Type: New Feature Signed-off-by: Feiyue Chen --- include/tim/vx/ops/stridedslice.h | 5 ++ include/tim/vx/ops/stridedslice.json | 5 ++ .../ops/stridedslice_layout_inference.h | 78 ++++++++++--------- src/tim/vx/ops/stridedslice.cc | 17 +++- 4 files changed, 67 insertions(+), 38 deletions(-) diff --git a/include/tim/vx/ops/stridedslice.h b/include/tim/vx/ops/stridedslice.h index 429f94c..1474759 100644 --- a/include/tim/vx/ops/stridedslice.h +++ b/include/tim/vx/ops/stridedslice.h @@ -58,6 +58,10 @@ class StridedSlice : public BuiltinOp { const std::vector end_dims, const std::vector stride_dims, int32_t begin_mask, int32_t end_mask, int32_t shrink_axis_mask); + StridedSlice(Graph* graph, const std::vector begin_dims, + const std::vector end_dims, + const std::vector stride_dims, int32_t begin_mask, + int32_t end_mask, int32_t shrink_axis_mask, int32_t new_axis_mask); std::shared_ptr Clone(std::shared_ptr& graph) const override; @@ -68,6 +72,7 @@ class StridedSlice : public BuiltinOp { int32_t begin_mask_; int32_t end_mask_; int32_t shrink_axis_mask_; + int32_t new_axis_mask_; }; } // namespace ops diff --git a/include/tim/vx/ops/stridedslice.json b/include/tim/vx/ops/stridedslice.json index e48ea70..1530823 100755 --- a/include/tim/vx/ops/stridedslice.json +++ b/include/tim/vx/ops/stridedslice.json @@ -19,6 +19,11 @@ }, {"name": "shrink_axis_mask", "dtype": "int32_t" + }, + {"name": "new_axis_mask", + "dtype": "int32_t", + "Optional": "true", + "default": "0" } ] } diff --git a/src/tim/transform/ops/stridedslice_layout_inference.h b/src/tim/transform/ops/stridedslice_layout_inference.h index 6679f5c..1cfca83 100644 --- a/src/tim/transform/ops/stridedslice_layout_inference.h +++ b/src/tim/transform/ops/stridedslice_layout_inference.h @@ -47,6 +47,8 @@ class StridedSliceLayoutInfer : public OpLayoutInfer { int32_t end_mask = op_->impl()->node()->nn_param.strided_slice.end_mask; int32_t shrink_axis_mask = op_->impl()->node()->nn_param.strided_slice.shrink_axis_mask; + int32_t new_axis_mask = + op_->impl()->node()->nn_param.strided_slice.new_axis_mask; uint32_t begin_dims_num = op_->impl()->node()->nn_param.strided_slice.begin_dims_num; std::vector begin_dims(begin_dims_num); @@ -66,45 +68,51 @@ class StridedSliceLayoutInfer : public OpLayoutInfer { op_->impl()->node()->nn_param.strided_slice.stride_dims, stride_dims_num * sizeof(uint32_t)); - begin_dims = MapMultipleAxis(input_pv->AsStdVec(), begin_dims); - end_dims = MapMultipleAxis(input_pv->AsStdVec(), end_dims); - stride_dims = MapMultipleAxis(input_pv->AsStdVec(), stride_dims); + if (!new_axis_mask) { + begin_dims = MapMultipleAxis(input_pv->AsStdVec(), begin_dims); + end_dims = MapMultipleAxis(input_pv->AsStdVec(), end_dims); + stride_dims = MapMultipleAxis(input_pv->AsStdVec(), stride_dims); - shrink_axis_mask = MapMask(input_pv->AsStdVec(), shrink_axis_mask); - begin_mask = MapMask(input_pv->AsStdVec(), begin_mask); - end_mask = MapMask(input_pv->AsStdVec(), end_mask); - auto strided_slice = - context_->infer_graph_->CreateOperation( - begin_dims, end_dims, stride_dims, begin_mask, end_mask, - shrink_axis_mask); - // The following is the normalized dimension calculation - std::set remaind_set; - std::vector remaind_axis; - for (uint32_t i = 0; i < input_pv->AsStdVec().size(); ++i) - if ((shrink_axis_mask & (1 << i)) == 0) { - remaind_axis.push_back( - input_pv->AsStdVec() - [i]); // Store unnormalized dimensionality reduction pv values - remaind_set.insert(input_pv->AsStdVec()[i]); - } - // Traverse the input pv to find a dimension smaller than the current remaining dimension - auto out_pv = MakeShared(remaind_axis.size()); - for (uint32_t i = 0; i < remaind_axis.size(); ++i) { - uint32_t cnt = 0; - for (uint32_t j = 0; j < input_pv->AsStdVec().size(); j++) { - if (input_pv->AsStdVec()[j] < remaind_axis[i] && - remaind_set.end() == remaind_set.find(input_pv->AsStdVec()[j])) { - cnt++; // Record the number of dimensions smaller than the current dimension + shrink_axis_mask = MapMask(input_pv->AsStdVec(), shrink_axis_mask); + begin_mask = MapMask(input_pv->AsStdVec(), begin_mask); + end_mask = MapMask(input_pv->AsStdVec(), end_mask); + auto strided_slice = + context_->infer_graph_->CreateOperation( + begin_dims, end_dims, stride_dims, begin_mask, end_mask, + shrink_axis_mask); + // The following is the normalized dimension calculation + std::set remained_set; + std::vector remained_axis; + for (uint32_t i = 0; i < input_pv->AsStdVec().size(); ++i) + if ((shrink_axis_mask & (1 << i)) == 0) { + remained_axis.push_back( + input_pv->AsStdVec() + [i]); // Store unnormalized dimensionality reduction pv values + remained_set.insert(input_pv->AsStdVec()[i]); } + // Traverse the input pv to find a dimension smaller than the current remaining dimension + auto out_pv = MakeShared(remained_axis.size()); + for (uint32_t i = 0; i < remained_axis.size(); ++i) { + uint32_t cnt = 0; + for (uint32_t j = 0; j < input_pv->AsStdVec().size(); j++) { + if (input_pv->AsStdVec()[j] < remained_axis[i] && + remained_set.end() == + remained_set.find(input_pv->AsStdVec()[j])) { + cnt++; // Record the number of dimensions smaller than the current dimension + } + } + out_pv->At(i) = remained_axis[i] - cnt; } - out_pv->At(i) = remaind_axis[i] - cnt; - } - auto infer_out = CreateOutputsTensor(out_pv); - (*strided_slice).BindInput(context_->GetMapedTensor(src_input)); - (*strided_slice).BindOutput(infer_out[0]); - context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], out_pv); - next_tensors.push_back(op_->impl()->OutputsTensor()[0]); + auto infer_out = CreateOutputsTensor(out_pv); + (*strided_slice).BindInput(context_->GetMapedTensor(src_input)); + (*strided_slice).BindOutput(infer_out[0]); + context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], out_pv); + next_tensors.push_back(op_->impl()->OutputsTensor()[0]); + } else { //TODO + VSILOGE("ERROR: implement not supported yet for new_axis_mask !=0"); + assert(false); + } } }; } // namespace transform diff --git a/src/tim/vx/ops/stridedslice.cc b/src/tim/vx/ops/stridedslice.cc index 8781cff..8a916b3 100644 --- a/src/tim/vx/ops/stridedslice.cc +++ b/src/tim/vx/ops/stridedslice.cc @@ -34,18 +34,21 @@ StridedSlice::StridedSlice(Graph* graph, const std::vector begin_dims, const std::vector end_dims, const std::vector stride_dims, int32_t begin_mask, int32_t end_mask, - int32_t shrink_axis_mask) + int32_t shrink_axis_mask, int32_t new_axis_mask) : BuiltinOp(graph, VSI_NN_OP_STRIDED_SLICE), begin_dims_(std::move(begin_dims)), end_dims_(std::move(end_dims)), stride_dims_(std::move(stride_dims)), begin_mask_(begin_mask), end_mask_(end_mask), - shrink_axis_mask_(shrink_axis_mask) { + shrink_axis_mask_(shrink_axis_mask), + new_axis_mask_(new_axis_mask) { this->impl()->node()->nn_param.strided_slice.begin_mask = begin_mask_; this->impl()->node()->nn_param.strided_slice.end_mask = end_mask_; this->impl()->node()->nn_param.strided_slice.shrink_axis_mask = shrink_axis_mask_; + this->impl()->node()->nn_param.strided_slice.new_axis_mask = + new_axis_mask_; this->impl()->node()->nn_param.strided_slice.begin_dims = begin_dims_.data(); this->impl()->node()->nn_param.strided_slice.begin_dims_num = begin_dims_.size(); @@ -57,11 +60,19 @@ StridedSlice::StridedSlice(Graph* graph, const std::vector begin_dims, stride_dims_.size(); } +StridedSlice::StridedSlice(Graph* graph, const std::vector begin_dims, + const std::vector end_dims, + const std::vector stride_dims, + int32_t begin_mask, int32_t end_mask, + int32_t shrink_axis_mask) + : StridedSlice(graph, begin_dims, end_dims, stride_dims, + begin_mask, end_mask, shrink_axis_mask, 0) {} + std::shared_ptr StridedSlice::Clone( std::shared_ptr& graph) const { return graph->CreateOperation( this->begin_dims_, this->end_dims_, this->stride_dims_, this->begin_mask_, - this->end_mask_, this->shrink_axis_mask_); + this->end_mask_, this->shrink_axis_mask_, this->new_axis_mask_); } } // namespace ops