Add layout inference feature (#34)

* mobilenet_v1_1.0_224_quant.tflite pass
* inception_v1_224_quant.tflite pass
* ssd_mobilenet_v2_fpnlite_320x320_coco17_quant.tflite pass

Signed-off-by: Zongwu Yang <zongwu.yang@verisilicon.com>
Author: Zongwu.Yang, 2021-05-08 09:29:47 +08:00 (committed by GitHub)
Commit: 77b801a590, parent: e5d6da20a7
32 changed files with 1856 additions and 27 deletions

BUILD (7 changes)

@@ -16,6 +16,7 @@ cc_library(
         "include/tim/vx/operation.h",
         "include/tim/vx/tensor.h",
         "include/tim/vx/types.h",
+        "include/tim/layout_infer/layout_inference.h"
     ] + glob([
         "include/tim/vx/ops/*.h"
     ]),
@@ -30,9 +31,11 @@ cc_library(
         "src/tim/vx/tensor_private.h",
         "src/tim/vx/type_utils.h",
         "src/tim/vx/type_utils.cc",
+        "src/tim/layout_infer/layout_inference.cc",
+        "src/tim/layout_infer/permute_vector.h"
     ] + glob([
         "src/tim/vx/ops/*.cc"
-    ]),
+    ]) + glob(["src/tim/layout_infer/ops/*.*"]),
     deps = [
         "//src/tim/vx/internal:ovxlibimpl",
     ],
@@ -55,7 +58,7 @@ cc_binary(
 cc_test (
     name = "unit_test",
     copts = ["-std=c++14", "-Werror"],
-    srcs = glob(["src/tim/vx/*_test.cc"]),
+    srcs = glob(["src/tim/**/*_test.cc"]),
     deps = [
         "@gtest//:gtest",
         "@gtest//:gtest_main",

bazel-TIM-VX (symbolic link, 1 change)

@@ -0,0 +1 @@
/home/zongwu.yang/.cache/bazel/_bazel_zongwu.yang/03042b74c832165a2ec8acf0ff919a07/execroot/TIM_VX

include/tim/layout_infer/layout_inference.h (new file)

@@ -0,0 +1,90 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifndef TIM_LAYOUT_INFERENCE_H_
#define TIM_LAYOUT_INFERENCE_H_
#include <map>
#include <vector>
#include "tim/vx/context.h"
#include "tim/vx/operation.h"
#include "src/tim/layout_infer/permute_vector.h"
namespace tim {
namespace transform {
namespace layout_inference_impl {
class LayoutInferContext {
public:
LayoutInferContext(const std::shared_ptr<vx::Graph>& src_graph,
std::shared_ptr<vx::Graph>& infer_graph)
: src_graph_(src_graph), infer_graph_(infer_graph) {}
void SetPermuteVector(std::shared_ptr<vx::Tensor> tensor,
std::shared_ptr<IPermuteVector> pv);
const std::shared_ptr<IPermuteVector> GetPermuteVector(
const std::shared_ptr<vx::Tensor>& tensor) const;
void MarkVisited(const std::shared_ptr<vx::Operation>& op);
bool IsVisited(const std::shared_ptr<vx::Operation>& op) const;
bool IsReadyForInfer(const std::shared_ptr<vx::Operation>& op) const;
void UpdateTensorMap(const std::shared_ptr<vx::Tensor>& t_src,
const std::shared_ptr<vx::Tensor>& t_layout);
std::shared_ptr<vx::Tensor> GetMapedTensor(
const std::shared_ptr<vx::Tensor>& t_src) const;
void UpdateGraphInputMap(const std::shared_ptr<vx::Tensor>& i_src,
const std::shared_ptr<vx::Tensor>& i_layout);
std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<vx::Tensor>>
GetGraphInputMap() const {
return graph_input_map_;
}
const std::shared_ptr<vx::Graph>& src_graph_;
std::shared_ptr<vx::Graph>& infer_graph_;
private:
std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<IPermuteVector>>
tensor_pv_;
std::vector<std::shared_ptr<vx::Operation>> visited_op_;
// tensor_in_src -> tensor_in_layout
std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<vx::Tensor>>
tensor_map_;
std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<vx::Tensor>>
graph_input_map_;
};
std::vector<std::shared_ptr<vx::Tensor>> HandleLayoutInfer(
std::shared_ptr<layout_inference_impl::LayoutInferContext>& ctx,
const std::shared_ptr<vx::Operation>& op);
} // namespace layout_inference_impl
std::pair<std::shared_ptr<vx::Graph>, /* infer graph */
std::map<std::shared_ptr<vx::Tensor>,
std::shared_ptr<vx::Tensor>> /* graph io tensor map */>
LayoutInference(const std::shared_ptr<vx::Graph>& src_graph,
std::shared_ptr<vx::Context>& ctx);
} // namespace transform
} // namespace tim
#endif
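
For orientation, here is a minimal usage sketch of the entry point declared above. It is not part of the commit; it mirrors the unit test added below, and the wrapper name RunWithLayoutInference is invented for illustration:

#include "tim/layout_infer/layout_inference.h"

void RunWithLayoutInference(std::shared_ptr<tim::vx::Graph> src_graph,
                            std::shared_ptr<tim::vx::Context> ctx) {
  // Returns the layout-normalized graph plus a map from the source graph's
  // I/O tensors to their counterparts in the inferred graph.
  auto result = tim::transform::LayoutInference(src_graph, ctx);
  auto infer_graph = result.first;
  auto io_map = result.second;

  infer_graph->Compile();
  // Bind data through the mapped I/O tensors, e.g.:
  // io_map[src_graph->InputsTensor()[0]]->CopyDataToTensor(...);
  infer_graph->Run();
}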

include/tim/vx/graph.h

@@ -57,12 +57,24 @@ class Graph {
 template <typename OpType, typename... Params>
 std::shared_ptr<OpType> CreateOperation(Params... parameters) {
   auto op = std::make_shared<OpType>(this, parameters...);
-  opVector.push_back(op);
+  op_vector_.push_back(op);
   return op;
 }

- private:
-  std::vector<std::shared_ptr<tim::vx::Operation>> opVector;
+ virtual const std::vector<std::shared_ptr<Tensor>> InputsTensor() const = 0;
+ virtual const std::vector<std::shared_ptr<Tensor>> OutputsTensor() const = 0;
+ virtual void UpdateTensorConsumersMap(const std::shared_ptr<Tensor>& tensor,
+                                       const Operation* op) = 0;
+ virtual const std::vector<std::shared_ptr<Operation>> GetConsumersOp(
+     std::shared_ptr<Tensor> tensor) const = 0;
+ virtual void PrintGraph() const = 0;
+
+ protected:
+  std::vector<std::shared_ptr<tim::vx::Operation>> op_vector_;
 };

 } // namespace vx

include/tim/vx/operation.h

@@ -34,8 +34,8 @@ class OperationImpl;
 class Operation {
  public:
-  Operation(Graph* graph, uint32_t operation_id, int input_cnt = 0,
-            int ouput_cnt = 0);
+  Operation(Graph* graph, uint32_t operation_id, int input_cnt = 0,
+            int ouput_cnt = 0, DataLayout layout = DataLayout::ANY);
   virtual ~Operation();
   Operation& BindInput(const std::shared_ptr<Tensor>& tensor);
   Operation& BindOutput(const std::shared_ptr<Tensor>& tensor);

include/tim/vx/ops/conv2d.h

@@ -37,12 +37,14 @@ class Conv2d : public Operation {
 Conv2d(Graph* graph, int32_t weights, PadType padding,
        const std::array<uint32_t, 2>& ksize,
        const std::array<uint32_t, 2>& stride,
-       const std::array<uint32_t, 2>& dilation, int32_t multiplier = 0);
+       const std::array<uint32_t, 2>& dilation, int32_t multiplier = 0,
+       DataLayout layout = DataLayout::WHCN);
 Conv2d(Graph* graph, int32_t weights, PadType padding,
        const std::array<uint32_t, 2>& ksize,
        const std::array<uint32_t, 2>& stride,
        const std::array<uint32_t, 2>& dilation,
-       const std::array<uint32_t, 4>& pad, int32_t multiplier = 0);
+       const std::array<uint32_t, 4>& pad, int32_t multiplier = 0,
+       DataLayout layout = DataLayout::WHCN);

 protected:
  const uint32_t weights_;

include/tim/vx/ops/pool2d.h

@@ -38,7 +38,8 @@ class Pool2d : public Operation {
 Pool2d(Graph* graph, PoolType type, PadType padding,
        const std::array<uint32_t, 2>& ksize,
        const std::array<uint32_t, 2>& stride,
-       RoundType round_type = RoundType::FLOOR);
+       RoundType round_type = RoundType::FLOOR,
+       DataLayout layout = DataLayout::WHCN);

 protected:
  const PoolType type_;

include/tim/vx/tensor.h

@@ -140,6 +140,7 @@ class Tensor {
   virtual bool CopyDataFromTensor(void* data) = 0;
   virtual bool IsPlaceHolder() = 0;
   virtual bool IsConstTensor() = 0;
+  virtual const void* GetDataRef() const = 0;
 };

 } // namespace vx

include/tim/vx/types.h

@@ -73,6 +73,8 @@ enum class ActivationType {
 enum class ResizeType { NEAREST_NEIGHBOR, BILINEAR, AREA };

+enum class DataLayout { WHCN, CWHN, ANY };
+
 } // namespace vx
 } // namespace tim
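
A note on the new enum (added for clarity, not part of the commit): a shape vector lists dimensions in the order the layout names them, so the permute vector {1, 2, 0, 3} used throughout this patch (kCWHN2WHCN in op_layout_inference.h below) rearranges a CWHN shape into WHCN, assuming the transpose convention out_shape[i] = in_shape[perm[i]]:

#include <array>
#include <cassert>
#include <cstdint>

int main() {
  std::array<uint32_t, 4> cwhn = {1, 3, 3, 1};  // C=1, W=3, H=3, N=1
  std::array<uint32_t, 4> perm = {1, 2, 0, 3};  // kCWHN2WHCN
  std::array<uint32_t, 4> whcn;
  for (int i = 0; i < 4; ++i) whcn[i] = cwhn[perm[i]];  // out[i] = in[perm[i]]
  assert((whcn == std::array<uint32_t, 4>{3, 3, 1, 1}));  // W, H, C, N
  return 0;
}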

src/tim/layout_infer/layout_inference.cc (new file)

@@ -0,0 +1,229 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#include "src/tim/layout_infer/permute_vector.h"
#include "tim/layout_infer/layout_inference.h"
#include "src/tim/vx/operation_private.h"
#include "src/tim/layout_infer/ops/conv2d_layout_inference.h"
#include "src/tim/layout_infer/ops/reduce_layout_inference.h"
#include "src/tim/layout_infer/ops/elementwise_layout_inference.h"
#include "src/tim/layout_infer/ops/activation_layout_inference.h"
#include "src/tim/layout_infer/ops/concat_layout_inferene.h"
#include "src/tim/layout_infer/ops/reshape_layout_inference.h"
#include "src/tim/layout_infer/ops/simple_ops_layout_inference.h"
#include "src/tim/layout_infer/ops/pool2d_layout_inference.h"
#include "src/tim/layout_infer/ops/softmax_layout_inference.h"
#include "src/tim/layout_infer/ops/squeeze_layout_inference.h"
#include "src/tim/layout_infer/ops/stack_layout_inference.h"
#include <algorithm>
#include <deque>
namespace tim {
namespace transform {
namespace layout_inference_impl {
// Implementation of LayoutInferContext
void LayoutInferContext::SetPermuteVector(std::shared_ptr<vx::Tensor> tensor,
std::shared_ptr<IPermuteVector> pv) {
if (tensor_pv_.end() != tensor_pv_.find(tensor)) {
    VSILOGD("Tensor PermuteVector has already been set.");
}
tensor_pv_[tensor] = pv;
}
const std::shared_ptr<IPermuteVector> LayoutInferContext::GetPermuteVector(
const std::shared_ptr<vx::Tensor>& tensor) const {
auto pv_it = tensor_pv_.find(tensor);
if (pv_it != tensor_pv_.end()) {
return pv_it->second;
} else {
    VSILOGE("Tensor PermuteVector has not been set.");
    assert(false);
  }
  return nullptr;  // unreachable when the assert fires; silences missing-return warnings
}
void LayoutInferContext::MarkVisited(const std::shared_ptr<vx::Operation>& op) {
if (visited_op_.end() !=
std::find(visited_op_.begin(), visited_op_.end(), op)) {
    VSILOGW("The operation has already been marked as visited.");
} else {
visited_op_.push_back(op);
}
}
bool LayoutInferContext::IsVisited(const std::shared_ptr<vx::Operation>& op) const {
if (visited_op_.end() !=
std::find(visited_op_.begin(), visited_op_.end(), op)) {
return true;
} else {
return false;
}
}
bool LayoutInferContext::IsReadyForInfer(
const std::shared_ptr<vx::Operation>& op) const {
for (const auto& tensor : op->impl()->InputsTensor()) {
if (tensor_pv_.end() == tensor_pv_.find(tensor)) {
return false;
}
}
return true;
}
void LayoutInferContext::UpdateTensorMap(
const std::shared_ptr<vx::Tensor>& t_src,
const std::shared_ptr<vx::Tensor>& t_layout) {
tensor_map_[t_src] = t_layout;
}
std::shared_ptr<vx::Tensor> LayoutInferContext::GetMapedTensor(
const std::shared_ptr<vx::Tensor>& t_src) const {
auto it = tensor_map_.find(t_src);
if (it != tensor_map_.end()) {
return it->second;
} else {
    VSILOGE("Tensor has not been inserted in the tensor map.");
    assert(false);
  }
  return nullptr;  // unreachable when the assert fires; silences missing-return warnings
}
void LayoutInferContext::UpdateGraphInputMap(const std::shared_ptr<vx::Tensor>& i_src,
const std::shared_ptr<vx::Tensor>& i_layout) {
graph_input_map_[i_src] = i_layout;
}
#define REGIST_LAYOUT_INFERENCE(op_idx, name)                       \
  case op_idx: {                                                    \
    auto op_infer = std::make_shared<name##LayoutInfer>(op, ctx);   \
    op_infer->OnInputs(next_tensors);                               \
    op_infer->OnOutputs(next_tensors);                              \
    break;                                                          \
  }
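// For clarity (comment not in the original commit): a registration such as
// REGIST_LAYOUT_INFERENCE(VSI_NN_OP_CONV2D, Conv2d) expands to
//   case VSI_NN_OP_CONV2D: {
//     auto op_infer = std::make_shared<Conv2dLayoutInfer>(op, ctx);
//     op_infer->OnInputs(next_tensors);
//     op_infer->OnOutputs(next_tensors);
//     break;
//   }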
std::vector<std::shared_ptr<vx::Tensor>> HandleLayoutInfer(
std::shared_ptr<layout_inference_impl::LayoutInferContext>& ctx,
const std::shared_ptr<vx::Operation>& op) {
ctx->MarkVisited(op);
auto op_id = op->impl()->operation_id_;
std::vector<std::shared_ptr<vx::Tensor>> next_tensors;
switch (op_id) {
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_CONV2D, Conv2d);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_RELU, Relu);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_RELU1, Relu1);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_RELU6, Relu6);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_ELU, Elu);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SIGMOID, Sigmoid);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_MISH, Mish);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_HARD_SIGMOID, HardSigmoid);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SOFTRELU, SoftRelu);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SWISH, HardSwish);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_TANH, Tanh);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_LEAKY_RELU, LeakyRelu);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_CONCAT, Concat);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_ADD, Add);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SUBTRACT, Sub);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_MULTIPLY, Multiply);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_DIVIDE, Div);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_POW, Pow);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_MINIMUM, Minimum);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_MAXIMUM, Maximum);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_RESHAPE, Reshape);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_DATACONVERT, DataConvert);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_NEG, Neg);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_ABS, Abs);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SIN, Sin);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_EXP, Exp);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_LOG, Log);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SQRT, Sqrt);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_RSQRT, Rsqrt);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SQUARE, Square);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_LOGICAL_NOT, LogicalNot);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_POOL, Pool2d);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SOFTMAX, Softmax);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SQUEEZE, Squeeze);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_STACK, Stack);
default:
VSILOGW("Op %d: Default layout inference pass.", op_id);
assert(false);
}
return next_tensors;
}
} // namespace layout_inference_impl
std::pair<std::shared_ptr<vx::Graph>,
std::map<std::shared_ptr<vx::Tensor>,
std::shared_ptr<vx::Tensor>>> LayoutInference(
const std::shared_ptr<vx::Graph>& src_graph,
std::shared_ptr<vx::Context>& ctx) {
std::shared_ptr<vx::Graph> infer_graph = ctx->CreateGraph();
std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<vx::Tensor>>
graph_io_map;
auto layout_infer_ctx =
std::make_shared<layout_inference_impl::LayoutInferContext>(src_graph,
infer_graph);
std::deque<std::shared_ptr<vx::Tensor>> tensor_queue;
auto graph_inputs = src_graph->InputsTensor();
for (const auto& t_src : graph_inputs) {
if (t_src->IsConstTensor()) {
layout_infer_ctx->UpdateTensorMap(
t_src,
infer_graph->CreateTensor(t_src->GetSpec(), t_src->GetDataRef()));
} else {
auto input = infer_graph->CreateTensor(t_src->GetSpec());
layout_infer_ctx->UpdateTensorMap(t_src, input);
layout_infer_ctx->UpdateGraphInputMap(t_src, input);
tensor_queue.push_back(t_src);
}
layout_infer_ctx->SetPermuteVector(t_src,
MakeShared(t_src->GetShape().size()));
}
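  // Note (comment added for clarity, not in the original commit): the loop
  // below is a breadth-first traversal of the source graph. A tensor is
  // popped, each unvisited consumer whose inputs all carry permute vectors is
  // lowered into infer_graph, and that op's output tensors are queued.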
while (!tensor_queue.empty()) {
const auto& tensor = tensor_queue.front();
tensor_queue.pop_front();
const auto& consumers = src_graph->GetConsumersOp(tensor);
for (const auto& op : consumers) {
if (!layout_infer_ctx->IsVisited(op) &&
layout_infer_ctx->IsReadyForInfer(op)) {
auto next_tensors =
layout_inference_impl::HandleLayoutInfer(layout_infer_ctx, op);
for (const auto& t : next_tensors) {
tensor_queue.push_back(t);
}
}
}
}
for (const auto& graph_input : layout_infer_ctx->GetGraphInputMap()) {
graph_io_map[graph_input.first] = graph_input.second;
}
for (const auto& out_src : src_graph->OutputsTensor()) {
graph_io_map[out_src] = layout_infer_ctx->GetMapedTensor(out_src);
}
return std::make_pair(infer_graph, graph_io_map);
}
} // namespace transform
} // namespace tim

src/tim/layout_infer/layout_inference_test.cc (new file)

@@ -0,0 +1,61 @@
#include "tim/vx/context.h"
#include "tim/vx/graph.h"
#include "tim/vx/ops/conv2d.h"
#include "tim/layout_infer/layout_inference.h"
#include "gtest/gtest.h"
TEST(LayoutInference, simple_conv2d) {
auto ctx = tim::vx::Context::Create();
auto src_graph = ctx->CreateGraph();
tim::vx::ShapeType input_shape({1, 3, 3, 1});
tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, input_shape,
tim::vx::TensorAttribute::INPUT);
auto input = src_graph->CreateTensor(input_spec);
tim::vx::ShapeType kernel_shape({1, 2, 2, 1});
tim::vx::TensorSpec kernel_spec(tim::vx::DataType::FLOAT32, kernel_shape,
tim::vx::TensorAttribute::CONSTANT);
std::vector<float> kernel_data = {
0.25f, 0.25f, 0.25f, 0.25f};
auto kernel = src_graph->CreateTensor(kernel_spec, kernel_data.data());
tim::vx::ShapeType bias_shape({1});
tim::vx::TensorSpec bias_spec(tim::vx::DataType::FLOAT32, bias_shape,
tim::vx::TensorAttribute::CONSTANT);
std::vector<float> bias_data = {0.0f};
auto bias = src_graph->CreateTensor(bias_spec, bias_data.data());
tim::vx::ShapeType output_shape({1, 2, 2, 1});
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, output_shape,
tim::vx::TensorAttribute::OUTPUT);
auto output = src_graph->CreateTensor(output_spec);
auto conv2d = src_graph->CreateOperation<tim::vx::ops::Conv2d>(
kernel_shape[0], tim::vx::PadType::AUTO,
std::array<uint32_t, 2>({kernel_shape[2], kernel_shape[1]}),
std::array<uint32_t, 2>({1, 1}), std::array<uint32_t, 2>({0, 0}),
std::array<uint32_t, 4>({0, 0, 0, 0}), 0, tim::vx::DataLayout::CWHN);
(*conv2d).BindInputs({input, kernel, bias}).BindOutput(output);
// Do layout inference
auto layout_infer = tim::transform::LayoutInference(src_graph, ctx);
auto infer_graph = layout_infer.first;
auto graph_io_map = layout_infer.second;
infer_graph->Compile();
std::vector<float> input_data = {1.0f, 1.0f, 1.0f, 1.0f, 0.5f, 1.0f, 1.0f, 1.0f, 1.0f};
auto infer_input = graph_io_map[src_graph->InputsTensor()[0]];
auto infer_output = graph_io_map[src_graph->OutputsTensor()[0]];
infer_input->CopyDataToTensor(input_data.data(), input_data.size() * sizeof(float));
infer_graph->Run();
std::vector<float> out_data;
auto infer_out_shape = infer_output->GetShape();
out_data.resize(infer_out_shape[0] * infer_out_shape[1] * infer_out_shape[2] *
infer_out_shape[3]);
infer_output->CopyDataFromTensor(out_data.data());
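  // Sanity check on the expectation (comment added for clarity): every 2x2
  // window of the 3x3 input contains the 0.5 center plus three 1.0s, so each
  // output element is 0.25 * (1 + 1 + 1 + 0.5) = 0.875.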
std::vector<float> expect_output = {0.875f, 0.875f, 0.875f, 0.875f};
EXPECT_TRUE(0 == memcmp((void*)out_data.data(), (void*)expect_output.data(),
sizeof(float) * out_data.size()));
tim::vx::ShapeType expect_shape({1, 2, 2, 1});
EXPECT_EQ(infer_out_shape, expect_shape);
}

src/tim/layout_infer/ops/activation_layout_inference.h (new file)

@@ -0,0 +1,99 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifndef TIM_LAYOUT_INFER_ACTIVATION_LAYOUT_INFERENCE_H_
#define TIM_LAYOUT_INFER_ACTIVATION_LAYOUT_INFERENCE_H_
#include "tim/vx/ops/activations.h"
#include "src/tim/layout_infer/ops/op_layout_inference.h"
#include "src/tim/layout_infer/permute_vector.h"
#include "src/tim/vx/operation_private.h"
namespace tim {
namespace transform {
template <typename OpType>
class ActivationLayoutInfer : public OpLayoutInfer {
public:
ActivationLayoutInfer(
const std::shared_ptr<vx::Operation> op,
std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
: OpLayoutInfer(op, context) {}
void OnInputs(
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
// Transmit input pv to out pv directly for activation ops
assert(op_->impl()->InputsTensor().size() == 1);
auto i_src = op_->impl()->InputsTensor()[0];
auto input_pv = context_->GetPermuteVector(i_src);
auto activation = context_->infer_graph_->CreateOperation<OpType>();
auto out_infer = CreateOutputsTensor(input_pv);
(*activation)
.BindInput(context_->GetMapedTensor(i_src))
.BindOutput(out_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], input_pv);
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
}
};
class LeakyReluLayoutInfer : public OpLayoutInfer {
public:
LeakyReluLayoutInfer(
const std::shared_ptr<vx::Operation> op,
std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
: OpLayoutInfer(op, context) {}
void OnInputs(
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
assert(op_->impl()->InputsTensor().size() == 1);
auto i_src = op_->impl()->InputsTensor()[0];
auto input_pv = context_->GetPermuteVector(i_src);
auto leaky_relu =
context_->infer_graph_->CreateOperation<vx::ops::LeakyRelu>(
op_->impl()->node()->nn_param.activation.leaky_ratio);
auto out_infer = CreateOutputsTensor(input_pv);
(*leaky_relu)
.BindInput(context_->GetMapedTensor(i_src))
.BindOutput(out_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], input_pv);
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
}
};
// TODO(yzw): Add Prelu
using ReluLayoutInfer = ActivationLayoutInfer<vx::ops::Relu>;
using Relu1LayoutInfer = ActivationLayoutInfer<vx::ops::Relu1>;
using Relu6LayoutInfer = ActivationLayoutInfer<vx::ops::Relu6>;
using EluLayoutInfer = ActivationLayoutInfer<vx::ops::Elu>;
using SigmoidLayoutInfer = ActivationLayoutInfer<vx::ops::Sigmoid>;
using MishLayoutInfer = ActivationLayoutInfer<vx::ops::Mish>;
using HardSigmoidLayoutInfer = ActivationLayoutInfer<vx::ops::HardSigmoid>;
using SoftReluLayoutInfer = ActivationLayoutInfer<vx::ops::SoftRelu>;
using HardSwishLayoutInfer = ActivationLayoutInfer<vx::ops::HardSwish>;
using TanhLayoutInfer = ActivationLayoutInfer<vx::ops::Tanh>;
} // namespace transform
} // namespace tim
#endif

src/tim/layout_infer/ops/concat_layout_inferene.h (new file)

@@ -0,0 +1,62 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifndef TIM_LAYOUT_INFER_CONCAT_LAYOUT_INFERENCE_H_
#define TIM_LAYOUT_INFER_CONCAT_LAYOUT_INFERENCE_H_
#include "tim/vx/ops/concat.h"
#include "src/tim/layout_infer/ops/op_layout_inference.h"
#include "src/tim/layout_infer/permute_vector.h"
#include "src/tim/vx/operation_private.h"
namespace tim {
namespace transform {
class ConcatLayoutInfer : public OpLayoutInfer {
public:
ConcatLayoutInfer(
const std::shared_ptr<vx::Operation> op,
std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
: OpLayoutInfer(op, context) {}
void OnInputs(
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
auto required_pv = AlignPermuteVectorForMutilInputs();
auto axis = MapAxis(required_pv->AsStdVec(),
op_->impl()->node()->nn_param.concat.axis);
auto concat = context_->infer_graph_->CreateOperation<vx::ops::Concat>(
axis, op_->impl()->InputsTensor().size());
for (const auto& i_src : op_->impl()->InputsTensor()) {
(*concat).BindInput(context_->GetMapedTensor(i_src));
}
auto out_infer = CreateOutputsTensor(required_pv);
(*concat).BindOutput(out_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
}
};
} // namespace transform
} // namespace tim
#endif

src/tim/layout_infer/ops/conv2d_layout_inference.h (new file)

@@ -0,0 +1,101 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifndef TIM_LAYOUT_INFER_CONV2D_LAYOUT_INFERENCE_H_
#define TIM_LAYOUT_INFER_CONV2D_LAYOUT_INFERENCE_H_
#include "tim/vx/ops/conv2d.h"
#include "src/tim/vx/operation_private.h"
#include "src/tim/layout_infer/permute_vector.h"
#include "src/tim/layout_infer/ops/op_layout_inference.h"
namespace tim {
namespace transform {
class Conv2dLayoutInfer : public OpLayoutInfer {
public:
Conv2dLayoutInfer(
const std::shared_ptr<vx::Operation> op,
std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
: OpLayoutInfer(op, context) {}
void OnInputs(
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
vx::DataLayout layout = op_->impl()->layout_;
auto required_pv = MakeShared(4);
if (layout == vx::DataLayout::CWHN) {
required_pv = std::make_shared<PermuteVector<4>>(kCWHN2WHCN);
}
auto input_tensors = op_->impl()->InputsTensor();
// for input and weight
for (uint32_t i = 0; i < 2; i++) {
auto pv = context_->GetPermuteVector(input_tensors[i]);
auto final_pv = pv->Reverse()->Add(required_pv);
if (!final_pv->IsAligned()) {
auto perm_out =
InsertPermute(context_->GetMapedTensor(input_tensors[i]), final_pv);
context_->UpdateTensorMap(input_tensors[i], perm_out);
context_->SetPermuteVector(input_tensors[i], required_pv);
}
}
auto pad_type = TranslatePadType(op_->impl()->node()->nn_param.conv2d.pad_type);
std::array<uint32_t, 2> ksize = {
op_->impl()->node()->nn_param.conv2d.ksize[0],
op_->impl()->node()->nn_param.conv2d.ksize[1]
};
std::array<uint32_t, 2> stride = {
op_->impl()->node()->nn_param.conv2d.stride[0],
op_->impl()->node()->nn_param.conv2d.stride[1]
};
std::array<uint32_t, 2> dilation = {
op_->impl()->node()->nn_param.conv2d.dilation[0],
op_->impl()->node()->nn_param.conv2d.dilation[1]
};
std::array<uint32_t, 4> pad = {
op_->impl()->node()->nn_param.conv2d.pad[0],
op_->impl()->node()->nn_param.conv2d.pad[1],
op_->impl()->node()->nn_param.conv2d.pad[2],
op_->impl()->node()->nn_param.conv2d.pad[3]
};
int32_t multiplier = op_->impl()->node()->nn_param.conv2d.multiplier;
int32_t out_channels = op_->impl()->node()->nn_param.conv2d.weights;
auto conv2d = context_->infer_graph_->CreateOperation<vx::ops::Conv2d>(
out_channels, pad_type, ksize, stride, dilation, pad, multiplier,
vx::DataLayout::WHCN);
auto otensor_infer = CreateOutputsTensor(required_pv);
(*conv2d).BindInputs({context_->GetMapedTensor(input_tensors[0]),
context_->GetMapedTensor(input_tensors[1]),
context_->GetMapedTensor(input_tensors[2])});
(*conv2d).BindOutput(otensor_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
// Add out tensor of src_graph into next_tensor
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
}
};
} // namespace transform
} // namespace tim
#endif

src/tim/layout_infer/ops/elementwise_layout_inference.h (new file)

@@ -0,0 +1,90 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifndef TIM_LAYOUT_INFER_ELEMENTWISE_LAYOUT_INFERENCE_H_
#define TIM_LAYOUT_INFER_ELEMENTWISE_LAYOUT_INFERENCE_H_
#include "tim/vx/ops/elementwise.h"
#include "src/tim/layout_infer/ops/op_layout_inference.h"
#include "src/tim/layout_infer/permute_vector.h"
#include "src/tim/vx/operation_private.h"
namespace tim {
namespace transform {
template <typename OpType>
class ElementWiseLayoutInfer : public OpLayoutInfer {
public:
ElementWiseLayoutInfer(
const std::shared_ptr<vx::Operation> op,
std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
: OpLayoutInfer(op, context) {}
void OnInputs(
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
auto required_pv = AlignPermuteVectorForMutilInputs();
auto elementwise = context_->infer_graph_->CreateOperation<OpType>();
for (const auto& i_src : op_->impl()->InputsTensor()) {
(*elementwise).BindInput(context_->GetMapedTensor(i_src));
}
auto out_infer = CreateOutputsTensor(required_pv);
(*elementwise).BindOutput(out_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
}
};
class MultiplyLayoutInfer : public OpLayoutInfer {
public:
MultiplyLayoutInfer(
const std::shared_ptr<vx::Operation> op,
std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
: OpLayoutInfer(op, context) {}
void OnInputs(
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
auto required_pv = AlignPermuteVectorForMutilInputs();
auto multiply =
context_->infer_graph_->CreateOperation<tim::vx::ops::Multiply>(
op_->impl()->node()->nn_param.multiply.scale);
for (const auto& i_src : op_->impl()->InputsTensor()) {
(*multiply).BindInput(context_->GetMapedTensor(i_src));
}
auto out_infer = CreateOutputsTensor(required_pv);
(*multiply).BindOutput(out_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
}
};
using AddLayoutInfer = ElementWiseLayoutInfer<tim::vx::ops::Add>;
using SubLayoutInfer = ElementWiseLayoutInfer<tim::vx::ops::Sub>;
using DivLayoutInfer = ElementWiseLayoutInfer<tim::vx::ops::Div>;
using PowLayoutInfer = ElementWiseLayoutInfer<tim::vx::ops::Pow>;
using MinimumLayoutInfer = ElementWiseLayoutInfer<tim::vx::ops::Minimum>;
using MaximumLayoutInfer = ElementWiseLayoutInfer<tim::vx::ops::Maximum>;
} // namespace transform
} // namespace tim
#endif

src/tim/layout_infer/ops/op_layout_inference.cc (new file)

@@ -0,0 +1,182 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#include "op_layout_inference.h"
#include "src/tim/layout_infer/permute_vector.h"
#include "src/tim/vx/operation_private.h"
#include "tim/vx/ops/transpose.h"
#include <algorithm>
#include <vector>
namespace tim {
namespace transform {
void OpLayoutInfer::OnOutputs(
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) {
auto graph_outputs = context_->src_graph_->OutputsTensor();
auto op_outputs = op_->impl()->OutputsTensor();
for (const auto& out : op_outputs) {
if (graph_outputs.end() !=
std::find(graph_outputs.begin(), graph_outputs.end(), out)) {
auto pv = context_->GetPermuteVector(out);
if (!pv->IsAligned()) {
auto perm_out = InsertPermute(context_->GetMapedTensor(out),
pv->Reverse(), true, out);
// Update graph out tensor
context_->UpdateTensorMap(out, perm_out);
}
if (!context_->src_graph_->GetConsumersOp(out).empty()) {
        // The tensor is a graph output, but it is also an input to other operations
context_->SetPermuteVector(out, MakeShared(pv->Rank()));
} else {
auto it = std::find(next_tensors.begin(), next_tensors.end(), out);
if (it != next_tensors.end()) {
next_tensors.erase(it);
}
}
}
}
}
std::shared_ptr<vx::Tensor> OpLayoutInfer::InsertPermute(
std::shared_ptr<vx::Tensor> input, std::shared_ptr<IPermuteVector> perm,
bool is_graph_output, std::shared_ptr<vx::Tensor> src_out) {
auto out_spec = input->GetSpec();
if (is_graph_output) {
auto out_shape = src_out->GetShape();
out_spec.SetShape(out_shape);
out_spec.SetAttribute(vx::TensorAttribute::OUTPUT);
} else {
out_spec.SetAttribute(vx::TensorAttribute::TRANSIENT);
}
if (out_spec.quantization_.Type() == vx::QuantType::SYMMETRIC_PER_CHANNEL) {
out_spec.quantization_.SetChannelDim(
MapAxis(perm->AsStdVec(), out_spec.quantization_.ChannelDim()));
}
auto out_tensor = context_->infer_graph_->CreateTensor(out_spec);
auto perm_op =
context_->infer_graph_->CreateOperation<vx::ops::Transpose>(perm->AsStdVec());
(*perm_op).BindInput(input).BindOutput(out_tensor);
return out_tensor;
}
std::vector<std::shared_ptr<vx::Tensor>> OpLayoutInfer::CreateOutputsTensor(
std::shared_ptr<IPermuteVector> required_pv) {
  std::vector<std::shared_ptr<vx::Tensor>> outputs_tensor;
for (const auto& o : op_->impl()->OutputsTensor()) {
auto in_shape = o->GetShape();
auto out_spec = o->GetSpec();
if (!required_pv->IsAligned()) {
out_spec = out_spec.AsTransientSpec();
}
auto t_infer = context_->infer_graph_->CreateTensor(out_spec);
context_->UpdateTensorMap(o, t_infer);
    outputs_tensor.push_back(t_infer);
}
  return outputs_tensor;
}
vx::PadType OpLayoutInfer::TranslatePadType(int32_t pad) {
switch (pad) {
case VSI_NN_PAD_AUTO:
return vx::PadType::AUTO;
case VSI_NN_PAD_VALID:
return vx::PadType::VALID;
case VSI_NN_PAD_SAME:
return vx::PadType::SAME;
default:
return vx::PadType::AUTO;
}
}
vx::PoolType OpLayoutInfer::TranslatePoolType(int32_t pool) {
switch (pool) {
case VX_CONVOLUTIONAL_NETWORK_POOLING_MAX:
return vx::PoolType::MAX;
case VX_CONVOLUTIONAL_NETWORK_POOLING_AVG:
return vx::PoolType::AVG;
case VX_CONVOLUTIONAL_NETWORK_POOLING_L2:
return vx::PoolType::L2;
case VX_CONVOLUTIONAL_NETWORK_POOLING_AVG_ANDROID:
return vx::PoolType::AVG_ANDROID;
default:
return vx::PoolType::MAX;
}
}
vx::RoundType OpLayoutInfer::TranslateRoundType(int32_t round) {
switch (round) {
case VSI_NN_ROUND_CEIL:
return vx::RoundType::CEILING;
case VSI_NN_ROUND_FLOOR:
return vx::RoundType::FLOOR;
default:
return vx::RoundType::FLOOR;
}
}
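// Example (comment added for clarity, not in the original commit): MapAxis
// returns the position of `axis` inside `perm`, i.e. where a source-layout
// axis lands after permutation. With perm = {1, 2, 0, 3} (CWHN -> WHCN),
// source axis 0 (the C dimension) maps to index 2.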
uint32_t OpLayoutInfer::MapAxis(const std::vector<uint32_t>& perm,
uint32_t axis) {
for (uint32_t i = 0; i < perm.size(); i++) {
if (axis == perm[i]) {
return i;
}
}
VSILOGE("Map axis failed.");
assert(false);
return perm.size() - 1;
}
std::shared_ptr<IPermuteVector>
OpLayoutInfer::AlignPermuteVectorForMutilInputs() {
auto src_inputs = op_->impl()->InputsTensor();
// Suppose the inputs have same dimension rank
// TODO(yzw): should choose a optimal required_pv
auto required_pv = context_->GetPermuteVector(src_inputs[0]);
for (const auto& i_src : src_inputs) {
auto pv = context_->GetPermuteVector(i_src);
auto final_pv = pv->Reverse()->Add(required_pv);
if (!final_pv->IsAligned()) {
auto perm_out =
InsertPermute(context_->GetMapedTensor(i_src), final_pv);
context_->UpdateTensorMap(i_src, perm_out);
context_->SetPermuteVector(i_src, required_pv);
}
}
return required_pv;
}
void OpLayoutInfer::ReverseInputsPermuteVector() {
for (const auto& i_src : op_->impl()->InputsTensor()) {
auto input_pv = context_->GetPermuteVector(i_src);
if (!input_pv->IsAligned()) {
auto perm_out = InsertPermute(context_->GetMapedTensor(i_src),
input_pv->Reverse());
context_->UpdateTensorMap(i_src, perm_out);
context_->SetPermuteVector(i_src, MakeShared(input_pv->Rank()));
}
}
}
} // namespace transform
} // namespace tim

src/tim/layout_infer/ops/op_layout_inference.h (new file)

@@ -0,0 +1,74 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifndef TIM_LAYOUT_INFER_OPS_OP_LAYOUT_INFERENCE_H_
#define TIM_LAYOUT_INFER_OPS_OP_LAYOUT_INFERENCE_H_
#include <memory>
#include "tim/layout_infer/layout_inference.h"
namespace tim {
namespace transform {
constexpr std::initializer_list<uint32_t> kCWHN2WHCN = {1, 2, 0, 3};
class OpLayoutInfer {
public:
OpLayoutInfer(
const std::shared_ptr<vx::Operation> op,
std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
: op_(op), context_(context) {}
virtual void OnInputs(std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) = 0;
virtual void OnOutputs(
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors);
protected:
std::shared_ptr<vx::Tensor> InsertPermute(std::shared_ptr<vx::Tensor> input,
std::shared_ptr<IPermuteVector> perm,
bool is_graph_output = false,
std::shared_ptr<vx::Tensor> src_out = nullptr);
std::vector<std::shared_ptr<vx::Tensor>> CreateOutputsTensor(
std::shared_ptr<IPermuteVector> required_pv);
vx::PadType TranslatePadType(int32_t pad);
vx::PoolType TranslatePoolType(int32_t pool);
vx::RoundType TranslateRoundType(int32_t round);
uint32_t MapAxis(const std::vector<uint32_t>& perm, uint32_t axis);
std::shared_ptr<IPermuteVector> AlignPermuteVectorForMutilInputs();
void ReverseInputsPermuteVector();
protected:
const std::shared_ptr<vx::Operation> op_;
std::shared_ptr<layout_inference_impl::LayoutInferContext>& context_;
};
} // namespace transform
} // namespace tim
#endif
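
The pv->Reverse()->Add(required_pv) pattern used by the per-op classes composes permutations: undo whatever permute the mapped tensor already carries, then apply the one the op requires; if the result is the identity (IsAligned), no Transpose needs to be inserted. A standalone sketch of that algebra (hypothetical helpers, not the project's IPermuteVector implementation; assumes out[i] = in[perm[i]] semantics and that Add means "apply left, then right"):

#include <cassert>
#include <cstdint>
#include <vector>

using Perm = std::vector<uint32_t>;

// Inverse permutation: Reverse(p)[p[i]] == i.
Perm Reverse(const Perm& p) {
  Perm r(p.size());
  for (uint32_t i = 0; i < p.size(); ++i) r[p[i]] = i;
  return r;
}

// "Apply p, then q": out[i] = in[p[q[i]]], so the fused vector is p[q[i]].
Perm Add(const Perm& p, const Perm& q) {
  Perm r(q.size());
  for (uint32_t i = 0; i < q.size(); ++i) r[i] = p[q[i]];
  return r;
}

bool IsAligned(const Perm& p) {
  for (uint32_t i = 0; i < p.size(); ++i)
    if (p[i] != i) return false;
  return true;
}

int main() {
  Perm pv = {1, 2, 0, 3};        // permute already applied (kCWHN2WHCN)
  Perm required = {1, 2, 0, 3};  // permute the op requires
  // Undoing pv and re-applying the same requirement leaves nothing to do,
  // so no Transpose op would be inserted.
  assert(IsAligned(Add(Reverse(pv), required)));
  return 0;
}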

src/tim/layout_infer/ops/pool2d_layout_inference.h (new file)

@@ -0,0 +1,85 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifndef TIM_LAYOUT_INFER_POOL2D_LAYOUT_INFERENCE_H_
#define TIM_LAYOUT_INFER_POOL2D_LAYOUT_INFERENCE_H_
#include "src/tim/layout_infer/ops/op_layout_inference.h"
#include "src/tim/layout_infer/permute_vector.h"
#include "src/tim/vx/operation_private.h"
#include "tim/vx/ops/pool2d.h"
namespace tim {
namespace transform {
class Pool2dLayoutInfer : public OpLayoutInfer {
public:
Pool2dLayoutInfer(
const std::shared_ptr<vx::Operation> op,
std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
: OpLayoutInfer(op, context) {}
void OnInputs(
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
vx::DataLayout layout = op_->impl()->layout_;
auto required_pv = MakeShared(4);
if (layout == vx::DataLayout::CWHN) {
required_pv = std::make_shared<PermuteVector<4>>(kCWHN2WHCN);
}
auto input_tensors = op_->impl()->InputsTensor();
auto pv = context_->GetPermuteVector(input_tensors[0]);
auto final_pv = pv->Reverse()->Add(required_pv);
if (!final_pv->IsAligned()) {
auto perm_out =
InsertPermute(context_->GetMapedTensor(input_tensors[0]), final_pv);
context_->UpdateTensorMap(input_tensors[0], perm_out);
context_->SetPermuteVector(input_tensors[0], required_pv);
}
auto pool_type = TranslatePoolType(op_->impl()->node()->nn_param.pool.type);
auto round_type =
TranslateRoundType(op_->impl()->node()->nn_param.pool.round_type);
auto pad_type =
TranslatePadType(op_->impl()->node()->nn_param.pool.pad_type);
std::array<uint32_t, 2> ksize = {
op_->impl()->node()->nn_param.pool.ksize[0],
op_->impl()->node()->nn_param.pool.ksize[1]};
std::array<uint32_t, 2> stride = {
op_->impl()->node()->nn_param.pool.stride[0],
op_->impl()->node()->nn_param.pool.stride[1]};
auto pool2d = context_->infer_graph_->CreateOperation<vx::ops::Pool2d>(
pool_type, pad_type, ksize, stride, round_type, vx::DataLayout::WHCN);
auto otensor_infer = CreateOutputsTensor(required_pv);
(*pool2d).BindInput(context_->GetMapedTensor(input_tensors[0]));
(*pool2d).BindOutput(otensor_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
// Add out tensor of src_graph into next_tensor
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
}
};
} // namespace transform
} // namespace tim
#endif

src/tim/layout_infer/ops/reduce_layout_inference.h (new file)

@@ -0,0 +1,93 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifndef TIM_LAYOUT_INFER_REDUCE_LAYOUT_INFERENCE_H_
#define TIM_LAYOUT_INFER_REDUCE_LAYOUT_INFERENCE_H_
#include "tim/vx/ops/reduce.h"
#include <set>
#include "src/tim/layout_infer/ops/op_layout_inference.h"
#include "src/tim/layout_infer/permute_vector.h"
#include "src/tim/vx/operation_private.h"
namespace tim {
namespace transform {
template <typename OpType>
class ReduceLayoutInfer : public OpLayoutInfer {
 public:
  ReduceLayoutInfer(
const std::shared_ptr<vx::Operation> op,
std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
: OpLayoutInfer(op, context) {}
void OnInputs(
std::vector<std::shared_ptr<vx::Tensor>>& next_tensor) override {
auto t_src = op_->impl()->InputsTensor()[0];
auto pv = context_->GetPermuteVector(op_->impl()->InputsTensor()[0]);
std::set<int32_t> unique_axis;
std::vector<int32_t> new_axis;
for (uint32_t i = 0; i < op_->impl()->node()->nn_param.reduce.axis_num;
++i) {
int32_t axis = op_->impl()->node()->nn_param.reduce.axis[i];
if (axis < 0) {
axis += pv->Rank();
}
unique_axis.insert(axis);
new_axis.push_back(MapAxis(pv->AsStdVec(), axis));
}
auto reduce = context_->infer_graph_->CreateOperation<OpType>(
new_axis, op_->impl()->node()->nn_param.reduce.keep_dim);
(*reduce).BindInput(context_->GetMapedTensor(t_src));
if (op_->impl()->node()->nn_param.reduce.keep_dim) {
auto otensor_infer = CreateOutputsTensor(pv);
      (*reduce).BindOutput(otensor_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], pv);
} else {
auto out_pv = MakeShared(pv->Rank() - unique_axis.size());
uint32_t j = 0;
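      // Worked example (comment added for clarity): with pv = {1, 2, 0, 3}
      // and unique_axis = {0}, the surviving pv entries are 1, 2, 3; each is
      // decremented by the count of removed axes smaller than it, giving
      // out_pv = {0, 1, 2}.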
      for (uint32_t i = 0; i < pv->Rank(); i++) {
if (unique_axis.end() != unique_axis.find(pv->At(i))) continue;
uint32_t cnt = 0;
for (auto axis : unique_axis) {
if (pv->At(i) > (uint32_t)axis) cnt++;
}
out_pv->At(j) = pv->At(i) - cnt;
j++;
}
auto otensor_infer = CreateOutputsTensor(out_pv);
(*reduce).BindOutput(otensor_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], out_pv);
}
next_tensor.push_back(op_->impl()->OutputsTensor()[0]);
}
};
using ReduceMinLayoutInfer = ReduceLayoutInfer<tim::vx::ops::ReduceMin>;
using ReduceMaxLayoutInfer = ReduceLayoutInfer<tim::vx::ops::ReduceMax>;
using ReduceAnyLayoutInfer = ReduceLayoutInfer<tim::vx::ops::ReduceAny>;
using ReduceProdLayoutInfer = ReduceLayoutInfer<tim::vx::ops::ReduceProd>;
using ReduceMeanLayoutInfer = ReduceLayoutInfer<tim::vx::ops::ReduceMean>;
} // namespace transform
} // namespace tim
#endif

src/tim/layout_infer/ops/reshape_layout_inference.h (new file)

@@ -0,0 +1,67 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifndef TIM_LAYOUT_INFER_RESHAPE_LAYOUT_INFERENCE_H_
#define TIM_LAYOUT_INFER_RESHAPE_LAYOUT_INFERENCE_H_
#include "tim/vx/ops/reshape.h"
#include "src/tim/layout_infer/ops/op_layout_inference.h"
#include "src/tim/layout_infer/permute_vector.h"
#include "src/tim/vx/operation_private.h"
namespace tim {
namespace transform {
class ReshapeLayoutInfer : public OpLayoutInfer {
public:
ReshapeLayoutInfer(
const std::shared_ptr<vx::Operation> op,
std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
: OpLayoutInfer(op, context) {}
  // Reverse any permute applied to its input tensor
void OnInputs(
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
ReverseInputsPermuteVector();
std::vector<uint32_t> perm;
for (uint32_t i = 0; i < op_->impl()->node()->nn_param.reshape.dim_num;
i++) {
perm.push_back(op_->impl()->node()->nn_param.reshape.size[i]);
}
auto reshape =
context_->infer_graph_->CreateOperation<vx::ops::Reshape>(perm);
(*reshape).BindInput(
context_->GetMapedTensor(op_->impl()->InputsTensor()[0]));
auto required_pv =
MakeShared(op_->impl()->OutputsTensor()[0]->GetShape().size());
auto out_infer = CreateOutputsTensor(required_pv);
(*reshape).BindOutput(out_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
}
};
} // namespace transform
} // namespace tim
#endif

src/tim/layout_infer/ops/simple_ops_layout_inference.h (new file)

@@ -0,0 +1,76 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifndef TIM_LAYOUT_INFER_SIMPLE_OPS_LAYOUT_INFERENCE_H_
#define TIM_LAYOUT_INFER_SIMPLE_OPS_LAYOUT_INFERENCE_H_
#include "tim/vx/ops/simple_operations.h"
#include "src/tim/layout_infer/ops/op_layout_inference.h"
#include "src/tim/layout_infer/permute_vector.h"
#include "src/tim/vx/operation_private.h"
namespace tim {
namespace transform {
template <typename OpType>
class SimpleOpsLayoutInfer : public OpLayoutInfer {
public:
SimpleOpsLayoutInfer(
const std::shared_ptr<vx::Operation> op,
std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
: OpLayoutInfer(op, context) {}
void OnInputs(
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
// Transmit input pv to out pv directly for simple ops
assert(op_->impl()->InputsTensor().size() == 1);
auto i_src = op_->impl()->InputsTensor()[0];
auto input_pv = context_->GetPermuteVector(i_src);
auto out_infer = CreateOutputsTensor(input_pv);
auto simple_op = context_->infer_graph_->CreateOperation<OpType>();
(*simple_op)
.BindInput(context_->GetMapedTensor(i_src))
.BindOutput(out_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], input_pv);
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
}
};
using DataConvertLayoutInfer = SimpleOpsLayoutInfer<vx::ops::DataConvert>;
using NegLayoutInfer = SimpleOpsLayoutInfer<vx::ops::Neg>;
using AbsLayoutInfer = SimpleOpsLayoutInfer<vx::ops::Abs>;
using SinLayoutInfer = SimpleOpsLayoutInfer<vx::ops::Sin>;
// TODO(yzw): enable it when TIM-VX support 'Cos'
// using CosLayoutInfer = SimpleOpsLayoutInfer<vx::ops::Cos>;
using ExpLayoutInfer = SimpleOpsLayoutInfer<vx::ops::Exp>;
using LogLayoutInfer = SimpleOpsLayoutInfer<vx::ops::Log>;
using SqrtLayoutInfer = SimpleOpsLayoutInfer<vx::ops::Sqrt>;
using RsqrtLayoutInfer = SimpleOpsLayoutInfer<vx::ops::Rsqrt>;
using SquareLayoutInfer = SimpleOpsLayoutInfer<vx::ops::Square>;
using LogicalNotLayoutInfer = SimpleOpsLayoutInfer<vx::ops::LogicalNot>;
} // namespace transform
} // namespace tim
#endif

src/tim/layout_infer/ops/softmax_layout_inference.h (new file)

@@ -0,0 +1,64 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifndef TIM_LAYOUT_INFER_SOFTMAX_LAYOUT_INFERENCE_H_
#define TIM_LAYOUT_INFER_SOFTMAX_LAYOUT_INFERENCE_H_
#include "tim/vx/ops/softmax.h"
#include "src/tim/vx/operation_private.h"
#include "src/tim/layout_infer/permute_vector.h"
#include "src/tim/layout_infer/ops/op_layout_inference.h"
namespace tim {
namespace transform {
class SoftmaxLayoutInfer : public OpLayoutInfer {
public:
SoftmaxLayoutInfer(
const std::shared_ptr<vx::Operation> op,
std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
: OpLayoutInfer(op, context) {}
void OnInputs(
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
auto input_tensors = op_->impl()->InputsTensor();
auto required_pv = context_->GetPermuteVector(input_tensors[0]);
float beta = op_->impl()->node()->nn_param.softmax.beta;
int32_t axis = op_->impl()->node()->nn_param.softmax.axis;
axis = MapAxis(required_pv->AsStdVec(), static_cast<uint32_t>(axis));
auto softmax =
context_->infer_graph_->CreateOperation<vx::ops::Softmax>(beta, axis);
auto otensor_infer = CreateOutputsTensor(required_pv);
(*softmax).BindInput(context_->GetMapedTensor(input_tensors[0]));
(*softmax).BindOutput(otensor_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
// Add the output tensor of the source graph to the next-tensors worklist
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
}
};
} // namespace transform
} // namespace tim
#endif
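
MapAxis is declared on the OpLayoutInfer base class, which is not part of this excerpt; it remaps the source-graph axis into the permuted layout so Softmax still normalizes the same logical dimension. A plausible sketch, under the assumption that pv[i] names the source dimension stored at permuted position i:

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical illustration of the remapping MapAxis is expected to perform:
// the mapped axis is the position of `axis` inside the permute vector.
uint32_t MapAxisSketch(const std::vector<uint32_t>& pv, uint32_t axis) {
  for (uint32_t i = 0; i < pv.size(); ++i) {
    if (pv[i] == axis) return i;  // e.g. pv = {0, 3, 1, 2}: axis 3 -> 1
  }
  assert(false && "axis not present in permute vector");
  return axis;
}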

View File

@ -0,0 +1,67 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifndef TIM_LAYOUT_INFER_SQUEEZE_LAYOUT_INFERENCE_H_
#define TIM_LAYOUT_INFER_SQUEEZE_LAYOUT_INFERENCE_H_
#include "tim/vx/ops/squeeze.h"
#include "src/tim/layout_infer/ops/op_layout_inference.h"
#include "src/tim/layout_infer/permute_vector.h"
#include "src/tim/vx/operation_private.h"
namespace tim {
namespace transform {
class SqueezeLayoutInfer : public OpLayoutInfer {
public:
SqueezeLayoutInfer(
const std::shared_ptr<vx::Operation> op,
std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
: OpLayoutInfer(op, context) {}
// Reverse any permute applied to its input tensor
void OnInputs(
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
ReverseInputsPermuteVector();
std::vector<uint32_t> axis;
for (uint32_t i = 0; i < op_->impl()->node()->nn_param.squeeze.axis_num;
i++) {
axis.push_back(op_->impl()->node()->nn_param.squeeze.axis[i]);
}
auto squeeze =
context_->infer_graph_->CreateOperation<vx::ops::Squeeze>(axis);
(*squeeze).BindInput(
context_->GetMapedTensor(op_->impl()->InputsTensor()[0]));
auto required_pv =
MakeShared(op_->impl()->OutputsTensor()[0]->GetShape().size());
auto out_infer = CreateOutputsTensor(required_pv);
(*squeeze).BindOutput(out_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
}
};
} // namespace transform
} // namespace tim
#endif
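
Squeeze axes are given in source-graph coordinates, so the handler first restores the input to source layout; the input's permute vector then becomes the identity, and the output starts with a fresh identity vector of the smaller rank. Reverse() is what makes the restore cheap to derive:

// If the input carried pv = {0, 3, 1, 2}, restoring source layout means a
// Transpose with pv.Reverse() = {0, 2, 3, 1}; composing the two is aligned.
auto pv = std::make_shared<tim::transform::PermuteVector<4>>(
    std::initializer_list<uint32_t>{0, 3, 1, 2});
assert(pv->Add(pv->Reverse())->IsAligned());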

View File

@ -0,0 +1,63 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifndef TIM_LAYOUT_INFER_STACK_LAYOUT_INFERENCE_H_
#define TIM_LAYOUT_INFER_STACK_LAYOUT_INFERENCE_H_
#include "tim/vx/ops/stack.h"
#include "src/tim/vx/operation_private.h"
#include "src/tim/layout_infer/permute_vector.h"
#include "src/tim/layout_infer/ops/op_layout_inference.h"
namespace tim {
namespace transform {
class StackLayoutInfer : public OpLayoutInfer {
public:
StackLayoutInfer(
const std::shared_ptr<vx::Operation> op,
std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
: OpLayoutInfer(op, context) {}
void OnInputs(
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
auto input_tensors = op_->impl()->InputsTensor();
auto required_pv = context_->GetPermuteVector(input_tensors[0]);
int32_t axis = op_->impl()->node()->nn_param.stack.axis;
axis = MapAxis(required_pv->AsStdVec(), static_cast<uint32_t>(axis));
auto stack =
context_->infer_graph_->CreateOperation<vx::ops::Stack>(1, axis);
auto otensor_infer = CreateOutputsTensor(required_pv);
(*stack).BindInput(context_->GetMapedTensor(input_tensors[0]));
(*stack).BindOutput(otensor_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
// Add the output tensor of the source graph to the next-tensors worklist
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
}
};
} // namespace transform
} // namespace tim
#endif
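
Note that only input_tensors[0] is bound even though Stack generally takes several inputs; a guard like the one in SimpleOpsLayoutInfer would make the single-input assumption explicit (a sketch, not part of the commit):

// Hypothetical guard mirroring the assert used by SimpleOpsLayoutInfer:
assert(op_->impl()->InputsTensor().size() == 1 &&
       "StackLayoutInfer currently handles single-input Stack only");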

View File

@ -0,0 +1,223 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifndef TIM_VX_LAYOUT_INFERENCE_PERMUTE_VECTOR_H_
#define TIM_VX_LAYOUT_INFERENCE_PERMUTE_VECTOR_H_
#include <array>
#include <cassert>
#include <memory>
#include <vector>
namespace tim {
namespace transform {
class IPermuteVector;
using IPermuteVectorPtr = std::shared_ptr<IPermuteVector>;
class IPermuteVector {
public:
virtual ~IPermuteVector() = default;
virtual uint32_t Rank() const = 0;
virtual const uint32_t& At(const uint32_t) const = 0;
virtual uint32_t& At(const uint32_t) = 0;
/**
 * @brief Get the reverse permute vector
 *
 * PermuteVector + PermuteVector.Reverse() = {0, 1, 2, ..., R-1}
 *
 * Example: data layout = NHWC with current permute = {0, 3, 1, 2} gives
 * output layout NCHW; its reverse permute is {0, 2, 3, 1}.
 *
 * @return a reverse permute vector with the same rank as the current one
 */
virtual IPermuteVectorPtr Reverse() = 0;
virtual std::string AsText() const = 0;
/**
 * @brief Apply an additional permute parameter
 *
 * @detail
 * Assume data is stored as NHWC and this->param_ = {0, 3, 1, 2}:
 * applying the current permute vector stores the data as NCHW.
 * With other->param_ = {0, 2, 1, 3}, applying the additional permute
 * stores the data as NHCW, and the combined permute parameter becomes
 * {0, 1, 3, 2}.
 *
 * @param other additional permute vector
 * @return PermuteVector result = data.apply_this_permute().apply_other_permute()
 */
virtual IPermuteVectorPtr Add(const IPermuteVectorPtr& other) const = 0;
virtual void ReInitialize() = 0;
virtual bool IsAligned() const = 0;
virtual std::vector<uint32_t> AsStdVec() const = 0;
};
template <uint32_t R>
class PermuteVector : public IPermuteVector {
public:
static constexpr uint32_t kMaxRank = 10;
PermuteVector() {
for (uint32_t i = 0; i < R; ++i) {
param_[i] = i;
}
}
// Copy Constructor
PermuteVector(const PermuteVector& other) : param_(other.param_) {}
PermuteVector& operator=(const PermuteVector& other) {
assert(this != &other);
this->param_ = other.param_;
return *this;
}
// Move Constructor
PermuteVector(PermuteVector&& other) : param_(std::move(other.param_)) {}
PermuteVector& operator=(PermuteVector&& other) {
assert(this != &other);
this->param_ = std::move(other.param_);
return *this;
}
// Initialize list
PermuteVector(std::initializer_list<uint32_t> init_list) {
std::vector<uint32_t> vec(init_list);
assert(vec.size() == R);
for (uint32_t i = 0; i < R; ++i) {
param_[i] = vec[i];
}
}
// Construct a PermuteVector of larger rank R from a smaller-rank permute;
// truncating a permute vector is not allowed.
template <uint32_t S>
explicit PermuteVector(const PermuteVector<S>& smaller) {
static_assert(S < R, "Cut Permute Vector is not allowed");
for (uint32_t i = 0; i < R; ++i) {
param_[i] = i < S ? smaller.At(i) : i;
}
}
const uint32_t& At(uint32_t idx) const override { return param_[idx]; }
uint32_t& At(uint32_t idx) override { return param_[idx]; }
uint32_t Rank() const override { return R; }
bool IsAligned() const override {
uint32_t i = 0;
for (; i < R; ++i) {
if (i != param_[i]) break;
}
return i == R;
}
IPermuteVectorPtr Reverse() override {
IPermuteVectorPtr r = std::make_shared<PermuteVector<R>>();
for (uint32_t i = 0; i < R; ++i) {
r->At(param_[i]) = i;
}
return r;
}
void ReInitialize() override {
for (uint32_t i = 0; i < R; ++i) {
param_[i] = i;
}
}
IPermuteVectorPtr Add(const IPermuteVectorPtr& other) const override {
IPermuteVectorPtr r = std::make_shared<PermuteVector<R>>();
for (uint32_t i = 0; i < other->Rank(); ++i) {
r->At(i) = param_[other->At(i)];
}
return r;
}
std::string AsText() const override {
std::string str(R + 1, '\0');
for (uint32_t i = 0; i < R; i++) {
str[i] = static_cast<char>('0' + param_[i]);  // printable digit per element
}
return str;
}
std::vector<uint32_t> AsStdVec() const override {
std::vector<uint32_t> data(R);
for (uint32_t i(0); i < R; ++i) {
data[i] = param_[i];
}
return data;
}
private:
std::array<uint32_t, R> param_;
};
/**
 * @brief Create an identity permute vector for the given rank
 *
 * @param rank_val tensor rank; 0 (a scalar) is treated as rank 1
 * @return IPermuteVectorPtr identity permute vector of that rank
 */
inline IPermuteVectorPtr MakeShared(uint32_t rank_val) {
switch (rank_val) {
// 0: represent scalar
case 0:
case 1:
return std::make_shared<PermuteVector<1>>();
case 2:
return std::make_shared<PermuteVector<2>>();
case 3:
return std::make_shared<PermuteVector<3>>();
case 4:
return std::make_shared<PermuteVector<4>>();
case 5:
return std::make_shared<PermuteVector<5>>();
case 6:
return std::make_shared<PermuteVector<6>>();
case 7:
return std::make_shared<PermuteVector<7>>();
case 8:
return std::make_shared<PermuteVector<8>>();
case 9:
return std::make_shared<PermuteVector<9>>();
case 10:
return std::make_shared<PermuteVector<10>>();
default:
assert("Not supported rankVal");
return nullptr;
}
}
} // namespace transform
} // namespace tim
#endif
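
Taken together, Reverse, Add, and IsAligned form the small algebra the whole pass relies on. A self-contained usage sketch (values match the doc comments above):

#include <cassert>
#include "src/tim/layout_infer/permute_vector.h"

int main() {
  using tim::transform::PermuteVector;
  auto nhwc_to_nchw = std::make_shared<PermuteVector<4>>(
      std::initializer_list<uint32_t>{0, 3, 1, 2});
  auto back = nhwc_to_nchw->Reverse();           // {0, 2, 3, 1}
  assert(nhwc_to_nchw->Add(back)->IsAligned());  // round trip is the identity
  assert(tim::transform::MakeShared(4)->IsAligned());  // fresh vectors start aligned
  return 0;
}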

View File

@ -58,6 +58,50 @@ void GraphImpl::AddOutput(vsi_nn_tensor_id_t id) {
  }
}

void GraphImpl::AddInput(const std::shared_ptr<Tensor>& tensor) {
  if (inputs_tensor_.end() ==
      std::find(inputs_tensor_.begin(), inputs_tensor_.end(), tensor)) {
    inputs_tensor_.push_back(tensor);
  }
}

void GraphImpl::AddOutput(const std::shared_ptr<Tensor>& tensor) {
  if (outputs_tensor_.end() ==
      std::find(outputs_tensor_.begin(), outputs_tensor_.end(), tensor)) {
    outputs_tensor_.push_back(tensor);
  }
}

const std::vector<std::shared_ptr<Tensor>> GraphImpl::InputsTensor() const {
  return inputs_tensor_;
}

const std::vector<std::shared_ptr<Tensor>> GraphImpl::OutputsTensor() const {
  return outputs_tensor_;
}

void GraphImpl::UpdateTensorConsumersMap(const std::shared_ptr<Tensor>& tensor,
                                         const Operation* op) {
  for (const auto& added_op : op_vector_) {
    if (added_op.get() == op) {
      tensor_consumers_[tensor].push_back(added_op);
    }
  }
}

const std::vector<std::shared_ptr<Operation>> GraphImpl::GetConsumersOp(
    std::shared_ptr<Tensor> tensor) const {
  auto consumers = tensor_consumers_.find(tensor);
  if (tensor_consumers_.end() != consumers) {
    return consumers->second;
  } else {
    VSILOGD("Tensor has no consumers, may be graph output.");
    return {};
  }
}

void GraphImpl::PrintGraph() const { vsi_nn_PrintGraph(this->graph_); }

std::shared_ptr<Tensor> GraphImpl::CreateTensor(const TensorSpec& spec,
                                                const void* data) {
  return std::make_shared<TensorImpl>(this, spec, data);
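
These accessors are what let the layout-inference pass walk the source graph forward from its inputs. A hedged sketch of the traversal loop (the worklist name and structure are illustrative, not the committed implementation):

// Starting from graph inputs, visit each tensor's consumers, then push the
// consuming ops' outputs back onto the worklist until it drains.
std::vector<std::shared_ptr<tim::vx::Tensor>> worklist = src_graph->InputsTensor();
while (!worklist.empty()) {
  auto tensor = worklist.back();
  worklist.pop_back();
  for (const auto& op : src_graph->GetConsumersOp(tensor)) {
    // dispatch to the matching OpLayoutInfer, whose OnInputs(next_tensors)
    // hook appends the op's output tensors to the worklist
  }
}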

View File

@ -28,6 +28,7 @@
#include <vector>
#include <mutex>
#include <utility>
#include <map>

#include "tim/vx/tensor.h"
#include "context_private.h"

@ -46,7 +47,18 @@ class GraphImpl : public Graph {
  vsi_nn_graph_t* graph();
  void AddInput(vsi_nn_tensor_id_t id);
  void AddOutput(vsi_nn_tensor_id_t id);
  void AddInput(const std::shared_ptr<Tensor>& tensor);
  void AddOutput(const std::shared_ptr<Tensor>& tensor);
  const std::vector<std::shared_ptr<Tensor>> InputsTensor() const override;
  const std::vector<std::shared_ptr<Tensor>> OutputsTensor() const override;
  void UpdateTensorConsumersMap(const std::shared_ptr<Tensor>& tensor,
                                const Operation* op) override;
  const std::vector<std::shared_ptr<Operation>> GetConsumersOp(
      std::shared_ptr<Tensor> tensor) const override;
  void PrintGraph() const override;

  /// Implement parents' virtual functions
  std::shared_ptr<Tensor> CreateTensor(const TensorSpec& spec,
                                       const void* data = nullptr) override;

@ -65,6 +77,9 @@ class GraphImpl : public Graph {
  std::once_flag verify_graph_once_;
  std::vector<vsi_nn_tensor_id_t> inputs_;
  std::vector<vsi_nn_tensor_id_t> outputs_;
  std::vector<std::shared_ptr<Tensor>> inputs_tensor_;
  std::vector<std::shared_ptr<Tensor>> outputs_tensor_;
  std::map<std::shared_ptr<Tensor>, std::vector<std::shared_ptr<Operation>>>
      tensor_consumers_;
};

}  // namespace vx

View File

@ -32,14 +32,13 @@
namespace tim {
namespace vx {

OperationImpl::OperationImpl(Graph* graph, uint32_t operation_id,
                             int input_cnt, int output_cnt, DataLayout layout)
    : graph_(reinterpret_cast<GraphImpl*>(graph)),
      operation_id_(operation_id),
      input_cnt_(input_cnt),
      output_cnt_(output_cnt),
      layout_(layout),
      node_(vsi_nn_AddNode(graph_->graph(), operation_id_, input_cnt_,
                           output_cnt_, NULL)) {
  SetRoundingPolicy();
@ -47,20 +46,27 @@ OperationImpl::OperationImpl(Graph* graph, uint32_t operation_id, int input_cnt,
}

OperationImpl& OperationImpl::BindInput(const std::shared_ptr<Tensor>& tensor) {
  inputs_tensor_.push_back(tensor);

  uint32_t tensor_id = tensor->GetId();
  node_->input.tensors[input_tensor_index++] = tensor_id;
  if (tensor->GetSpec().attr_ & TensorAttribute::INPUT) {
    graph_->AddInput(tensor_id);
  }
  if (tensor->GetSpec().attr_ & TensorAttribute::INPUT ||
      tensor->GetSpec().attr_ & TensorAttribute::CONSTANT) {
    graph_->AddInput(tensor);
  }
  return *this;
}

OperationImpl& OperationImpl::BindOutput(
    const std::shared_ptr<Tensor>& tensor) {
  outputs_tensor_.push_back(tensor);

  uint32_t tensor_id = tensor->GetId();
  node_->output.tensors[output_tensor_index++] = tensor_id;
  if (tensor->GetSpec().attr_ == TensorAttribute::OUTPUT) {
    graph_->AddOutput(tensor_id);
    graph_->AddOutput(tensor);
  }
  return *this;
}
@ -78,10 +84,10 @@ OperationImpl& OperationImpl::SetRoundingPolicy(
}

// Operation implementation
Operation::Operation(Graph* graph, uint32_t operation_id,
                     int input_cnt, int output_cnt, DataLayout layout) {
  impl_ = std::make_unique<OperationImpl>(graph, operation_id,
                                          input_cnt, output_cnt, layout);
}

Operation::~Operation() {}
@ -90,6 +96,7 @@ std::unique_ptr<OperationImpl>& Operation::impl() { return impl_; }

Operation& Operation::BindInput(const std::shared_ptr<Tensor>& tensor) {
  impl_->BindInput(tensor);
  impl_->graph_->UpdateTensorConsumersMap(tensor, this);
  return *this;
}
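
The new layout_ field is how layout inference learns each op's declared kernel layout, with ANY marking layout-agnostic ops. A hedged sketch of how a pass might consult it (the WHCN enumerator is an assumption; only ANY appears in this diff):

// Sketch: decide whether an op needs permute handling at all.
auto declared = op->impl()->layout_;  // DataLayout member added above
if (declared != tim::vx::DataLayout::ANY) {
  // layout-sensitive op (e.g. Conv2d/Pool2d below): align inputs to
  // `declared` before binding, tracking the permute vectors used
}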

View File

@ -30,8 +30,10 @@ namespace tim {
namespace vx {

class OperationImpl {
 public:
  OperationImpl(Graph* graph, uint32_t operation_id, int input_cnt = 0,
                int output_cnt = 0, DataLayout layout = DataLayout::ANY);
  ~OperationImpl() {}

  OperationImpl& BindInput(const std::shared_ptr<Tensor>& tensor);
@ -45,13 +47,22 @@ class OperationImpl {
  vsi_nn_node_t* node() { return this->node_; }

  std::vector<std::shared_ptr<Tensor>> InputsTensor() { return inputs_tensor_; }
  std::vector<std::shared_ptr<Tensor>> OutputsTensor() {
    return outputs_tensor_;
  }

  GraphImpl* graph_;
  uint32_t operation_id_{0};
  int32_t input_cnt_{0};
  int32_t output_cnt_{0};
  DataLayout layout_{DataLayout::ANY};
  vsi_nn_node_t* node_{nullptr};
  int32_t input_tensor_index{0};
  int32_t output_tensor_index{0};
  std::vector<std::shared_ptr<Tensor>> inputs_tensor_;
  std::vector<std::shared_ptr<Tensor>> outputs_tensor_;
};

}  // namespace vx

View File

@ -34,16 +34,17 @@ namespace ops {

Conv2d::Conv2d(Graph* graph, int32_t weights, PadType padding,
               const std::array<uint32_t, 2>& ksize,
               const std::array<uint32_t, 2>& stride,
               const std::array<uint32_t, 2>& dilation,
               int32_t multiplier, DataLayout layout)
    : Conv2d(graph, weights, padding, ksize, stride, dilation,
             {0, 0, 0, 0}, multiplier, layout) {}

Conv2d::Conv2d(Graph* graph, int32_t weights, PadType padding,
               const std::array<uint32_t, 2>& ksize,
               const std::array<uint32_t, 2>& stride,
               const std::array<uint32_t, 2>& dilation,
               const std::array<uint32_t, 4>& pad, int32_t multiplier,
               DataLayout layout)
    : Operation(graph, VSI_NN_OP_CONV2D, 0, 0, layout),
      weights_(weights),
      padding_(padding),
      ksize_(ksize),
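
Call sites stay source-compatible since the layout argument is appended last; passing one is opt-in. A hedged usage sketch (DataLayout::WHCN is an assumed enumerator, and the default argument lives in the header, which is not shown here):

// Hypothetical caller passing an explicit kernel layout:
auto conv = graph->CreateOperation<tim::vx::ops::Conv2d>(
    /*weights=*/32, tim::vx::PadType::SAME, ksize, stride, dilation,
    /*multiplier=*/0, tim::vx::DataLayout::WHCN);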

View File

@ -33,8 +33,9 @@ namespace ops {

Pool2d::Pool2d(Graph* graph, PoolType type, PadType padding,
               const std::array<uint32_t, 2>& ksize,
               const std::array<uint32_t, 2>& stride, RoundType round_type,
               DataLayout layout)
    : Operation(graph, VSI_NN_OP_POOL, 1, 1, layout),
      type_(type),
      padding_(padding),
      ksize_(ksize),

View File

@ -50,6 +50,7 @@ class TensorImpl : public Tensor {
  bool IsConstTensor() {
    return spec_.attr_ == tim::vx::TensorAttribute::CONSTANT;
  }
  const void* GetDataRef() const { return data_; }

  GraphImpl* graph_;
  vsi_nn_tensor_id_t id_;
@ -73,6 +74,7 @@ class TensorPlaceholder : public Tensor {
  bool IsConstTensor() {
    return spec_.attr_ == tim::vx::TensorAttribute::CONSTANT;
  }
  const void* GetDataRef() const { return nullptr; }

  vsi_nn_tensor_id_t id_;
  TensorSpec spec_;
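
GetDataRef rounds out the feature: the pass can read a constant tensor's host data and materialize a permuted copy (for example, rotating Conv2d weights) in the inferred graph, while TensorPlaceholder returns nullptr because it owns no storage. A hedged sketch (PermuteConstTensor and the access path are hypothetical):

// Sketch: only constant tensors are expected to expose data this way.
if (tensor->IsConstTensor()) {
  const void* host_data = tensor_impl->GetDataRef();  // via TensorImpl
  auto permuted = PermuteConstTensor(host_data, tensor->GetSpec(), pv);
}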