Optimize permute op for constant tensor (#37)
Signed-off-by: Zongwu Yang <zongwu.yang@verisilicon.com>
This commit is contained in:
parent 5cfa7a2c40
commit 22d423714f
BUILD (3 changed lines)
@@ -32,7 +32,8 @@ cc_library(
        "src/tim/vx/type_utils.h",
        "src/tim/vx/type_utils.cc",
        "src/tim/layout_infer/layout_inference.cc",
        "src/tim/layout_infer/permute_vector.h"
        "src/tim/layout_infer/permute_vector.h",
        "src/tim/layout_infer/layout_infer_context.h",
    ] + glob([
        "src/tim/vx/ops/*.cc"
    ]) + glob(["src/tim/layout_infer/ops/*.*"]),

@@ -4,6 +4,8 @@ project(tim-vx)
set(CMAKE_C_FLAGS "-Wall -Wextra -Wno-unused-parameter -Wno-sign-compare -Werror -fPIC -Wno-enum-conversion")
set(CMAKE_CXX_FLAGS "--std=c++14 -Wall -Wextra -Wno-unused-parameter -Wno-sign-compare -Werror -fPIC")

OPTION(ENABLE_LAYOUT_INFER "add layout inference support" ON)

set(CMAKE_C_VISIBILITY_PRESET hidden)
set(OVXLIB_API_ATTR "__attribute__\(\(visibility\(\"default\"\)\)\)")
message(${OVXLIB_API_ATTR})

@@ -22,5 +24,5 @@ endif()
include_directories(${PROJECT_SOURCE_DIR}/include/tim/vx)
include_directories(${OVXDRV_INCLUDE_DIRS})

add_subdirectory("src/tim/vx")
add_subdirectory("src/tim/")
add_subdirectory("samples/lenet")

@@ -28,56 +28,10 @@
#include <vector>

#include "tim/vx/context.h"
#include "tim/vx/operation.h"
#include "src/tim/layout_infer/permute_vector.h"
#include "tim/vx/graph.h"

namespace tim {
namespace transform {
namespace layout_inference_impl {
class LayoutInferContext {
 public:
  LayoutInferContext(const std::shared_ptr<vx::Graph>& src_graph,
                     std::shared_ptr<vx::Graph>& infer_graph)
      : src_graph_(src_graph), infer_graph_(infer_graph) {}
  void SetPermuteVector(std::shared_ptr<vx::Tensor> tensor,
                        std::shared_ptr<IPermuteVector> pv);
  const std::shared_ptr<IPermuteVector> GetPermuteVector(
      const std::shared_ptr<vx::Tensor>& tensor) const;
  void MarkVisited(const std::shared_ptr<vx::Operation>& op);
  bool IsVisited(const std::shared_ptr<vx::Operation>& op) const;
  bool IsReadyForInfer(const std::shared_ptr<vx::Operation>& op) const;
  void UpdateTensorMap(const std::shared_ptr<vx::Tensor>& t_src,
                       const std::shared_ptr<vx::Tensor>& t_layout);
  std::shared_ptr<vx::Tensor> GetMapedTensor(
      const std::shared_ptr<vx::Tensor>& t_src) const;

  void UpdateGraphInputMap(const std::shared_ptr<vx::Tensor>& i_src,
                           const std::shared_ptr<vx::Tensor>& i_layout);

  std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<vx::Tensor>>
  GetGraphInputMap() const {
    return graph_input_map_;
  }

  const std::shared_ptr<vx::Graph>& src_graph_;
  std::shared_ptr<vx::Graph>& infer_graph_;

 private:
  std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<IPermuteVector>>
      tensor_pv_;
  std::vector<std::shared_ptr<vx::Operation>> visited_op_;
  // tensor_in_src -> tensor_in_layout
  std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<vx::Tensor>>
      tensor_map_;
  std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<vx::Tensor>>
      graph_input_map_;
};

std::vector<std::shared_ptr<vx::Tensor>> HandleLayoutInfer(
    std::shared_ptr<layout_inference_impl::LayoutInferContext>& ctx,
    const std::shared_ptr<vx::Operation>& op);
} // namespace layout_inference_impl

std::pair<std::shared_ptr<vx::Graph>, /* infer graph */
          std::map<std::shared_ptr<vx::Tensor>,
                   std::shared_ptr<vx::Tensor>> /* graph io tensor map */>

@@ -97,6 +97,14 @@ struct TensorSpec {
    this->quantization_ = other.quantization_;
  }

  TensorSpec& operator =(const TensorSpec& other) {
    this->datatype_ = other.datatype_;
    this->shape_ = other.shape_;
    this->attr_ = other.attr_;
    this->quantization_ = other.quantization_;
    return *this;
  }

  TensorSpec& SetDataType(DataType datatype) {
    this->datatype_ = datatype;
    return *this;

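This hunk makes TensorSpec copy-assignable in addition to copy-constructible. A small sketch of what that enables, assuming only the TensorSpec API visible in this diff (the helper name and parameters are illustrative, not from the patch):

    // Reuse an existing TensorSpec object by assigning over it and then
    // overriding just the shape.
    void DescribeLikeWithShape(const tim::vx::TensorSpec& src,
                               const tim::vx::ShapeType& new_shape,
                               tim::vx::TensorSpec& dst) {
      dst = src;                 // copy-assignment added in this hunk
      dst.SetShape(new_shape);   // keep dtype/attr/quantization, change shape
    }
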
@@ -0,0 +1,51 @@
message("src/tim/vx")

set(TARGET_NAME "tim-vx")

add_subdirectory("vx/internal")

aux_source_directory(./vx VX_SRC)
aux_source_directory(./vx/ops OPS_SRC)

set(SRC)
list(APPEND SRC
    ${VX_SRC}
    ${OPS_SRC}
)
list(REMOVE_ITEM SRC ./vx/context_test.cc)
list(REMOVE_ITEM SRC ./vx/graph_test.cc)

include_directories(${PROJECT_SOURCE_DIR}/include)
include_directories(${PROJECT_SOURCE_DIR}/include/tim/vx)
include_directories(${PROJECT_SOURCE_DIR}/src/tim/vx)
include_directories(${PROJECT_SOURCE_DIR}/src/tim/vx/internal/include)

if(ENABLE_LAYOUT_INFER)
    include_directories(${PROJECT_SOURCE_DIR}/)

    aux_source_directory(./layout_infer LAYOUT_INFER_FRAMEWORK_SRCS)
    aux_source_directory(./layout_infer/ops LAYOUT_INFER_OP_SRCS)

    list(APPEND SRC
        ${LAYOUT_INFER_FRAMEWORK_SRCS}
        ${LAYOUT_INFER_OP_SRCS}
    )
    list(REMOVE_ITEM SRC ./layout_infer/layout_inference_test.cc)
endif()

add_library(${TARGET_NAME} SHARED ${SRC})
target_link_libraries(${TARGET_NAME} PRIVATE
    -Wl,--whole-archive tim_internal -Wl,--no-whole-archive)

add_library(${TARGET_NAME}-static STATIC ${SRC})
target_link_libraries(${TARGET_NAME}-static PRIVATE
    -Wl,--whole-archive tim_internal -Wl,--no-whole-archive)

install(TARGETS ${TARGET_NAME} ${TARGET_NAME}-static
    DESTINATION ${CMAKE_BINARY_DIR}/install/lib)

install(DIRECTORY ${CMAKE_SOURCE_DIR}/include/tim/vx DESTINATION ${CMAKE_BINARY_DIR}/install/include/tim/)

if(ENABLE_LAYOUT_INFER)
    install(DIRECTORY ${CMAKE_SOURCE_DIR}/include/tim/layout_infer DESTINATION ${CMAKE_BINARY_DIR}/install/include/tim/)
endif()

@@ -0,0 +1,52 @@
#ifndef TIM_VX_LAYOUT_INFER_CONTEXT_H_
#define TIM_VX_LAYOUT_INFER_CONTEXT_H_
#include "permute_vector.h"
#include "tim/layout_infer/layout_inference.h"

namespace tim {
namespace transform {
namespace layout_inference_impl {
class LayoutInferContext {
 public:
  LayoutInferContext(const std::shared_ptr<vx::Graph>& src_graph,
                     std::shared_ptr<vx::Graph>& infer_graph)
      : src_graph_(src_graph), infer_graph_(infer_graph) {}
  void SetPermuteVector(std::shared_ptr<vx::Tensor> tensor,
                        std::shared_ptr<IPermuteVector> pv);
  const std::shared_ptr<IPermuteVector> GetPermuteVector(
      const std::shared_ptr<vx::Tensor>& tensor) const;
  void MarkVisited(const std::shared_ptr<vx::Operation>& op);
  bool IsVisited(const std::shared_ptr<vx::Operation>& op) const;
  bool IsReadyForInfer(const std::shared_ptr<vx::Operation>& op) const;
  void UpdateTensorMap(const std::shared_ptr<vx::Tensor>& t_src,
                       const std::shared_ptr<vx::Tensor>& t_layout);
  std::shared_ptr<vx::Tensor> GetMapedTensor(
      const std::shared_ptr<vx::Tensor>& t_src) const;

  void UpdateGraphInputMap(const std::shared_ptr<vx::Tensor>& i_src,
                           const std::shared_ptr<vx::Tensor>& i_layout);

  std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<vx::Tensor>>
  GetGraphInputMap() const {
    return graph_input_map_;
  }

  const std::shared_ptr<vx::Graph>& src_graph_;
  std::shared_ptr<vx::Graph>& infer_graph_;

 private:
  std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<IPermuteVector>>
      tensor_pv_;
  std::vector<std::shared_ptr<vx::Operation>> visited_op_;
  // tensor_in_src -> tensor_in_layout
  std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<vx::Tensor>>
      tensor_map_;
  std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<vx::Tensor>>
      graph_input_map_;
};

} // namespace layout_inference_impl
} // namespace transform
} // namespace tim

#endif

@@ -22,20 +22,20 @@
 *
 *****************************************************************************/

#include "src/tim/layout_infer/permute_vector.h"
#include "permute_vector.h"
#include "layout_infer_context.h"
#include "tim/layout_infer/layout_inference.h"
#include "src/tim/vx/operation_private.h"
#include "src/tim/layout_infer/ops/conv2d_layout_inference.h"
#include "src/tim/layout_infer/ops/reduce_layout_inference.h"
#include "src/tim/layout_infer/ops/elementwise_layout_inference.h"
#include "src/tim/layout_infer/ops/activation_layout_inference.h"
#include "src/tim/layout_infer/ops/concat_layout_inferene.h"
#include "src/tim/layout_infer/ops/reshape_layout_inference.h"
#include "src/tim/layout_infer/ops/simple_ops_layout_inference.h"
#include "src/tim/layout_infer/ops/pool2d_layout_inference.h"
#include "src/tim/layout_infer/ops/softmax_layout_inference.h"
#include "src/tim/layout_infer/ops/squeeze_layout_inference.h"
#include "src/tim/layout_infer/ops/stack_layout_inference.h"
#include "ops/conv2d_layout_inference.h"
#include "ops/reduce_layout_inference.h"
#include "ops/elementwise_layout_inference.h"
#include "ops/activation_layout_inference.h"
#include "ops/concat_layout_inferene.h"
#include "ops/reshape_layout_inference.h"
#include "ops/simple_ops_layout_inference.h"
#include "ops/pool2d_layout_inference.h"
#include "ops/softmax_layout_inference.h"
#include "ops/squeeze_layout_inference.h"
#include "ops/stack_layout_inference.h"

#include <algorithm>
#include <deque>

@@ -43,6 +43,11 @@
namespace tim {
namespace transform {
namespace layout_inference_impl {

std::vector<std::shared_ptr<vx::Tensor>> HandleLayoutInfer(
    std::shared_ptr<layout_inference_impl::LayoutInferContext>& ctx,
    const std::shared_ptr<vx::Operation>& op);

// Implemention for LayoutInferContext
void LayoutInferContext::SetPermuteVector(std::shared_ptr<vx::Tensor> tensor,
                                          std::shared_ptr<IPermuteVector> pv) {

@@ -84,7 +89,8 @@ bool LayoutInferContext::IsVisited(const std::shared_ptr<vx::Operation>& op) con
bool LayoutInferContext::IsReadyForInfer(
    const std::shared_ptr<vx::Operation>& op) const {
  for (const auto& tensor : op->impl()->InputsTensor()) {
    if (tensor_pv_.end() == tensor_pv_.find(tensor)) {
    if (!tensor->IsConstTensor() &&
        (tensor_pv_.end() == tensor_pv_.find(tensor))) {
      return false;
    }
  }

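The hunk counts (-84,7 +89,8) indicate the single-condition check is the removed side and the two-line condition is its replacement. Reassembled from the lines above (not copied from the repository), the resulting loop body reads:

    for (const auto& tensor : op->impl()->InputsTensor()) {
      // A constant input no longer receives a permute vector up front,
      // so it must not block readiness; only a non-constant input with no
      // recorded permute vector does.
      if (!tensor->IsConstTensor() &&
          (tensor_pv_.end() == tensor_pv_.find(tensor))) {
        return false;
      }
    }
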
@@ -187,16 +193,10 @@ std::pair<std::shared_ptr<vx::Graph>,
  std::deque<std::shared_ptr<vx::Tensor>> tensor_queue;
  auto graph_inputs = src_graph->InputsTensor();
  for (const auto& t_src : graph_inputs) {
    if (t_src->IsConstTensor()) {
      layout_infer_ctx->UpdateTensorMap(
          t_src,
          infer_graph->CreateTensor(t_src->GetSpec(), t_src->GetDataRef()));
    } else {
      auto input = infer_graph->CreateTensor(t_src->GetSpec());
      layout_infer_ctx->UpdateTensorMap(t_src, input);
      layout_infer_ctx->UpdateGraphInputMap(t_src, input);
      tensor_queue.push_back(t_src);
    }
    auto input = infer_graph->CreateTensor(t_src->GetSpec());
    layout_infer_ctx->UpdateTensorMap(t_src, input);
    layout_infer_ctx->UpdateGraphInputMap(t_src, input);
    tensor_queue.push_back(t_src);
    layout_infer_ctx->SetPermuteVector(t_src,
                                       MakeShared(t_src->GetShape().size()));
  }

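By the same count reading (-187,16 +193,10: ten lines removed, four added), the constant-tensor branch is the side being removed here, leaving an unconditional loop body. My reconstruction of the result, for orientation only:

    for (const auto& t_src : graph_inputs) {
      // Every graph input is mapped 1:1 and seeded with an identity
      // permute vector; constants are no longer special-cased here.
      auto input = infer_graph->CreateTensor(t_src->GetSpec());
      layout_infer_ctx->UpdateTensorMap(t_src, input);
      layout_infer_ctx->UpdateGraphInputMap(t_src, input);
      tensor_queue.push_back(t_src);
      layout_infer_ctx->SetPermuteVector(t_src,
                                         MakeShared(t_src->GetShape().size()));
    }

Constant tensors are instead handled where they are consumed, as the conv2d and OpLayoutInfer hunks below show.
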
@@ -48,16 +48,41 @@ class Conv2dLayoutInfer : public OpLayoutInfer {
    }
    auto input_tensors = op_->impl()->InputsTensor();

    // for input and weight
    for (uint32_t i = 0; i < 2; i++) {
      auto pv = context_->GetPermuteVector(input_tensors[i]);
      auto final_pv = pv->Reverse()->Add(required_pv);
      if (!final_pv->IsAligned()) {
        auto perm_out =
            InsertPermute(context_->GetMapedTensor(input_tensors[i]), final_pv);
        context_->UpdateTensorMap(input_tensors[i], perm_out);
        context_->SetPermuteVector(input_tensors[i], required_pv);
    for (const auto& in : input_tensors) {
      std::shared_ptr<vx::Tensor> infer_tensor;
      std::shared_ptr<IPermuteVector> trans_pv;
      if (in->IsConstTensor() &&
          !(in->GetSpec().attr_ & vx::TensorAttribute::INPUT)) {
        // For bias
        if (in->GetShape().size() == 1) {
          infer_tensor = context_->infer_graph_->CreateTensor(
              in->GetSpec(), in->GetDataRef());
          trans_pv = MakeShared(1);
        } else {
          // For input/weight
          if (!required_pv->IsAligned()) {
            infer_tensor = PermuteConstTensor(in, required_pv);
            trans_pv = required_pv;
          } else {
            infer_tensor = context_->infer_graph_->CreateTensor(
                in->GetSpec(), in->GetDataRef());
            trans_pv = MakeShared(required_pv->Rank());
          }
        }
      } else {
        // For input/weight
        auto pv = context_->GetPermuteVector(in);
        auto final_pv = pv->Reverse()->Add(required_pv);
        if (!final_pv->IsAligned()) {
          infer_tensor = InsertPermute(context_->GetMapedTensor(in), final_pv);
          trans_pv = required_pv;
        } else {
          infer_tensor = context_->GetMapedTensor(in);
          trans_pv = pv;
        }
      }
      context_->UpdateTensorMap(in, infer_tensor);
      context_->SetPermuteVector(in, trans_pv);
    }

    auto pad_type = TranslatePadType(op_->impl()->node()->nn_param.conv2d.pad_type);

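To make the constant-weight branch concrete: when required_pv is not aligned, PermuteConstTensor (defined later in this diff) rewrites the shape as dst_shape[i] = src_shape[pv[i]] and transposes the raw data at graph-build time, so no runtime Permute op is inserted for the weight. A standalone illustration of that shape arithmetic using the KOcHWIc2OcIcHW constant introduced later in this diff; the concrete dimension values are made up:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
      // Hypothetical conv2d weight described as [Oc, H, W, Ic].
      std::vector<uint32_t> src_shape = {32, 3, 3, 16};
      std::vector<uint32_t> perm = {0, 3, 1, 2};  // KOcHWIc2OcIcHW
      std::vector<uint32_t> dst_shape;
      for (uint32_t i = 0; i < src_shape.size(); ++i)
        dst_shape.push_back(src_shape[perm[i]]);  // same rule as PermuteConstTensor
      // dst_shape is now {32, 16, 3, 3}, i.e. [Oc, Ic, H, W].
      for (auto d : dst_shape) std::printf("%u ", d);
      return 0;
    }
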
@@ -26,6 +26,7 @@
#include "src/tim/layout_infer/permute_vector.h"
#include "src/tim/vx/operation_private.h"
#include "tim/vx/ops/transpose.h"
#include "src/tim/vx/type_utils.h"

#include <algorithm>
#include <vector>

@@ -155,11 +156,15 @@ OpLayoutInfer::AlignPermuteVectorForMutilInputs() {
  // TODO(yzw): should choose a optimal required_pv
  auto required_pv = context_->GetPermuteVector(src_inputs[0]);
  for (const auto& i_src : src_inputs) {
    std::shared_ptr<vx::Tensor> perm_out;
    auto pv = context_->GetPermuteVector(i_src);
    auto final_pv = pv->Reverse()->Add(required_pv);
    if (!final_pv->IsAligned()) {
      auto perm_out =
          InsertPermute(context_->GetMapedTensor(i_src), final_pv);
      if (i_src->IsConstTensor()) {
        perm_out = PermuteConstTensor(i_src, final_pv);
      } else {
        perm_out = InsertPermute(context_->GetMapedTensor(i_src), final_pv);
      }
      context_->UpdateTensorMap(i_src, perm_out);
      context_->SetPermuteVector(i_src, required_pv);
    }

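The pv->Reverse()->Add(required_pv) composition appears in every handler. A self-contained sketch of that algebra, under assumptions that are mine rather than stated in the patch (a permute vector p means out[i] = in[p[i]], Reverse() is the inverse permutation, Add() applies this permutation and then the argument):

    #include <array>
    #include <cstdint>
    #include <cstdio>

    using Perm = std::array<uint32_t, 4>;

    Perm Reverse(const Perm& p) {  // inverse permutation: r[p[i]] = i
      Perm r{};
      for (uint32_t i = 0; i < p.size(); ++i) r[p[i]] = i;
      return r;
    }

    Perm Add(const Perm& a, const Perm& b) {  // compose: apply a, then b
      Perm c{};
      for (uint32_t i = 0; i < a.size(); ++i) c[i] = a[b[i]];
      return c;
    }

    bool IsAligned(const Perm& p) {  // identity permutation, nothing to do
      for (uint32_t i = 0; i < p.size(); ++i)
        if (p[i] != i) return false;
      return true;
    }

    int main() {
      const Perm required = {1, 2, 0, 3};  // kCWHN2WHCN, defined in a later hunk
      // Input still in source layout: a transpose (or, for a constant,
      // a build-time data permute) is needed.
      std::printf("identity input aligned: %d\n",
                  IsAligned(Add(Reverse(Perm{0, 1, 2, 3}), required)));  // prints 0
      // Input already carries the required permute vector: nothing to insert.
      std::printf("pre-permuted input aligned: %d\n",
                  IsAligned(Add(Reverse(required), required)));          // prints 1
      return 0;
    }
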
@@ -169,14 +174,65 @@

void OpLayoutInfer::ReverseInputsPermuteVector() {
  for (const auto& i_src : op_->impl()->InputsTensor()) {
    std::shared_ptr<vx::Tensor> perm_out;
    auto input_pv = context_->GetPermuteVector(i_src);
    if (!input_pv->IsAligned()) {
      auto perm_out = InsertPermute(context_->GetMapedTensor(i_src),
                                    input_pv->Reverse());
      if (i_src->IsConstTensor()) {
        perm_out = PermuteConstTensor(i_src, input_pv);
      } else {
        perm_out =
            InsertPermute(context_->GetMapedTensor(i_src), input_pv->Reverse());
      }
      context_->UpdateTensorMap(i_src, perm_out);
      context_->SetPermuteVector(i_src, MakeShared(input_pv->Rank()));
    }
  }
}

bool OpLayoutInfer::TransposeConstTensorData(
    const std::shared_ptr<vx::Tensor>& input,
    const std::shared_ptr<IPermuteVector>& pv, std::vector<uint8_t>& out_data) {
  auto vx_type = vx::TranslateDataType(input->GetDataType());
  auto type_size = vsi_nn_GetTypeBytes(vx_type);
  uint32_t out_size = 1;
  for (const auto& s : input->GetShape()) out_size *= s;
  out_size *= type_size;
  out_data.resize(out_size);
  if (!input->GetDataRef()) {
    return false;
  }

  std::vector<uint32_t> perm = KOcHWIc2OcIcHW;
  vx::ShapeType reverse_shape;
  for (int32_t i = input->GetShape().size() - 1; i >= 0; i--) {
    reverse_shape.push_back(input->GetShape()[i]);
  }

  vsi_nn_Transpose(out_data.data(), (uint8_t*)(input->GetDataRef()),
                   (uint32_t*)(reverse_shape.data()),
                   static_cast<uint32_t>(input->GetShape().size()),
                   perm.data(), vx_type);
  return true;
}

std::shared_ptr<vx::Tensor> OpLayoutInfer::PermuteConstTensor(
    const std::shared_ptr<vx::Tensor>& input,
    const std::shared_ptr<IPermuteVector>& pv) {
  std::vector<uint8_t> data;
  bool is_ok = TransposeConstTensorData(input, pv, data);
  assert(is_ok);
  auto src_shape = input->GetShape();
  auto dst_spec = input->GetSpec();
  vx::ShapeType dst_shape;
  for (uint32_t i = 0; i < src_shape.size(); i++) {
    dst_shape.push_back(src_shape[pv->AsStdVec()[i]]);
  }
  dst_spec.SetShape(dst_shape);
  if (dst_spec.quantization_.Type() == vx::QuantType::SYMMETRIC_PER_CHANNEL) {
    dst_spec.quantization_.SetChannelDim(
        MapAxis(pv->AsStdVec(), dst_spec.quantization_.ChannelDim()));
  }
  return context_->infer_graph_->CreateTensor(dst_spec, data.data());
}
} // namespace transform
} // namespace tim

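Two details worth noting in the lines shown above: TransposeConstTensorData hands vsi_nn_Transpose the shape in reversed order, and it builds perm from the fixed KOcHWIc2OcIcHW constant rather than from its pv argument, so the data path as written targets the 4-D weight layout. The effect is a one-time host-side transposition of the constant buffer instead of a Permute op that would run on every inference. A minimal, library-free sketch of that idea (2-D case, made-up values, not the patch's code):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Transpose a row-major [rows x cols] byte-per-element buffer once,
    // at graph-build time.
    std::vector<uint8_t> TransposeHostSide(const std::vector<uint8_t>& src,
                                           uint32_t rows, uint32_t cols) {
      std::vector<uint8_t> dst(src.size());
      for (uint32_t r = 0; r < rows; ++r)
        for (uint32_t c = 0; c < cols; ++c)
          dst[c * rows + r] = src[r * cols + c];
      return dst;
    }

    int main() {
      std::vector<uint8_t> w = {1, 2, 3, 4, 5, 6};  // 2x3 constant "weight"
      auto wt = TransposeHostSide(w, 2, 3);         // 3x2: 1 4 2 5 3 6
      for (auto v : wt) std::printf("%u ", v);
      return 0;
    }
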
@@ -26,12 +26,15 @@

#include <memory>

#include "../layout_infer_context.h"
#include "tim/layout_infer/layout_inference.h"
#include "tim/vx/types.h"

namespace tim {
namespace transform {

constexpr std::initializer_list<uint32_t> kCWHN2WHCN = {1, 2, 0, 3};
constexpr std::initializer_list<uint32_t> KOcHWIc2OcIcHW = {0, 3, 1, 2};

class OpLayoutInfer {
 public:

@@ -43,6 +46,8 @@ class OpLayoutInfer {
  virtual void OnOutputs(
      std::vector<std::shared_ptr<vx::Tensor>>& next_tensors);

  virtual ~OpLayoutInfer() = default;

 protected:
  std::shared_ptr<vx::Tensor> InsertPermute(std::shared_ptr<vx::Tensor> input,
                                            std::shared_ptr<IPermuteVector> perm,

@@ -63,6 +68,14 @@ class OpLayoutInfer {

  void ReverseInputsPermuteVector();

  bool TransposeConstTensorData(const std::shared_ptr<vx::Tensor>& input,
                                const std::shared_ptr<IPermuteVector>& pv,
                                std::vector<uint8_t>& out_data);

  std::shared_ptr<vx::Tensor> PermuteConstTensor(
      const std::shared_ptr<vx::Tensor>& input,
      const std::shared_ptr<IPermuteVector>& pv);

 protected:
  const std::shared_ptr<vx::Operation> op_;
  std::shared_ptr<layout_inference_impl::LayoutInferContext>& context_;

@@ -41,16 +41,14 @@ class StackLayoutInfer : public OpLayoutInfer {
      : OpLayoutInfer(op, context) {}
  void OnInputs(
      std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
    auto input_tensors = op_->impl()->InputsTensor();
    auto required_pv = context_->GetPermuteVector(input_tensors[0]);
    ReverseInputsPermuteVector();
    int32_t axis = op_->impl()->node()->nn_param.stack.axis;
    axis = MapAxis(required_pv->AsStdVec(), static_cast<uint32_t>(axis));

    auto stack =
        context_->infer_graph_->CreateOperation<vx::ops::Stack>(1, axis);
    auto otensor_infer = CreateOutputsTensor(required_pv);
    (*stack).BindInput(context_->GetMapedTensor(input_tensors[0]));
    (*stack).BindOutput(otensor_infer[0]);
    auto stack = context_->infer_graph_->CreateOperation<vx::ops::Stack>(
        axis, op_->impl()->input_cnt_);
    (*stack).BindInput(context_->GetMapedTensor(op_->impl()->InputsTensor()[0]));
    auto required_pv = MakeShared(op_->impl()->OutputsTensor()[0]->GetShape().size());
    auto out_infer = CreateOutputsTensor(required_pv);
    (*stack).BindOutput(out_infer[0]);
    context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
    // Add out tensor of src_graph into next_tensor
    next_tensors.push_back(op_->impl()->OutputsTensor()[0]);

@@ -28,6 +28,7 @@
#include <cassert>
#include <memory>
#include <vector>
#include <string>

namespace tim {
namespace transform {

@@ -1,34 +0,0 @@
message("src/tim/vx")

set(TARGET_NAME "tim-vx")

add_subdirectory("internal")

aux_source_directory(. VX_SRC)
aux_source_directory(ops OPS_SRC)

set(SRC)
list(APPEND SRC
    ${VX_SRC}
    ${OPS_SRC}
)
list(REMOVE_ITEM SRC ./context_test.cc)
list(REMOVE_ITEM SRC ./graph_test.cc)

include_directories(${PROJECT_SOURCE_DIR}/include)
include_directories(${PROJECT_SOURCE_DIR}/include/tim/vx)
include_directories(${PROJECT_SOURCE_DIR}/src/tim/vx)
include_directories(${PROJECT_SOURCE_DIR}/src/tim/vx/internal/include)

add_library(${TARGET_NAME} SHARED ${SRC})
target_link_libraries(${TARGET_NAME} PRIVATE
    -Wl,--whole-archive tim_internal -Wl,--no-whole-archive)

add_library(${TARGET_NAME}-static STATIC ${SRC})
target_link_libraries(${TARGET_NAME}-static PRIVATE
    -Wl,--whole-archive tim_internal -Wl,--no-whole-archive)

install(TARGETS ${TARGET_NAME} ${TARGET_NAME}-static
    DESTINATION ${CMAKE_BINARY_DIR}/install/lib)

install(DIRECTORY ${CMAKE_SOURCE_DIR}/include DESTINATION ${CMAKE_BINARY_DIR}/install)

@@ -51,9 +51,6 @@ OperationImpl& OperationImpl::BindInput(const std::shared_ptr<Tensor>& tensor) {
  node_->input.tensors[input_tensor_index++] = tensor_id;
  if (tensor->GetSpec().attr_ & TensorAttribute::INPUT) {
    graph_->AddInput(tensor_id);
  }
  if (tensor->GetSpec().attr_ & TensorAttribute::INPUT ||
      tensor->GetSpec().attr_ & TensorAttribute::CONSTANT) {
    graph_->AddInput(tensor);
  }
  return *this;

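Per the counts on this last hunk (-51,9 +51,6: three lines removed, none added), the INPUT-only registration block is what goes away, while the check that also accepts CONSTANT tensors remains. Reassembled from the lines above, the method body now reads roughly:

    node_->input.tensors[input_tensor_index++] = tensor_id;
    // Register both external inputs and constants with the graph.
    if (tensor->GetSpec().attr_ & TensorAttribute::INPUT ||
        tensor->GetSpec().attr_ & TensorAttribute::CONSTANT) {
      graph_->AddInput(tensor);
    }
    return *this;
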