fix layoutinfer crash when logical op inputs are different rank (#667)
Type: Bug fix Signed-off-by: Chen <jack.chen@verisilicon.com> Co-authored-by: Chen <jack.chen@verisilicon.com>
This commit is contained in:
parent
0dc7a3465e
commit
11d12f03a8
|
|
@ -40,6 +40,33 @@ class LogicalOpsLayoutInfer : public OpLayoutInfer {
|
|||
|
||||
void OnInputs(
|
||||
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
|
||||
auto in_0 = op_->impl()->InputsTensor()[0];
|
||||
auto in_1 = op_->impl()->InputsTensor()[1];
|
||||
std::shared_ptr<tim::vx::Tensor> short_tensor =
|
||||
in_0->GetShape().size() > in_1->GetShape().size() ? in_1 : in_0;
|
||||
std::shared_ptr<tim::vx::Tensor> long_tensor =
|
||||
in_0->GetShape().size() < in_1->GetShape().size() ? in_1 : in_0;
|
||||
if (in_0->GetSpec().attr_ != tim::vx::CONSTANT &&
|
||||
in_1->GetSpec().attr_ != tim::vx::CONSTANT &&
|
||||
in_0->GetShape().size() != in_1->GetShape().size()) {
|
||||
auto pv_long = context_->GetPermuteVector(long_tensor);
|
||||
auto pv_short = context_->GetPermuteVector(short_tensor);
|
||||
auto rank_long = pv_long->Rank();
|
||||
auto rank_short = pv_short->Rank();
|
||||
auto expanded_pv = MakeShared(rank_long);
|
||||
// if different size, expand short pv to long pv
|
||||
for (uint32_t i = 0; i < rank_short; ++i) {
|
||||
expanded_pv->At(i) = pv_short->At(i); // replace low dims with short pv
|
||||
}
|
||||
std::vector<uint32_t> expanded_shape(short_tensor->GetShape());
|
||||
for (uint32_t i = 0; i < rank_long; ++i) {
|
||||
if (i >= rank_short) expanded_shape.push_back(1);
|
||||
}
|
||||
short_tensor->GetSpec().SetShape(expanded_shape);
|
||||
|
||||
context_->SetPermuteVector(short_tensor,
|
||||
expanded_pv); // set new expand pv
|
||||
}
|
||||
auto required_pv = AlignPermuteVectorForMutilInputs();
|
||||
auto infer_out = CreateOutputsTensor(required_pv);
|
||||
auto logical_op = context_->infer_graph_->CreateOperation<OpTpye>();
|
||||
|
|
|
|||
Loading…
Reference in New Issue