diff --git a/src/tim/vx/internal/include/ops/vsi_nn_op_expand_broadcast.h b/src/tim/vx/internal/include/ops/vsi_nn_op_expand_broadcast.h index e9ce86a..9c5f8b6 100644 --- a/src/tim/vx/internal/include/ops/vsi_nn_op_expand_broadcast.h +++ b/src/tim/vx/internal/include/ops/vsi_nn_op_expand_broadcast.h @@ -30,6 +30,8 @@ extern "C" { #endif +#define VSI_EXPAND_BROADCAST_ENABLE_DIMENSIONS + typedef struct _vsi_nn_expand_broadcast_param { uint32_t *shape; diff --git a/src/tim/vx/ops/broadcast.cc b/src/tim/vx/ops/broadcast.cc index 4a1d5ac..0954f21 100644 --- a/src/tim/vx/ops/broadcast.cc +++ b/src/tim/vx/ops/broadcast.cc @@ -37,6 +37,7 @@ Broadcast::Broadcast(Graph* graph, const std::vector& shape, dimensions_(dimensions) { this->impl()->node()->nn_param.expand_broadcast.dim_num = shape_.size(); this->impl()->node()->nn_param.expand_broadcast.shape = (uint32_t*)shape_.data(); +#ifdef VSI_EXPAND_BROADCAST_ENABLE_DIMENSIONS this->impl()->node()->nn_param.expand_broadcast.dimensions_num = dimensions_.size(); if (dimensions.size() > 0) { @@ -44,7 +45,7 @@ Broadcast::Broadcast(Graph* graph, const std::vector& shape, } else { this->impl()->node()->nn_param.expand_broadcast.dimensions = nullptr; } - +#endif } std::shared_ptr Broadcast::Clone(