From 720f0a485a3952560d6d9ab7bea0e6abc1514680 Mon Sep 17 00:00:00 2001 From: chxin66 <57057788+chxin66@users.noreply.github.com> Date: Wed, 6 Dec 2023 17:15:15 +0800 Subject: [PATCH] fix crash when eletwise inputs are different rank (#665) Fix crash in AlignPermuteVectorForElmentWise() if inputs tensor have different rank Type: Bug fix Signed-off-by: Chen --- .../ops/elementwise_layout_inference.h | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/src/tim/transform/ops/elementwise_layout_inference.h b/src/tim/transform/ops/elementwise_layout_inference.h index 102609d..1248e86 100644 --- a/src/tim/transform/ops/elementwise_layout_inference.h +++ b/src/tim/transform/ops/elementwise_layout_inference.h @@ -42,6 +42,32 @@ class ElementWiseLayoutInfer : public OpLayoutInfer { void OnInputs( std::vector>& next_tensors) override { + auto in_0 = op_->impl()->InputsTensor()[0]; + auto in_1 = op_->impl()->InputsTensor()[1]; + std::shared_ptr short_tensor = + in_0->GetShape().size() > in_1->GetShape().size() ? in_1 : in_0; + std::shared_ptr long_tensor = + in_0->GetShape().size() < in_1->GetShape().size() ? in_1 : in_0; + if (in_0->GetSpec().attr_ != tim::vx::CONSTANT && + in_1->GetSpec().attr_ != tim::vx::CONSTANT && + in_0->GetShape().size() != in_1->GetShape().size()) { + auto pv_long = context_->GetPermuteVector(long_tensor); + auto pv_short = context_->GetPermuteVector(short_tensor); + auto rank_long = pv_long->Rank(); + auto rank_short = pv_short->Rank(); + auto expanded_pv = MakeShared(rank_long); + // if different size, expand short pv to long pv + for (uint32_t i = 0; i < rank_short; ++i) { + expanded_pv->At(i) = pv_short->At(i); // replace low dims with short pv + } + std::vector expanded_shape(short_tensor->GetShape()); + for (uint32_t i = 0; i < rank_long; ++i) { + if (i >= rank_short) expanded_shape.push_back(1); + } + short_tensor->GetSpec().SetShape(expanded_shape); + + context_->SetPermuteVector(short_tensor, expanded_pv); // set new expand pv + } auto required_pv = AlignPermuteVectorForElementWise(); auto elementwise = context_->infer_graph_->CreateOperation(); for (const auto& i_src : op_->impl()->InputsTensor()) { @@ -63,6 +89,32 @@ class MultiplyLayoutInfer : public OpLayoutInfer { void OnInputs( std::vector>& next_tensors) override { + auto in_0 = op_->impl()->InputsTensor()[0]; + auto in_1 = op_->impl()->InputsTensor()[1]; + std::shared_ptr short_tensor = + in_0->GetShape().size() > in_1->GetShape().size() ? in_1 : in_0; + std::shared_ptr long_tensor = + in_0->GetShape().size() < in_1->GetShape().size() ? in_1 : in_0; + if (in_0->GetSpec().attr_ != tim::vx::CONSTANT && + in_1->GetSpec().attr_ != tim::vx::CONSTANT && + in_0->GetShape().size() != in_1->GetShape().size()) { + auto pv_long = context_->GetPermuteVector(long_tensor); + auto pv_short = context_->GetPermuteVector(short_tensor); + auto rank_long = pv_long->Rank(); + auto rank_short = pv_short->Rank(); + auto expanded_pv = MakeShared(rank_long); + // if different size, expand short pv to long pv + for (uint32_t i = 0; i < rank_short; ++i) { + expanded_pv->At(i) = pv_short->At(i); // replace low dims with short pv + } + std::vector expanded_shape(short_tensor->GetShape()); + for (uint32_t i = 0; i < rank_long; ++i) { + if (i >= rank_short) expanded_shape.push_back(1); + } + short_tensor->GetSpec().SetShape(expanded_shape); + + context_->SetPermuteVector(short_tensor, expanded_pv); // set new expand pv + } auto required_pv = AlignPermuteVectorForElementWise(); auto multiply = context_->infer_graph_->CreateOperation(