fixed reduce layoutinfer bug (#605)

Type: Bug fixed

Signed-off-by: Chen <jack.chen@verisilicon.com>
Co-authored-by: Chen <jack.chen@verisilicon.com>
This commit is contained in:
chxin66 2023-06-19 21:56:08 +08:00 committed by GitHub
parent fbfbdd7c83
commit 26b4e53fe7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 15 additions and 11 deletions

View File

@ -45,16 +45,17 @@ class ReduceLayoutInfer : public OpLayoutInfer {
std::vector<std::shared_ptr<vx::Tensor>>& next_tensor) override {
auto t_src = op_->impl()->InputsTensor()[0];
auto pv = context_->GetPermuteVector(op_->impl()->InputsTensor()[0]);
std::set<int32_t> unique_axis;
std::vector<int32_t> new_axis;
std::set<int32_t> axis_set; //Same value as new_axis, convenient for searching
std::vector<int32_t> new_axis, pv_reduced;
for (uint32_t i = 0; i < op_->impl()->node()->nn_param.reduce.axis_num;
++i) {
int32_t axis = op_->impl()->node()->nn_param.reduce.axis[i];
if (axis < 0) {
axis += pv->Rank();
}
unique_axis.insert(axis);
new_axis.push_back(MapAxis(pv->AsStdVec(), axis));
axis = MapAxis(pv->AsStdVec(), axis);
axis_set.insert(axis);
new_axis.push_back(axis);
}
auto reduce = context_->infer_graph_->CreateOperation<OpType>(
new_axis, op_->impl()->node()->nn_param.reduce.keep_dim);
@ -64,16 +65,19 @@ class ReduceLayoutInfer : public OpLayoutInfer {
(*reduce).BindOutput(otensor_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], pv);
} else {
auto out_pv = MakeShared(pv->Rank() - unique_axis.size());
auto out_pv = MakeShared(pv->Rank() - axis_set.size());
for (uint32_t i = 0; i < pv->Rank(); i++) {
if (axis_set.end() != axis_set.find(i)) continue;
pv_reduced.push_back(pv->At(i));
}
uint32_t j = 0;
for (uint32_t i = 0; i < out_pv->Rank(); i++) {
if (unique_axis.end() != unique_axis.find(pv->At(i))) continue;
for (auto axis_remine : pv_reduced) {
uint32_t cnt = 0;
for (auto axis : unique_axis) {
if (pv->At(i) > (uint32_t)axis) cnt++;
for(auto axis_reduced : axis_set) {
if ((uint32_t)axis_remine > pv->At(axis_reduced)) cnt++;
}
out_pv->At(j) = pv->At(i) - cnt;
j++;
out_pv->At(j) = (uint32_t)axis_remine - cnt;
++j;
}
auto otensor_infer = CreateOutputsTensor(out_pv);
(*reduce).BindOutput(otensor_infer[0]);