Refine arg in layout inference
Signed-off-by: yuenan.li <yuenan.li@verisilicon.com>
This commit is contained in:
parent
be066fb9bd
commit
98b9759663
|
|
@ -38,12 +38,12 @@ class ArgMaxLayoutInfer : public OpLayoutInfer {
|
|||
|
||||
void OnInputs(
|
||||
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
|
||||
ReverseInputsPermuteVector();
|
||||
assert(1 == op_->impl()->InputsTensor().size());
|
||||
auto src_input = op_->impl()->InputsTensor()[0];
|
||||
auto input_pv = context_->GetPermuteVector(src_input);
|
||||
|
||||
uint32_t axis = op_->impl()->node()->nn_param.argmax.axis;
|
||||
axis = MapAxis(input_pv->AsStdVec(), axis);
|
||||
|
||||
auto argmax =
|
||||
context_->infer_graph_->CreateOperation<vx::ops::ArgMax>(axis);
|
||||
|
|
@ -65,12 +65,12 @@ class ArgMinLayoutInfer : public OpLayoutInfer {
|
|||
|
||||
void OnInputs(
|
||||
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
|
||||
ReverseInputsPermuteVector();
|
||||
assert(1 == op_->impl()->InputsTensor().size());
|
||||
auto src_input = op_->impl()->InputsTensor()[0];
|
||||
auto input_pv = context_->GetPermuteVector(src_input);
|
||||
|
||||
uint32_t axis = op_->impl()->node()->nn_param.argmin.axis;
|
||||
axis = MapAxis(input_pv->AsStdVec(), axis);
|
||||
|
||||
auto argmin =
|
||||
context_->infer_graph_->CreateOperation<vx::ops::ArgMin>(axis);
|
||||
|
|
|
|||
Loading…
Reference in New Issue