Added new_axis_mask param for stridedslice (#600)

Add another constructor for stridedslice when new_axis_mask is set

The layout inference need to reconstruct the axis mapping when
new_axis_mask is set(TODO)

Type: New Feature

Signed-off-by: Feiyue Chen <Feiyue.Chen@verisilicon.com>
This commit is contained in:
Chen Feiyue 2023-06-25 09:24:41 +08:00 committed by GitHub
parent d823ef6fcb
commit 75882d4195
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 67 additions and 38 deletions

View File

@ -58,6 +58,10 @@ class StridedSlice : public BuiltinOp {
const std::vector<int32_t> end_dims, const std::vector<int32_t> end_dims,
const std::vector<int32_t> stride_dims, int32_t begin_mask, const std::vector<int32_t> stride_dims, int32_t begin_mask,
int32_t end_mask, int32_t shrink_axis_mask); int32_t end_mask, int32_t shrink_axis_mask);
StridedSlice(Graph* graph, const std::vector<int32_t> begin_dims,
const std::vector<int32_t> end_dims,
const std::vector<int32_t> stride_dims, int32_t begin_mask,
int32_t end_mask, int32_t shrink_axis_mask, int32_t new_axis_mask);
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override; std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
@ -68,6 +72,7 @@ class StridedSlice : public BuiltinOp {
int32_t begin_mask_; int32_t begin_mask_;
int32_t end_mask_; int32_t end_mask_;
int32_t shrink_axis_mask_; int32_t shrink_axis_mask_;
int32_t new_axis_mask_;
}; };
} // namespace ops } // namespace ops

View File

@ -19,6 +19,11 @@
}, },
{"name": "shrink_axis_mask", {"name": "shrink_axis_mask",
"dtype": "int32_t" "dtype": "int32_t"
},
{"name": "new_axis_mask",
"dtype": "int32_t",
"Optional": "true",
"default": "0"
} }
] ]
} }

View File

@ -47,6 +47,8 @@ class StridedSliceLayoutInfer : public OpLayoutInfer {
int32_t end_mask = op_->impl()->node()->nn_param.strided_slice.end_mask; int32_t end_mask = op_->impl()->node()->nn_param.strided_slice.end_mask;
int32_t shrink_axis_mask = int32_t shrink_axis_mask =
op_->impl()->node()->nn_param.strided_slice.shrink_axis_mask; op_->impl()->node()->nn_param.strided_slice.shrink_axis_mask;
int32_t new_axis_mask =
op_->impl()->node()->nn_param.strided_slice.new_axis_mask;
uint32_t begin_dims_num = uint32_t begin_dims_num =
op_->impl()->node()->nn_param.strided_slice.begin_dims_num; op_->impl()->node()->nn_param.strided_slice.begin_dims_num;
std::vector<int32_t> begin_dims(begin_dims_num); std::vector<int32_t> begin_dims(begin_dims_num);
@ -66,45 +68,51 @@ class StridedSliceLayoutInfer : public OpLayoutInfer {
op_->impl()->node()->nn_param.strided_slice.stride_dims, op_->impl()->node()->nn_param.strided_slice.stride_dims,
stride_dims_num * sizeof(uint32_t)); stride_dims_num * sizeof(uint32_t));
begin_dims = MapMultipleAxis(input_pv->AsStdVec(), begin_dims); if (!new_axis_mask) {
end_dims = MapMultipleAxis(input_pv->AsStdVec(), end_dims); begin_dims = MapMultipleAxis(input_pv->AsStdVec(), begin_dims);
stride_dims = MapMultipleAxis(input_pv->AsStdVec(), stride_dims); end_dims = MapMultipleAxis(input_pv->AsStdVec(), end_dims);
stride_dims = MapMultipleAxis(input_pv->AsStdVec(), stride_dims);
shrink_axis_mask = MapMask(input_pv->AsStdVec(), shrink_axis_mask); shrink_axis_mask = MapMask(input_pv->AsStdVec(), shrink_axis_mask);
begin_mask = MapMask(input_pv->AsStdVec(), begin_mask); begin_mask = MapMask(input_pv->AsStdVec(), begin_mask);
end_mask = MapMask(input_pv->AsStdVec(), end_mask); end_mask = MapMask(input_pv->AsStdVec(), end_mask);
auto strided_slice = auto strided_slice =
context_->infer_graph_->CreateOperation<vx::ops::StridedSlice>( context_->infer_graph_->CreateOperation<vx::ops::StridedSlice>(
begin_dims, end_dims, stride_dims, begin_mask, end_mask, begin_dims, end_dims, stride_dims, begin_mask, end_mask,
shrink_axis_mask); shrink_axis_mask);
// The following is the normalized dimension calculation // The following is the normalized dimension calculation
std::set<uint32_t> remaind_set; std::set<uint32_t> remained_set;
std::vector<uint32_t> remaind_axis; std::vector<uint32_t> remained_axis;
for (uint32_t i = 0; i < input_pv->AsStdVec().size(); ++i) for (uint32_t i = 0; i < input_pv->AsStdVec().size(); ++i)
if ((shrink_axis_mask & (1 << i)) == 0) { if ((shrink_axis_mask & (1 << i)) == 0) {
remaind_axis.push_back( remained_axis.push_back(
input_pv->AsStdVec() input_pv->AsStdVec()
[i]); // Store unnormalized dimensionality reduction pv values [i]); // Store unnormalized dimensionality reduction pv values
remaind_set.insert(input_pv->AsStdVec()[i]); remained_set.insert(input_pv->AsStdVec()[i]);
}
// Traverse the input pv to find a dimension smaller than the current remaining dimension
auto out_pv = MakeShared(remaind_axis.size());
for (uint32_t i = 0; i < remaind_axis.size(); ++i) {
uint32_t cnt = 0;
for (uint32_t j = 0; j < input_pv->AsStdVec().size(); j++) {
if (input_pv->AsStdVec()[j] < remaind_axis[i] &&
remaind_set.end() == remaind_set.find(input_pv->AsStdVec()[j])) {
cnt++; // Record the number of dimensions smaller than the current dimension
} }
// Traverse the input pv to find a dimension smaller than the current remaining dimension
auto out_pv = MakeShared(remained_axis.size());
for (uint32_t i = 0; i < remained_axis.size(); ++i) {
uint32_t cnt = 0;
for (uint32_t j = 0; j < input_pv->AsStdVec().size(); j++) {
if (input_pv->AsStdVec()[j] < remained_axis[i] &&
remained_set.end() ==
remained_set.find(input_pv->AsStdVec()[j])) {
cnt++; // Record the number of dimensions smaller than the current dimension
}
}
out_pv->At(i) = remained_axis[i] - cnt;
} }
out_pv->At(i) = remaind_axis[i] - cnt;
}
auto infer_out = CreateOutputsTensor(out_pv); auto infer_out = CreateOutputsTensor(out_pv);
(*strided_slice).BindInput(context_->GetMapedTensor(src_input)); (*strided_slice).BindInput(context_->GetMapedTensor(src_input));
(*strided_slice).BindOutput(infer_out[0]); (*strided_slice).BindOutput(infer_out[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], out_pv); context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], out_pv);
next_tensors.push_back(op_->impl()->OutputsTensor()[0]); next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
} else { //TODO
VSILOGE("ERROR: implement not supported yet for new_axis_mask !=0");
assert(false);
}
} }
}; };
} // namespace transform } // namespace transform

