Fixed layout inference bug for stack (#375)

Signed-off-by: Chen Xin <jack.chen@verisilicon.com>
This commit is contained in:
chxin66 2022-05-05 17:18:09 +08:00 committed by GitHub
parent eab0d807a6
commit 11572140d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 22 deletions

View File

@ -53,30 +53,28 @@ class StackLayoutInfer : public OpLayoutInfer {
(*stack).BindInput(context_->GetMapedTensor(i_src));
}
std::vector<uint32_t> v;
uint32_t dim_num = src_input->GetShape().size();
if (axis < 0) {
axis += dim_num;
}
for (uint32_t i = 0; i < src_input->GetShape().size(); ++i) {
if (input_pv->At(i) > (uint32_t)axis) {
v.push_back(input_pv->At(i) + 1);
} else if (input_pv->At(i) == (uint32_t)axis) {
v.push_back(input_pv->At(i));
v.push_back(input_pv->At(i) + 1);
} else {
v.push_back(input_pv->At(i));
}
}
auto out_pv =
MakeShared(op_->impl()->OutputsTensor()[0]->GetShape().size());
for (uint32_t i = 0; i < out_pv->Rank(); ++i) {
out_pv->At(i) = v[i];
axis += src_input->GetShape().size();
}
auto out_infer = CreateOutputsTensor(out_pv);
auto output_pv = MakeShared(input_pv->Rank() + 1);
if (!input_pv->IsAligned()) {
output_pv->At(axis) = (uint32_t)axis;
for (uint32_t i = 0, j = 0; i < input_pv->Rank(); ++i, ++j) {
if ((uint32_t)axis == i) {
++j;
}
if (input_pv->At(i) < (uint32_t)axis) {
output_pv->At(j) = input_pv->At(i);
} else {
output_pv->At(j) = input_pv->At(i) + 1;
}
}
}
auto out_infer = CreateOutputsTensor(output_pv);
(*stack).BindOutput(out_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], out_pv);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], output_pv);
// Add out tensor of src_graph into next_tensor
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
}

View File

@ -38,7 +38,7 @@ TEST(Stack, LayoutinferernceTest_1) {
1, 1, 1, 1, 2, 1, 1, 5, 3, 1, 2, 3, 1, 1, 2, 1, 1, 1,
};
std::vector<float> golden = {
64, 77, 49, 44, 81, 97, 64, 77, 49, 44, 81, 97
64, 64, 49, 49, 81, 81, 77, 77, 44, 44, 97, 97
};
auto kernel_tensor = graph->CreateTensor(kernel_spec, kernel_data.data());
@ -180,7 +180,7 @@ TEST(Stack, LayoutinferernceTest_3) {
1, 1, 1, 1, 2, 1, 1, 5, 3, 1, 2, 3, 1, 1, 2, 1, 1, 1,
};
std::vector<float> golden = {
55, 39, 21, 28, 37, 41, 49, 55, 28, 24, 40, 41
55, 49, 21, 28, 37, 40, 39, 55, 28, 24, 41, 41,
};
auto kernel_tensor = graph->CreateTensor(kernel_spec, kernel_data.data());
auto kernel2_tensor = graph->CreateTensor(kernel_spec, kernel_data.data());