support layout inference for operations (#39)

Add layout inference support for the space2depth, depth2space, space2batch, batch2space,
pad, and reduce operations.

Signed-off-by: yuenan.li <yuenan.li@verisilicon.com>

Co-authored-by: yuenan.li <yuenan.li@verisilicon.com>
Authored by liyuenan on 2021-05-13 22:27:23 +08:00, committed by GitHub
parent d645494dcc
commit 748274143b
19 changed files with 471 additions and 19 deletions
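Before the per-file hunks: the ops listed in the commit message gain an optional DataLayout argument, and the layout-inference transform learns to rewrite them into WHCN. A minimal usage sketch of the new path follows; the tim::transform::LayoutInference entry point, the tensor shapes, and the exact TensorSpec arguments are assumptions, not part of this diff.

#include "tim/transform/layout_inference.h"
#include "tim/vx/context.h"
#include "tim/vx/graph.h"
#include "tim/vx/ops/space2depth.h"

int main() {
  auto ctx = tim::vx::Context::Create();
  auto graph = ctx->CreateGraph();

  // Hypothetical channel-first (CWHN) tensors: shapes are {C, W, H, N}.
  tim::vx::ShapeType in_shape({16, 4, 4, 1});
  tim::vx::ShapeType out_shape({64, 2, 2, 1});
  tim::vx::TensorSpec in_spec(tim::vx::DataType::FLOAT32, in_shape,
                              tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec out_spec(tim::vx::DataType::FLOAT32, out_shape,
                               tim::vx::TensorAttribute::OUTPUT);
  auto input = graph->CreateTensor(in_spec);
  auto output = graph->CreateTensor(out_spec);

  // New in this commit: the DataLayout argument on SpaceToDepth.
  auto s2d = graph->CreateOperation<tim::vx::ops::SpaceToDepth>(
      std::vector<int>({2, 2}), tim::vx::DataLayout::CWHN);
  (*s2d).BindInput(input);
  (*s2d).BindOutput(output);

  // Layout inference rewrites the graph into WHCN, inserting transposes
  // where needed, and returns the new graph plus a src->infer tensor map.
  auto result = tim::transform::LayoutInference(graph, ctx);
  auto infer_graph = result.first;
  return (infer_graph && infer_graph->Compile()) ? 0 : 1;
}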


@@ -32,10 +32,11 @@ namespace tim {
namespace vx {
namespace ops {
-class Batch2Space : public Operation {
+class BatchToSpace : public Operation {
 public:
-Batch2Space(Graph* graph, const std::vector<int>& block_size,
-const std::vector<int>& crop);
+BatchToSpace(Graph* graph, const std::vector<int>& block_size,
+const std::vector<int>& crop,
+DataLayout layout = DataLayout::WHCN);
protected:
std::vector<int> block_size_;


@@ -31,7 +31,8 @@ namespace ops {
class DepthToSpace : public Operation {
public:
-DepthToSpace(Graph* Graph, int block_size);
+DepthToSpace(Graph* Graph, int block_size,
+DataLayout layout = DataLayout::WHCN);
protected:
int block_size_;


@@ -45,6 +45,7 @@ DECLARE_REDUCE_OP(Max);
DECLARE_REDUCE_OP(Any);
DECLARE_REDUCE_OP(Prod);
DECLARE_REDUCE_OP(Mean);
+DECLARE_REDUCE_OP(Sum);
#undef DECLARE_REDUCE_OP
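DECLARE_REDUCE_OP(Sum) adds a public ReduceSum op alongside the existing reduce ops. Assuming it takes the same constructor arguments as its siblings (an axis list and a keep_dims flag; the macro body is not shown in this hunk), usage would look roughly like:

// Hypothetical: sum over axes 0 and 1, dropping them from the output shape.
auto reduce_sum = graph->CreateOperation<tim::vx::ops::ReduceSum>(
    std::vector<int32_t>({0, 1}), /*keep_dims=*/false);
(*reduce_sum).BindInput(input);
(*reduce_sum).BindOutput(output);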


@@ -32,10 +32,11 @@ namespace tim {
namespace vx {
namespace ops {
-class Space2Batch : public Operation {
+class SpaceToBatch : public Operation {
 public:
-Space2Batch(Graph* graph, const std::vector<int>& block_size,
-const std::vector<int>& pad);
+SpaceToBatch(Graph* graph, const std::vector<int>& block_size,
+const std::vector<int>& pad,
+DataLayout layout = DataLayout::WHCN);
protected:
std::vector<int> block_size_;


@@ -31,7 +31,8 @@ namespace ops {
class SpaceToDepth : public Operation {
public:
-SpaceToDepth(Graph* graph, std::vector<int> block_size);
+SpaceToDepth(Graph* graph, std::vector<int> block_size,
+DataLayout layout = DataLayout::WHCN);
protected:
std::vector<int> block_size_;


@@ -36,6 +36,12 @@
#include "ops/softmax_layout_inference.h"
#include "ops/squeeze_layout_inference.h"
#include "ops/stack_layout_inference.h"
+#include "ops/space2depth_layout_inference.h"
+#include "ops/depth2space_layout_inference.h"
+#include "ops/space2batch_layout_inference.h"
+#include "ops/batch2space_layout_inference.h"
+#include "ops/pad_layout_inference.h"
+#include "ops/reduce_layout_inference.h"
#include <algorithm>
#include <deque>
@@ -127,6 +133,23 @@ void LayoutInferContext::UpdateGraphInputMap(const std::shared_ptr<vx::Tensor>&
break; \
} \
+#define REGIST_REDUCE_LAYOUT_INFERENCE(op_idx) \
+case op_idx: { \
+auto reduce_type = op->impl()->node()->nn_param.reduce.type; \
+switch (reduce_type) { \
+REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_MEAN, ReduceMean); \
+REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_MAX, ReduceMax); \
+REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_MIN, ReduceMin); \
+REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_PROD, ReduceProd); \
+REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_ANY, ReduceAny); \
+REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_SUM, ReduceSum); \
+default: \
+VSILOGW("Op %d: Default layout inference pass for reduce.", reduce_type); \
+assert(false); \
+} \
+break; \
+} \
std::vector<std::shared_ptr<vx::Tensor>> HandleLayoutInfer(
std::shared_ptr<layout_inference_impl::LayoutInferContext>& ctx,
const std::shared_ptr<vx::Operation>& op) {
@@ -169,6 +192,12 @@ std::vector<std::shared_ptr<vx::Tensor>> HandleLayoutInfer(
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SOFTMAX, Softmax);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SQUEEZE, Squeeze);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_STACK, Stack);
+REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SPACE2DEPTH, SpaceToDepth);
+REGIST_LAYOUT_INFERENCE(VSI_NN_OP_DEPTH2SPACE, DepthToSpace);
+REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SPACE2BATCH, SpaceToBatch);
+REGIST_LAYOUT_INFERENCE(VSI_NN_OP_BATCH2SPACE, BatchToSpace);
+REGIST_LAYOUT_INFERENCE(VSI_NN_OP_PAD, Pad);
+REGIST_REDUCE_LAYOUT_INFERENCE(VSI_NN_OP_REDUCE);
default:
VSILOGW("Op %d: Default layout inference pass.", op_id);
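Reduce needs its own macro because a single VSI_NN_OP_REDUCE node covers every reduce flavour; the inner switch on nn_param.reduce.type picks the concrete *LayoutInfer class. With the macros expanded away, the dispatch amounts to something like the sketch below (conceptual only; the body of REGIST_LAYOUT_INFERENCE is not reproduced here).

// Conceptual, de-macro-ified equivalent of REGIST_REDUCE_LAYOUT_INFERENCE.
void HandleReduceLayoutInfer(
    std::shared_ptr<layout_inference_impl::LayoutInferContext>& ctx,
    const std::shared_ptr<vx::Operation>& op,
    std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) {
  switch (op->impl()->node()->nn_param.reduce.type) {
    case VSI_NN_REDUCE_SUM: {
      auto op_infer = std::make_shared<ReduceSumLayoutInfer>(op, ctx);
      op_infer->OnInputs(next_tensors);
      break;
    }
    // MEAN / MAX / MIN / PROD / ANY are dispatched the same way.
    default:
      VSILOGW("Unhandled reduce type %d.",
              op->impl()->node()->nn_param.reduce.type);
      assert(false);
  }
}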


@@ -0,0 +1,84 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifndef TIM_LAYOUT_INFER_BATCH2SPACE_LAYOUT_INFERENCE_H_
#define TIM_LAYOUT_INFER_BATCH2SPACE_LAYOUT_INFERENCE_H_
#include "tim/vx/ops/batch2space.h"
#include "src/tim/transform/ops/op_layout_inference.h"
#include "src/tim/transform/permute_vector.h"
#include "src/tim/vx/operation_private.h"
namespace tim {
namespace transform {
class BatchToSpaceLayoutInfer : public OpLayoutInfer {
public:
BatchToSpaceLayoutInfer(
const std::shared_ptr<vx::Operation> op,
std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
: OpLayoutInfer(op, context) {}
void OnInputs(
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
vx::DataLayout layout = op_->impl()->layout_;
auto required_pv = MakeShared(4);
if (layout == vx::DataLayout::CWHN) {
required_pv = std::make_shared<PermuteVector<4>>(kCWHN2WHCN);
}
auto input_tensors = op_->impl()->InputsTensor();
auto pv = context_->GetPermuteVector(input_tensors[0]);
auto final_pv = pv->Reverse()->Add(required_pv);
if (!final_pv->IsAligned()) {
auto perm_out =
InsertPermute(context_->GetMapedTensor(input_tensors[0]), final_pv);
context_->UpdateTensorMap(input_tensors[0], perm_out);
context_->SetPermuteVector(input_tensors[0], required_pv);
}
uint32_t block_size_num =
op_->impl()->node()->nn_param.batch2space.block_size_num;
std::vector<int> block_size(block_size_num);
memcpy(block_size.data(),
op_->impl()->node()->nn_param.batch2space.block_size,
sizeof(int) * block_size_num);
std::vector<int> crop(4);
memcpy(crop.data(), op_->impl()->node()->nn_param.batch2space.crop,
sizeof(int) * 4);
auto batch2space =
context_->infer_graph_->CreateOperation<vx::ops::BatchToSpace>(
block_size, crop, vx::DataLayout::WHCN);
auto out_tensor_infer = CreateOutputsTensor(required_pv);
(*batch2space).BindInput(context_->GetMapedTensor(input_tensors[0]));
(*batch2space).BindOutput(out_tensor_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
// Add out tensor of src_graph into next_tensor
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
}
};
} // namespace transform
} // namespace tim
#endif
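The four new spatial *LayoutInfer classes share this preamble: fetch the input's current permute vector, compose its inverse with the permutation the op requires, and insert a transpose only when the result is not the identity. Below is a standalone sketch of that arithmetic on plain vectors; treating IPermuteVector::Reverse as the inverse and Add as left-to-right composition is an assumption about the library, and the values are made up.

#include <cassert>
#include <cstdint>
#include <vector>

using Perm = std::vector<uint32_t>;  // perm[i] = source axis placed at slot i

// Inverse permutation: Inverse(p)[p[i]] == i.
Perm Inverse(const Perm& p) {
  Perm inv(p.size());
  for (uint32_t i = 0; i < p.size(); ++i) inv[p[i]] = i;
  return inv;
}

// "Apply a, then b": slot i of the final order holds original axis a[b[i]].
Perm Compose(const Perm& a, const Perm& b) {
  Perm r(b.size());
  for (uint32_t i = 0; i < b.size(); ++i) r[i] = a[b[i]];
  return r;
}

bool IsIdentity(const Perm& p) {
  for (uint32_t i = 0; i < p.size(); ++i)
    if (p[i] != i) return false;
  return true;
}

int main() {
  Perm current({1, 2, 0, 3});   // permutation already applied to the tensor
  Perm required({1, 2, 0, 3});  // permutation the op wants (e.g. CWHN->WHCN)
  // final_pv = pv->Reverse()->Add(required_pv): undo the current layout,
  // then apply the required one. Identity means no extra transpose.
  Perm final_pv = Compose(Inverse(current), required);
  assert(IsIdentity(final_pv));  // already aligned, nothing to insert
  return 0;
}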


@@ -0,0 +1,77 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifndef TIM_LAYOUT_INFER_DEPTH2SPACE_LAYOUT_INFERENCE_H_
#define TIM_LAYOUT_INFER_DEPTH2SPACE_LAYOUT_INFERENCE_H_
#include "tim/vx/ops/depth2space.h"
#include "src/tim/transform/ops/op_layout_inference.h"
#include "src/tim/transform/permute_vector.h"
#include "src/tim/vx/operation_private.h"
namespace tim {
namespace transform {
class DepthToSpaceLayoutInfer : public OpLayoutInfer {
public:
DepthToSpaceLayoutInfer(
const std::shared_ptr<vx::Operation> op,
std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
: OpLayoutInfer(op, context) {}
void OnInputs(
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
vx::DataLayout layout = op_->impl()->layout_;
auto required_pv = MakeShared(4);
if (layout == vx::DataLayout::CWHN) {
required_pv = std::make_shared<PermuteVector<4>>(kCWHN2WHCN);
}
auto input_tensors = op_->impl()->InputsTensor();
auto pv = context_->GetPermuteVector(input_tensors[0]);
auto final_pv = pv->Reverse()->Add(required_pv);
if (!final_pv->IsAligned()) {
auto perm_out =
InsertPermute(context_->GetMapedTensor(input_tensors[0]), final_pv);
context_->UpdateTensorMap(input_tensors[0], perm_out);
context_->SetPermuteVector(input_tensors[0], required_pv);
}
int block_size = op_->impl()->node()->nn_param.depth2space.block_size;
auto depth2space =
context_->infer_graph_->CreateOperation<vx::ops::DepthToSpace>(
block_size, vx::DataLayout::WHCN);
auto out_tensor_infer = CreateOutputsTensor(required_pv);
(*depth2space).BindInput(context_->GetMapedTensor(input_tensors[0]));
(*depth2space).BindOutput(out_tensor_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
// Add out tensor of src_graph into next_tensor
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
}
};
} // namespace transform
} // namespace tim
#endif


@@ -234,5 +234,17 @@ std::shared_ptr<vx::Tensor> OpLayoutInfer::PermuteConstTensor(
}
return context_->infer_graph_->CreateTensor(dst_spec, data.data());
}
+std::vector<uint32_t> OpLayoutInfer::MapPadding(const std::vector<uint32_t>& perm,
+const std::vector<uint32_t>& padding) {
+assert(perm.size() == padding.size());
+std::vector<uint32_t> r(padding.size());
+for (int i = 0; i < padding.size(); ++i) {
+r[i] = padding[perm[i]];
+}
+return r;
+}
} // namespace transform
} // namespace tim
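MapPadding simply gathers the per-axis pad amounts into the permuted axis order, so PadLayoutInfer can preserve the original padding semantics after a layout change. A quick worked example with made-up values:

// r[i] = padding[perm[i]]
std::vector<uint32_t> perm    = {2, 0, 1, 3};  // hypothetical permute vector
std::vector<uint32_t> padding = {1, 2, 3, 0};  // per-axis pad, original order
// MapPadding(perm, padding) == {padding[2], padding[0], padding[1], padding[3]}
//                           == {3, 1, 2, 0}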


@@ -75,6 +75,9 @@ class OpLayoutInfer {
std::shared_ptr<vx::Tensor> PermuteConstTensor(
const std::shared_ptr<vx::Tensor>& input,
const std::shared_ptr<IPermuteVector>& pv);
+std::vector<uint32_t> MapPadding(const std::vector<uint32_t>& perm,
+const std::vector<uint32_t>& padding);
protected:
const std::shared_ptr<vx::Operation> op_;


@@ -0,0 +1,74 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifndef TIM_LAYOUT_INFER_PAD_LAYOUT_INFERENCE_H_
#define TIM_LAYOUT_INFER_PAD_LAYOUT_INFERENCE_H_
#include "tim/vx/ops/pad.h"
#include "src/tim/transform/ops/op_layout_inference.h"
#include "src/tim/transform/permute_vector.h"
#include "src/tim/vx/operation_private.h"
namespace tim {
namespace transform {
class PadLayoutInfer : public OpLayoutInfer {
public:
PadLayoutInfer(
const std::shared_ptr<vx::Operation> op,
std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
: OpLayoutInfer(op, context) {}
void OnInputs(
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
assert(op_->impl()->InputsTensor().size() == 1);
auto i_src = op_->impl()->InputsTensor()[0];
auto input_pv = context_->GetPermuteVector(i_src);
uint32_t dim_num = op_->impl()->node()->nn_param.pad.dim_num;
std::vector<uint32_t> front_size(dim_num);
std::vector<uint32_t> back_size(dim_num);
memcpy(front_size.data(), op_->impl()->node()->nn_param.pad.front_size,
sizeof(uint32_t) * dim_num);
memcpy(back_size.data(), op_->impl()->node()->nn_param.pad.back_size,
sizeof(uint32_t) * dim_num);
int32_t pad_value = op_->impl()->node()->nn_param.pad.const_val;
if (!input_pv->IsAligned()) {
front_size = MapPadding(input_pv->AsStdVec(), front_size);
back_size = MapPadding(input_pv->AsStdVec(), back_size);
}
auto pad = context_->infer_graph_->CreateOperation<vx::ops::Pad>(
front_size, back_size, pad_value);
auto out_infer = CreateOutputsTensor(input_pv);
(*pad).BindInput(context_->GetMapedTensor(i_src));
(*pad).BindOutput(out_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], input_pv);
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
}
};
} // namespace transform
} // namespace tim
#endif


@@ -36,6 +36,7 @@ namespace tim {
namespace transform {
template <typename OpType>
class ReduceLayoutInfer : public OpLayoutInfer {
public:
ReduceLayoutInfer(
const std::shared_ptr<vx::Operation> op,
std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
@@ -60,7 +61,7 @@ class ReduceLayoutInfer : public OpLayoutInfer {
(*reduce).BindInput(context_->GetMapedTensor(t_src));
if (op_->impl()->node()->nn_param.reduce.keep_dim) {
auto otensor_infer = CreateOutputsTensor(pv);
-(*reduce).BindOuput(otensor_infer[0]);
+(*reduce).BindOutput(otensor_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], pv);
} else {
auto out_pv = MakeShared(pv->Rank() - unique_axis.size());
@@ -87,6 +88,7 @@ using ReduceMaxLayoutInfer = ReduceLayoutInfer<tim::vx::ops::ReduceMax>;
using ReduceAnyLayoutInfer = ReduceLayoutInfer<tim::vx::ops::ReduceAny>;
using ReduceProdLayoutInfer = ReduceLayoutInfer<tim::vx::ops::ReduceProd>;
using ReduceMeanLayoutInfer = ReduceLayoutInfer<tim::vx::ops::ReduceMean>;
+using ReduceSumLayoutInfer = ReduceLayoutInfer<tim::vx::ops::ReduceSum>;
} // namespace transform
} // namespace tim
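The keep_dim branch above matters for the permute vector's rank: with keep_dim false, every distinct reduced axis drops one dimension, which is why the output permute vector is rebuilt with pv->Rank() - unique_axis.size() entries. A small standalone sketch of that count (the negative-axis normalization is an assumption, not shown in this hunk):

#include <cstdint>
#include <set>
#include <vector>

// Rank of the output (and of its permute vector) after a reduce
// with keep_dim == false.
uint32_t ReducedRank(uint32_t input_rank, const std::vector<int32_t>& axes) {
  std::set<int32_t> unique_axis;
  for (int32_t a : axes)
    unique_axis.insert(a < 0 ? a + static_cast<int32_t>(input_rank) : a);
  return input_rank - static_cast<uint32_t>(unique_axis.size());
}

// ReducedRank(4, {1, 2})     == 2
// ReducedRank(4, {1, 1, -1}) == 2   (duplicates and negative axes collapse)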


@@ -0,0 +1,84 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifndef TIM_LAYOUT_INFER_SPACE2BATCH_LAYOUT_INFERENCE_H_
#define TIM_LAYOUT_INFER_SPACE2BATCH_LAYOUT_INFERENCE_H_
#include "tim/vx/ops/space2batch.h"
#include "src/tim/transform/ops/op_layout_inference.h"
#include "src/tim/transform/permute_vector.h"
#include "src/tim/vx/operation_private.h"
namespace tim {
namespace transform {
class SpaceToBatchLayoutInfer : public OpLayoutInfer {
public:
SpaceToBatchLayoutInfer(
const std::shared_ptr<vx::Operation> op,
std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
: OpLayoutInfer(op, context) {}
void OnInputs(
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
vx::DataLayout layout = op_->impl()->layout_;
auto required_pv = MakeShared(4);
if (layout == vx::DataLayout::CWHN) {
required_pv = std::make_shared<PermuteVector<4>>(kCWHN2WHCN);
}
auto input_tensors = op_->impl()->InputsTensor();
auto pv = context_->GetPermuteVector(input_tensors[0]);
auto final_pv = pv->Reverse()->Add(required_pv);
if (!final_pv->IsAligned()) {
auto perm_out =
InsertPermute(context_->GetMapedTensor(input_tensors[0]), final_pv);
context_->UpdateTensorMap(input_tensors[0], perm_out);
context_->SetPermuteVector(input_tensors[0], required_pv);
}
uint32_t block_size_num =
op_->impl()->node()->nn_param.space2batch.block_size_num;
std::vector<int> block_size(block_size_num);
memcpy(block_size.data(),
op_->impl()->node()->nn_param.space2batch.block_size,
sizeof(int) * block_size_num);
std::vector<int> pad(4);
memcpy(pad.data(), op_->impl()->node()->nn_param.space2batch.pad,
sizeof(int) * 4);
auto space2batch =
context_->infer_graph_->CreateOperation<vx::ops::SpaceToBatch>(
block_size, pad, vx::DataLayout::WHCN);
auto out_tensor_infer = CreateOutputsTensor(required_pv);
(*space2batch).BindInput(context_->GetMapedTensor(input_tensors[0]));
(*space2batch).BindOutput(out_tensor_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
// Add out tensor of src_graph into next_tensor
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
}
};
} // namespace transform
} // namespace tim
#endif


@@ -0,0 +1,78 @@
/****************************************************************************
*
* Copyright (c) 2020 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifndef TIM_LAYOUT_INFER_SPACE2DEPTH_LAYOUT_INFERENCE_H_
#define TIM_LAYOUT_INFER_SPACE2DEPTH_LAYOUT_INFERENCE_H_
#include "tim/vx/ops/space2depth.h"
#include "src/tim/transform/ops/op_layout_inference.h"
#include "src/tim/transform/permute_vector.h"
#include "src/tim/vx/operation_private.h"
namespace tim {
namespace transform {
class SpaceToDepthLayoutInfer : public OpLayoutInfer {
public:
SpaceToDepthLayoutInfer(
const std::shared_ptr<vx::Operation> op,
std::shared_ptr<layout_inference_impl::LayoutInferContext>& context)
: OpLayoutInfer(op, context) {}
void OnInputs(
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
vx::DataLayout layout = op_->impl()->layout_;
auto required_pv = MakeShared(4);
if (layout == vx::DataLayout::CWHN) {
required_pv = std::make_shared<PermuteVector<4>>(kCWHN2WHCN);
}
auto input_tensors = op_->impl()->InputsTensor();
auto pv = context_->GetPermuteVector(input_tensors[0]);
auto final_pv = pv->Reverse()->Add(required_pv);
if (!final_pv->IsAligned()) {
auto perm_out =
InsertPermute(context_->GetMapedTensor(input_tensors[0]), final_pv);
context_->UpdateTensorMap(input_tensors[0], perm_out);
context_->SetPermuteVector(input_tensors[0], required_pv);
}
std::vector<int> block_size = {
op_->impl()->node()->nn_param.space2depth.block_size[0],
op_->impl()->node()->nn_param.space2depth.block_size[1]};
auto space2depth =
context_->infer_graph_->CreateOperation<vx::ops::SpaceToDepth>(
block_size, vx::DataLayout::WHCN);
auto out_tensor_infer = CreateOutputsTensor(required_pv);
(*space2depth).BindInput(context_->GetMapedTensor(input_tensors[0]));
(*space2depth).BindOutput(out_tensor_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
// Add out tensor of src_graph into next_tensor
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
}
};
} // namespace transform
} // namespace tim
#endif


@@ -30,9 +30,9 @@ namespace tim {
namespace vx {
namespace ops {
-Batch2Space::Batch2Space(Graph* graph, const std::vector<int>& block_size,
-const std::vector<int>& crop)
-: Operation(graph, VSI_NN_OP_BATCH2SPACE),
+BatchToSpace::BatchToSpace(Graph* graph, const std::vector<int>& block_size,
+const std::vector<int>& crop, DataLayout layout)
+: Operation(graph, VSI_NN_OP_BATCH2SPACE, 0, 0, layout),
block_size_(block_size),
crop_(crop) {
this->impl()->node()->nn_param.batch2space.block_size = block_size_.data();


@@ -30,8 +30,9 @@ namespace tim {
namespace vx {
namespace ops {
-DepthToSpace::DepthToSpace(Graph* graph, int block_size)
-: Operation(graph, VSI_NN_OP_DEPTH2SPACE), block_size_(block_size) {
+DepthToSpace::DepthToSpace(Graph* graph, int block_size, DataLayout layout)
+: Operation(graph, VSI_NN_OP_DEPTH2SPACE, 0, 0, layout),
+block_size_(block_size) {
this->impl()->node()->nn_param.depth2space.block_size = block_size_;
}
} // namespace ops


@@ -47,6 +47,7 @@ DEFINE_REDUCE_OP(Max, VSI_NN_REDUCE_MAX);
DEFINE_REDUCE_OP(Any, VSI_NN_REDUCE_ANY);
DEFINE_REDUCE_OP(Prod, VSI_NN_REDUCE_PROD);
DEFINE_REDUCE_OP(Mean, VSI_NN_REDUCE_MEAN);
+DEFINE_REDUCE_OP(Sum, VSI_NN_REDUCE_SUM);
#undef DEFINE_REDUCE_OP


@@ -30,9 +30,9 @@ namespace tim {
namespace vx {
namespace ops {
-Space2Batch::Space2Batch(Graph* graph, const std::vector<int>& block_size,
-const std::vector<int>& pad)
-: Operation(graph, VSI_NN_OP_SPACE2BATCH),
+SpaceToBatch::SpaceToBatch(Graph* graph, const std::vector<int>& block_size,
+const std::vector<int>& pad, DataLayout layout)
+: Operation(graph, VSI_NN_OP_SPACE2BATCH, 0, 0, layout),
block_size_(block_size),
pad_(pad) {
this->impl()->node()->nn_param.space2batch.block_size = block_size_.data();


@@ -30,8 +30,10 @@ namespace tim {
namespace vx {
namespace ops {
-SpaceToDepth::SpaceToDepth(Graph* graph, std::vector<int> block_size)
-: Operation(graph, VSI_NN_OP_SPACE2DEPTH), block_size_(block_size) {
+SpaceToDepth::SpaceToDepth(Graph* graph, std::vector<int> block_size,
+DataLayout layout)
+: Operation(graph, VSI_NN_OP_SPACE2DEPTH, 0, 0, layout),
+block_size_(block_size) {
this->impl()->node()->nn_param.space2depth.block_size[0] = block_size_[0];
this->impl()->node()->nn_param.space2depth.block_size[1] = block_size_[1];
}