View File

@ -34,18 +34,21 @@ StridedSlice::StridedSlice(Graph* graph, const std::vector<int32_t> begin_dims,
const std::vector<int32_t> end_dims, const std::vector<int32_t> end_dims,
const std::vector<int32_t> stride_dims, const std::vector<int32_t> stride_dims,
int32_t begin_mask, int32_t end_mask, int32_t begin_mask, int32_t end_mask,
int32_t shrink_axis_mask) int32_t shrink_axis_mask, int32_t new_axis_mask)
: BuiltinOp(graph, VSI_NN_OP_STRIDED_SLICE), : BuiltinOp(graph, VSI_NN_OP_STRIDED_SLICE),
begin_dims_(std::move(begin_dims)), begin_dims_(std::move(begin_dims)),
end_dims_(std::move(end_dims)), end_dims_(std::move(end_dims)),
stride_dims_(std::move(stride_dims)), stride_dims_(std::move(stride_dims)),
begin_mask_(begin_mask), begin_mask_(begin_mask),
end_mask_(end_mask), end_mask_(end_mask),
shrink_axis_mask_(shrink_axis_mask) { shrink_axis_mask_(shrink_axis_mask),
new_axis_mask_(new_axis_mask) {
this->impl()->node()->nn_param.strided_slice.begin_mask = begin_mask_; this->impl()->node()->nn_param.strided_slice.begin_mask = begin_mask_;
this->impl()->node()->nn_param.strided_slice.end_mask = end_mask_; this->impl()->node()->nn_param.strided_slice.end_mask = end_mask_;
this->impl()->node()->nn_param.strided_slice.shrink_axis_mask = this->impl()->node()->nn_param.strided_slice.shrink_axis_mask =
shrink_axis_mask_; shrink_axis_mask_;
this->impl()->node()->nn_param.strided_slice.new_axis_mask =
new_axis_mask_;
this->impl()->node()->nn_param.strided_slice.begin_dims = begin_dims_.data(); this->impl()->node()->nn_param.strided_slice.begin_dims = begin_dims_.data();
this->impl()->node()->nn_param.strided_slice.begin_dims_num = this->impl()->node()->nn_param.strided_slice.begin_dims_num =
begin_dims_.size(); begin_dims_.size();
@ -57,11 +60,19 @@ StridedSlice::StridedSlice(Graph* graph, const std::vector<int32_t> begin_dims,
stride_dims_.size(); stride_dims_.size();
} }
StridedSlice::StridedSlice(Graph* graph, const std::vector<int32_t> begin_dims,
const std::vector<int32_t> end_dims,
const std::vector<int32_t> stride_dims,
int32_t begin_mask, int32_t end_mask,
int32_t shrink_axis_mask)
: StridedSlice(graph, begin_dims, end_dims, stride_dims,
begin_mask, end_mask, shrink_axis_mask, 0) {}
std::shared_ptr<Operation> StridedSlice::Clone( std::shared_ptr<Operation> StridedSlice::Clone(
std::shared_ptr<Graph>& graph) const { std::shared_ptr<Graph>& graph) const {
return graph->CreateOperation<StridedSlice>( return graph->CreateOperation<StridedSlice>(
this->begin_dims_, this->end_dims_, this->stride_dims_, this->begin_mask_, this->begin_dims_, this->end_dims_, this->stride_dims_, this->begin_mask_,
this->end_mask_, this->shrink_axis_mask_); this->end_mask_, this->shrink_axis_mask_, this->new_axis_mask_);
} }
} // namespace ops } // namespace ops