diff --git a/src/tim/transform/ops/elementwise_layout_inference.h b/src/tim/transform/ops/elementwise_layout_inference.h index 102609d..1248e86 100644 --- a/src/tim/transform/ops/elementwise_layout_inference.h +++ b/src/tim/transform/ops/elementwise_layout_inference.h @@ -42,6 +42,32 @@ class ElementWiseLayoutInfer : public OpLayoutInfer { void OnInputs( std::vector>& next_tensors) override { + auto in_0 = op_->impl()->InputsTensor()[0]; + auto in_1 = op_->impl()->InputsTensor()[1]; + std::shared_ptr short_tensor = + in_0->GetShape().size() > in_1->GetShape().size() ? in_1 : in_0; + std::shared_ptr long_tensor = + in_0->GetShape().size() < in_1->GetShape().size() ? in_1 : in_0; + if (in_0->GetSpec().attr_ != tim::vx::CONSTANT && + in_1->GetSpec().attr_ != tim::vx::CONSTANT && + in_0->GetShape().size() != in_1->GetShape().size()) { + auto pv_long = context_->GetPermuteVector(long_tensor); + auto pv_short = context_->GetPermuteVector(short_tensor); + auto rank_long = pv_long->Rank(); + auto rank_short = pv_short->Rank(); + auto expanded_pv = MakeShared(rank_long); + // if different size, expand short pv to long pv + for (uint32_t i = 0; i < rank_short; ++i) { + expanded_pv->At(i) = pv_short->At(i); // replace low dims with short pv + } + std::vector expanded_shape(short_tensor->GetShape()); + for (uint32_t i = 0; i < rank_long; ++i) { + if (i >= rank_short) expanded_shape.push_back(1); + } + short_tensor->GetSpec().SetShape(expanded_shape); + + context_->SetPermuteVector(short_tensor, expanded_pv); // set new expand pv + } auto required_pv = AlignPermuteVectorForElementWise(); auto elementwise = context_->infer_graph_->CreateOperation(); for (const auto& i_src : op_->impl()->InputsTensor()) { @@ -63,6 +89,32 @@ class MultiplyLayoutInfer : public OpLayoutInfer { void OnInputs( std::vector>& next_tensors) override { + auto in_0 = op_->impl()->InputsTensor()[0]; + auto in_1 = op_->impl()->InputsTensor()[1]; + std::shared_ptr short_tensor = + in_0->GetShape().size() > in_1->GetShape().size() ? in_1 : in_0; + std::shared_ptr long_tensor = + in_0->GetShape().size() < in_1->GetShape().size() ? in_1 : in_0; + if (in_0->GetSpec().attr_ != tim::vx::CONSTANT && + in_1->GetSpec().attr_ != tim::vx::CONSTANT && + in_0->GetShape().size() != in_1->GetShape().size()) { + auto pv_long = context_->GetPermuteVector(long_tensor); + auto pv_short = context_->GetPermuteVector(short_tensor); + auto rank_long = pv_long->Rank(); + auto rank_short = pv_short->Rank(); + auto expanded_pv = MakeShared(rank_long); + // if different size, expand short pv to long pv + for (uint32_t i = 0; i < rank_short; ++i) { + expanded_pv->At(i) = pv_short->At(i); // replace low dims with short pv + } + std::vector expanded_shape(short_tensor->GetShape()); + for (uint32_t i = 0; i < rank_long; ++i) { + if (i >= rank_short) expanded_shape.push_back(1); + } + short_tensor->GetSpec().SetShape(expanded_shape); + + context_->SetPermuteVector(short_tensor, expanded_pv); // set new expand pv + } auto required_pv = AlignPermuteVectorForElementWise(); auto multiply = context_->infer_graph_->CreateOperation(