Support multiply attribute for tensor spec
Signed-off-by: yuenan.li <yuenan.li@verisilicon.com>
This commit is contained in:
parent
2390ece5ac
commit
bd4d277ac1
|
|
@ -41,7 +41,13 @@ enum class DataType {
|
||||||
|
|
||||||
enum class QuantType { NONE, ASYMMETRIC, SYMMETRIC_PER_CHANNEL };
|
enum class QuantType { NONE, ASYMMETRIC, SYMMETRIC_PER_CHANNEL };
|
||||||
|
|
||||||
enum class TensorAttribute { CONSTANT, TRANSIENT, VARIABLE, INPUT, OUTPUT };
|
enum TensorAttribute {
|
||||||
|
CONSTANT = 1 << 0,
|
||||||
|
TRANSIENT = 1 << 1,
|
||||||
|
VARIABLE = 1 << 2,
|
||||||
|
INPUT = 1 << 3,
|
||||||
|
OUTPUT = 1 << 4
|
||||||
|
};
|
||||||
|
|
||||||
enum class PadType { NONE = -1, AUTO, VALID, SAME };
|
enum class PadType { NONE = -1, AUTO, VALID, SAME };
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -49,7 +49,7 @@ OperationImpl::OperationImpl(Graph* graph, uint32_t operation_id, int input_cnt,
|
||||||
OperationImpl& OperationImpl::BindInput(const std::shared_ptr<Tensor>& tensor) {
|
OperationImpl& OperationImpl::BindInput(const std::shared_ptr<Tensor>& tensor) {
|
||||||
uint32_t tensor_id = tensor->GetId();
|
uint32_t tensor_id = tensor->GetId();
|
||||||
node_->input.tensors[input_tensor_index++] = tensor_id;
|
node_->input.tensors[input_tensor_index++] = tensor_id;
|
||||||
if (tensor->GetSpec().attr_ == TensorAttribute::INPUT) {
|
if (tensor->GetSpec().attr_ & TensorAttribute::INPUT) {
|
||||||
graph_->AddInput(tensor_id);
|
graph_->AddInput(tensor_id);
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
|
|
|
||||||
|
|
@ -151,18 +151,8 @@ bool TensorImpl::Init() {
|
||||||
|
|
||||||
memset(&attr, 0x00, sizeof(attr));
|
memset(&attr, 0x00, sizeof(attr));
|
||||||
attr.dim_num = spec_.shape_.size();
|
attr.dim_num = spec_.shape_.size();
|
||||||
attr.is_const = FALSE;
|
attr.is_const = spec_.attr_ & TensorAttribute::CONSTANT;
|
||||||
attr.vtl = FALSE;
|
attr.vtl = spec_.attr_ & TensorAttribute::TRANSIENT;
|
||||||
switch (spec_.attr_) {
|
|
||||||
case TensorAttribute::CONSTANT:
|
|
||||||
attr.is_const = TRUE;
|
|
||||||
break;
|
|
||||||
case TensorAttribute::TRANSIENT:
|
|
||||||
attr.vtl = TRUE;
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (ShapeType::size_type i = 0; i < spec_.shape_.size(); i++) {
|
for (ShapeType::size_type i = 0; i < spec_.shape_.size(); i++) {
|
||||||
attr.size[i] = spec_.shape_[i];
|
attr.size[i] = spec_.shape_[i];
|
||||||
|
|
@ -170,8 +160,8 @@ bool TensorImpl::Init() {
|
||||||
|
|
||||||
PackTensorDtype(spec_, &attr.dtype);
|
PackTensorDtype(spec_, &attr.dtype);
|
||||||
|
|
||||||
if (spec_.attr_ == TensorAttribute::INPUT ||
|
if ((spec_.attr_ & TensorAttribute::INPUT) ||
|
||||||
spec_.attr_ == TensorAttribute::OUTPUT) {
|
(spec_.attr_ & TensorAttribute::OUTPUT)) {
|
||||||
id_ = vsi_nn_AddTensorFromHandle(graph_->graph(), VSI_NN_TENSOR_ID_AUTO,
|
id_ = vsi_nn_AddTensorFromHandle(graph_->graph(), VSI_NN_TENSOR_ID_AUTO,
|
||||||
&attr, nullptr);
|
&attr, nullptr);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue