fixed reduce layoutinfer bug (#605)
Type: Bug fixed Signed-off-by: Chen <jack.chen@verisilicon.com> Co-authored-by: Chen <jack.chen@verisilicon.com>
This commit is contained in:
parent
fbfbdd7c83
commit
26b4e53fe7
|
|
@ -45,16 +45,17 @@ class ReduceLayoutInfer : public OpLayoutInfer {
|
|||
std::vector<std::shared_ptr<vx::Tensor>>& next_tensor) override {
|
||||
auto t_src = op_->impl()->InputsTensor()[0];
|
||||
auto pv = context_->GetPermuteVector(op_->impl()->InputsTensor()[0]);
|
||||
std::set<int32_t> unique_axis;
|
||||
std::vector<int32_t> new_axis;
|
||||
std::set<int32_t> axis_set; //Same value as new_axis, convenient for searching
|
||||
std::vector<int32_t> new_axis, pv_reduced;
|
||||
for (uint32_t i = 0; i < op_->impl()->node()->nn_param.reduce.axis_num;
|
||||
++i) {
|
||||
int32_t axis = op_->impl()->node()->nn_param.reduce.axis[i];
|
||||
if (axis < 0) {
|
||||
axis += pv->Rank();
|
||||
}
|
||||
unique_axis.insert(axis);
|
||||
new_axis.push_back(MapAxis(pv->AsStdVec(), axis));
|
||||
axis = MapAxis(pv->AsStdVec(), axis);
|
||||
axis_set.insert(axis);
|
||||
new_axis.push_back(axis);
|
||||
}
|
||||
auto reduce = context_->infer_graph_->CreateOperation<OpType>(
|
||||
new_axis, op_->impl()->node()->nn_param.reduce.keep_dim);
|
||||
|
|
@ -64,16 +65,19 @@ class ReduceLayoutInfer : public OpLayoutInfer {
|
|||
(*reduce).BindOutput(otensor_infer[0]);
|
||||
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], pv);
|
||||
} else {
|
||||
auto out_pv = MakeShared(pv->Rank() - unique_axis.size());
|
||||
auto out_pv = MakeShared(pv->Rank() - axis_set.size());
|
||||
for (uint32_t i = 0; i < pv->Rank(); i++) {
|
||||
if (axis_set.end() != axis_set.find(i)) continue;
|
||||
pv_reduced.push_back(pv->At(i));
|
||||
}
|
||||
uint32_t j = 0;
|
||||
for (uint32_t i = 0; i < out_pv->Rank(); i++) {
|
||||
if (unique_axis.end() != unique_axis.find(pv->At(i))) continue;
|
||||
for (auto axis_remine : pv_reduced) {
|
||||
uint32_t cnt = 0;
|
||||
for (auto axis : unique_axis) {
|
||||
if (pv->At(i) > (uint32_t)axis) cnt++;
|
||||
for(auto axis_reduced : axis_set) {
|
||||
if ((uint32_t)axis_remine > pv->At(axis_reduced)) cnt++;
|
||||
}
|
||||
out_pv->At(j) = pv->At(i) - cnt;
|
||||
j++;
|
||||
out_pv->At(j) = (uint32_t)axis_remine - cnt;
|
||||
++j;
|
||||
}
|
||||
auto otensor_infer = CreateOutputsTensor(out_pv);
|
||||
(*reduce).BindOutput(otensor_infer[0]);
|
||||
|
|
|
|||
Loading…
Reference in New Issue