Supported composed-op layout inference & added unit test

Fixed FC layout inference inside RNNCell layout inference
Chen Xin 2022-09-14 18:07:08 +08:00 committed by Sven
parent 1802e558ad
commit 72f2c5b69e
9 changed files with 239 additions and 23 deletions

View File

@@ -38,16 +38,18 @@ namespace vx {
 }
 namespace transform {
+class IPermuteVector;
 std::pair<
     /*graph after layout inference*/
     std::shared_ptr<vx::Graph>,
     /* tensor mapping between original graph and graph after layout infer*/
-    std::map<
-        std::shared_ptr<vx::Tensor>,
-        std::shared_ptr<vx::Tensor>>
-    >
-LayoutInference(const std::shared_ptr<vx::Graph>& src_graph,
-                std::shared_ptr<vx::Context>& ctx);
+    std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<vx::Tensor>>>
+LayoutInference(
+    const std::shared_ptr<vx::Graph>& src_graph,
+    std::shared_ptr<vx::Context>& ctx,
+    std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<IPermuteVector>>
+        tensor_pv_map = std::map<std::shared_ptr<vx::Tensor>,
+                                 std::shared_ptr<IPermuteVector>>());
 }  // namespace transform
 }  // namespace tim
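For context, a minimal caller-side sketch of the new overload, mirroring the unit test added below (src_graph, ctx, and input_tensor are assumed to already exist; the {1, 0} permute vector seeds layout inference with a known permutation for that input instead of the default identity):

std::map<std::shared_ptr<tim::vx::Tensor>,
         std::shared_ptr<tim::transform::IPermuteVector>>
    tensor_pv_map;
tensor_pv_map.insert(
    {input_tensor, std::make_shared<tim::transform::PermuteVector<2>>(
                       std::initializer_list<uint32_t>({1U, 0U}))});
auto transform =
    tim::transform::LayoutInference(src_graph, ctx, tensor_pv_map);
auto infer_graph = transform.first;    // graph after layout inference
auto graph_io_map = transform.second;  // original tensor -> inferred tensor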

View File

@@ -282,10 +282,12 @@ std::vector<std::shared_ptr<vx::Tensor>> HandleLayoutInfer(
 }  // namespace layout_inference_impl
 std::pair<std::shared_ptr<vx::Graph>,
-          std::map<std::shared_ptr<vx::Tensor>,
-                   std::shared_ptr<vx::Tensor>>> LayoutInference(
+          std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<vx::Tensor>>>
+LayoutInference(
     const std::shared_ptr<vx::Graph>& src_graph,
-    std::shared_ptr<vx::Context>& ctx) {
+    std::shared_ptr<vx::Context>& ctx,
+    std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<IPermuteVector>>
+        tensor_pv_map) {
   std::shared_ptr<vx::Graph> infer_graph = ctx->CreateGraph();
   std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<vx::Tensor>>
       graph_io_map;
@@ -300,8 +302,10 @@ std::pair<std::shared_ptr<vx::Graph>,
     layout_infer_ctx->UpdateTensorMap(t_src, input);
     layout_infer_ctx->UpdateGraphInputMap(t_src, input);
     tensor_queue.push_back(t_src);
-    layout_infer_ctx->SetPermuteVector(t_src,
-                                       MakeShared(t_src->GetShape().size()));
+    layout_infer_ctx->SetPermuteVector(
+        t_src, tensor_pv_map.find(t_src) != tensor_pv_map.end()
+                   ? tensor_pv_map[t_src]
+                   : MakeShared(t_src->GetShape().size()));
   }
   auto const_inputs = src_graph->GetConstantInputs();
@@ -310,8 +314,10 @@ std::pair<std::shared_ptr<vx::Graph>,
         infer_graph->CreateTensor(const_in->GetSpec(), const_in->GetDataRef());
     layout_infer_ctx->UpdateTensorMap(const_in, input);
     tensor_queue.push_back(const_in);
-    layout_infer_ctx->SetPermuteVector(const_in,
-                                       MakeShared(const_in->GetShape().size()));
+    layout_infer_ctx->SetPermuteVector(
+        const_in, tensor_pv_map.find(const_in) != tensor_pv_map.end()
+                      ? tensor_pv_map[const_in]
+                      : MakeShared(const_in->GetShape().size()));
   }
   while (!tensor_queue.empty()) {
@@ -319,7 +325,7 @@ std::pair<std::shared_ptr<vx::Graph>,
     tensor_queue.pop_front();
     const auto& consumers = src_graph->GetConsumersOp(tensor);
     for (const auto& op : consumers) {
-      if (!layout_infer_ctx->IsVisited(op) &&
+      if (!layout_infer_ctx->IsVisited(op) && op->impl()->kind_ != -1 &&
           layout_infer_ctx->IsReadyForInfer(op)) {
         auto next_tensors =
             layout_inference_impl::HandleLayoutInfer(layout_infer_ctx, op);
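The ternaries above read as: use the caller-supplied permute vector when tensor_pv_map contains the tensor, otherwise fall back to the identity permutation of matching rank. A hedged equivalent that reuses the iterator from find (avoiding the second lookup that operator[] performs) would be:

auto it = tensor_pv_map.find(t_src);
layout_infer_ctx->SetPermuteVector(
    t_src, it != tensor_pv_map.end()
               ? it->second
               : MakeShared(t_src->GetShape().size()));

The added op->impl()->kind_ != -1 guard in the traversal pairs with the uint32_t to int32_t change to kind_ further down, which makes the -1 sentinel representable.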

View File

@@ -41,8 +41,12 @@ class FullyConnectedLayoutInfer : public OpLayoutInfer {
   void OnInputs(
       std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
+    /* vx_delegate has reshaped the input to two dimensions when mapping;
+       the axis of a 2-dimensional fc can only be 0. */
     auto input_tensors = op_->impl()->InputsTensor();
+    if (!context_->GetPermuteVector(input_tensors[0])->IsAligned()) {
+      ReverseInputsPermuteVector();
+    }
     for (const auto& in : input_tensors) {
       if (in->IsConstTensor()) {
         auto infer_tensor = context_->infer_graph_->CreateTensor(in->GetSpec(),
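A note on the new guard: IsAligned() is taken here to be the identity test on a permute vector, i.e. no transpose is pending on that tensor; when the first input is not aligned, ReverseInputsPermuteVector() restores the original dimension order so the axis of the 2-D fully connected stays 0, per the comment in the hunk. A sketch of the assumed semantics (not the library's actual implementation):

// Aligned iff pv[i] == i for every i, i.e. the permutation is the identity.
bool IsAlignedSketch(const std::vector<uint32_t>& pv) {
  for (uint32_t i = 0; i < pv.size(); ++i) {
    if (pv[i] != i) return false;  // dimension i was moved by a transpose
  }
  return true;
}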

View File

@@ -0,0 +1,205 @@
#include "tim/vx/context.h"
#include "tim/vx/graph.h"
#include "tim/vx/ops.h"
#include "tim/transform/layout_inference.h"
#include "permute_vector.h"
#include "test_utils.h"
#include <algorithm>
#include "gtest/gtest.h"
TEST(RNNCell, layout_infer_align) {
auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph();
uint32_t input_size = 3, batch_size = 2, num_units = 4;
tim::vx::ShapeType input_shape({input_size, batch_size});
tim::vx::ShapeType weights_shape({input_size, num_units});
tim::vx::ShapeType recurrent_weights_shape({num_units, num_units});
tim::vx::ShapeType bias_shape({num_units});
tim::vx::ShapeType state_in_shape({num_units, batch_size});
tim::vx::ShapeType output_shape({num_units, batch_size});
tim::vx::ShapeType state_out_shape({num_units, batch_size});
tim::vx::Quantization quant(tim::vx::QuantType::ASYMMETRIC, 0.0036, 0);
tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
tim::vx::TensorAttribute::INPUT);
tim::vx::TensorSpec weights_spec(tim::vx::DataType::FLOAT32, weights_shape,
tim::vx::TensorAttribute::CONSTANT);
tim::vx::TensorSpec recurrent_weights_spec(
tim::vx::DataType::FLOAT32, recurrent_weights_shape,
tim::vx::TensorAttribute::CONSTANT);
tim::vx::TensorSpec bias_spec(tim::vx::DataType::FLOAT32, bias_shape,
tim::vx::TensorAttribute::CONSTANT);
tim::vx::TensorSpec state_in_spec(tim::vx::DataType::FLOAT32, state_in_shape,
tim::vx::TensorAttribute::INPUT);
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
tim::vx::TensorAttribute::OUTPUT);
tim::vx::TensorSpec state_out_spec(tim::vx::DataType::UINT8, state_out_shape,
tim::vx::TensorAttribute::OUTPUT, quant);
std::vector<float> in_data = {
0.12609188, 0.46347019, 0.89598465, 0.35867718, 0.36897406, 0.73463392,
};
std::vector<float> weights_data = {
0.12609188, 0.46347019, 0.89598465, 0.35867718, 0.36897406, 0.73463392,
0.12609188, 0.46347019, 0.89598465, 0.35867718, 0.36897406, 0.73463392,
};
std::vector<float> recurrent_weights_data = {
-0.31930989, 0.37613347, 0.27901134, 0.36137494, -1.36916667, 0.38031587,
0.21580373, 0.27072677, 1.01580888, 0.14943552, 1.15465137, 0.09784451,
-1.02702999, 1.39296314, 0.15785322, 0.21931258,
};
std::vector<float> bias_data = {
0.01580888,
0.14943552,
0.15465137,
0.09784451,
};
std::vector<float> state_in_data = {0, 0, 0, 0, 0, 0, 0, 0};
std::vector<float> output_golden = {0.781534, 0.771447, 0.830002, 0.749713,
0.711524, 0.74155, 0.77355, 0.717427};
std::vector<uint8_t> state_out_golden = {
217, 214, 231, 208, 198, 206, 215, 199,
};
auto input_tensor = graph->CreateTensor(input_spec);
auto weights_tensor = graph->CreateTensor(weights_spec, weights_data.data());
auto recurrent_weights_tensor = graph->CreateTensor(
recurrent_weights_spec, recurrent_weights_data.data());
auto bias_tensor = graph->CreateTensor(bias_spec, bias_data.data());
auto state_in_tensor = graph->CreateTensor(state_in_spec);
auto output_tensor = graph->CreateTensor(output_spec);
auto state_out_tensor = graph->CreateTensor(state_out_spec);
auto op = graph->CreateOperation<tim::vx::ops::RNNCell>(
tim::vx::ops::RNNCell::ActivationType::kSIGMOID);
(*op)
.BindInputs({input_tensor, weights_tensor, bias_tensor, state_in_tensor,
recurrent_weights_tensor})
.BindOutputs({output_tensor, state_out_tensor});
auto transform = tim::transform::LayoutInference(graph, ctx);
auto infer_graph = transform.first;
EXPECT_TRUE(infer_graph->Compile());
auto graph_io_map = transform.second;
auto infer_input = graph_io_map[graph->InputsTensor()[0]];
auto infer_input_state = graph_io_map[graph->InputsTensor()[1]];
auto infer_output = graph_io_map[graph->OutputsTensor()[0]];
auto infer_output_state = graph_io_map[graph->OutputsTensor()[1]];
infer_input->CopyDataToTensor(in_data.data(), in_data.size() * sizeof(float));
infer_input_state->CopyDataToTensor(state_in_data.data(),
state_in_data.size() * sizeof(float));
EXPECT_TRUE(infer_graph->Run());
std::vector<float> output(output_golden.size());
std::vector<uint8_t> state_out(state_out_golden.size());
EXPECT_TRUE(infer_output->CopyDataFromTensor(output.data()));
EXPECT_TRUE(infer_output_state->CopyDataFromTensor(state_out.data()));
EXPECT_TRUE(ArraysMatch(output_golden, output, 1e-5f));
EXPECT_EQ(state_out_golden, state_out);
}
TEST(RNNCell, layout_infer_notalign) {
auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph();
uint32_t input_size = 3, batch_size = 2, num_units = 4;
tim::vx::ShapeType input_shape({batch_size, input_size});  // input_pv = {1, 0}
tim::vx::ShapeType weights_shape({input_size, num_units});
tim::vx::ShapeType recurrent_weights_shape({num_units, num_units});
tim::vx::ShapeType bias_shape({num_units});
tim::vx::ShapeType state_in_shape({num_units, batch_size});
tim::vx::ShapeType output_shape({num_units, batch_size});
tim::vx::ShapeType state_out_shape({num_units, batch_size});
tim::vx::Quantization quant(tim::vx::QuantType::ASYMMETRIC, 0.0036, 0);
tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
tim::vx::TensorAttribute::INPUT);
tim::vx::TensorSpec weights_spec(tim::vx::DataType::FLOAT32, weights_shape,
tim::vx::TensorAttribute::CONSTANT);
tim::vx::TensorSpec recurrent_weights_spec(
tim::vx::DataType::FLOAT32, recurrent_weights_shape,
tim::vx::TensorAttribute::CONSTANT);
tim::vx::TensorSpec bias_spec(tim::vx::DataType::FLOAT32, bias_shape,
tim::vx::TensorAttribute::CONSTANT);
tim::vx::TensorSpec state_in_spec(tim::vx::DataType::FLOAT32, state_in_shape,
tim::vx::TensorAttribute::INPUT);
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
tim::vx::TensorAttribute::OUTPUT);
tim::vx::TensorSpec state_out_spec(tim::vx::DataType::UINT8, state_out_shape,
tim::vx::TensorAttribute::OUTPUT, quant);
std::vector<float> in_data = {
0.12609188, 0.35867718, 0.46347019, 0.36897406, 0.89598465, 0.73463392,
};
std::vector<float> weights_data = {
0.12609188, 0.46347019, 0.89598465, 0.35867718, 0.36897406, 0.73463392,
0.12609188, 0.46347019, 0.89598465, 0.35867718, 0.36897406, 0.73463392,
};
std::vector<float> recurrent_weights_data = {
-0.31930989, 0.37613347, 0.27901134, 0.36137494, -1.36916667, 0.38031587,
0.21580373, 0.27072677, 1.01580888, 0.14943552, 1.15465137, 0.09784451,
-1.02702999, 1.39296314, 0.15785322, 0.21931258,
};
std::vector<float> bias_data = {
0.01580888,
0.14943552,
0.15465137,
0.09784451,
};
std::vector<float> state_in_data = {0, 0, 0, 0, 0, 0, 0, 0};
std::vector<float> output_golden = {0.781534, 0.771447, 0.830002, 0.749713,
0.711524, 0.74155, 0.77355, 0.717427};
std::vector<uint8_t> state_out_golden = {
217, 214, 231, 208, 198, 206, 215, 199,
};
auto input_tensor = graph->CreateTensor(input_spec);
auto weights_tensor = graph->CreateTensor(weights_spec, weights_data.data());
auto recurrent_weights_tensor = graph->CreateTensor(
recurrent_weights_spec, recurrent_weights_data.data());
auto bias_tensor = graph->CreateTensor(bias_spec, bias_data.data());
auto state_in_tensor = graph->CreateTensor(state_in_spec);
auto output_tensor = graph->CreateTensor(output_spec);
auto state_out_tensor = graph->CreateTensor(state_out_spec);
std::map<std::shared_ptr<tim::vx::Tensor>,
std::shared_ptr<tim::transform::IPermuteVector>>
tensor_pv_map;
std::shared_ptr<tim::transform::IPermuteVector> pv =
std::make_shared<tim::transform::PermuteVector<2>>(
std::initializer_list<uint32_t>({1U, 0U}));
tensor_pv_map.insert({input_tensor, pv});
auto op = graph->CreateOperation<tim::vx::ops::RNNCell>(
tim::vx::ops::RNNCell::ActivationType::kSIGMOID);
(*op)
.BindInputs({input_tensor, weights_tensor, bias_tensor, state_in_tensor,
recurrent_weights_tensor})
.BindOutputs({output_tensor, state_out_tensor});
auto transform = tim::transform::LayoutInference(graph, ctx, tensor_pv_map);
auto infer_graph = transform.first;
EXPECT_TRUE(infer_graph->Compile());
auto graph_io_map = transform.second;
auto infer_input = graph_io_map[graph->InputsTensor()[0]];
auto infer_input_state = graph_io_map[graph->InputsTensor()[1]];
auto infer_output = graph_io_map[graph->OutputsTensor()[0]];
auto infer_output_state = graph_io_map[graph->OutputsTensor()[1]];
infer_input->CopyDataToTensor(in_data.data(), in_data.size() * sizeof(float));
infer_input_state->CopyDataToTensor(state_in_data.data(),
state_in_data.size() * sizeof(float));
EXPECT_TRUE(infer_graph->Run());
std::vector<float> output(output_golden.size());
std::vector<uint8_t> state_out(state_out_golden.size());
EXPECT_TRUE(infer_output->CopyDataFromTensor(output.data()));
EXPECT_TRUE(infer_output_state->CopyDataFromTensor(state_out.data()));
EXPECT_TRUE(ArraysMatch(output_golden, output, 1e-5f));
EXPECT_EQ(state_out_golden, state_out);
}
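The notalign variant feeds the same values with the two input dimensions swapped (shape {batch_size, input_size} plus input_pv = {1, 0}) and expects the same goldens, so layout inference must insert the transpose itself. Assuming TIM-VX shapes list the fastest-varying dimension first, the two in_data buffers are related by the following hypothetical helper (not part of the test):

// Maps the aligned buffer (shape {input_size, batch_size}) to the permuted
// one (shape {batch_size, input_size}); reproduces the two in_data arrays.
void TransposeInput(const float* aligned, float* permuted,
                    uint32_t input_size, uint32_t batch_size) {
  for (uint32_t b = 0; b < batch_size; ++b)
    for (uint32_t i = 0; i < input_size; ++i)
      permuted[i * batch_size + b] = aligned[b * input_size + i];
}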

View File

@@ -27,7 +27,7 @@
 namespace tim {
 namespace vx {
-BuiltinOpImpl::BuiltinOpImpl(Graph* graph, uint32_t kind, int input_cnt,
+BuiltinOpImpl::BuiltinOpImpl(Graph* graph, int32_t kind, int input_cnt,
                              int output_cnt, DataLayout layout)
     : OpImpl(graph, kind, input_cnt, output_cnt, layout),
       node_(vsi_nn_AddNode(graph_->graph(), kind_, input_cnt_, output_cnt_,

View File

@@ -34,7 +34,7 @@ namespace vx {
 class BuiltinOpImpl : public OpImpl {
  public:
-  BuiltinOpImpl(Graph* graph, uint32_t kind, int input_cnt = 0,
+  BuiltinOpImpl(Graph* graph, int32_t kind, int input_cnt = 0,
                 int output_cnt = 0, DataLayout layout = DataLayout::ANY);
   BuiltinOpImpl(Graph* graph, DataLayout layout = DataLayout::ANY);
   ~BuiltinOpImpl() {}

View File

@@ -26,7 +26,7 @@
 namespace tim {
 namespace vx {
-OpImpl::OpImpl(Graph* graph, uint32_t kind, int input_cnt, int output_cnt,
+OpImpl::OpImpl(Graph* graph, int32_t kind, int input_cnt, int output_cnt,
               DataLayout layout)
     : graph_(reinterpret_cast<GraphImpl*>(graph)),
       kind_(kind),

View File

@@ -33,7 +33,7 @@ namespace vx {
 class OpImpl {
  public:
-  OpImpl(Graph* graph, uint32_t kind, int input_cnt, int output_cnt,
+  OpImpl(Graph* graph, int32_t kind, int input_cnt, int output_cnt,
          DataLayout layout);
   OpImpl(Graph* graph, DataLayout layout);
@@ -50,7 +50,7 @@ class OpImpl {
                 uint32_t accumulator_bits = 0);
   GraphImpl* graph_{nullptr};
-  uint32_t kind_{0};
+  int32_t kind_{0};
   int32_t input_cnt_{0};
   int32_t output_cnt_{0};
   DataLayout layout_{DataLayout::ANY};
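The uint32_t to int32_t switch for kind, repeated across the four files above, is what makes the -1 sentinel in the new op->impl()->kind_ != -1 guard representable without relying on unsigned wraparound. A tiny standalone illustration (not code from the patch):

#include <cstdint>
#include <iostream>

int main() {
  uint32_t unsigned_kind = static_cast<uint32_t>(-1);  // wraps to 4294967295
  int32_t signed_kind = -1;                            // stores the sentinel
  std::cout << unsigned_kind << " " << signed_kind << "\n";  // 4294967295 -1
  return 0;
}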

View File

@@ -98,10 +98,9 @@ class RNNCellImpl : public OpImpl {
   RNNCellImpl& BindOutput(const std::shared_ptr<Tensor>& tensor) override {
     out_tensors_[output_tensor_index] = tensor;
-    tanh_->BindOutput(out_tensors_[OUT]);
-    data_convert_->BindInput(out_tensors_[OUT]);
     if (this->output_tensor_index == OUT_CNT - 1) {
+      tanh_->BindOutput(out_tensors_[OUT]);
+      data_convert_->BindInput(out_tensors_[OUT]);
       data_convert_->BindOutput(out_tensors_[STATE_OUT]);
     }
     this->output_tensor_index++;
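The RNNCell fix above addresses a double-binding bug: the old code wired tanh_'s output and data_convert_'s input on every BindOutput call, so binding the second output repeated the wiring. Under the output_tensor_index == OUT_CNT - 1 branch, each internal bind now runs exactly once, after all output slots are filled. An annotated trace of the two calls made by the test's BindOutputs:

// BindOutput(output_tensor);     // index 0: only records the tensor
// BindOutput(state_out_tensor);  // index 1 == OUT_CNT - 1: tanh_ and
//                                // data_convert_ are wired exactly once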