Refine prelu layout inference (#577)
Previously, all inputs were reversed to the default-order permute vector (pv), which caused unnecessary transpose operations. In this commit, only a constant slope tensor is handled, and it is transposed only if necessary. Type: Code Improvement Signed-off-by: Feiyue Chen <Feiyue.Chen@verisilicon.com>
This commit is contained in:
parent
3a3c9fa5fa
commit
3c5ee7a46e
|
|
@ -29,6 +29,7 @@
|
|||
#include "ops/op_layout_inference.h"
|
||||
#include "permute_vector.h"
|
||||
#include "builtin_op_impl.h"
|
||||
#include "tim/vx/ops/transpose.h"
|
||||
|
||||
namespace tim {
|
||||
namespace transform {
|
||||
|
|
@ -65,15 +66,50 @@ class PReluLayoutInfer : public OpLayoutInfer {
|
|||
|
||||
void OnInputs(
|
||||
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
|
||||
ReverseInputsPermuteVector();
|
||||
auto src_input = op_->impl()->InputsTensor()[0];
|
||||
auto src_slope = op_->impl()->InputsTensor()[1];
|
||||
auto input_pv = context_->GetPermuteVector(src_input);
|
||||
auto prelu = context_->infer_graph_->CreateOperation<vx::ops::Prelu>(
|
||||
op_->impl()->node()->nn_param.prelu.axis);
|
||||
auto out_infer = CreateOutputsTensor(input_pv);
|
||||
for (const auto& i_src : op_->impl()->InputsTensor()) {
|
||||
(*prelu).BindInput(context_->GetMapedTensor(i_src));
|
||||
|
||||
if (src_slope->IsConstTensor()) {
|
||||
std::shared_ptr<vx::Tensor> infer_tensor;
|
||||
std::shared_ptr<IPermuteVector> slope_pv;
|
||||
std::vector<uint8_t> dataRef(src_slope->GetSpec().GetByteSize());
|
||||
src_slope->CopyDataFromTensor(dataRef.data());
|
||||
auto infer_slope = context_->infer_graph_->CreateTensor(
|
||||
src_slope->GetSpec(), (const void*)dataRef.data());
|
||||
slope_pv = MakeShared(src_slope->GetShape().size());
|
||||
|
||||
if(!input_pv->IsAligned()){
|
||||
// compute transpose param
|
||||
std::vector<uint32_t> perm;
|
||||
for(uint32_t i = 0,j=0; i< input_pv->Rank(); i++,j++){
|
||||
if(j == slope_pv->Rank()) break;
|
||||
if(input_pv->At(i) < slope_pv->Rank()){
|
||||
perm.push_back(input_pv->At(i));
|
||||
}
|
||||
else i++; // if dims of input is higher than slope
|
||||
}
|
||||
auto out_slope = context_->infer_graph_->CreateTensor(src_slope->GetSpec().AsTransientSpec());
|
||||
auto permute = context_->infer_graph_->CreateOperation<vx::ops::Transpose>(perm);
|
||||
(*permute).BindInput(infer_slope).BindOutput(out_slope);
|
||||
context_->UpdateTensorMap(src_slope, out_slope);
|
||||
}
|
||||
else {
|
||||
context_->UpdateTensorMap(src_slope, infer_slope);
|
||||
}
|
||||
context_->SetPermuteVector(src_slope,slope_pv);
|
||||
}
|
||||
else{
|
||||
VSILOGE("Slope tensor cannot be handled yet if not constant.");
|
||||
assert(false);
|
||||
}
|
||||
auto axis = MapAxis(input_pv->AsStdVec(),
|
||||
op_->impl()->node()->nn_param.prelu.axis);
|
||||
auto prelu = context_->infer_graph_->CreateOperation<vx::ops::Prelu>(axis);
|
||||
auto out_infer = CreateOutputsTensor(input_pv);
|
||||
|
||||
(*prelu).BindInput(context_->GetMapedTensor(src_input)).BindInput(
|
||||
context_->GetMapedTensor(src_slope));
|
||||
(*prelu).BindOutput(out_infer[0]);
|
||||
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], input_pv);
|
||||
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
|
||||
|
|
|
|||
Loading…
Reference in New Issue