fixed bug when broadcast dimensions is negative

Signed-off-by: Chen Xin <jack.chen@verisilicon.com>
This commit is contained in:
Chen Xin 2022-10-08 14:50:14 +08:00 committed by Sven
parent a038df2a84
commit 3fed6d6757
2 changed files with 7 additions and 3 deletions

View File

@ -51,7 +51,7 @@ class Broadcast : public BuiltinOp {
protected: protected:
const std::vector<int32_t> shape_; const std::vector<int32_t> shape_;
const std::vector<int32_t> dimensions_; std::vector<int32_t> dimensions_;
}; };
} // namespace ops } // namespace ops

View File

@ -41,6 +41,10 @@ Broadcast::Broadcast(Graph* graph, const std::vector<int32_t>& shape,
this->impl()->node()->nn_param.expand_broadcast.dimensions_num = dimensions_.size(); this->impl()->node()->nn_param.expand_broadcast.dimensions_num = dimensions_.size();
if (dimensions.size() > 0) if (dimensions.size() > 0)
{ {
int dim_num = shape.size();
for (uint32_t i = 0; i < dimensions.size(); ++i) {
dimensions_[i] += (dimensions[i] < 0 ? dim_num : 0U);
}
this->impl()->node()->nn_param.expand_broadcast.dimensions = (uint32_t*)dimensions_.data(); this->impl()->node()->nn_param.expand_broadcast.dimensions = (uint32_t*)dimensions_.data();
} else { } else {
this->impl()->node()->nn_param.expand_broadcast.dimensions = nullptr; this->impl()->node()->nn_param.expand_broadcast.dimensions = nullptr;