Fixed layout inference bug for stack (#375)
Signed-off-by: Chen Xin <jack.chen@verisilicon.com>
This commit is contained in:
parent
eab0d807a6
commit
11572140d2
|
|
@ -53,30 +53,28 @@ class StackLayoutInfer : public OpLayoutInfer {
|
|||
(*stack).BindInput(context_->GetMapedTensor(i_src));
|
||||
}
|
||||
|
||||
std::vector<uint32_t> v;
|
||||
uint32_t dim_num = src_input->GetShape().size();
|
||||
if (axis < 0) {
|
||||
axis += dim_num;
|
||||
}
|
||||
for (uint32_t i = 0; i < src_input->GetShape().size(); ++i) {
|
||||
if (input_pv->At(i) > (uint32_t)axis) {
|
||||
v.push_back(input_pv->At(i) + 1);
|
||||
} else if (input_pv->At(i) == (uint32_t)axis) {
|
||||
v.push_back(input_pv->At(i));
|
||||
v.push_back(input_pv->At(i) + 1);
|
||||
} else {
|
||||
v.push_back(input_pv->At(i));
|
||||
}
|
||||
}
|
||||
auto out_pv =
|
||||
MakeShared(op_->impl()->OutputsTensor()[0]->GetShape().size());
|
||||
for (uint32_t i = 0; i < out_pv->Rank(); ++i) {
|
||||
out_pv->At(i) = v[i];
|
||||
axis += src_input->GetShape().size();
|
||||
}
|
||||
|
||||
auto out_infer = CreateOutputsTensor(out_pv);
|
||||
auto output_pv = MakeShared(input_pv->Rank() + 1);
|
||||
if (!input_pv->IsAligned()) {
|
||||
output_pv->At(axis) = (uint32_t)axis;
|
||||
for (uint32_t i = 0, j = 0; i < input_pv->Rank(); ++i, ++j) {
|
||||
if ((uint32_t)axis == i) {
|
||||
++j;
|
||||
}
|
||||
if (input_pv->At(i) < (uint32_t)axis) {
|
||||
output_pv->At(j) = input_pv->At(i);
|
||||
} else {
|
||||
output_pv->At(j) = input_pv->At(i) + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto out_infer = CreateOutputsTensor(output_pv);
|
||||
(*stack).BindOutput(out_infer[0]);
|
||||
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], out_pv);
|
||||
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], output_pv);
|
||||
// Add out tensor of src_graph into next_tensor
|
||||
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ TEST(Stack, LayoutinferernceTest_1) {
|
|||
1, 1, 1, 1, 2, 1, 1, 5, 3, 1, 2, 3, 1, 1, 2, 1, 1, 1,
|
||||
};
|
||||
std::vector<float> golden = {
|
||||
64, 77, 49, 44, 81, 97, 64, 77, 49, 44, 81, 97
|
||||
64, 64, 49, 49, 81, 81, 77, 77, 44, 44, 97, 97
|
||||
};
|
||||
auto kernel_tensor = graph->CreateTensor(kernel_spec, kernel_data.data());
|
||||
|
||||
|
|
@ -180,7 +180,7 @@ TEST(Stack, LayoutinferernceTest_3) {
|
|||
1, 1, 1, 1, 2, 1, 1, 5, 3, 1, 2, 3, 1, 1, 2, 1, 1, 1,
|
||||
};
|
||||
std::vector<float> golden = {
|
||||
55, 39, 21, 28, 37, 41, 49, 55, 28, 24, 40, 41
|
||||
55, 49, 21, 28, 37, 40, 39, 55, 28, 24, 41, 41,
|
||||
};
|
||||
auto kernel_tensor = graph->CreateTensor(kernel_spec, kernel_data.data());
|
||||
auto kernel2_tensor = graph->CreateTensor(kernel_spec, kernel_data.data());
|
||||
|
|
|
|||
Loading…
Reference in New Issue