Support multiply attribute for tensor spec

Signed-off-by: yuenan.li <yuenan.li@verisilicon.com>
This commit is contained in:
yuenan.li 2021-01-29 16:14:33 +08:00
parent 2390ece5ac
commit bd4d277ac1
3 changed files with 12 additions and 16 deletions

View File

@ -41,7 +41,13 @@ enum class DataType {
enum class QuantType { NONE, ASYMMETRIC, SYMMETRIC_PER_CHANNEL };
enum class TensorAttribute { CONSTANT, TRANSIENT, VARIABLE, INPUT, OUTPUT };
enum TensorAttribute {
CONSTANT = 1 << 0,
TRANSIENT = 1 << 1,
VARIABLE = 1 << 2,
INPUT = 1 << 3,
OUTPUT = 1 << 4
};
enum class PadType { NONE = -1, AUTO, VALID, SAME };

View File

@ -49,7 +49,7 @@ OperationImpl::OperationImpl(Graph* graph, uint32_t operation_id, int input_cnt,
OperationImpl& OperationImpl::BindInput(const std::shared_ptr<Tensor>& tensor) {
uint32_t tensor_id = tensor->GetId();
node_->input.tensors[input_tensor_index++] = tensor_id;
if (tensor->GetSpec().attr_ == TensorAttribute::INPUT) {
if (tensor->GetSpec().attr_ & TensorAttribute::INPUT) {
graph_->AddInput(tensor_id);
}
return *this;

View File

@ -151,18 +151,8 @@ bool TensorImpl::Init() {
memset(&attr, 0x00, sizeof(attr));
attr.dim_num = spec_.shape_.size();
attr.is_const = FALSE;
attr.vtl = FALSE;
switch (spec_.attr_) {
case TensorAttribute::CONSTANT:
attr.is_const = TRUE;
break;
case TensorAttribute::TRANSIENT:
attr.vtl = TRUE;
break;
default:
break;
}
attr.is_const = spec_.attr_ & TensorAttribute::CONSTANT;
attr.vtl = spec_.attr_ & TensorAttribute::TRANSIENT;
for (ShapeType::size_type i = 0; i < spec_.shape_.size(); i++) {
attr.size[i] = spec_.shape_[i];
@ -170,8 +160,8 @@ bool TensorImpl::Init() {
PackTensorDtype(spec_, &attr.dtype);
if (spec_.attr_ == TensorAttribute::INPUT ||
spec_.attr_ == TensorAttribute::OUTPUT) {
if ((spec_.attr_ & TensorAttribute::INPUT) ||
(spec_.attr_ & TensorAttribute::OUTPUT)) {
id_ = vsi_nn_AddTensorFromHandle(graph_->graph(), VSI_NN_TENSOR_ID_AUTO,
&attr, nullptr);
} else {