diff --git a/src/tim/transform/ops/arg_layout_inference.h b/src/tim/transform/ops/arg_layout_inference.h index 618de29..eb0a745 100644 --- a/src/tim/transform/ops/arg_layout_inference.h +++ b/src/tim/transform/ops/arg_layout_inference.h @@ -38,12 +38,12 @@ class ArgMaxLayoutInfer : public OpLayoutInfer { void OnInputs( std::vector>& next_tensors) override { + ReverseInputsPermuteVector(); assert(1 == op_->impl()->InputsTensor().size()); auto src_input = op_->impl()->InputsTensor()[0]; auto input_pv = context_->GetPermuteVector(src_input); uint32_t axis = op_->impl()->node()->nn_param.argmax.axis; - axis = MapAxis(input_pv->AsStdVec(), axis); auto argmax = context_->infer_graph_->CreateOperation(axis); @@ -65,12 +65,12 @@ class ArgMinLayoutInfer : public OpLayoutInfer { void OnInputs( std::vector>& next_tensors) override { + ReverseInputsPermuteVector(); assert(1 == op_->impl()->InputsTensor().size()); auto src_input = op_->impl()->InputsTensor()[0]; auto input_pv = context_->GetPermuteVector(src_input); uint32_t axis = op_->impl()->node()->nn_param.argmin.axis; - axis = MapAxis(input_pv->AsStdVec(), axis); auto argmin = context_->infer_graph_->CreateOperation(axis);