fix layout inference crash when logical op inputs have different ranks (#667)

Type: Bug fix

Signed-off-by: Chen <jack.chen@verisilicon.com>
Co-authored-by: Chen <jack.chen@verisilicon.com>
chxin66 2023-12-13 09:57:17 +08:00 committed by GitHub
parent 0dc7a3465e
commit 11d12f03a8
1 changed file with 27 additions and 0 deletions

@@ -40,6 +40,33 @@ class LogicalOpsLayoutInfer : public OpLayoutInfer {
  void OnInputs(
      std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
    auto in_0 = op_->impl()->InputsTensor()[0];
    auto in_1 = op_->impl()->InputsTensor()[1];
    std::shared_ptr<tim::vx::Tensor> short_tensor =
        in_0->GetShape().size() > in_1->GetShape().size() ? in_1 : in_0;
    std::shared_ptr<tim::vx::Tensor> long_tensor =
        in_0->GetShape().size() < in_1->GetShape().size() ? in_1 : in_0;
    if (in_0->GetSpec().attr_ != tim::vx::CONSTANT &&
        in_1->GetSpec().attr_ != tim::vx::CONSTANT &&
        in_0->GetShape().size() != in_1->GetShape().size()) {
      auto pv_long = context_->GetPermuteVector(long_tensor);
      auto pv_short = context_->GetPermuteVector(short_tensor);
      auto rank_long = pv_long->Rank();
      auto rank_short = pv_short->Rank();
      auto expanded_pv = MakeShared(rank_long);
      // if different size, expand short pv to long pv
      for (uint32_t i = 0; i < rank_short; ++i) {
        expanded_pv->At(i) = pv_short->At(i);  // replace low dims with short pv
      }
      std::vector<uint32_t> expanded_shape(short_tensor->GetShape());
      for (uint32_t i = 0; i < rank_long; ++i) {
        if (i >= rank_short) expanded_shape.push_back(1);
      }
      short_tensor->GetSpec().SetShape(expanded_shape);
      context_->SetPermuteVector(short_tensor,
                                 expanded_pv);  // set new expand pv
    }
    auto required_pv = AlignPermuteVectorForMutilInputs();
    auto infer_out = CreateOutputsTensor(required_pv);
    auto logical_op = context_->infer_graph_->CreateOperation<OpTpye>();
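
The core of the fix is the rank-alignment step above: when the two logical-op inputs have different ranks (and neither is constant), the shorter input's shape is padded with trailing size-1 dimensions and its permute vector is extended with identity entries, so the subsequent AlignPermuteVectorForMutilInputs call sees two inputs of equal rank instead of indexing past the shorter permute vector. Below is a minimal standalone sketch of that padding idea using plain std::vector in place of the TIM-VX tensor and permute-vector types; ExpandShapeToRank and ExpandPermuteVectorToRank are illustrative names, not part of the library's API.

#include <cstdint>
#include <iostream>
#include <vector>

// Pad a lower-rank shape with trailing 1s so it matches the target rank;
// this mirrors what the patch does to the shorter tensor's spec.
std::vector<uint32_t> ExpandShapeToRank(std::vector<uint32_t> shape,
                                        size_t target_rank) {
  while (shape.size() < target_rank) {
    shape.push_back(1);  // size-1 dims leave the element layout unchanged
  }
  return shape;
}

// Extend a lower-rank permute vector with identity entries for the new
// trailing dims, analogous to MakeShared(rank_long) plus the copy loop above.
std::vector<uint32_t> ExpandPermuteVectorToRank(std::vector<uint32_t> pv,
                                                size_t target_rank) {
  for (uint32_t i = static_cast<uint32_t>(pv.size()); i < target_rank; ++i) {
    pv.push_back(i);  // new dims keep their own position
  }
  return pv;
}

int main() {
  std::vector<uint32_t> long_shape = {8, 4, 2, 2};  // rank-4 input
  std::vector<uint32_t> short_shape = {8, 4};       // rank-2 input
  std::vector<uint32_t> short_pv = {1, 0};          // permute vector of the rank-2 input

  auto padded_shape = ExpandShapeToRank(short_shape, long_shape.size());
  auto padded_pv = ExpandPermuteVectorToRank(short_pv, long_shape.size());

  for (auto d : padded_shape) std::cout << d << ' ';  // prints: 8 4 1 1
  std::cout << '\n';
  for (auto p : padded_pv) std::cout << p << ' ';     // prints: 1 0 2 3
  std::cout << '\n';
  return 0;
}

With both the shape and the permute vector brought up to the longer input's rank, the existing multi-input alignment logic can proceed without reading out of range.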