Added new_axis_mask param for stridedslice (#600)
Add another constructor for stridedslice when new_axis_mask is set The layout inference need to reconstruct the axis mapping when new_axis_mask is set(TODO) Type: New Feature Signed-off-by: Feiyue Chen <Feiyue.Chen@verisilicon.com>
This commit is contained in:
parent
d823ef6fcb
commit
75882d4195
|
|
@ -58,6 +58,10 @@ class StridedSlice : public BuiltinOp {
|
|||
const std::vector<int32_t> end_dims,
|
||||
const std::vector<int32_t> stride_dims, int32_t begin_mask,
|
||||
int32_t end_mask, int32_t shrink_axis_mask);
|
||||
StridedSlice(Graph* graph, const std::vector<int32_t> begin_dims,
|
||||
const std::vector<int32_t> end_dims,
|
||||
const std::vector<int32_t> stride_dims, int32_t begin_mask,
|
||||
int32_t end_mask, int32_t shrink_axis_mask, int32_t new_axis_mask);
|
||||
|
||||
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
|
||||
|
||||
|
|
@ -68,6 +72,7 @@ class StridedSlice : public BuiltinOp {
|
|||
int32_t begin_mask_;
|
||||
int32_t end_mask_;
|
||||
int32_t shrink_axis_mask_;
|
||||
int32_t new_axis_mask_;
|
||||
};
|
||||
|
||||
} // namespace ops
|
||||
|
|
|
|||
|
|
@ -19,6 +19,11 @@
|
|||
},
|
||||
{"name": "shrink_axis_mask",
|
||||
"dtype": "int32_t"
|
||||
},
|
||||
{"name": "new_axis_mask",
|
||||
"dtype": "int32_t",
|
||||
"Optional": "true",
|
||||
"default": "0"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -47,6 +47,8 @@ class StridedSliceLayoutInfer : public OpLayoutInfer {
|
|||
int32_t end_mask = op_->impl()->node()->nn_param.strided_slice.end_mask;
|
||||
int32_t shrink_axis_mask =
|
||||
op_->impl()->node()->nn_param.strided_slice.shrink_axis_mask;
|
||||
int32_t new_axis_mask =
|
||||
op_->impl()->node()->nn_param.strided_slice.new_axis_mask;
|
||||
uint32_t begin_dims_num =
|
||||
op_->impl()->node()->nn_param.strided_slice.begin_dims_num;
|
||||
std::vector<int32_t> begin_dims(begin_dims_num);
|
||||
|
|
@ -66,45 +68,51 @@ class StridedSliceLayoutInfer : public OpLayoutInfer {
|
|||
op_->impl()->node()->nn_param.strided_slice.stride_dims,
|
||||
stride_dims_num * sizeof(uint32_t));
|
||||
|
||||
begin_dims = MapMultipleAxis(input_pv->AsStdVec(), begin_dims);
|
||||
end_dims = MapMultipleAxis(input_pv->AsStdVec(), end_dims);
|
||||
stride_dims = MapMultipleAxis(input_pv->AsStdVec(), stride_dims);
|
||||
if (!new_axis_mask) {
|
||||
begin_dims = MapMultipleAxis(input_pv->AsStdVec(), begin_dims);
|
||||
end_dims = MapMultipleAxis(input_pv->AsStdVec(), end_dims);
|
||||
stride_dims = MapMultipleAxis(input_pv->AsStdVec(), stride_dims);
|
||||
|
||||
shrink_axis_mask = MapMask(input_pv->AsStdVec(), shrink_axis_mask);
|
||||
begin_mask = MapMask(input_pv->AsStdVec(), begin_mask);
|
||||
end_mask = MapMask(input_pv->AsStdVec(), end_mask);
|
||||
auto strided_slice =
|
||||
context_->infer_graph_->CreateOperation<vx::ops::StridedSlice>(
|
||||
begin_dims, end_dims, stride_dims, begin_mask, end_mask,
|
||||
shrink_axis_mask);
|
||||
// The following is the normalized dimension calculation
|
||||
std::set<uint32_t> remaind_set;
|
||||
std::vector<uint32_t> remaind_axis;
|
||||
for (uint32_t i = 0; i < input_pv->AsStdVec().size(); ++i)
|
||||
if ((shrink_axis_mask & (1 << i)) == 0) {
|
||||
remaind_axis.push_back(
|
||||
input_pv->AsStdVec()
|
||||
[i]); // Store unnormalized dimensionality reduction pv values
|
||||
remaind_set.insert(input_pv->AsStdVec()[i]);
|
||||
}
|
||||
// Traverse the input pv to find a dimension smaller than the current remaining dimension
|
||||
auto out_pv = MakeShared(remaind_axis.size());
|
||||
for (uint32_t i = 0; i < remaind_axis.size(); ++i) {
|
||||
uint32_t cnt = 0;
|
||||
for (uint32_t j = 0; j < input_pv->AsStdVec().size(); j++) {
|
||||
if (input_pv->AsStdVec()[j] < remaind_axis[i] &&
|
||||
remaind_set.end() == remaind_set.find(input_pv->AsStdVec()[j])) {
|
||||
cnt++; // Record the number of dimensions smaller than the current dimension
|
||||
shrink_axis_mask = MapMask(input_pv->AsStdVec(), shrink_axis_mask);
|
||||
begin_mask = MapMask(input_pv->AsStdVec(), begin_mask);
|
||||
end_mask = MapMask(input_pv->AsStdVec(), end_mask);
|
||||
auto strided_slice =
|
||||
context_->infer_graph_->CreateOperation<vx::ops::StridedSlice>(
|
||||
begin_dims, end_dims, stride_dims, begin_mask, end_mask,
|
||||
shrink_axis_mask);
|
||||
// The following is the normalized dimension calculation
|
||||
std::set<uint32_t> remained_set;
|
||||
std::vector<uint32_t> remained_axis;
|
||||
for (uint32_t i = 0; i < input_pv->AsStdVec().size(); ++i)
|
||||
if ((shrink_axis_mask & (1 << i)) == 0) {
|
||||
remained_axis.push_back(
|
||||
input_pv->AsStdVec()
|
||||
[i]); // Store unnormalized dimensionality reduction pv values
|
||||
remained_set.insert(input_pv->AsStdVec()[i]);
|
||||
}
|
||||
// Traverse the input pv to find a dimension smaller than the current remaining dimension
|
||||
auto out_pv = MakeShared(remained_axis.size());
|
||||
for (uint32_t i = 0; i < remained_axis.size(); ++i) {
|
||||
uint32_t cnt = 0;
|
||||
for (uint32_t j = 0; j < input_pv->AsStdVec().size(); j++) {
|
||||
if (input_pv->AsStdVec()[j] < remained_axis[i] &&
|
||||
remained_set.end() ==
|
||||
remained_set.find(input_pv->AsStdVec()[j])) {
|
||||
cnt++; // Record the number of dimensions smaller than the current dimension
|
||||
}
|
||||
}
|
||||
out_pv->At(i) = remained_axis[i] - cnt;
|
||||
}
|
||||
out_pv->At(i) = remaind_axis[i] - cnt;
|
||||
}
|
||||
|
||||
auto infer_out = CreateOutputsTensor(out_pv);
|
||||
(*strided_slice).BindInput(context_->GetMapedTensor(src_input));
|
||||
(*strided_slice).BindOutput(infer_out[0]);
|
||||
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], out_pv);
|
||||
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
|
||||
auto infer_out = CreateOutputsTensor(out_pv);
|
||||
(*strided_slice).BindInput(context_->GetMapedTensor(src_input));
|
||||
(*strided_slice).BindOutput(infer_out[0]);
|
||||
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], out_pv);
|
||||
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
|
||||
} else { //TODO
|
||||
VSILOGE("ERROR: implement not supported yet for new_axis_mask !=0");
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace transform
|
||||
|
|
|
|||
|
|
@ -34,18 +34,21 @@ StridedSlice::StridedSlice(Graph* graph, const std::vector<int32_t> begin_dims,
|
|||
const std::vector<int32_t> end_dims,
|
||||
const std::vector<int32_t> stride_dims,
|
||||
int32_t begin_mask, int32_t end_mask,
|
||||
int32_t shrink_axis_mask)
|
||||
int32_t shrink_axis_mask, int32_t new_axis_mask)
|
||||
: BuiltinOp(graph, VSI_NN_OP_STRIDED_SLICE),
|
||||
begin_dims_(std::move(begin_dims)),
|
||||
end_dims_(std::move(end_dims)),
|
||||
stride_dims_(std::move(stride_dims)),
|
||||
begin_mask_(begin_mask),
|
||||
end_mask_(end_mask),
|
||||
shrink_axis_mask_(shrink_axis_mask) {
|
||||
shrink_axis_mask_(shrink_axis_mask),
|
||||
new_axis_mask_(new_axis_mask) {
|
||||
this->impl()->node()->nn_param.strided_slice.begin_mask = begin_mask_;
|
||||
this->impl()->node()->nn_param.strided_slice.end_mask = end_mask_;
|
||||
this->impl()->node()->nn_param.strided_slice.shrink_axis_mask =
|
||||
shrink_axis_mask_;
|
||||
this->impl()->node()->nn_param.strided_slice.new_axis_mask =
|
||||
new_axis_mask_;
|
||||
this->impl()->node()->nn_param.strided_slice.begin_dims = begin_dims_.data();
|
||||
this->impl()->node()->nn_param.strided_slice.begin_dims_num =
|
||||
begin_dims_.size();
|
||||
|
|
@ -57,11 +60,19 @@ StridedSlice::StridedSlice(Graph* graph, const std::vector<int32_t> begin_dims,
|
|||
stride_dims_.size();
|
||||
}
|
||||
|
||||
StridedSlice::StridedSlice(Graph* graph, const std::vector<int32_t> begin_dims,
|
||||
const std::vector<int32_t> end_dims,
|
||||
const std::vector<int32_t> stride_dims,
|
||||
int32_t begin_mask, int32_t end_mask,
|
||||
int32_t shrink_axis_mask)
|
||||
: StridedSlice(graph, begin_dims, end_dims, stride_dims,
|
||||
begin_mask, end_mask, shrink_axis_mask, 0) {}
|
||||
|
||||
std::shared_ptr<Operation> StridedSlice::Clone(
|
||||
std::shared_ptr<Graph>& graph) const {
|
||||
return graph->CreateOperation<StridedSlice>(
|
||||
this->begin_dims_, this->end_dims_, this->stride_dims_, this->begin_mask_,
|
||||
this->end_mask_, this->shrink_axis_mask_);
|
||||
this->end_mask_, this->shrink_axis_mask_, this->new_axis_mask_);
|
||||
}
|
||||
|
||||
} // namespace ops
|
||||
|
|
|
|||
Loading…
Reference in New Issue