diff --git a/src/tim/transform/ops/transpose_layout_inference.h b/src/tim/transform/ops/transpose_layout_inference.h index bda00a4..e926a93 100644 --- a/src/tim/transform/ops/transpose_layout_inference.h +++ b/src/tim/transform/ops/transpose_layout_inference.h @@ -69,7 +69,8 @@ class TransposeLayoutInfer : public OpLayoutInfer { context_->infer_graph_->CreateOperation( final_pv->AsStdVec()); transpose_op->BindInput(infer_input); - auto infer_out = CreateOutputsTensor(final_pv); + // The layout after final_pv permute is the default sequence + auto infer_out = CreateOutputsTensor(MakeShared(perm.size())); transpose_op->BindOutput(infer_out[0]); } context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], MakeShared(perm.size()));