fixed bug when broadcast dimensions is negative
Signed-off-by: Chen Xin <jack.chen@verisilicon.com>
This commit is contained in:
parent
a038df2a84
commit
3fed6d6757
|
|
@ -36,10 +36,10 @@ namespace ops {
|
|||
*
|
||||
* Input:
|
||||
* - input.
|
||||
*
|
||||
*
|
||||
* Attribute:
|
||||
* - shape: the shape which broadcast to.
|
||||
* - dimensions(optional): Which dimension in the target shape each dimension
|
||||
* - dimensions(optional): Which dimension in the target shape each dimension
|
||||
* of the operand shape corresponds to. For BroadcastInDim.
|
||||
*/
|
||||
|
||||
|
|
@ -51,7 +51,7 @@ class Broadcast : public BuiltinOp {
|
|||
|
||||
protected:
|
||||
const std::vector<int32_t> shape_;
|
||||
const std::vector<int32_t> dimensions_;
|
||||
std::vector<int32_t> dimensions_;
|
||||
};
|
||||
|
||||
} // namespace ops
|
||||
|
|
|
|||
|
|
@ -41,6 +41,10 @@ Broadcast::Broadcast(Graph* graph, const std::vector<int32_t>& shape,
|
|||
this->impl()->node()->nn_param.expand_broadcast.dimensions_num = dimensions_.size();
|
||||
if (dimensions.size() > 0)
|
||||
{
|
||||
int dim_num = shape.size();
|
||||
for (uint32_t i = 0; i < dimensions.size(); ++i) {
|
||||
dimensions_[i] += (dimensions[i] < 0 ? dim_num : 0U);
|
||||
}
|
||||
this->impl()->node()->nn_param.expand_broadcast.dimensions = (uint32_t*)dimensions_.data();
|
||||
} else {
|
||||
this->impl()->node()->nn_param.expand_broadcast.dimensions = nullptr;
|
||||
|
|
|
|||
Loading…
Reference in New Issue