fixed bug when broadcast dimensions is negative

Signed-off-by: Chen Xin <jack.chen@verisilicon.com>
This commit is contained in:
Chen Xin 2022-10-08 14:50:14 +08:00 committed by Sven
parent a038df2a84
commit 3fed6d6757
2 changed files with 7 additions and 3 deletions

View File

@ -36,10 +36,10 @@ namespace ops {
*
* Input:
* - input.
*
*
* Attribute:
* - shape: the shape which broadcast to.
* - dimensions(optional): Which dimension in the target shape each dimension
* - dimensions(optional): Which dimension in the target shape each dimension
* of the operand shape corresponds to. For BroadcastInDim.
*/
@ -51,7 +51,7 @@ class Broadcast : public BuiltinOp {
protected:
const std::vector<int32_t> shape_;
const std::vector<int32_t> dimensions_;
std::vector<int32_t> dimensions_;
};
} // namespace ops

View File

@ -41,6 +41,10 @@ Broadcast::Broadcast(Graph* graph, const std::vector<int32_t>& shape,
this->impl()->node()->nn_param.expand_broadcast.dimensions_num = dimensions_.size();
if (dimensions.size() > 0)
{
int dim_num = shape.size();
for (uint32_t i = 0; i < dimensions.size(); ++i) {
dimensions_[i] += (dimensions[i] < 0 ? dim_num : 0U);
}
this->impl()->node()->nn_param.expand_broadcast.dimensions = (uint32_t*)dimensions_.data();
} else {
this->impl()->node()->nn_param.expand_broadcast.dimensions = nullptr;