From 3fed6d675757ae9962c461ed193d758aedb7067d Mon Sep 17 00:00:00 2001 From: Chen Xin Date: Sat, 8 Oct 2022 14:50:14 +0800 Subject: [PATCH] fixed bug when broadcast dimensions is negative Signed-off-by: Chen Xin --- include/tim/vx/ops/broadcast.h | 6 +++--- src/tim/vx/ops/broadcast.cc | 4 ++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/include/tim/vx/ops/broadcast.h b/include/tim/vx/ops/broadcast.h index 7e2329c..0bf224e 100644 --- a/include/tim/vx/ops/broadcast.h +++ b/include/tim/vx/ops/broadcast.h @@ -36,10 +36,10 @@ namespace ops { * * Input: * - input. - * + * * Attribute: * - shape: the shape which broadcast to. - * - dimensions(optional): Which dimension in the target shape each dimension + * - dimensions(optional): Which dimension in the target shape each dimension * of the operand shape corresponds to. For BroadcastInDim. */ @@ -51,7 +51,7 @@ class Broadcast : public BuiltinOp { protected: const std::vector shape_; - const std::vector dimensions_; + std::vector dimensions_; }; } // namespace ops diff --git a/src/tim/vx/ops/broadcast.cc b/src/tim/vx/ops/broadcast.cc index 3746853..5a9bc49 100644 --- a/src/tim/vx/ops/broadcast.cc +++ b/src/tim/vx/ops/broadcast.cc @@ -41,6 +41,10 @@ Broadcast::Broadcast(Graph* graph, const std::vector& shape, this->impl()->node()->nn_param.expand_broadcast.dimensions_num = dimensions_.size(); if (dimensions.size() > 0) { + int dim_num = shape.size(); + for (uint32_t i = 0; i < dimensions.size(); ++i) { + dimensions_[i] += (dimensions[i] < 0 ? dim_num : 0U); + } this->impl()->node()->nn_param.expand_broadcast.dimensions = (uint32_t*)dimensions_.data(); } else { this->impl()->node()->nn_param.expand_broadcast.dimensions = nullptr;