Refine arg in layout inference
Signed-off-by: yuenan.li <yuenan.li@verisilicon.com>
This commit is contained in:
parent
be066fb9bd
commit
98b9759663
|
|
@ -38,12 +38,12 @@ class ArgMaxLayoutInfer : public OpLayoutInfer {
|
||||||
|
|
||||||
void OnInputs(
|
void OnInputs(
|
||||||
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
|
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
|
||||||
|
ReverseInputsPermuteVector();
|
||||||
assert(1 == op_->impl()->InputsTensor().size());
|
assert(1 == op_->impl()->InputsTensor().size());
|
||||||
auto src_input = op_->impl()->InputsTensor()[0];
|
auto src_input = op_->impl()->InputsTensor()[0];
|
||||||
auto input_pv = context_->GetPermuteVector(src_input);
|
auto input_pv = context_->GetPermuteVector(src_input);
|
||||||
|
|
||||||
uint32_t axis = op_->impl()->node()->nn_param.argmax.axis;
|
uint32_t axis = op_->impl()->node()->nn_param.argmax.axis;
|
||||||
axis = MapAxis(input_pv->AsStdVec(), axis);
|
|
||||||
|
|
||||||
auto argmax =
|
auto argmax =
|
||||||
context_->infer_graph_->CreateOperation<vx::ops::ArgMax>(axis);
|
context_->infer_graph_->CreateOperation<vx::ops::ArgMax>(axis);
|
||||||
|
|
@ -65,12 +65,12 @@ class ArgMinLayoutInfer : public OpLayoutInfer {
|
||||||
|
|
||||||
void OnInputs(
|
void OnInputs(
|
||||||
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
|
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
|
||||||
|
ReverseInputsPermuteVector();
|
||||||
assert(1 == op_->impl()->InputsTensor().size());
|
assert(1 == op_->impl()->InputsTensor().size());
|
||||||
auto src_input = op_->impl()->InputsTensor()[0];
|
auto src_input = op_->impl()->InputsTensor()[0];
|
||||||
auto input_pv = context_->GetPermuteVector(src_input);
|
auto input_pv = context_->GetPermuteVector(src_input);
|
||||||
|
|
||||||
uint32_t axis = op_->impl()->node()->nn_param.argmin.axis;
|
uint32_t axis = op_->impl()->node()->nn_param.argmin.axis;
|
||||||
axis = MapAxis(input_pv->AsStdVec(), axis);
|
|
||||||
|
|
||||||
auto argmin =
|
auto argmin =
|
||||||
context_->infer_graph_->CreateOperation<vx::ops::ArgMin>(axis);
|
context_->infer_graph_->CreateOperation<vx::ops::ArgMin>(axis);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue