Refine prelu layout inference (#577)

Previously, all inputs were reversed to the default-order permute vector (pv),
which caused unnecessary transpose operations.
With this commit, only a constant slope tensor is handled, and it is transposed only if necessary.

Type: Code Improvement

Signed-off-by: Feiyue Chen <Feiyue.Chen@verisilicon.com>
This commit is contained in:
Chen Feiyue 2023-04-25 11:25:55 +08:00 committed by GitHub
parent 3a3c9fa5fa
commit 3c5ee7a46e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 42 additions and 6 deletions

View File

@ -29,6 +29,7 @@
#include "ops/op_layout_inference.h"
#include "permute_vector.h"
#include "builtin_op_impl.h"
#include "tim/vx/ops/transpose.h"
namespace tim {
namespace transform {
@ -65,15 +66,50 @@ class PReluLayoutInfer : public OpLayoutInfer {
void OnInputs(
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
ReverseInputsPermuteVector();
auto src_input = op_->impl()->InputsTensor()[0];
auto src_slope = op_->impl()->InputsTensor()[1];
auto input_pv = context_->GetPermuteVector(src_input);
auto prelu = context_->infer_graph_->CreateOperation<vx::ops::Prelu>(
op_->impl()->node()->nn_param.prelu.axis);
auto out_infer = CreateOutputsTensor(input_pv);
for (const auto& i_src : op_->impl()->InputsTensor()) {
(*prelu).BindInput(context_->GetMapedTensor(i_src));
if (src_slope->IsConstTensor()) {
std::shared_ptr<vx::Tensor> infer_tensor;
std::shared_ptr<IPermuteVector> slope_pv;
std::vector<uint8_t> dataRef(src_slope->GetSpec().GetByteSize());
src_slope->CopyDataFromTensor(dataRef.data());
auto infer_slope = context_->infer_graph_->CreateTensor(
src_slope->GetSpec(), (const void*)dataRef.data());
slope_pv = MakeShared(src_slope->GetShape().size());
if(!input_pv->IsAligned()){
// compute transpose param
std::vector<uint32_t> perm;
for(uint32_t i = 0,j=0; i< input_pv->Rank(); i++,j++){
if(j == slope_pv->Rank()) break;
if(input_pv->At(i) < slope_pv->Rank()){
perm.push_back(input_pv->At(i));
}
else i++; // if dims of input is higher than slope
}
auto out_slope = context_->infer_graph_->CreateTensor(src_slope->GetSpec().AsTransientSpec());
auto permute = context_->infer_graph_->CreateOperation<vx::ops::Transpose>(perm);
(*permute).BindInput(infer_slope).BindOutput(out_slope);
context_->UpdateTensorMap(src_slope, out_slope);
}
else {
context_->UpdateTensorMap(src_slope, infer_slope);
}
context_->SetPermuteVector(src_slope,slope_pv);
}
else{
VSILOGE("Slope tensor cannot be handled yet if not constant.");
assert(false);
}
auto axis = MapAxis(input_pv->AsStdVec(),
op_->impl()->node()->nn_param.prelu.axis);
auto prelu = context_->infer_graph_->CreateOperation<vx::ops::Prelu>(axis);
auto out_infer = CreateOutputsTensor(input_pv);
(*prelu).BindInput(context_->GetMapedTensor(src_input)).BindInput(
context_->GetMapedTensor(src_slope));
(*prelu).BindOutput(out_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], input_pv);
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);