Added axis param for TopK (#610)
Topk support specifying dimensions with later internal ovxlib Type: New Feature Signed-off-by: Feiyue Chen <Feiyue.Chen@verisilicon.com>
This commit is contained in:
parent
18749f5d05
commit
62c6b6560c
|
|
@ -39,11 +39,12 @@ namespace ops {
|
|||
* Finds values and indices of the k largest entries for the last dimension.
|
||||
*
|
||||
* - k : Number of top elements to look for along the last dimension.
|
||||
* -axis : Dimension on which to do th sort. Default is 0.
|
||||
*/
|
||||
|
||||
class Topk : public BuiltinOp {
|
||||
public:
|
||||
Topk(Graph* graph, uint32_t k);
|
||||
Topk(Graph* graph, uint32_t k, int32_t axis = 0);
|
||||
|
||||
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -4,6 +4,11 @@
|
|||
[
|
||||
{"name":"k",
|
||||
"dtype": "uint32_t"
|
||||
},
|
||||
{"name":"axis",
|
||||
"dtype": "uint32_t",
|
||||
"Optional":"true",
|
||||
"default":"0"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,13 +31,15 @@ namespace tim {
|
|||
namespace vx {
|
||||
namespace ops {
|
||||
|
||||
Topk::Topk(Graph* graph, uint32_t k)
|
||||
Topk::Topk(Graph* graph, uint32_t k, int32_t axis)
|
||||
: BuiltinOp(graph, VSI_NN_OP_TOPK) {
|
||||
this->impl()->node()->nn_param.topk.k = k;
|
||||
this->impl()->node()->nn_param.topk.axis = axis;
|
||||
}
|
||||
|
||||
std::shared_ptr<Operation> Topk::Clone(std::shared_ptr<Graph>& graph) const {
|
||||
return graph->CreateOperation<Topk>(this->impl()->node()->nn_param.topk.k);
|
||||
return graph->CreateOperation<Topk>(this->impl()->node()->nn_param.topk.k,
|
||||
this->impl()->node()->nn_param.topk.axis);
|
||||
}
|
||||
|
||||
} // namespace ops
|
||||
|
|
|
|||
Loading…
Reference in New Issue