Added new_axis_mask param for stridedslice (#600)

Add another constructor for stridedslice when new_axis_mask is set

The layout inference need to reconstruct the axis mapping when
new_axis_mask is set(TODO)

Type: New Feature

Signed-off-by: Feiyue Chen <Feiyue.Chen@verisilicon.com>
This commit is contained in:
Chen Feiyue 2023-06-25 09:24:41 +08:00 committed by GitHub
parent d823ef6fcb
commit 75882d4195
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 67 additions and 38 deletions

View File

@ -58,6 +58,10 @@ class StridedSlice : public BuiltinOp {
const std::vector<int32_t> end_dims,
const std::vector<int32_t> stride_dims, int32_t begin_mask,
int32_t end_mask, int32_t shrink_axis_mask);
StridedSlice(Graph* graph, const std::vector<int32_t> begin_dims,
const std::vector<int32_t> end_dims,
const std::vector<int32_t> stride_dims, int32_t begin_mask,
int32_t end_mask, int32_t shrink_axis_mask, int32_t new_axis_mask);
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
@ -68,6 +72,7 @@ class StridedSlice : public BuiltinOp {
int32_t begin_mask_;
int32_t end_mask_;
int32_t shrink_axis_mask_;
int32_t new_axis_mask_;
};
} // namespace ops

View File

@ -19,6 +19,11 @@
},
{"name": "shrink_axis_mask",
"dtype": "int32_t"
},
{"name": "new_axis_mask",
"dtype": "int32_t",
"Optional": "true",
"default": "0"
}
]
}

View File

@ -47,6 +47,8 @@ class StridedSliceLayoutInfer : public OpLayoutInfer {
int32_t end_mask = op_->impl()->node()->nn_param.strided_slice.end_mask;
int32_t shrink_axis_mask =
op_->impl()->node()->nn_param.strided_slice.shrink_axis_mask;
int32_t new_axis_mask =
op_->impl()->node()->nn_param.strided_slice.new_axis_mask;
uint32_t begin_dims_num =
op_->impl()->node()->nn_param.strided_slice.begin_dims_num;
std::vector<int32_t> begin_dims(begin_dims_num);
@ -66,45 +68,51 @@ class StridedSliceLayoutInfer : public OpLayoutInfer {
op_->impl()->node()->nn_param.strided_slice.stride_dims,
stride_dims_num * sizeof(uint32_t));
begin_dims = MapMultipleAxis(input_pv->AsStdVec(), begin_dims);
end_dims = MapMultipleAxis(input_pv->AsStdVec(), end_dims);
stride_dims = MapMultipleAxis(input_pv->AsStdVec(), stride_dims);
if (!new_axis_mask) {
begin_dims = MapMultipleAxis(input_pv->AsStdVec(), begin_dims);
end_dims = MapMultipleAxis(input_pv->AsStdVec(), end_dims);
stride_dims = MapMultipleAxis(input_pv->AsStdVec(), stride_dims);
shrink_axis_mask = MapMask(input_pv->AsStdVec(), shrink_axis_mask);
begin_mask = MapMask(input_pv->AsStdVec(), begin_mask);
end_mask = MapMask(input_pv->AsStdVec(), end_mask);
auto strided_slice =
context_->infer_graph_->CreateOperation<vx::ops::StridedSlice>(
begin_dims, end_dims, stride_dims, begin_mask, end_mask,
shrink_axis_mask);
// The following is the normalized dimension calculation
std::set<uint32_t> remaind_set;
std::vector<uint32_t> remaind_axis;
for (uint32_t i = 0; i < input_pv->AsStdVec().size(); ++i)
if ((shrink_axis_mask & (1 << i)) == 0) {
remaind_axis.push_back(
input_pv->AsStdVec()
[i]); // Store unnormalized dimensionality reduction pv values
remaind_set.insert(input_pv->AsStdVec()[i]);
}
// Traverse the input pv to find a dimension smaller than the current remaining dimension
auto out_pv = MakeShared(remaind_axis.size());
for (uint32_t i = 0; i < remaind_axis.size(); ++i) {
uint32_t cnt = 0;
for (uint32_t j = 0; j < input_pv->AsStdVec().size(); j++) {
if (input_pv->AsStdVec()[j] < remaind_axis[i] &&
remaind_set.end() == remaind_set.find(input_pv->AsStdVec()[j])) {
cnt++; // Record the number of dimensions smaller than the current dimension
shrink_axis_mask = MapMask(input_pv->AsStdVec(), shrink_axis_mask);
begin_mask = MapMask(input_pv->AsStdVec(), begin_mask);
end_mask = MapMask(input_pv->AsStdVec(), end_mask);
auto strided_slice =
context_->infer_graph_->CreateOperation<vx::ops::StridedSlice>(
begin_dims, end_dims, stride_dims, begin_mask, end_mask,
shrink_axis_mask);
// The following is the normalized dimension calculation
std::set<uint32_t> remained_set;
std::vector<uint32_t> remained_axis;
for (uint32_t i = 0; i < input_pv->AsStdVec().size(); ++i)
if ((shrink_axis_mask & (1 << i)) == 0) {
remained_axis.push_back(
input_pv->AsStdVec()
[i]); // Store unnormalized dimensionality reduction pv values
remained_set.insert(input_pv->AsStdVec()[i]);
}
// Traverse the input pv to find a dimension smaller than the current remaining dimension
auto out_pv = MakeShared(remained_axis.size());
for (uint32_t i = 0; i < remained_axis.size(); ++i) {
uint32_t cnt = 0;
for (uint32_t j = 0; j < input_pv->AsStdVec().size(); j++) {
if (input_pv->AsStdVec()[j] < remained_axis[i] &&
remained_set.end() ==
remained_set.find(input_pv->AsStdVec()[j])) {
cnt++; // Record the number of dimensions smaller than the current dimension
}
}
out_pv->At(i) = remained_axis[i] - cnt;
}
out_pv->At(i) = remaind_axis[i] - cnt;
}
auto infer_out = CreateOutputsTensor(out_pv);
(*strided_slice).BindInput(context_->GetMapedTensor(src_input));
(*strided_slice).BindOutput(infer_out[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], out_pv);
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
auto infer_out = CreateOutputsTensor(out_pv);
(*strided_slice).BindInput(context_->GetMapedTensor(src_input));
(*strided_slice).BindOutput(infer_out[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], out_pv);
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
} else { //TODO
VSILOGE("ERROR: implement not supported yet for new_axis_mask !=0");
assert(false);
}
}
};
} // namespace transform

