Refine arg in layout inference

Signed-off-by: yuenan.li <yuenan.li@verisilicon.com>
This commit is contained in:
yuenan.li 2021-06-28 14:19:30 +08:00 committed by Kainan Cha
parent be066fb9bd
commit 98b9759663
1 changed files with 2 additions and 2 deletions

View File

@ -38,12 +38,12 @@ class ArgMaxLayoutInfer : public OpLayoutInfer {
void OnInputs(
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
ReverseInputsPermuteVector();
assert(1 == op_->impl()->InputsTensor().size());
auto src_input = op_->impl()->InputsTensor()[0];
auto input_pv = context_->GetPermuteVector(src_input);
uint32_t axis = op_->impl()->node()->nn_param.argmax.axis;
axis = MapAxis(input_pv->AsStdVec(), axis);
auto argmax =
context_->infer_graph_->CreateOperation<vx::ops::ArgMax>(axis);
@ -65,12 +65,12 @@ class ArgMinLayoutInfer : public OpLayoutInfer {
void OnInputs(
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
ReverseInputsPermuteVector();
assert(1 == op_->impl()->InputsTensor().size());
auto src_input = op_->impl()->InputsTensor()[0];
auto input_pv = context_->GetPermuteVector(src_input);
uint32_t axis = op_->impl()->node()->nn_param.argmin.axis;
axis = MapAxis(input_pv->AsStdVec(), axis);
auto argmin =
context_->infer_graph_->CreateOperation<vx::ops::ArgMin>(axis);