Fixed layout inference bug for stack (#375)
Signed-off-by: Chen Xin <jack.chen@verisilicon.com>
This commit is contained in:
parent
eab0d807a6
commit
11572140d2
|
|
@ -53,30 +53,28 @@ class StackLayoutInfer : public OpLayoutInfer {
|
||||||
(*stack).BindInput(context_->GetMapedTensor(i_src));
|
(*stack).BindInput(context_->GetMapedTensor(i_src));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<uint32_t> v;
|
|
||||||
uint32_t dim_num = src_input->GetShape().size();
|
|
||||||
if (axis < 0) {
|
if (axis < 0) {
|
||||||
axis += dim_num;
|
axis += src_input->GetShape().size();
|
||||||
}
|
|
||||||
for (uint32_t i = 0; i < src_input->GetShape().size(); ++i) {
|
|
||||||
if (input_pv->At(i) > (uint32_t)axis) {
|
|
||||||
v.push_back(input_pv->At(i) + 1);
|
|
||||||
} else if (input_pv->At(i) == (uint32_t)axis) {
|
|
||||||
v.push_back(input_pv->At(i));
|
|
||||||
v.push_back(input_pv->At(i) + 1);
|
|
||||||
} else {
|
|
||||||
v.push_back(input_pv->At(i));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
auto out_pv =
|
|
||||||
MakeShared(op_->impl()->OutputsTensor()[0]->GetShape().size());
|
|
||||||
for (uint32_t i = 0; i < out_pv->Rank(); ++i) {
|
|
||||||
out_pv->At(i) = v[i];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto out_infer = CreateOutputsTensor(out_pv);
|
auto output_pv = MakeShared(input_pv->Rank() + 1);
|
||||||
|
if (!input_pv->IsAligned()) {
|
||||||
|
output_pv->At(axis) = (uint32_t)axis;
|
||||||
|
for (uint32_t i = 0, j = 0; i < input_pv->Rank(); ++i, ++j) {
|
||||||
|
if ((uint32_t)axis == i) {
|
||||||
|
++j;
|
||||||
|
}
|
||||||
|
if (input_pv->At(i) < (uint32_t)axis) {
|
||||||
|
output_pv->At(j) = input_pv->At(i);
|
||||||
|
} else {
|
||||||
|
output_pv->At(j) = input_pv->At(i) + 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto out_infer = CreateOutputsTensor(output_pv);
|
||||||
(*stack).BindOutput(out_infer[0]);
|
(*stack).BindOutput(out_infer[0]);
|
||||||
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], out_pv);
|
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], output_pv);
|
||||||
// Add out tensor of src_graph into next_tensor
|
// Add out tensor of src_graph into next_tensor
|
||||||
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
|
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ TEST(Stack, LayoutinferernceTest_1) {
|
||||||
1, 1, 1, 1, 2, 1, 1, 5, 3, 1, 2, 3, 1, 1, 2, 1, 1, 1,
|
1, 1, 1, 1, 2, 1, 1, 5, 3, 1, 2, 3, 1, 1, 2, 1, 1, 1,
|
||||||
};
|
};
|
||||||
std::vector<float> golden = {
|
std::vector<float> golden = {
|
||||||
64, 77, 49, 44, 81, 97, 64, 77, 49, 44, 81, 97
|
64, 64, 49, 49, 81, 81, 77, 77, 44, 44, 97, 97
|
||||||
};
|
};
|
||||||
auto kernel_tensor = graph->CreateTensor(kernel_spec, kernel_data.data());
|
auto kernel_tensor = graph->CreateTensor(kernel_spec, kernel_data.data());
|
||||||
|
|
||||||
|
|
@ -180,7 +180,7 @@ TEST(Stack, LayoutinferernceTest_3) {
|
||||||
1, 1, 1, 1, 2, 1, 1, 5, 3, 1, 2, 3, 1, 1, 2, 1, 1, 1,
|
1, 1, 1, 1, 2, 1, 1, 5, 3, 1, 2, 3, 1, 1, 2, 1, 1, 1,
|
||||||
};
|
};
|
||||||
std::vector<float> golden = {
|
std::vector<float> golden = {
|
||||||
55, 39, 21, 28, 37, 41, 49, 55, 28, 24, 40, 41
|
55, 49, 21, 28, 37, 40, 39, 55, 28, 24, 41, 41,
|
||||||
};
|
};
|
||||||
auto kernel_tensor = graph->CreateTensor(kernel_spec, kernel_data.data());
|
auto kernel_tensor = graph->CreateTensor(kernel_spec, kernel_data.data());
|
||||||
auto kernel2_tensor = graph->CreateTensor(kernel_spec, kernel_data.data());
|
auto kernel2_tensor = graph->CreateTensor(kernel_spec, kernel_data.data());
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue