From 98b9759663a1eb6189cf86d93d8eb9d25925280d Mon Sep 17 00:00:00 2001 From: "yuenan.li" Date: Mon, 28 Jun 2021 14:19:30 +0800 Subject: [PATCH] Refine arg in layout inference Signed-off-by: yuenan.li --- src/tim/transform/ops/arg_layout_inference.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tim/transform/ops/arg_layout_inference.h b/src/tim/transform/ops/arg_layout_inference.h index 618de29..eb0a745 100644 --- a/src/tim/transform/ops/arg_layout_inference.h +++ b/src/tim/transform/ops/arg_layout_inference.h @@ -38,12 +38,12 @@ class ArgMaxLayoutInfer : public OpLayoutInfer { void OnInputs( std::vector>& next_tensors) override { + ReverseInputsPermuteVector(); assert(1 == op_->impl()->InputsTensor().size()); auto src_input = op_->impl()->InputsTensor()[0]; auto input_pv = context_->GetPermuteVector(src_input); uint32_t axis = op_->impl()->node()->nn_param.argmax.axis; - axis = MapAxis(input_pv->AsStdVec(), axis); auto argmax = context_->infer_graph_->CreateOperation(axis); @@ -65,12 +65,12 @@ class ArgMinLayoutInfer : public OpLayoutInfer { void OnInputs( std::vector>& next_tensors) override { + ReverseInputsPermuteVector(); assert(1 == op_->impl()->InputsTensor().size()); auto src_input = op_->impl()->InputsTensor()[0]; auto input_pv = context_->GetPermuteVector(src_input); uint32_t axis = op_->impl()->node()->nn_param.argmin.axis; - axis = MapAxis(input_pv->AsStdVec(), axis); auto argmin = context_->infer_graph_->CreateOperation(axis);