diff --git a/include/tim/vx/ops/topk.h b/include/tim/vx/ops/topk.h index d7b584f..eaf13d1 100644 --- a/include/tim/vx/ops/topk.h +++ b/include/tim/vx/ops/topk.h @@ -39,11 +39,12 @@ namespace ops { * Finds values and indices of the k largest entries for the last dimension. * * - k : Number of top elements to look for along the last dimension. + * -axis : Dimension on which to do th sort. Default is 0. */ class Topk : public BuiltinOp { public: - Topk(Graph* graph, uint32_t k); + Topk(Graph* graph, uint32_t k, int32_t axis = 0); std::shared_ptr Clone(std::shared_ptr& graph) const override; }; diff --git a/include/tim/vx/ops/topk.json b/include/tim/vx/ops/topk.json index 9b60d67..1dac118 100755 --- a/include/tim/vx/ops/topk.json +++ b/include/tim/vx/ops/topk.json @@ -4,6 +4,11 @@ [ {"name":"k", "dtype": "uint32_t" + }, + {"name":"axis", + "dtype": "uint32_t", + "Optional":"true", + "default":"0" } ] } diff --git a/src/tim/vx/ops/topk.cc b/src/tim/vx/ops/topk.cc index 1785c67..a43d581 100644 --- a/src/tim/vx/ops/topk.cc +++ b/src/tim/vx/ops/topk.cc @@ -31,13 +31,15 @@ namespace tim { namespace vx { namespace ops { -Topk::Topk(Graph* graph, uint32_t k) +Topk::Topk(Graph* graph, uint32_t k, int32_t axis) : BuiltinOp(graph, VSI_NN_OP_TOPK) { this->impl()->node()->nn_param.topk.k = k; + this->impl()->node()->nn_param.topk.axis = axis; } std::shared_ptr Topk::Clone(std::shared_ptr& graph) const { - return graph->CreateOperation(this->impl()->node()->nn_param.topk.k); + return graph->CreateOperation(this->impl()->node()->nn_param.topk.k, + this->impl()->node()->nn_param.topk.axis); } } // namespace ops