From bd4d277ac11b2b763244588d9d9c00771afb3598 Mon Sep 17 00:00:00 2001 From: "yuenan.li" Date: Fri, 29 Jan 2021 16:14:33 +0800 Subject: [PATCH] Support multiply attribute for tensor spec Signed-off-by: yuenan.li --- include/tim/vx/types.h | 8 +++++++- src/tim/vx/operation.cc | 2 +- src/tim/vx/tensor.cc | 18 ++++-------------- 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/include/tim/vx/types.h b/include/tim/vx/types.h index 4948207..89f22f5 100644 --- a/include/tim/vx/types.h +++ b/include/tim/vx/types.h @@ -41,7 +41,13 @@ enum class DataType { enum class QuantType { NONE, ASYMMETRIC, SYMMETRIC_PER_CHANNEL }; -enum class TensorAttribute { CONSTANT, TRANSIENT, VARIABLE, INPUT, OUTPUT }; +enum TensorAttribute { + CONSTANT = 1 << 0, + TRANSIENT = 1 << 1, + VARIABLE = 1 << 2, + INPUT = 1 << 3, + OUTPUT = 1 << 4 +}; enum class PadType { NONE = -1, AUTO, VALID, SAME }; diff --git a/src/tim/vx/operation.cc b/src/tim/vx/operation.cc index 2295441..a347ad3 100644 --- a/src/tim/vx/operation.cc +++ b/src/tim/vx/operation.cc @@ -49,7 +49,7 @@ OperationImpl::OperationImpl(Graph* graph, uint32_t operation_id, int input_cnt, OperationImpl& OperationImpl::BindInput(const std::shared_ptr& tensor) { uint32_t tensor_id = tensor->GetId(); node_->input.tensors[input_tensor_index++] = tensor_id; - if (tensor->GetSpec().attr_ == TensorAttribute::INPUT) { + if (tensor->GetSpec().attr_ & TensorAttribute::INPUT) { graph_->AddInput(tensor_id); } return *this; diff --git a/src/tim/vx/tensor.cc b/src/tim/vx/tensor.cc index c59e5ba..7c72604 100644 --- a/src/tim/vx/tensor.cc +++ b/src/tim/vx/tensor.cc @@ -151,18 +151,8 @@ bool TensorImpl::Init() { memset(&attr, 0x00, sizeof(attr)); attr.dim_num = spec_.shape_.size(); - attr.is_const = FALSE; - attr.vtl = FALSE; - switch (spec_.attr_) { - case TensorAttribute::CONSTANT: - attr.is_const = TRUE; - break; - case TensorAttribute::TRANSIENT: - attr.vtl = TRUE; - break; - default: - break; - } + attr.is_const = spec_.attr_ & TensorAttribute::CONSTANT; + attr.vtl = spec_.attr_ & TensorAttribute::TRANSIENT; for (ShapeType::size_type i = 0; i < spec_.shape_.size(); i++) { attr.size[i] = spec_.shape_[i]; @@ -170,8 +160,8 @@ bool TensorImpl::Init() { PackTensorDtype(spec_, &attr.dtype); - if (spec_.attr_ == TensorAttribute::INPUT || - spec_.attr_ == TensorAttribute::OUTPUT) { + if ((spec_.attr_ & TensorAttribute::INPUT) || + (spec_.attr_ & TensorAttribute::OUTPUT)) { id_ = vsi_nn_AddTensorFromHandle(graph_->graph(), VSI_NN_TENSOR_ID_AUTO, &attr, nullptr); } else {