fixed reduce layoutinfer bug (#605)

Type: Bug fix

Signed-off-by: Chen <jack.chen@verisilicon.com>
Co-authored-by: Chen <jack.chen@verisilicon.com>
This commit is contained in:
chxin66 2023-06-19 21:56:08 +08:00 committed by GitHub
parent fbfbdd7c83
commit 26b4e53fe7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 15 additions and 11 deletions

View File

@ -45,16 +45,17 @@ class ReduceLayoutInfer : public OpLayoutInfer {
std::vector<std::shared_ptr<vx::Tensor>>& next_tensor) override { std::vector<std::shared_ptr<vx::Tensor>>& next_tensor) override {
auto t_src = op_->impl()->InputsTensor()[0]; auto t_src = op_->impl()->InputsTensor()[0];
auto pv = context_->GetPermuteVector(op_->impl()->InputsTensor()[0]); auto pv = context_->GetPermuteVector(op_->impl()->InputsTensor()[0]);
std::set<int32_t> unique_axis; std::set<int32_t> axis_set; //Same value as new_axis, convenient for searching
std::vector<int32_t> new_axis; std::vector<int32_t> new_axis, pv_reduced;
for (uint32_t i = 0; i < op_->impl()->node()->nn_param.reduce.axis_num; for (uint32_t i = 0; i < op_->impl()->node()->nn_param.reduce.axis_num;
++i) { ++i) {
int32_t axis = op_->impl()->node()->nn_param.reduce.axis[i]; int32_t axis = op_->impl()->node()->nn_param.reduce.axis[i];
if (axis < 0) { if (axis < 0) {
axis += pv->Rank(); axis += pv->Rank();
} }
unique_axis.insert(axis); axis = MapAxis(pv->AsStdVec(), axis);
new_axis.push_back(MapAxis(pv->AsStdVec(), axis)); axis_set.insert(axis);
new_axis.push_back(axis);
} }
auto reduce = context_->infer_graph_->CreateOperation<OpType>( auto reduce = context_->infer_graph_->CreateOperation<OpType>(
new_axis, op_->impl()->node()->nn_param.reduce.keep_dim); new_axis, op_->impl()->node()->nn_param.reduce.keep_dim);
@ -64,16 +65,19 @@ class ReduceLayoutInfer : public OpLayoutInfer {
(*reduce).BindOutput(otensor_infer[0]); (*reduce).BindOutput(otensor_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], pv); context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], pv);
} else { } else {
auto out_pv = MakeShared(pv->Rank() - unique_axis.size()); auto out_pv = MakeShared(pv->Rank() - axis_set.size());
for (uint32_t i = 0; i < pv->Rank(); i++) {
if (axis_set.end() != axis_set.find(i)) continue;
pv_reduced.push_back(pv->At(i));
}
uint32_t j = 0; uint32_t j = 0;
for (uint32_t i = 0; i < out_pv->Rank(); i++) { for (auto axis_remine : pv_reduced) {
if (unique_axis.end() != unique_axis.find(pv->At(i))) continue;
uint32_t cnt = 0; uint32_t cnt = 0;
for (auto axis : unique_axis) { for(auto axis_reduced : axis_set) {
if (pv->At(i) > (uint32_t)axis) cnt++; if ((uint32_t)axis_remine > pv->At(axis_reduced)) cnt++;
} }
out_pv->At(j) = pv->At(i) - cnt; out_pv->At(j) = (uint32_t)axis_remine - cnt;
j++; ++j;
} }
auto otensor_infer = CreateOutputsTensor(out_pv); auto otensor_infer = CreateOutputsTensor(out_pv);
(*reduce).BindOutput(otensor_infer[0]); (*reduce).BindOutput(otensor_infer[0]);