Added axis param for TopK (#610)
Topk support specifying dimensions with later internal ovxlib Type: New Feature Signed-off-by: Feiyue Chen <Feiyue.Chen@verisilicon.com>
This commit is contained in:
parent
18749f5d05
commit
62c6b6560c
|
|
@ -39,11 +39,12 @@ namespace ops {
|
||||||
* Finds values and indices of the k largest entries for the last dimension.
|
* Finds values and indices of the k largest entries for the last dimension.
|
||||||
*
|
*
|
||||||
* - k : Number of top elements to look for along the last dimension.
|
* - k : Number of top elements to look for along the last dimension.
|
||||||
|
* -axis : Dimension on which to do th sort. Default is 0.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
class Topk : public BuiltinOp {
|
class Topk : public BuiltinOp {
|
||||||
public:
|
public:
|
||||||
Topk(Graph* graph, uint32_t k);
|
Topk(Graph* graph, uint32_t k, int32_t axis = 0);
|
||||||
|
|
||||||
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
|
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,11 @@
|
||||||
[
|
[
|
||||||
{"name":"k",
|
{"name":"k",
|
||||||
"dtype": "uint32_t"
|
"dtype": "uint32_t"
|
||||||
|
},
|
||||||
|
{"name":"axis",
|
||||||
|
"dtype": "uint32_t",
|
||||||
|
"Optional":"true",
|
||||||
|
"default":"0"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -31,13 +31,15 @@ namespace tim {
|
||||||
namespace vx {
|
namespace vx {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
||||||
Topk::Topk(Graph* graph, uint32_t k)
|
Topk::Topk(Graph* graph, uint32_t k, int32_t axis)
|
||||||
: BuiltinOp(graph, VSI_NN_OP_TOPK) {
|
: BuiltinOp(graph, VSI_NN_OP_TOPK) {
|
||||||
this->impl()->node()->nn_param.topk.k = k;
|
this->impl()->node()->nn_param.topk.k = k;
|
||||||
|
this->impl()->node()->nn_param.topk.axis = axis;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<Operation> Topk::Clone(std::shared_ptr<Graph>& graph) const {
|
std::shared_ptr<Operation> Topk::Clone(std::shared_ptr<Graph>& graph) const {
|
||||||
return graph->CreateOperation<Topk>(this->impl()->node()->nn_param.topk.k);
|
return graph->CreateOperation<Topk>(this->impl()->node()->nn_param.topk.k,
|
||||||
|
this->impl()->node()->nn_param.topk.axis);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue