fix slope shpae 1 crash issue (#663)
graph compile will crash when shape is broadcast from 1 to 1,1,1,1 Type: Bug fix Signed-off-by: Chen <jack.chen@verisilicon.com> Co-authored-by: Chen <jack.chen@verisilicon.com>
This commit is contained in:
parent
4578f40953
commit
e013cf0a65
|
|
@ -73,40 +73,56 @@ class PReluLayoutInfer : public OpLayoutInfer {
|
|||
auto slope_shape = src_slope->GetShape();
|
||||
auto input_pv = context_->GetPermuteVector(src_input);
|
||||
std::vector<uint32_t> boardcast_shape;
|
||||
for (uint32_t i = 0; i < input_shape.size(); ++i) {
|
||||
if (i < slope_shape.size()) {
|
||||
boardcast_shape.push_back(slope_shape[i]);
|
||||
} else {
|
||||
boardcast_shape.push_back(1);
|
||||
if (slope_shape.size() != 1) { // Need to be transposed along with the input
|
||||
for (uint32_t i = 0; i < input_shape.size(); ++i) {
|
||||
if (i < slope_shape.size()) {
|
||||
boardcast_shape.push_back(slope_shape[i]);
|
||||
} else {
|
||||
boardcast_shape.push_back(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (src_slope->IsConstTensor()) {
|
||||
std::vector<uint8_t> dataRef(src_slope->GetSpec().GetByteSize());
|
||||
src_slope->CopyDataFromTensor(dataRef.data());
|
||||
auto infer_slope_spec = src_slope->GetSpec();
|
||||
infer_slope_spec.SetShape(boardcast_shape);
|
||||
auto infer_slope = context_->infer_graph_->CreateTensor(
|
||||
infer_slope_spec, (const void*)dataRef.data());
|
||||
if (src_slope->IsConstTensor()) {
|
||||
std::vector<uint8_t> dataRef(src_slope->GetSpec().GetByteSize());
|
||||
src_slope->CopyDataFromTensor(dataRef.data());
|
||||
auto infer_slope_spec = src_slope->GetSpec();
|
||||
infer_slope_spec.SetShape(boardcast_shape);
|
||||
auto infer_slope = context_->infer_graph_->CreateTensor(
|
||||
infer_slope_spec, (const void*)dataRef.data());
|
||||
|
||||
if (!input_pv->IsAligned()) {
|
||||
//The dimension of slop is already the same as input, directly use input_pv to convert
|
||||
auto out_slope = PermuteConstTensor(infer_slope, input_pv);
|
||||
context_->UpdateTensorMap(src_slope, out_slope);
|
||||
if (!input_pv->IsAligned()) {
|
||||
//The dimension of slop is already the same as input, directly use input_pv to convert
|
||||
auto out_slope = PermuteConstTensor(infer_slope, input_pv);
|
||||
context_->UpdateTensorMap(src_slope, out_slope);
|
||||
} else {
|
||||
context_->UpdateTensorMap(src_slope, infer_slope);
|
||||
}
|
||||
} else {
|
||||
auto infer_slope_spec = src_slope->GetSpec().AsTransientSpec();
|
||||
auto reshape_out =
|
||||
context_->infer_graph_->CreateTensor(infer_slope_spec);
|
||||
boardcast_shape =
|
||||
MapMultipleAxis(input_pv->AsStdVec(), boardcast_shape);
|
||||
auto reshape =
|
||||
context_->infer_graph_->CreateOperation<vx::ops::Reshape>(
|
||||
boardcast_shape);
|
||||
(*reshape)
|
||||
.BindInput(context_->GetMapedTensor(src_slope))
|
||||
.BindOutput(reshape_out);
|
||||
context_->UpdateTensorMap(src_slope, reshape_out);
|
||||
}
|
||||
context_->SetPermuteVector(src_slope, input_pv);
|
||||
} else { // 1d slope tensor need not transpose
|
||||
if (src_slope->IsConstTensor()) {
|
||||
std::vector<uint8_t> dataRef(src_slope->GetSpec().GetByteSize());
|
||||
src_slope->CopyDataFromTensor(dataRef.data());
|
||||
auto infer_slope_spec = src_slope->GetSpec();
|
||||
auto infer_slope = context_->infer_graph_->CreateTensor(
|
||||
infer_slope_spec, (const void*)dataRef.data());
|
||||
context_->UpdateTensorMap(src_slope, infer_slope);
|
||||
context_->SetPermuteVector(src_slope, MakeShared(1));
|
||||
}
|
||||
} else {
|
||||
auto infer_slope_spec = src_slope->GetSpec().AsTransientSpec();
|
||||
auto reshape_out = context_->infer_graph_->CreateTensor(infer_slope_spec);
|
||||
boardcast_shape = MapMultipleAxis(input_pv->AsStdVec(), boardcast_shape);
|
||||
auto reshape = context_->infer_graph_->CreateOperation<vx::ops::Reshape>(boardcast_shape);
|
||||
(*reshape)
|
||||
.BindInput(context_->GetMapedTensor(src_slope))
|
||||
.BindOutput(reshape_out);
|
||||
context_->UpdateTensorMap(src_slope, reshape_out);
|
||||
}
|
||||
context_->SetPermuteVector(src_slope, input_pv);
|
||||
|
||||
auto axis =
|
||||
MapAxis(input_pv->AsStdVec(), op_->impl()->node()->nn_param.prelu.axis);
|
||||
|
|
|
|||
Loading…
Reference in New Issue