View File

@ -34,18 +34,21 @@ StridedSlice::StridedSlice(Graph* graph, const std::vector<int32_t> begin_dims,
const std::vector<int32_t> end_dims,
const std::vector<int32_t> stride_dims,
int32_t begin_mask, int32_t end_mask,
int32_t shrink_axis_mask)
int32_t shrink_axis_mask, int32_t new_axis_mask)
: BuiltinOp(graph, VSI_NN_OP_STRIDED_SLICE),
begin_dims_(std::move(begin_dims)),
end_dims_(std::move(end_dims)),
stride_dims_(std::move(stride_dims)),
begin_mask_(begin_mask),
end_mask_(end_mask),
shrink_axis_mask_(shrink_axis_mask) {
shrink_axis_mask_(shrink_axis_mask),
new_axis_mask_(new_axis_mask) {
this->impl()->node()->nn_param.strided_slice.begin_mask = begin_mask_;
this->impl()->node()->nn_param.strided_slice.end_mask = end_mask_;
this->impl()->node()->nn_param.strided_slice.shrink_axis_mask =
shrink_axis_mask_;
this->impl()->node()->nn_param.strided_slice.new_axis_mask =
new_axis_mask_;
this->impl()->node()->nn_param.strided_slice.begin_dims = begin_dims_.data();
this->impl()->node()->nn_param.strided_slice.begin_dims_num =
begin_dims_.size();
@ -57,11 +60,19 @@ StridedSlice::StridedSlice(Graph* graph, const std::vector<int32_t> begin_dims,
stride_dims_.size();
}
StridedSlice::StridedSlice(Graph* graph, const std::vector<int32_t> begin_dims,
const std::vector<int32_t> end_dims,
const std::vector<int32_t> stride_dims,
int32_t begin_mask, int32_t end_mask,
int32_t shrink_axis_mask)
: StridedSlice(graph, begin_dims, end_dims, stride_dims,
begin_mask, end_mask, shrink_axis_mask, 0) {}
std::shared_ptr<Operation> StridedSlice::Clone(
std::shared_ptr<Graph>& graph) const {
return graph->CreateOperation<StridedSlice>(
this->begin_dims_, this->end_dims_, this->stride_dims_, this->begin_mask_,
this->end_mask_, this->shrink_axis_mask_);
this->end_mask_, this->shrink_axis_mask_, this->new_axis_mask_);
}
} // namespace ops