Fix reduce layout-inference bug (#605)
Type: Bug fix
Signed-off-by: Chen <jack.chen@verisilicon.com>
Co-authored-by: Chen <jack.chen@verisilicon.com>
This commit is contained in:
parent
fbfbdd7c83
commit
26b4e53fe7
|
|
@ -45,16 +45,17 @@ class ReduceLayoutInfer : public OpLayoutInfer {
|
||||||
std::vector<std::shared_ptr<vx::Tensor>>& next_tensor) override {
|
std::vector<std::shared_ptr<vx::Tensor>>& next_tensor) override {
|
||||||
auto t_src = op_->impl()->InputsTensor()[0];
|
auto t_src = op_->impl()->InputsTensor()[0];
|
||||||
auto pv = context_->GetPermuteVector(op_->impl()->InputsTensor()[0]);
|
auto pv = context_->GetPermuteVector(op_->impl()->InputsTensor()[0]);
|
||||||
std::set<int32_t> unique_axis;
|
std::set<int32_t> axis_set; //Same value as new_axis, convenient for searching
|
||||||
std::vector<int32_t> new_axis;
|
std::vector<int32_t> new_axis, pv_reduced;
|
||||||
for (uint32_t i = 0; i < op_->impl()->node()->nn_param.reduce.axis_num;
|
for (uint32_t i = 0; i < op_->impl()->node()->nn_param.reduce.axis_num;
|
||||||
++i) {
|
++i) {
|
||||||
int32_t axis = op_->impl()->node()->nn_param.reduce.axis[i];
|
int32_t axis = op_->impl()->node()->nn_param.reduce.axis[i];
|
||||||
if (axis < 0) {
|
if (axis < 0) {
|
||||||
axis += pv->Rank();
|
axis += pv->Rank();
|
||||||
}
|
}
|
||||||
unique_axis.insert(axis);
|
axis = MapAxis(pv->AsStdVec(), axis);
|
||||||
new_axis.push_back(MapAxis(pv->AsStdVec(), axis));
|
axis_set.insert(axis);
|
||||||
|
new_axis.push_back(axis);
|
||||||
}
|
}
|
||||||
auto reduce = context_->infer_graph_->CreateOperation<OpType>(
|
auto reduce = context_->infer_graph_->CreateOperation<OpType>(
|
||||||
new_axis, op_->impl()->node()->nn_param.reduce.keep_dim);
|
new_axis, op_->impl()->node()->nn_param.reduce.keep_dim);
|
||||||
|
|
@ -64,16 +65,19 @@ class ReduceLayoutInfer : public OpLayoutInfer {
|
||||||
(*reduce).BindOutput(otensor_infer[0]);
|
(*reduce).BindOutput(otensor_infer[0]);
|
||||||
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], pv);
|
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], pv);
|
||||||
} else {
|
} else {
|
||||||
auto out_pv = MakeShared(pv->Rank() - unique_axis.size());
|
auto out_pv = MakeShared(pv->Rank() - axis_set.size());
|
||||||
|
for (uint32_t i = 0; i < pv->Rank(); i++) {
|
||||||
|
if (axis_set.end() != axis_set.find(i)) continue;
|
||||||
|
pv_reduced.push_back(pv->At(i));
|
||||||
|
}
|
||||||
uint32_t j = 0;
|
uint32_t j = 0;
|
||||||
for (uint32_t i = 0; i < out_pv->Rank(); i++) {
|
for (auto axis_remine : pv_reduced) {
|
||||||
if (unique_axis.end() != unique_axis.find(pv->At(i))) continue;
|
|
||||||
uint32_t cnt = 0;
|
uint32_t cnt = 0;
|
||||||
for (auto axis : unique_axis) {
|
for(auto axis_reduced : axis_set) {
|
||||||
if (pv->At(i) > (uint32_t)axis) cnt++;
|
if ((uint32_t)axis_remine > pv->At(axis_reduced)) cnt++;
|
||||||
}
|
}
|
||||||
out_pv->At(j) = pv->At(i) - cnt;
|
out_pv->At(j) = (uint32_t)axis_remine - cnt;
|
||||||
j++;
|
++j;
|
||||||
}
|
}
|
||||||
auto otensor_infer = CreateOutputsTensor(out_pv);
|
auto otensor_infer = CreateOutputsTensor(out_pv);
|
||||||
(*reduce).BindOutput(otensor_infer[0]);
|
(*reduce).BindOutput(otensor_infer[0]);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue