Added op fusion for mean_stddev_normalization (#629)

Added op fusion for mean_stddev_normalization ops such as layernorm and
instance norm

Type: New Feature

Signed-off-by: Feiyue Chen <Feiyue.Chen@verisilicon.com>
This commit is contained in:
Chen Feiyue 2023-08-09 22:10:45 +08:00 committed by GitHub
parent bff26a32c4
commit 35e50d7692
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 573 additions and 131 deletions

View File

@ -0,0 +1,22 @@
#ifndef TIM_MEAN_STD_DEV_NORMALIZE_FUSION_H
#define TIM_MEAN_STD_DEV_NORMALIZE_FUSION_H
#include <map>
#include <vector>
#include <memory>
namespace tim {
namespace vx {
// Forward declarations only; this interface needs just shared_ptr<Graph>.
class Context;
class Graph;
class Tensor;
class Operation;
} // namespace vx
namespace transform {
// Scans `src_graph` in place for mean/std-dev normalization subgraphs
// (the Mean-Sub-Pow-Mean-Add-Rsqrt-Mul... pattern) and fuses each match
// into a single LayerNormalization or InstanceNormalization operation.
void MeanStdDevNormalization(std::shared_ptr<vx::Graph>& src_graph);
} // namespace transform
} // namespace tim
#endif

View File

@ -33,6 +33,7 @@
#endif #endif
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <map>
namespace tim { namespace tim {
namespace vx { namespace vx {
#ifdef ENABLE_TENSOR_CACHE #ifdef ENABLE_TENSOR_CACHE
@ -61,7 +62,7 @@ class Graph {
/// Create a tensor with given `TensorSpec`. /// Create a tensor with given `TensorSpec`.
/// spec.attr_ must be TensorAttribute::Input or Output /// spec.attr_ must be TensorAttribute::Input or Output
virtual std::shared_ptr<Tensor> CreateIOTensor(const TensorSpec& spec, virtual std::shared_ptr<Tensor> CreateIOTensor(const TensorSpec& spec,
void* data = nullptr) = 0; void* data = nullptr) = 0;
/// Create a placeholder tensor for optional inputs of operations /// Create a placeholder tensor for optional inputs of operations
virtual std::shared_ptr<Tensor> CreateTensorPlaceHolder() = 0; virtual std::shared_ptr<Tensor> CreateTensorPlaceHolder() = 0;
@ -102,6 +103,12 @@ class Graph {
virtual void PrintGraph() const = 0; virtual void PrintGraph() const = 0;
const std::vector<std::shared_ptr<Tensor>> GetConstantInputs() const; const std::vector<std::shared_ptr<Tensor>> GetConstantInputs() const;
virtual std::vector<std::shared_ptr<Operation>>& OpVector() = 0;
virtual std::map<std::shared_ptr<Tensor>,
std::vector<std::shared_ptr<Operation>>>&
TensorConsumer() = 0;
virtual std::map<std::shared_ptr<Tensor>, std::shared_ptr<Operation>>&
TensorProducer() = 0;
protected: protected:
std::vector<std::shared_ptr<tim::vx::Operation>> op_vector_; std::vector<std::shared_ptr<tim::vx::Operation>> op_vector_;

View File

@ -48,9 +48,7 @@ class Quantization {
channel_dim_(channel_dim), channel_dim_(channel_dim),
scales_(std::move(scales)), scales_(std::move(scales)),
zero_points_(std::move(zero_points)) {} zero_points_(std::move(zero_points)) {}
Quantization(QuantType type, int8_t fl) Quantization(QuantType type, int8_t fl) : type_(type), fl_(fl) {}
: type_(type),
fl_(fl){}
QuantType& Type() { return type_; } QuantType& Type() { return type_; }
const QuantType& Type() const { return type_; } const QuantType& Type() const { return type_; }
Quantization& SetType(QuantType type) { Quantization& SetType(QuantType type) {
@ -79,9 +77,9 @@ class Quantization {
return *this; return *this;
} }
const std::int8_t& Fl() const{ return this->fl_; } const std::int8_t& Fl() const { return this->fl_; }
bool operator == (const Quantization& other_quant) const; bool operator==(const Quantization& other_quant) const;
protected: protected:
QuantType type_{QuantType::NONE}; QuantType type_{QuantType::NONE};
@ -148,7 +146,8 @@ class Tensor {
virtual const Quantization& GetQuantization() = 0; virtual const Quantization& GetQuantization() = 0;
virtual TensorSpec& GetSpec() = 0; virtual TensorSpec& GetSpec() = 0;
virtual uint32_t GetId() = 0; virtual uint32_t GetId() = 0;
virtual bool CopyDataToTensor(const void* data, uint32_t size_in_bytes = 0) = 0; virtual bool CopyDataToTensor(const void* data,
uint32_t size_in_bytes = 0) = 0;
virtual bool CopyDataFromTensor(void* data) = 0; virtual bool CopyDataFromTensor(void* data) = 0;
virtual bool FlushCacheForHandle() = 0; virtual bool FlushCacheForHandle() = 0;
virtual bool InvalidateCacheForHandle() = 0; virtual bool InvalidateCacheForHandle() = 0;
@ -158,10 +157,13 @@ class Tensor {
virtual bool IsConstTensor() = 0; virtual bool IsConstTensor() = 0;
virtual bool SaveTensorToTextByFp32(std::string filename) = 0; virtual bool SaveTensorToTextByFp32(std::string filename) = 0;
virtual void* ConvertTensorToData(uint8_t* tensorData) = 0; virtual void* ConvertTensorToData(uint8_t* tensorData) = 0;
virtual float* ConvertTensorToFloat32Data() = 0;
}; };
namespace utils{ namespace utils {
bool Float32ToDtype(std::shared_ptr<tim::vx::Tensor> tensor, std::vector<float> fval, uint8_t* tensorData); bool Float32ToDtype(std::shared_ptr<tim::vx::Tensor> tensor,
bool DtypeToFloat32(std::shared_ptr<tim::vx::Tensor> tensor, uint8_t* tensorData, float* data); std::vector<float> fval, uint8_t* tensorData);
bool DtypeToFloat32(std::shared_ptr<tim::vx::Tensor> tensor,
uint8_t* tensorData, float* data);
} //namespace utils } //namespace utils
} // namespace vx } // namespace vx
} // namespace tim } // namespace tim

View File

@ -0,0 +1,368 @@
#include <algorithm>
#include <stdarg.h>
#include "tim/transform/mean_stddev_normalize_fusion.h"
#include "tim/vx/context.h"
#include "tim/vx/graph.h"
#include "tim/vx/operation.h"
#include "tim/vx/ops/layernormalization.h"
#include "tim/vx/ops/instancenormalization.h"
#include "builtin_op_impl.h"
namespace tim {
namespace transform {
// Indices of each matched op inside the temporary pattern vector built by
// MeanStdDevNormalization (see the structure diagram later in this file).
// The vector is filled in exactly this order during pattern matching.
enum {
  NORMALIZATION_INDEX_MEAN_0 = 0,  // Mean0: mean over the normalized axes
  NORMALIZATION_INDEX_SUB_0 = 1,   // Sub0:  input - mean
  NORMALIZATION_INDEX_MUL_2 = 2,   // Mul2:  (input * rsqrt-scale) branch
  NORMALIZATION_INDEX_POW = 3,     // Pow:   squared deviation
  NORMALIZATION_INDEX_MEAN_1 = 4,  // Mean1: variance
  NORMALIZATION_INDEX_ADD_0 = 5,   // Add0:  variance + eps
  NORMALIZATION_INDEX_RSQRT = 6,   // Rsqrt: 1 / sqrt(var + eps)
  NORMALIZATION_INDEX_MUL_0 = 7,   // Mul0:  rsqrt * gamma
  NORMALIZATION_INDEX_MUL_1 = 8,   // Mul1:  input * scale branch
  NORMALIZATION_INDEX_ADD_1 = 9,   // Add1:  final add producing the output
  NORMALIZATION_INDEX_SUB_1 = 10   // Sub1:  beta - mean * scale
};
//Determine whether the needed opkind is in given consumer op list
bool OpkindInConsumers(std::vector<std::shared_ptr<vx::Operation>> consumers,
int32_t op_id) {
if (consumers.size() == 1 && consumers[0]->impl()->kind_ != op_id) {
return false;
}
auto op_iter = std::find_if(consumers.begin(), consumers.end(),
[op_id](std::shared_ptr<vx::Operation> oper) {
return oper.get()->impl()->kind_ == op_id;
});
return op_iter != consumers.end();
}
// Check whether `compared` is among the consumers of `current`'s first
// output tensor. Used to confirm that two already-matched pattern ops are
// actually connected in the graph.
bool OpInConsumer(const std::shared_ptr<vx::Graph>& graph,
                  const std::shared_ptr<vx::Operation>& current,
                  const std::shared_ptr<vx::Operation>& compared) {
  auto out_tensor = current->impl()->OutputsTensor()[0];
  auto consumers = graph->GetConsumersOp(out_tensor);
  return std::find(consumers.begin(), consumers.end(), compared) !=
         consumers.end();
}
// Determine whether the op at temp[curr_index] matches the next step of the
// pattern: its single output must be consumed by exactly the op kinds listed
// in `op_kind` (at most two kinds supported). On success the consumer op(s)
// are appended to `temp` in the same order as `op_kind`. The special case
// where a consumer is already in the list is NOT handled here.
bool UpdateTempVector(std::vector<std::shared_ptr<vx::Operation>>& temp,
                      int32_t curr_index,
                      const std::shared_ptr<vx::Graph>& graph,
                      std::vector<int32_t> op_kind) {
  auto outputs = temp[curr_index]->impl()->OutputsTensor();
  auto ops = graph->GetConsumersOp(outputs[0]);
  // Reject ops with more than one output, a consumer count different from
  // the expected kind count, or more than two expected kinds.
  if (outputs.size() > 1 || ops.size() != op_kind.size() || op_kind.size() > 2)
    return false;
  else {
    for (int32_t op_k : op_kind) {
      if (!OpkindInConsumers(ops, op_k)) return false;
    }
    if (op_kind.size() == 2) {
      // Identify which consumer matches op_kind[0] so both are pushed in
      // the order the caller (and the index enum) expects.
      int32_t first_index = ops[0]->impl()->kind_ == op_kind[0] ? 0 : 1;
      //push back ops as same order as need
      temp.push_back(graph->GetConsumersOp(outputs[0])[first_index]);
      temp.push_back(graph->GetConsumersOp(outputs[0])[1 - first_index]);
    } else {
      temp.push_back(ops[0]);
    }
    return true;
  }
}
// Remove ops and intermediate tensors of one matched normalization pattern
// from the graph's bookkeeping maps, so the fused op can be rebound cleanly.
// Fixed: the output-tensor producer erases were nested inside the
// input-tensor loops (running once per input and shadowing `tensor`);
// erase(find(...)) was unguarded; TensorProducer()[tensor] inserted empty
// map entries as a lookup side effect.
void RemoveTensorsAndOps(
    std::shared_ptr<vx::Graph>& graph,
    std::vector<std::shared_ptr<vx::Operation>>& norm_ops) {
  for (uint32_t i = 0; i < norm_ops.size(); i++) {
    // Remove the current op from the graph's op vector.
    auto it = std::find(graph->OpVector().begin(), graph->OpVector().end(),
                        norm_ops[i]);
    if (it != graph->OpVector().end()) graph->OpVector().erase(it);
    auto input_tensors = norm_ops[i]->impl()->InputsTensor();
    auto output_tensors = norm_ops[i]->impl()->OutputsTensor();
    switch (i) {
      case NORMALIZATION_INDEX_MEAN_0:
      case NORMALIZATION_INDEX_SUB_0:
      case NORMALIZATION_INDEX_MUL_1:
        // These ops read the pattern's source tensor, which survives the
        // fusion: only detach this op from each input's consumer list.
        for (auto& tensor : input_tensors) {
          if (tensor->GetSpec().attr_ == vx::TensorAttribute::CONSTANT) {
            graph->TensorConsumer().erase(tensor);
          } else {
            auto& consumers = graph->TensorConsumer()[tensor];
            auto op_it =
                std::find(consumers.begin(), consumers.end(), norm_ops[i]);
            if (op_it != consumers.end()) consumers.erase(op_it);
            if (consumers.empty()) graph->TensorConsumer().erase(tensor);
          }
        }
        graph->TensorProducer().erase(output_tensors[0]);
        break;
      case NORMALIZATION_INDEX_ADD_1:
        // Add1 produces the pattern's final output tensor; keep its maps so
        // the fused op can be bound to that same tensor.
        break;
      default:
        // Interior ops: their inputs are pattern-internal, so drop each
        // input's producer op (if any) and all consumer records.
        for (auto& tensor : input_tensors) {
          if (tensor->GetSpec().attr_ != vx::TensorAttribute::CONSTANT) {
            auto producer_it = graph->TensorProducer().find(tensor);
            if (producer_it != graph->TensorProducer().end() &&
                producer_it->second != nullptr) {
              auto producer_op_it =
                  std::find(graph->OpVector().begin(), graph->OpVector().end(),
                            graph->GetProducerOp(tensor));
              if (producer_op_it != graph->OpVector().end())
                graph->OpVector().erase(producer_op_it);
              graph->TensorProducer().erase(tensor);
            }
          }
          graph->TensorConsumer().erase(tensor);
        }
        for (auto& tensor : output_tensors)
          graph->TensorProducer().erase(tensor);
        break;
    }
  }
}
// Verify the Mul0 fan-out of the pattern: Mul0's single output must feed
// exactly two multiplies — Mul2 (already matched) and Mul1. On success Mul1
// is appended to `norm_ops` (becoming NORMALIZATION_INDEX_MUL_1).
// Fixed: the multiply op kind was hard-coded as the magic number 1; use
// VSI_NN_OP_MULTIPLY as the rest of this file does. Also reuse the already
// fetched consumer list instead of calling GetConsumersOp again.
bool CheckMediumMul(const std::shared_ptr<vx::Graph>& graph,
                    std::vector<std::shared_ptr<vx::Operation>>& norm_ops) {
  auto mul0_output_tensor =
      norm_ops[NORMALIZATION_INDEX_MUL_0]->impl()->OutputsTensor();
  auto mul0_consumers = graph->GetConsumersOp(mul0_output_tensor[0]);
  if (mul0_output_tensor.size() > 1 || mul0_consumers.size() != 2 ||
      mul0_consumers[0]->impl()->kind_ != VSI_NN_OP_MULTIPLY ||
      mul0_consumers[1]->impl()->kind_ != VSI_NN_OP_MULTIPLY)
    return false;
  // One of the two consumers must be the Mul2 matched earlier.
  if (!OpInConsumer(graph, norm_ops[NORMALIZATION_INDEX_MUL_0],
                    norm_ops[NORMALIZATION_INDEX_MUL_2]))
    return false;
  // The other consumer is Mul1.
  int32_t mul1_index =
      mul0_consumers[0] == norm_ops[NORMALIZATION_INDEX_MUL_2] ? 1 : 0;
  norm_ops.push_back(mul0_consumers[mul1_index]);
  return true;
}
// Report whether the two ops share at least one input tensor.
bool HaveASameInput(const std::shared_ptr<vx::Operation>& op1,
                    const std::shared_ptr<vx::Operation>& op2) {
  const auto lhs_inputs = op1->impl()->InputsTensor();
  const auto rhs_inputs = op2->impl()->InputsTensor();
  return std::any_of(lhs_inputs.begin(), lhs_inputs.end(),
                     [&rhs_inputs](const auto& candidate) {
                       return std::find(rhs_inputs.begin(), rhs_inputs.end(),
                                        candidate) != rhs_inputs.end();
                     });
}
// Replace a matched normalization subgraph (single reduced axis) with one
// LayerNormalization op bound to the pattern's source and output tensors.
// Fixed: eps lookup used hard-coded indices norm_ops[5]/norm_ops[4] instead
// of the named constants NORMALIZATION_INDEX_ADD_0/NORMALIZATION_INDEX_MEAN_1
// (InstancenormConnection already uses the names).
void LayernormConnection(std::shared_ptr<vx::Graph>& graph,
                         std::vector<std::shared_ptr<vx::Operation>> norm_ops) {
  auto src_tensor =
      norm_ops[NORMALIZATION_INDEX_MEAN_0]->impl()->InputsTensor()[0];
  auto final_tensor =
      norm_ops[NORMALIZATION_INDEX_ADD_1]->impl()->OutputsTensor()[0];
  int32_t axis = *norm_ops[NORMALIZATION_INDEX_MEAN_0]
                      ->impl()
                      ->node()
                      ->nn_param.reduce.axis;
  axis = src_tensor->GetShape().size() - axis - 1;  // reverse axis
  // Get eps, gamma, beta.
  // Do datatype convert due to the normalization op requirements.
  // eps is the Add0 input that is NOT produced by Mean1.
  int32_t eps_index =
      graph->GetProducerOp(norm_ops[NORMALIZATION_INDEX_ADD_0]
                               ->impl()
                               ->InputsTensor()[0]) ==
              norm_ops[NORMALIZATION_INDEX_MEAN_1]
          ? 1
          : 0;
  auto org_eps =
      norm_ops[NORMALIZATION_INDEX_ADD_0]->impl()->InputsTensor()[eps_index];
  if (!org_eps->IsConstTensor()) {
    // eps may be fed through a producer op; take that op's constant input.
    org_eps = graph->GetProducerOp(org_eps)->impl()->InputsTensor()[0];
  }
  auto org_gamma =
      norm_ops[NORMALIZATION_INDEX_MUL_0]->impl()->InputsTensor()[1];
  auto org_beta =
      norm_ops[NORMALIZATION_INDEX_SUB_1]->impl()->InputsTensor()[0];
  // Convert parameters to float32 before the original tensors are dropped.
  float* float_eps = org_eps->ConvertTensorToFloat32Data();
  float* float_gamma = org_gamma->ConvertTensorToFloat32Data();
  float* float_beta = org_beta->ConvertTensorToFloat32Data();
  RemoveTensorsAndOps(graph, norm_ops);
  // gamma/beta are reshaped to broadcast along the normalized axis.
  std::vector<uint32_t> shape(src_tensor->GetShape().size(), 1);
  shape[axis] = src_tensor->GetShape()[axis];
  vx::TensorSpec param_spec(vx::DataType::FLOAT32, shape,
                            vx::TensorAttribute::CONSTANT);
  auto beta = graph->CreateTensor(param_spec);
  auto gamma = graph->CreateTensor(param_spec);
  float eps = *float_eps;
  beta->CopyDataToTensor(float_beta);
  gamma->CopyDataToTensor(float_gamma);
  vsi_nn_Free(float_gamma);
  vsi_nn_Free(float_beta);
  vsi_nn_Free(float_eps);
  auto layernorm =
      graph->CreateOperation<vx::ops::LayerNormalization>(axis, eps);
  graph->TensorConsumer()[src_tensor].push_back(layernorm);
  layernorm->BindInputs({src_tensor, beta, gamma});
  layernorm->BindOutputs({final_tensor});
}
// Replace a matched normalization subgraph (multiple reduced axes) with one
// InstanceNormalization op bound to the pattern's source and output tensors.
void InstancenormConnection(
    std::shared_ptr<vx::Graph>& graph,
    std::vector<std::shared_ptr<vx::Operation>> norm_ops) {
  auto input_tensor =
      norm_ops[NORMALIZATION_INDEX_MEAN_0]->impl()->InputsTensor()[0];
  auto output_tensor =
      norm_ops[NORMALIZATION_INDEX_ADD_1]->impl()->OutputsTensor()[0];
  // eps is the Add0 input that is NOT produced by Mean1.
  int32_t eps_index = 0;
  if (graph->GetProducerOp(norm_ops[NORMALIZATION_INDEX_ADD_0]
                               ->impl()
                               ->InputsTensor()[0]) ==
      norm_ops[NORMALIZATION_INDEX_MEAN_1]) {
    eps_index = 1;
  }
  auto eps_tensor =
      norm_ops[NORMALIZATION_INDEX_ADD_0]->impl()->InputsTensor()[eps_index];
  if (!eps_tensor->IsConstTensor()) {
    // eps may be fed through a producer op; take that op's constant input.
    eps_tensor = graph->GetProducerOp(eps_tensor)->impl()->InputsTensor()[0];
  }
  auto gamma_tensor =
      norm_ops[NORMALIZATION_INDEX_MUL_0]->impl()->InputsTensor()[1];
  auto beta_tensor =
      norm_ops[NORMALIZATION_INDEX_SUB_1]->impl()->InputsTensor()[0];
  // InstanceNormalization requires float32 parameters; convert them before
  // the matched ops and tensors are removed from the graph.
  float* eps_data = eps_tensor->ConvertTensorToFloat32Data();
  float* gamma_data = gamma_tensor->ConvertTensorToFloat32Data();
  float* beta_data = beta_tensor->ConvertTensorToFloat32Data();
  RemoveTensorsAndOps(graph, norm_ops);
  const float eps = *eps_data;
  // gamma/beta are reshaped to broadcast along the channel dimension.
  std::vector<uint32_t> param_shape(input_tensor->GetShape().size(), 1);
  param_shape[0] = input_tensor->GetShape()[0];
  vx::TensorSpec param_spec(vx::DataType::FLOAT32, param_shape,
                            vx::TensorAttribute::CONSTANT);
  auto beta = graph->CreateTensor(param_spec);
  auto gamma = graph->CreateTensor(param_spec);
  beta->CopyDataToTensor(beta_data);
  gamma->CopyDataToTensor(gamma_data);
  vsi_nn_Free(gamma_data);
  vsi_nn_Free(beta_data);
  vsi_nn_Free(eps_data);
  auto instancenorm = graph->CreateOperation<vx::ops::InstanceNormalization>(
      eps, vx::DataLayout::CWHN);
  graph->TensorConsumer()[input_tensor].push_back(instancenorm);
  instancenorm->BindInputs({input_tensor, beta, gamma});
  instancenorm->BindOutputs({output_tensor});
}
/* Checking Mean StdDev Normalization structure:
input
/ | \
/ | Mean0
| | / |
| Sub0 |
| | |
| Pow |
| | |
| Mean1 |
| | |
| Add0 |
| | |
| Rsqrt |
| | |
| Mul0 |
| / \ |
Mul1 Mul2
| |
| Sub1
\ /
Add1
|
output
*/
void MeanStdDevNormalization(std::shared_ptr<vx::Graph>& graph) {
  // Iterate over a snapshot of the op list: the Connection helpers mutate
  // the graph's real op vector while we walk.
  std::vector<std::shared_ptr<vx::Operation>> op_vector = graph->OpVector();
  for (const auto& op : op_vector) {
    // Every candidate pattern is rooted at a reduce (Mean0) op.
    if (op->impl()->kind_ != VSI_NN_OP_REDUCE) continue;
    std::vector<std::shared_ptr<vx::Operation>> temp;
    temp.push_back(op);
    // Walk the expected chain (see diagram above); bail out of this
    // candidate as soon as any step fails to match.
    if (!UpdateTempVector(temp, NORMALIZATION_INDEX_MEAN_0, graph,
                          {VSI_NN_OP_SUBTRACT, VSI_NN_OP_MULTIPLY}))
      continue;
    if (!UpdateTempVector(temp, NORMALIZATION_INDEX_SUB_0, graph,
                          {VSI_NN_OP_POW}))
      continue;
    if (!UpdateTempVector(temp, NORMALIZATION_INDEX_POW, graph,
                          {VSI_NN_OP_REDUCE}))
      continue;  //Mean1
    if (!UpdateTempVector(temp, NORMALIZATION_INDEX_MEAN_1, graph,
                          {VSI_NN_OP_ADD}))
      continue;  //Add0
    if (!UpdateTempVector(temp, NORMALIZATION_INDEX_ADD_0, graph,
                          {VSI_NN_OP_RSQRT}))
      continue;  //Rsqrt
    if (!UpdateTempVector(temp, NORMALIZATION_INDEX_RSQRT, graph,
                          {VSI_NN_OP_MULTIPLY}))
      continue;  //Mul0
    // Mul0 must fan out to exactly two multiplies; appends Mul1 to temp.
    if (!CheckMediumMul(graph, temp)) continue;
    // At least one pair among Mean0/Sub0/Mul1 must share an input tensor
    // (the normalized source); otherwise this is not the same subgraph.
    if (!HaveASameInput(temp[NORMALIZATION_INDEX_MEAN_0],
                        temp[NORMALIZATION_INDEX_SUB_0]) &&
        !HaveASameInput(temp[NORMALIZATION_INDEX_MEAN_0],
                        temp[NORMALIZATION_INDEX_MUL_1]) &&
        !HaveASameInput(temp[NORMALIZATION_INDEX_SUB_0],
                        temp[NORMALIZATION_INDEX_MUL_1]))
      continue;
    if (!UpdateTempVector(temp, NORMALIZATION_INDEX_MUL_1, graph,
                          {VSI_NN_OP_ADD}))
      continue;  //Add1
    if (!UpdateTempVector(temp, NORMALIZATION_INDEX_MUL_2, graph,
                          {VSI_NN_OP_SUBTRACT}))  //Sub1
      continue;
    // Sub1 must have one output, consumed by a single op of kind 0.
    // NOTE(review): the literal 0 below is presumably VSI_NN_OP_ADD (Add1)
    // — confirm against the vsi_nn op-kind enum.
    auto sub_outputs = temp[NORMALIZATION_INDEX_SUB_1]->impl()->OutputsTensor();
    if (sub_outputs.size() >= 2 ||
        graph->GetConsumersOp(sub_outputs[0]).size() > 1 ||
        graph->GetConsumersOp(sub_outputs[0])[0]->impl()->kind_ != 0)
      continue;
    // That single consumer must be the Add1 already matched.
    if (!OpInConsumer(graph, temp[NORMALIZATION_INDEX_SUB_1],
                      temp[NORMALIZATION_INDEX_ADD_1]))
      continue;
    // One reduced axis -> layer norm; multiple axes -> instance norm.
    int axis_num = temp[NORMALIZATION_INDEX_MEAN_0]
                       ->impl()
                       ->node()
                       ->nn_param.reduce.axis_num;
    if (axis_num == 1) {
      LayernormConnection(graph, temp);
    } else {
      InstancenormConnection(graph, temp);
    }
  }
}
} // namespace transform
} // namespace tim

View File

@ -44,49 +44,49 @@ namespace vx {
#define MD5_SECRET_LEN_16 (16) #define MD5_SECRET_LEN_16 (16)
#define MD5_BYTE_STRING_LEN (4) #define MD5_BYTE_STRING_LEN (4)
const std::string calculateMd5Secret32(const std::string& src) { const std::string calculateMd5Secret32(const std::string& src) {
std::string md5String; std::string md5String;
EVP_MD_CTX *mdctx; EVP_MD_CTX* mdctx;
const EVP_MD *md; const EVP_MD* md;
uint32_t md_len; uint32_t md_len;
unsigned char md_value[MD5_SECRET_LEN_16] = {0}; unsigned char md_value[MD5_SECRET_LEN_16] = {0};
char tmp[MD5_BYTE_STRING_LEN] = {0}; char tmp[MD5_BYTE_STRING_LEN] = {0};
md = EVP_md5(); md = EVP_md5();
if (md == NULL) { if (md == NULL) {
printf("Unknown EVP_md5 message."); printf("Unknown EVP_md5 message.");
}
mdctx = EVP_MD_CTX_new();
if (!EVP_DigestInit_ex(mdctx, md, NULL)) {
printf("EVP_MD_CTX initialization failed.");
EVP_MD_CTX_free(mdctx);
}
if (!EVP_DigestUpdate(mdctx, src.c_str(), src.size())) {
printf("EVP_MD_CTX update failed.");
EVP_MD_CTX_free(mdctx);
}
if (!EVP_DigestFinal_ex(mdctx, md_value, &md_len)) {
printf("EVP_MD_CTX finalization failed.");
EVP_MD_CTX_free(mdctx);
}
EVP_MD_CTX_free(mdctx);
for (int i = 0; i < MD5_SECRET_LEN_16; ++i) {
memset(tmp, 0x00, sizeof(tmp));
snprintf(tmp, sizeof(tmp), "%02X", md_value[i]);
md5String += tmp;
}
return md5String;
} }
mdctx = EVP_MD_CTX_new();
if (!EVP_DigestInit_ex(mdctx, md, NULL)) {
printf("EVP_MD_CTX initialization failed.");
EVP_MD_CTX_free(mdctx);
}
if (!EVP_DigestUpdate(mdctx, src.c_str(), src.size())) {
printf("EVP_MD_CTX update failed.");
EVP_MD_CTX_free(mdctx);
}
if (!EVP_DigestFinal_ex(mdctx, md_value, &md_len)) {
printf("EVP_MD_CTX finalization failed.");
EVP_MD_CTX_free(mdctx);
}
EVP_MD_CTX_free(mdctx);
for (int i = 0; i < MD5_SECRET_LEN_16; ++i) {
memset(tmp, 0x00, sizeof(tmp));
snprintf(tmp, sizeof(tmp), "%02X", md_value[i]);
md5String += tmp;
}
return md5String;
}
#endif #endif
const std::vector<std::shared_ptr<Tensor>> Graph::GetConstantInputs() const { const std::vector<std::shared_ptr<Tensor>> Graph::GetConstantInputs() const {
std::vector<std::shared_ptr<Tensor>> const_inputs; std::vector<std::shared_ptr<Tensor>> const_inputs;
for (auto op : op_vector_) { for (auto op : op_vector_) {
auto const_i = op->ConstantInputsTensor(); auto const_i = op->ConstantInputsTensor();
const_inputs.insert(const_inputs.end(), const_i.begin(), const_i.end()); const_inputs.insert(const_inputs.end(), const_i.begin(), const_i.end());
}
return const_inputs;
} }
return const_inputs;
}
GraphImpl::GraphImpl(ContextImpl* context, const CompileOption& options) GraphImpl::GraphImpl(ContextImpl* context, const CompileOption& options)
: context_(context), : context_(context),
@ -94,16 +94,18 @@ GraphImpl::GraphImpl(ContextImpl* context, const CompileOption& options)
tensor_placeholder_(nullptr), tensor_placeholder_(nullptr),
not_consumed_input_cnt_(0), not_consumed_input_cnt_(0),
not_consumed_output_cnt_(0), not_consumed_output_cnt_(0),
options_(options){} options_(options) {}
GraphImpl::~GraphImpl() { vsi_nn_ReleaseGraph(&graph_); } GraphImpl::~GraphImpl() { vsi_nn_ReleaseGraph(&graph_); }
#ifdef ENABLE_TENSOR_CACHE #ifdef ENABLE_TENSOR_CACHE
std::map<std::string, std::shared_ptr<tim::vx::Tensor>>& GraphImpl::GetTensorCacheMap() { std::map<std::string, std::shared_ptr<tim::vx::Tensor>>&
GraphImpl::GetTensorCacheMap() {
return cached_tensor_; return cached_tensor_;
} }
const std::string GraphImpl::CalculateCacheKey(const TensorSpec& spec, const void* data) { const std::string GraphImpl::CalculateCacheKey(const TensorSpec& spec,
const void* data) {
std::string md5_key; std::string md5_key;
uint32_t data_size = 1; uint32_t data_size = 1;
for (auto it = spec.shape_.begin(); it != spec.shape_.end(); ++it) { for (auto it = spec.shape_.begin(); it != spec.shape_.end(); ++it) {
@ -135,12 +137,15 @@ const std::string GraphImpl::CalculateCacheKey(const TensorSpec& spec, const voi
return md5_key; return md5_key;
} }
std::shared_ptr<Tensor> GraphImpl::GetTensorFromCache(const TensorSpec& spec, const void* data) { std::shared_ptr<Tensor> GraphImpl::GetTensorFromCache(const TensorSpec& spec,
const void* data) {
std::shared_ptr<tim::vx::Tensor> tensor; std::shared_ptr<tim::vx::Tensor> tensor;
std::string md5_key = CalculateCacheKey(spec, data); std::string md5_key = CalculateCacheKey(spec, data);
if (GetTensorCacheMap().find(md5_key) != GetTensorCacheMap().end() && if (GetTensorCacheMap().find(md5_key) != GetTensorCacheMap().end() &&
GetTensorCacheMap()[md5_key]->GetQuantization().Scales() == spec.quantization_.Scales() && GetTensorCacheMap()[md5_key]->GetQuantization().Scales() ==
GetTensorCacheMap()[md5_key]->GetQuantization().ZeroPoints() == spec.quantization_.ZeroPoints()) { spec.quantization_.Scales() &&
GetTensorCacheMap()[md5_key]->GetQuantization().ZeroPoints() ==
spec.quantization_.ZeroPoints()) {
tensor = GetTensorCacheMap()[md5_key]; tensor = GetTensorCacheMap()[md5_key];
} else { } else {
tensor = std::make_shared<TensorImpl>(this, spec, data); tensor = std::make_shared<TensorImpl>(this, spec, data);
@ -190,6 +195,20 @@ const std::vector<std::shared_ptr<Tensor>> GraphImpl::OutputsTensor() const {
return outputs_tensor_; return outputs_tensor_;
} }
std::vector<std::shared_ptr<Operation>>& GraphImpl::OpVector() {
return op_vector_;
}
std::map<std::shared_ptr<Tensor>, std::vector<std::shared_ptr<Operation>>>&
GraphImpl::TensorConsumer() {
return tensor_consumers_;
}
std::map<std::shared_ptr<Tensor>, std::shared_ptr<Operation>>&
GraphImpl::TensorProducer() {
return tensor_producer_;
}
void GraphImpl::UpdateTensorConsumersMap(const std::shared_ptr<Tensor>& tensor, void GraphImpl::UpdateTensorConsumersMap(const std::shared_ptr<Tensor>& tensor,
const Operation* op) { const Operation* op) {
for (const auto& added_op : op_vector_) { for (const auto& added_op : op_vector_) {
@ -216,7 +235,7 @@ void GraphImpl::RenewTensorConsumersMap(
} }
void GraphImpl::UpdateTensorProducerMap(const std::shared_ptr<Tensor>& tensor, void GraphImpl::UpdateTensorProducerMap(const std::shared_ptr<Tensor>& tensor,
const Operation* op) { const Operation* op) {
for (const auto& added_op : op_vector_) { for (const auto& added_op : op_vector_) {
if (added_op.get() == op) { if (added_op.get() == op) {
tensor_producer_[tensor] = added_op; tensor_producer_[tensor] = added_op;
@ -236,7 +255,7 @@ const std::vector<std::shared_ptr<Operation>> GraphImpl::GetConsumersOp(
} }
std::shared_ptr<Operation> GraphImpl::GetProducerOp( std::shared_ptr<Operation> GraphImpl::GetProducerOp(
std::shared_ptr<Tensor> tensor) { std::shared_ptr<Tensor> tensor) {
auto producer = tensor_producer_.find(tensor); auto producer = tensor_producer_.find(tensor);
if (tensor_producer_.end() != producer) { if (tensor_producer_.end() != producer) {
return producer->second; return producer->second;
@ -286,7 +305,7 @@ std::shared_ptr<Tensor> GraphImpl::CreateTensor(const TensorSpec& spec,
} }
std::shared_ptr<Tensor> GraphImpl::CreateIOTensor(const TensorSpec& spec, std::shared_ptr<Tensor> GraphImpl::CreateIOTensor(const TensorSpec& spec,
void* data) { void* data) {
auto tensor = std::make_shared<TensorImpl>(this, spec, data); auto tensor = std::make_shared<TensorImpl>(this, spec, data);
if (spec.attr_ & TensorAttribute::INPUT) { if (spec.attr_ & TensorAttribute::INPUT) {
this->AddInput(tensor); this->AddInput(tensor);
@ -320,16 +339,17 @@ bool GraphImpl::Setup() {
bool is_fast_mode = options_.isRelaxMode(); bool is_fast_mode = options_.isRelaxMode();
if (is_fast_mode) { if (is_fast_mode) {
VSILOGW("Important notice: float model executed in bfloat16 " VSILOGW(
"mode which will have better performance but lower precesion"); "Important notice: float model executed in bfloat16 "
"mode which will have better performance but lower precesion");
} }
vsi_nn_SetGraphFastMode(graph_, is_fast_mode); vsi_nn_SetGraphFastMode(graph_, is_fast_mode);
#if defined(ENABLE_PLATFORM) #if defined(ENABLE_PLATFORM)
auto id = options_.getDeviceId(); auto id = options_.getDeviceId();
vxSetGraphAttribute(graph_->g, VX_GRAPH_DEVICE_INDEX_VIV, vxSetGraphAttribute(graph_->g, VX_GRAPH_DEVICE_INDEX_VIV, (void*)(&id),
(void*)(&id), sizeof(id)); sizeof(id));
#endif #endif
std::call_once(setio_once_, [&status, this]() { std::call_once(setio_once_, [&status, this]() {
status = (vsi_nn_SetGraphInputs(this->graph_, this->inputs_.data(), status = (vsi_nn_SetGraphInputs(this->graph_, this->inputs_.data(),
@ -346,12 +366,16 @@ bool GraphImpl::Setup() {
bool GraphImpl::Compile() { bool GraphImpl::Compile() {
bool status = true; bool status = true;
if (not_consumed_input_cnt_ > 0 ) { if (not_consumed_input_cnt_ > 0) {
// Tensor can bind to different operations // Tensor can bind to different operations
VSILOGW("Graph has free input, INPUT tensor may be created but not consumed."); VSILOGW(
"Graph has free input, INPUT tensor may be created but not "
"consumed.");
} }
if (not_consumed_output_cnt_ != 0) { if (not_consumed_output_cnt_ != 0) {
VSILOGW("Graph has free output, OUTPUT tensor may be created but not consumed."); VSILOGW(
"Graph has free output, OUTPUT tensor may be created but not "
"consumed.");
} }
status = Setup(); status = Setup();
std::call_once(verify_graph_once_, [&status, this]() { std::call_once(verify_graph_once_, [&status, this]() {

View File

@ -42,10 +42,12 @@ namespace vx {
class GraphImpl : public Graph { class GraphImpl : public Graph {
public: public:
GraphImpl(ContextImpl* context, const CompileOption& options = CompileOption::DefaultOptions); GraphImpl(ContextImpl* context,
const CompileOption& options = CompileOption::DefaultOptions);
~GraphImpl(); ~GraphImpl();
#ifdef ENABLE_TENSOR_CACHE #ifdef ENABLE_TENSOR_CACHE
std::shared_ptr<Tensor> GetTensorFromCache(const TensorSpec& spec, const void* data); std::shared_ptr<Tensor> GetTensorFromCache(const TensorSpec& spec,
const void* data);
const std::string CalculateCacheKey(const TensorSpec& spec, const void* data); const std::string CalculateCacheKey(const TensorSpec& spec, const void* data);
std::map<std::string, std::shared_ptr<tim::vx::Tensor>>& GetTensorCacheMap(); std::map<std::string, std::shared_ptr<tim::vx::Tensor>>& GetTensorCacheMap();
#endif #endif
@ -62,14 +64,18 @@ class GraphImpl : public Graph {
const std::vector<std::shared_ptr<Tensor>> InputsTensor() const override; const std::vector<std::shared_ptr<Tensor>> InputsTensor() const override;
const std::vector<std::shared_ptr<Tensor>> OutputsTensor() const override; const std::vector<std::shared_ptr<Tensor>> OutputsTensor() const override;
std::vector<std::shared_ptr<Operation>>& OpVector() override;
std::map<std::shared_ptr<Tensor>, std::vector<std::shared_ptr<Operation>>>&
TensorConsumer() override;
std::map<std::shared_ptr<Tensor>, std::shared_ptr<Operation>>&
TensorProducer() override;
void UpdateTensorConsumersMap(const std::shared_ptr<Tensor>& tensor, void UpdateTensorConsumersMap(const std::shared_ptr<Tensor>& tensor,
const Operation* op) override; const Operation* op) override;
void RenewTensorConsumersMap(const std::shared_ptr<Tensor>& org_tensor, void RenewTensorConsumersMap(const std::shared_ptr<Tensor>& org_tensor,
const std::shared_ptr<Tensor>& dst_tensor, const std::shared_ptr<Tensor>& dst_tensor,
const Operation* op) override; const Operation* op) override;
void UpdateTensorProducerMap(const std::shared_ptr<Tensor>& tensor, void UpdateTensorProducerMap(const std::shared_ptr<Tensor>& tensor,
const Operation* op) override; const Operation* op) override;
const std::vector<std::shared_ptr<Operation>> GetConsumersOp( const std::vector<std::shared_ptr<Operation>> GetConsumersOp(
std::shared_ptr<Tensor> tensor) const override; std::shared_ptr<Tensor> tensor) const override;
std::shared_ptr<Operation> GetProducerOp( std::shared_ptr<Operation> GetProducerOp(
@ -82,7 +88,7 @@ class GraphImpl : public Graph {
std::shared_ptr<Tensor> CreateTensor(const TensorSpec& spec, std::shared_ptr<Tensor> CreateTensor(const TensorSpec& spec,
const DmaBufferDesc& dmafd) override; const DmaBufferDesc& dmafd) override;
std::shared_ptr<Tensor> CreateIOTensor(const TensorSpec& spec, std::shared_ptr<Tensor> CreateIOTensor(const TensorSpec& spec,
void* data = nullptr) override; void* data = nullptr) override;
std::shared_ptr<Tensor> CreateTensorPlaceHolder() override; std::shared_ptr<Tensor> CreateTensorPlaceHolder() override;
bool Compile() override; bool Compile() override;
@ -106,14 +112,17 @@ class GraphImpl : public Graph {
int32_t not_consumed_input_cnt_; int32_t not_consumed_input_cnt_;
std::vector<std::shared_ptr<Tensor>> outputs_tensor_; std::vector<std::shared_ptr<Tensor>> outputs_tensor_;
int32_t not_consumed_output_cnt_; int32_t not_consumed_output_cnt_;
std::map<std::shared_ptr<Tensor>, std::vector<std::shared_ptr<Operation>>> tensor_consumers_; std::map<std::shared_ptr<Tensor>, std::vector<std::shared_ptr<Operation>>>
std::map<std::shared_ptr<Tensor>, std::shared_ptr<Operation>> tensor_producer_; tensor_consumers_;
std::map<std::shared_ptr<Tensor>, std::shared_ptr<Operation>>
tensor_producer_;
#ifdef ENABLE_TENSOR_CACHE #ifdef ENABLE_TENSOR_CACHE
std::map<std::string, std::shared_ptr<tim::vx::Tensor>> cached_tensor_; std::map<std::string, std::shared_ptr<tim::vx::Tensor>> cached_tensor_;
#endif #endif
CompileOption options_; CompileOption options_;
private: private:
/// Setup graph /// Setup graph
bool Setup(); bool Setup();
}; };

View File

@ -87,14 +87,15 @@ TensorImpl::TensorImpl(Graph* graph, const TensorSpec& spec, const void* data)
: graph_(reinterpret_cast<GraphImpl*>(graph)), : graph_(reinterpret_cast<GraphImpl*>(graph)),
id_(VSI_NN_TENSOR_ID_NA), id_(VSI_NN_TENSOR_ID_NA),
spec_(spec), spec_(spec),
data_(const_cast<void *>(data)) { data_(const_cast<void*>(data)) {
Init(); Init();
if (spec_.attr_ & (TensorAttribute::INPUT | TensorAttribute::OUTPUT)) { if (spec_.attr_ & (TensorAttribute::INPUT | TensorAttribute::OUTPUT)) {
data_ = nullptr; // it's not needed to reset it in a constant tensor data_ = nullptr; // it's not needed to reset it in a constant tensor
} }
} }
TensorImpl::TensorImpl(Graph* graph, const TensorSpec& spec, const DmaBufferDesc& dmafd) TensorImpl::TensorImpl(Graph* graph, const TensorSpec& spec,
const DmaBufferDesc& dmafd)
: graph_(reinterpret_cast<GraphImpl*>(graph)), : graph_(reinterpret_cast<GraphImpl*>(graph)),
id_(VSI_NN_TENSOR_ID_NA), id_(VSI_NN_TENSOR_ID_NA),
spec_(spec), spec_(spec),
@ -118,13 +119,14 @@ TensorImpl::TensorImpl(Graph* graph, const TensorSpec& spec, void* data)
TensorImpl::~TensorImpl() {} TensorImpl::~TensorImpl() {}
bool TensorImpl::SaveTensorToTextByFp32(std::string filename){ bool TensorImpl::SaveTensorToTextByFp32(std::string filename) {
vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_); vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_);
vsi_nn_SaveTensorToTextByFp32(graph_->graph(), tensor, filename.c_str(), NULL); vsi_nn_SaveTensorToTextByFp32(graph_->graph(), tensor, filename.c_str(),
NULL);
return true; return true;
} }
void* TensorImpl::ConvertTensorToData(uint8_t* tensorData){ void* TensorImpl::ConvertTensorToData(uint8_t* tensorData) {
vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_); vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_);
tensorData = vsi_nn_ConvertTensorToData(graph_->graph(), tensor); tensorData = vsi_nn_ConvertTensorToData(graph_->graph(), tensor);
return tensorData; return tensorData;
@ -142,10 +144,10 @@ bool TensorImpl::CopyDataToTensor(const void* data, uint32_t size_in_bytes) {
vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_); vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_);
if (tensor) { if (tensor) {
uint32_t tensor_bytes = vsi_nn_GetTensorSize( uint32_t tensor_bytes = vsi_nn_GetTensorSize(
tensor->attr.size, tensor->attr.dim_num, tensor->attr.dtype.vx_type); tensor->attr.size, tensor->attr.dim_num, tensor->attr.dtype.vx_type);
if (tensor->attr.is_created_from_handle) { if (tensor->attr.is_created_from_handle) {
void *ptr = NULL; void* ptr = NULL;
vsi_nn_GetTensorHandle(tensor, &ptr); vsi_nn_GetTensorHandle(tensor, &ptr);
if (ptr) { if (ptr) {
memcpy(ptr, data, tensor_bytes); memcpy(ptr, data, tensor_bytes);
@ -154,8 +156,7 @@ bool TensorImpl::CopyDataToTensor(const void* data, uint32_t size_in_bytes) {
} else { } else {
VSILOGE("GetTensorHandle fail"); VSILOGE("GetTensorHandle fail");
} }
} } else {
else {
/* /*
argument `data` of vsi_nn_CopyDataToTensor is non-const argument `data` of vsi_nn_CopyDataToTensor is non-const
convert it from const data to non-const, will be fixed in ovxlib convert it from const data to non-const, will be fixed in ovxlib
@ -163,8 +164,8 @@ bool TensorImpl::CopyDataToTensor(const void* data, uint32_t size_in_bytes) {
const uint8_t* end = static_cast<const uint8_t*>(data) + tensor_bytes; const uint8_t* end = static_cast<const uint8_t*>(data) + tensor_bytes;
std::vector<uint8_t> data_copy(static_cast<const uint8_t*>(data), end); std::vector<uint8_t> data_copy(static_cast<const uint8_t*>(data), end);
retn = (VSI_SUCCESS == retn = (VSI_SUCCESS == vsi_nn_CopyDataToTensor(graph_->graph(), tensor,
vsi_nn_CopyDataToTensor(graph_->graph(), tensor, data_copy.data())); data_copy.data()));
} }
} }
} }
@ -183,14 +184,14 @@ bool TensorImpl::CopyDataFromTensor(void* data) {
if (tensor) { if (tensor) {
uint32_t tensor_bytes = vsi_nn_GetTensorSize( uint32_t tensor_bytes = vsi_nn_GetTensorSize(
tensor->attr.size, tensor->attr.dim_num, tensor->attr.dtype.vx_type); tensor->attr.size, tensor->attr.dim_num, tensor->attr.dtype.vx_type);
if (tensor->attr.is_created_from_handle) { if (tensor->attr.is_created_from_handle) {
void* ptr = NULL; void* ptr = NULL;
vsi_nn_GetTensorHandle(tensor, &ptr); vsi_nn_GetTensorHandle(tensor, &ptr);
#ifdef VSI_INVALIDATE_HANDLE_SUPPORT #ifdef VSI_INVALIDATE_HANDLE_SUPPORT
vsi_nn_InvalidateHandle(tensor); vsi_nn_InvalidateHandle(tensor);
#endif #endif
if (ptr) { if (ptr) {
memcpy(data, ptr, tensor_bytes); memcpy(data, ptr, tensor_bytes);
retn = true; retn = true;
@ -207,6 +208,11 @@ bool TensorImpl::CopyDataFromTensor(void* data) {
return retn; return retn;
} }
float* TensorImpl::ConvertTensorToFloat32Data() {
return vsi_nn_ConvertTensorToFloat32Data(
graph_->graph(), vsi_nn_GetTensor(graph_->graph(), id_));
}
bool TensorImpl::FlushCacheForHandle() { bool TensorImpl::FlushCacheForHandle() {
if (!(spec_.attr_ & TensorAttribute::INPUT)) { if (!(spec_.attr_ & TensorAttribute::INPUT)) {
return false; return false;
@ -283,7 +289,7 @@ void TensorImpl::unmap() {
if (data_ && spec_.attr_ & TensorAttribute::INPUT) { if (data_ && spec_.attr_ & TensorAttribute::INPUT) {
// Here data_ is an external buffer and may have been updated // Here data_ is an external buffer and may have been updated
vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_); vsi_nn_tensor_t* tensor = vsi_nn_GetTensor(graph_->graph(), id_);
if ( tensor && tensor->attr.is_created_from_handle) { if (tensor && tensor->attr.is_created_from_handle) {
bool retn = (VSI_SUCCESS == vsi_nn_FlushHandle(tensor)); bool retn = (VSI_SUCCESS == vsi_nn_FlushHandle(tensor));
if (!retn) { if (!retn) {
VSILOGE("FlushHandle fail"); VSILOGE("FlushHandle fail");
@ -295,7 +301,7 @@ void TensorImpl::unmap() {
// TODO: unmap fd_ // TODO: unmap fd_
} }
bool TensorImpl::Init(void *external_cache) { bool TensorImpl::Init(void* external_cache) {
vsi_nn_tensor_attr_t attr; vsi_nn_tensor_attr_t attr;
#if (!ENABLE_TENSOR_HNDL) #if (!ENABLE_TENSOR_HNDL)
@ -319,7 +325,7 @@ bool TensorImpl::Init(void *external_cache) {
PackTensorDtype(spec_, &attr.dtype); PackTensorDtype(spec_, &attr.dtype);
#if(ENABLE_TENSOR_HNDL) #if (ENABLE_TENSOR_HNDL)
if ((spec_.attr_ & TensorAttribute::INPUT) || if ((spec_.attr_ & TensorAttribute::INPUT) ||
(spec_.attr_ & TensorAttribute::OUTPUT)) { (spec_.attr_ & TensorAttribute::OUTPUT)) {
#ifdef VX_CREATE_TENSOR_SUPPORT_PHYSICAL #ifdef VX_CREATE_TENSOR_SUPPORT_PHYSICAL
@ -331,7 +337,8 @@ bool TensorImpl::Init(void *external_cache) {
graph_->graph(), graph_->graph(),
VSI_NN_TENSOR_ID_AUTO, // DMABUF's fd is created by TensorFromHandle as input or output, VSI_NN_TENSOR_ID_AUTO, // DMABUF's fd is created by TensorFromHandle as input or output,
&attr, &attr,
fd_ != -1 ? (uint8_t*)fd_ : (uint8_t*)external_cache); // and cannot be set to const fd_ != -1 ? (uint8_t*)fd_
: (uint8_t*)external_cache); // and cannot be set to const
#else #else
if (-1 == fd_) { if (-1 == fd_) {
id_ = vsi_nn_AddTensorFromHandle(graph_->graph(), VSI_NN_TENSOR_ID_AUTO, id_ = vsi_nn_AddTensorFromHandle(graph_->graph(), VSI_NN_TENSOR_ID_AUTO,
@ -445,40 +452,42 @@ int64_t TensorSpec::GetByteSize() const {
return GetElementNum() * GetElementByteSize(); return GetElementNum() * GetElementByteSize();
} }
bool Quantization::operator == (const Quantization& other_quant) const { bool Quantization::operator==(const Quantization& other_quant) const {
if (type_ != tim::vx::QuantType::DYNAMIC_FIXED_POINT){ if (type_ != tim::vx::QuantType::DYNAMIC_FIXED_POINT) {
if(type_ == other_quant.type_ && if (type_ == other_quant.type_ && scales_ == other_quant.scales_ &&
scales_ == other_quant.scales_ && zero_points_ == other_quant.zero_points_ &&
zero_points_ == other_quant.zero_points_ && channel_dim_ == other_quant.channel_dim_)
channel_dim_ == other_quant.channel_dim_) return true;
return true; } else if (fl_ == other_quant.fl_)
return true;
return false;
}
namespace utils {
bool Float32ToDtype(std::shared_ptr<tim::vx::Tensor> tensor,
std::vector<float> fval, uint8_t* tensorData) {
bool retn = true;
vsi_nn_tensor_attr_t attr;
uint32_t sz = tensor->GetSpec().GetElementNum();
uint32_t stride = tensor->GetSpec().GetElementByteSize();
PackTensorDtype(tensor->GetSpec(), &attr.dtype);
for (uint32_t i = 0; i < sz; i++) {
retn = (VSI_SUCCESS == vsi_nn_Float32ToDtype(
fval[i], &tensorData[i * stride], &attr.dtype));
if (!retn) {
VSILOGE("Convert data fail");
return retn;
} }
else if(fl_ == other_quant.fl_) return true;
return false;
}
namespace utils{
bool Float32ToDtype(std::shared_ptr<tim::vx::Tensor> tensor, std::vector<float> fval, uint8_t* tensorData){
bool retn = true;
vsi_nn_tensor_attr_t attr;
uint32_t sz = tensor->GetSpec().GetElementNum();
uint32_t stride = tensor->GetSpec().GetElementByteSize();
PackTensorDtype(tensor->GetSpec(), &attr.dtype);
for (uint32_t i = 0; i < sz; i++){
retn = (VSI_SUCCESS == vsi_nn_Float32ToDtype(fval[i], &tensorData[i * stride], &attr.dtype));
if (!retn) {
VSILOGE("Convert data fail");
return retn;
} }
} return retn;
return retn;
} }
bool DtypeToFloat32(std::shared_ptr<tim::vx::Tensor> tensor, uint8_t* tensorData, float* data){ bool DtypeToFloat32(std::shared_ptr<tim::vx::Tensor> tensor,
uint8_t* tensorData, float* data) {
bool retn = true; bool retn = true;
vsi_nn_tensor_attr_t attr; vsi_nn_tensor_attr_t attr;
PackTensorDtype(tensor->GetSpec(), &attr.dtype); PackTensorDtype(tensor->GetSpec(), &attr.dtype);
retn = (VSI_SUCCESS == vsi_nn_DtypeToFloat32(tensorData, data, &attr.dtype)); retn = (VSI_SUCCESS == vsi_nn_DtypeToFloat32(tensorData, data, &attr.dtype));
return retn; return retn;
} }

View File

@ -37,7 +37,7 @@ class TensorImpl : public Tensor {
TensorImpl(Graph* graph, const TensorSpec& spec, void* data = nullptr); TensorImpl(Graph* graph, const TensorSpec& spec, void* data = nullptr);
~TensorImpl(); ~TensorImpl();
bool Init(void *external_cache = nullptr); bool Init(void* external_cache = nullptr);
bool IsWriteable(); bool IsWriteable();
bool IsReadable(); bool IsReadable();
@ -58,7 +58,7 @@ class TensorImpl : public Tensor {
} }
bool SaveTensorToTextByFp32(std::string filename) override; bool SaveTensorToTextByFp32(std::string filename) override;
void* ConvertTensorToData(uint8_t* tensorData) override; void* ConvertTensorToData(uint8_t* tensorData) override;
float* ConvertTensorToFloat32Data() override;
GraphImpl* graph_; GraphImpl* graph_;
vsi_nn_tensor_id_t id_; vsi_nn_tensor_id_t id_;
TensorSpec spec_; TensorSpec spec_;
@ -68,7 +68,7 @@ class TensorImpl : public Tensor {
class TensorPlaceholder : public Tensor { class TensorPlaceholder : public Tensor {
public: public:
TensorPlaceholder(Graph* graph) : id_(VSI_NN_TENSOR_ID_NA) {(void)(graph);} TensorPlaceholder(Graph* graph) : id_(VSI_NN_TENSOR_ID_NA) { (void)(graph); }
~TensorPlaceholder(){}; ~TensorPlaceholder(){};
const ShapeType& GetShape() override { return spec_.shape_; } const ShapeType& GetShape() override { return spec_.shape_; }
@ -95,14 +95,15 @@ class TensorPlaceholder : public Tensor {
bool IsConstTensor() override { bool IsConstTensor() override {
return spec_.attr_ == tim::vx::TensorAttribute::CONSTANT; return spec_.attr_ == tim::vx::TensorAttribute::CONSTANT;
} }
bool SaveTensorToTextByFp32(std::string filename) override { bool SaveTensorToTextByFp32(std::string filename) override {
(void)filename; (void)filename;
return false; return false;
} }
void* ConvertTensorToData(uint8_t* tensorData) override { void* ConvertTensorToData(uint8_t* tensorData) override {
(void)tensorData; (void)tensorData;
return nullptr; return nullptr;
} }
float* ConvertTensorToFloat32Data() override { return nullptr; }
vsi_nn_tensor_id_t id_; vsi_nn_tensor_id_t id_;
TensorSpec spec_; TensorSpec spec_;