Added axis param for TopK (#610)

Topk support specifying dimensions with later internal ovxlib

Type: New Feature

Signed-off-by: Feiyue Chen <Feiyue.Chen@verisilicon.com>
This commit is contained in:
Chen Feiyue 2023-07-12 09:54:07 +08:00 committed by GitHub
parent 18749f5d05
commit 62c6b6560c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 11 additions and 3 deletions

View File

@ -39,11 +39,12 @@ namespace ops {
* Finds values and indices of the k largest entries for the last dimension.
*
* - k : Number of top elements to look for along the last dimension.
* -axis : Dimension on which to do th sort. Default is 0.
*/
class Topk : public BuiltinOp {
public:
Topk(Graph* graph, uint32_t k);
Topk(Graph* graph, uint32_t k, int32_t axis = 0);
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
};

View File

@ -4,6 +4,11 @@
[
{"name":"k",
"dtype": "uint32_t"
},
{"name":"axis",
"dtype": "uint32_t",
"Optional":"true",
"default":"0"
}
]
}

View File

@ -31,13 +31,15 @@ namespace tim {
namespace vx {
namespace ops {
Topk::Topk(Graph* graph, uint32_t k)
Topk::Topk(Graph* graph, uint32_t k, int32_t axis)
: BuiltinOp(graph, VSI_NN_OP_TOPK) {
this->impl()->node()->nn_param.topk.k = k;
this->impl()->node()->nn_param.topk.axis = axis;
}
std::shared_ptr<Operation> Topk::Clone(std::shared_ptr<Graph>& graph) const {
return graph->CreateOperation<Topk>(this->impl()->node()->nn_param.topk.k);
return graph->CreateOperation<Topk>(this->impl()->node()->nn_param.topk.k,
this->impl()->node()->nn_param.topk.axis);
}
} // namespace ops