diff --git a/src/tim/transform/ops/activation_layout_inference.h b/src/tim/transform/ops/activation_layout_inference.h index de833d2..64adf31 100644 --- a/src/tim/transform/ops/activation_layout_inference.h +++ b/src/tim/transform/ops/activation_layout_inference.h @@ -29,6 +29,7 @@ #include "ops/op_layout_inference.h" #include "permute_vector.h" #include "builtin_op_impl.h" +#include "tim/vx/ops/transpose.h" namespace tim { namespace transform { @@ -65,15 +66,50 @@ class PReluLayoutInfer : public OpLayoutInfer { void OnInputs( std::vector>& next_tensors) override { - ReverseInputsPermuteVector(); auto src_input = op_->impl()->InputsTensor()[0]; + auto src_slope = op_->impl()->InputsTensor()[1]; auto input_pv = context_->GetPermuteVector(src_input); - auto prelu = context_->infer_graph_->CreateOperation( - op_->impl()->node()->nn_param.prelu.axis); - auto out_infer = CreateOutputsTensor(input_pv); - for (const auto& i_src : op_->impl()->InputsTensor()) { - (*prelu).BindInput(context_->GetMapedTensor(i_src)); + + if (src_slope->IsConstTensor()) { + std::shared_ptr infer_tensor; + std::shared_ptr slope_pv; + std::vector dataRef(src_slope->GetSpec().GetByteSize()); + src_slope->CopyDataFromTensor(dataRef.data()); + auto infer_slope = context_->infer_graph_->CreateTensor( + src_slope->GetSpec(), (const void*)dataRef.data()); + slope_pv = MakeShared(src_slope->GetShape().size()); + + if(!input_pv->IsAligned()){ + // compute transpose param + std::vector perm; + for(uint32_t i = 0,j=0; i< input_pv->Rank(); i++,j++){ + if(j == slope_pv->Rank()) break; + if(input_pv->At(i) < slope_pv->Rank()){ + perm.push_back(input_pv->At(i)); + } + else i++; // if dims of input is higher than slope + } + auto out_slope = context_->infer_graph_->CreateTensor(src_slope->GetSpec().AsTransientSpec()); + auto permute = context_->infer_graph_->CreateOperation(perm); + (*permute).BindInput(infer_slope).BindOutput(out_slope); + context_->UpdateTensorMap(src_slope, out_slope); + } + else { + context_->UpdateTensorMap(src_slope, infer_slope); + } + context_->SetPermuteVector(src_slope,slope_pv); } + else{ + VSILOGE("Slope tensor cannot be handled yet if not constant."); + assert(false); + } + auto axis = MapAxis(input_pv->AsStdVec(), + op_->impl()->node()->nn_param.prelu.axis); + auto prelu = context_->infer_graph_->CreateOperation(axis); + auto out_infer = CreateOutputsTensor(input_pv); + + (*prelu).BindInput(context_->GetMapedTensor(src_input)).BindInput( + context_->GetMapedTensor(src_slope)); (*prelu).BindOutput(out_infer[0]); context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], input_pv); next_tensors.push_back(op_->impl()->OutputsTensor()[0]);