diff --git a/include/tim/vx/ops/resize.h b/include/tim/vx/ops/resize.h index 25d41a6..4a3ca80 100644 --- a/include/tim/vx/ops/resize.h +++ b/include/tim/vx/ops/resize.h @@ -32,7 +32,8 @@ namespace ops { class Resize : public Operation { public: Resize(Graph* graph, ResizeType type, float factor, bool align_corners, - bool half_pixel_centers, int target_height, int target_width); + bool half_pixel_centers, int target_height, int target_width, + DataLayout layout = DataLayout::WHCN); protected: const ResizeType type_; diff --git a/src/tim/transform/layout_inference.cc b/src/tim/transform/layout_inference.cc index 7bfc135..b2b252d 100644 --- a/src/tim/transform/layout_inference.cc +++ b/src/tim/transform/layout_inference.cc @@ -42,6 +42,8 @@ #include "ops/batch2space_layout_inference.h" #include "ops/pad_layout_inference.h" #include "ops/reduce_layout_inference.h" +#include "ops/fullyconnected_layout_inference.h" +#include "ops/resize_layout_inference.h" #include #include @@ -198,7 +200,8 @@ std::vector> HandleLayoutInfer( REGIST_LAYOUT_INFERENCE(VSI_NN_OP_BATCH2SPACE, BatchToSpace); REGIST_LAYOUT_INFERENCE(VSI_NN_OP_PAD, Pad); REGIST_REDUCE_LAYOUT_INFERENCE(VSI_NN_OP_REDUCE); - + REGIST_LAYOUT_INFERENCE(VSI_NN_OP_FCL2, FullyConnected); + REGIST_LAYOUT_INFERENCE(VSI_NN_OP_RESIZE, Resize); default: VSILOGW("Op %d: Default layout inference pass.", op_id); assert(false); diff --git a/src/tim/transform/ops/batch2space_layout_inference.h b/src/tim/transform/ops/batch2space_layout_inference.h index 65fab8d..85e952d 100644 --- a/src/tim/transform/ops/batch2space_layout_inference.h +++ b/src/tim/transform/ops/batch2space_layout_inference.h @@ -67,7 +67,7 @@ class BatchToSpaceLayoutInfer : public OpLayoutInfer { sizeof(int) * 4); auto batch2space = - context_->infer_graph_->CreateOperation( + context_->infer_graph_->CreateOperation( block_size, crop, vx::DataLayout::WHCN); auto out_tensor_infer = CreateOutputsTensor(required_pv); (*batch2space).BindInput(context_->GetMapedTensor(input_tensors[0])); diff --git a/src/tim/transform/ops/fullyconnected_layout_inference.h b/src/tim/transform/ops/fullyconnected_layout_inference.h new file mode 100644 index 0000000..c329630 --- /dev/null +++ b/src/tim/transform/ops/fullyconnected_layout_inference.h @@ -0,0 +1,76 @@ +/**************************************************************************** + * + * Copyright (c) 2020 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#ifndef TIM_LAYOUT_INFER_FULLYCONNECTED_LAYOUT_INFERENCE_H_ +#define TIM_LAYOUT_INFER_FULLYCONNECTED_LAYOUT_INFERENCE_H_ + +#include "tim/vx/ops/fullyconnected.h" + +#include "src/tim/transform/ops/op_layout_inference.h" +#include "src/tim/transform/permute_vector.h" +#include "src/tim/vx/operation_private.h" + +namespace tim { +namespace transform { +class FullyConnectedLayoutInfer : public OpLayoutInfer { + public: + FullyConnectedLayoutInfer( + const std::shared_ptr op, + std::shared_ptr& context) + : OpLayoutInfer(op, context) {} + + void OnInputs( + std::vector>& next_tensors) override { + + auto input_tensors = op_->impl()->InputsTensor(); + for (const auto& in : input_tensors) { + if (in->IsConstTensor()) { + auto infer_tensor = context_->infer_graph_->CreateTensor(in->GetSpec(), + in->GetDataRef()); + auto trans_pv = MakeShared(in->GetShape().size()); + + context_->UpdateTensorMap(in, infer_tensor); + context_->SetPermuteVector(in, trans_pv); + } + } + uint32_t axis = op_->impl()->node()->nn_param.fcl.axis; + uint32_t weight = op_->impl()->node()->nn_param.fcl.weights; + + auto fcl = context_->infer_graph_->CreateOperation( + axis, weight); + auto required_pv = + MakeShared(op_->impl()->OutputsTensor()[0]->GetShape().size()); + auto out_infer = CreateOutputsTensor(required_pv); + (*fcl) + .BindInputs({context_->GetMapedTensor(op_->impl()->InputsTensor()[0]), + context_->GetMapedTensor(op_->impl()->InputsTensor()[1]), + context_->GetMapedTensor(op_->impl()->InputsTensor()[2])}) + .BindOutput(out_infer[0]); + context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv); + next_tensors.push_back(op_->impl()->OutputsTensor()[0]); + } +}; + +} // namespace transform +} // namespace tim +#endif \ No newline at end of file diff --git a/src/tim/transform/ops/resize_layout_inference.h b/src/tim/transform/ops/resize_layout_inference.h new file mode 100644 index 0000000..ca71cec --- /dev/null +++ b/src/tim/transform/ops/resize_layout_inference.h @@ -0,0 +1,83 @@ +/**************************************************************************** + * + * Copyright (c) 2020 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#ifndef TIM_LAYOUT_INFER_RESIZE_LAYOUT_INFERENCE_H_ +#define TIM_LAYOUT_INFER_RESIZE_LAYOUT_INFERENCE_H_ + +#include "tim/vx/ops/resize.h" + +#include "src/tim/transform/ops/op_layout_inference.h" +#include "src/tim/transform/permute_vector.h" +#include "src/tim/vx/operation_private.h" +namespace tim { +namespace transform { +class ResizeLayoutInfer : public OpLayoutInfer { + public: + ResizeLayoutInfer( + const std::shared_ptr op, + std::shared_ptr& context) + : OpLayoutInfer(op, context) {} + + void OnInputs( + std::vector>& next_tensors) override { + assert(op_->impl()->InputsTensor().size() == 1); + vx::DataLayout layout = op_->impl()->layout_; + auto required_pv = MakeShared(4); + if (layout == vx::DataLayout::CWHN) { + required_pv = std::make_shared>(kCWHN2WHCN); + } + auto i_src = op_->impl()->InputsTensor()[0]; + auto input_pv = context_->GetPermuteVector(i_src); + auto final_pv = input_pv->Reverse()->Add(required_pv); + + if (!final_pv->IsAligned()) { + auto perm_out = InsertPermute(i_src, final_pv); + context_->UpdateTensorMap(i_src, perm_out); + context_->SetPermuteVector(i_src, final_pv); + } + + auto resize_type = + static_cast(op_->impl()->node()->nn_param.resize.type); + auto factor = op_->impl()->node()->nn_param.resize.factor; + auto aglin_corners = op_->impl()->node()->nn_param.resize.align_corners; + auto half_pixel_centers = + op_->impl()->node()->nn_param.resize.half_pixel_centers; + auto target_width = op_->impl()->node()->nn_param.resize.size[0]; + auto target_height = op_->impl()->node()->nn_param.resize.size[1]; + + auto resize = context_->infer_graph_->CreateOperation( + resize_type, factor, aglin_corners, half_pixel_centers, target_height, + target_width); + + auto out_infer = CreateOutputsTensor(required_pv); + (*resize).BindInput(context_->GetMapedTensor(i_src)); + (*resize).BindOutput(out_infer[0]); + context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv); + next_tensors.push_back(op_->impl()->OutputsTensor()[0]); + } +}; + +} // namespace transform +} // namespace tim + +#endif \ No newline at end of file diff --git a/src/tim/transform/ops/space2batch_layout_inference.h b/src/tim/transform/ops/space2batch_layout_inference.h index 4d1886f..2a99b05 100644 --- a/src/tim/transform/ops/space2batch_layout_inference.h +++ b/src/tim/transform/ops/space2batch_layout_inference.h @@ -67,7 +67,7 @@ class SpaceToBatchLayoutInfer : public OpLayoutInfer { sizeof(int) * 4); auto space2batch = - context_->infer_graph_->CreateOperation( + context_->infer_graph_->CreateOperation( block_size, pad, vx::DataLayout::WHCN); auto out_tensor_infer = CreateOutputsTensor(required_pv); (*space2batch).BindInput(context_->GetMapedTensor(input_tensors[0])); diff --git a/src/tim/vx/ops/resize.cc b/src/tim/vx/ops/resize.cc index a06b1c6..41566a5 100644 --- a/src/tim/vx/ops/resize.cc +++ b/src/tim/vx/ops/resize.cc @@ -32,8 +32,9 @@ namespace vx { namespace ops { Resize::Resize(Graph* graph, ResizeType type, float factor, bool align_corners, - bool half_pixel_centers, int target_height, int target_width) - : Operation(graph, VSI_NN_OP_RESIZE), + bool half_pixel_centers, int target_height, int target_width, + DataLayout layout) + : Operation(graph, VSI_NN_OP_RESIZE, 0, 0, layout), type_(type), factor_(factor), align_corners_(align_corners),