From 2f018cc08876bb40e5d3d3058721e76480bdceaa Mon Sep 17 00:00:00 2001
From: Chen Feiyue <69809761+chenfeiyue-cfy@users.noreply.github.com>
Date: Tue, 15 Aug 2023 13:15:03 +0800
Subject: [PATCH] Code refinement for mean-stddev-normalization fuse (#632)

1. Added copyright header && switched function parameters to (const) references
2. Rewrote the function that determines whether the pattern's ops share a
   common input
3. Used std::remove_if instead of std::find before doing erase (see the
   sketch below)
4. Added a safety check to prevent access to deleted ops

Type: Code Improvement

Signed-off-by: Feiyue Chen
---
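For item 3, a minimal, self-contained sketch of the full erase-remove idiom
(standalone illustration; `Op` and RemoveOp are hypothetical stand-ins for
vx::Operation and the patch's removal loop, not code from this patch):

#include <algorithm>
#include <memory>
#include <vector>

struct Op {};

// Remove every occurrence of `target` from `ops`: std::remove_if shifts the
// elements to keep toward the front and returns the new logical end, and
// erase() then drops the leftover tail in one call.
void RemoveOp(std::vector<std::shared_ptr<Op>>& ops,
              const std::shared_ptr<Op>& target) {
  ops.erase(std::remove_if(ops.begin(), ops.end(),
                           [&target](const std::shared_ptr<Op>& op) {
                             return op == target;
                           }),
            ops.end());
}

The patch calls erase(it) with only the iterator returned by std::remove_if;
that works while each op occurs exactly once in OpVector(), but the
two-argument erase(first, last) form above also stays well-defined when the
predicate matches zero elements, whereas calling erase(end()) alone is
undefined behavior.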
 .../transform/mean_stddev_normalize_fusion.cc | 111 ++++++++++++------
 1 file changed, 74 insertions(+), 37 deletions(-)

diff --git a/src/tim/transform/mean_stddev_normalize_fusion.cc b/src/tim/transform/mean_stddev_normalize_fusion.cc
index 95213bf..eeb55df 100644
--- a/src/tim/transform/mean_stddev_normalize_fusion.cc
+++ b/src/tim/transform/mean_stddev_normalize_fusion.cc
@@ -1,3 +1,26 @@
+/****************************************************************************
+ *
+ * Copyright (c) 2020-2023 Vivante Corporation
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the "Software"),
+ * to deal in the Software without restriction, including without limitation
+ * the rights to use, copy, modify, merge, publish, distribute, sublicense,
+ * and/or sell copies of the Software, and to permit persons to whom the
+ * Software is furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ * DEALINGS IN THE SOFTWARE.
+ *
+ *****************************************************************************/
 #include <algorithm>
 #include <vector>
 #include "tim/transform/mean_stddev_normalize_fusion.h"
@@ -84,10 +107,13 @@ bool UpdateTempVector(std::vector<std::shared_ptr<vx::Operation>>& temp,
 // Remove ops and tensors in each matched normalization pattern
 void RemoveTensorsAndOps(
     std::shared_ptr<vx::Graph>& graph,
-    std::vector<std::shared_ptr<vx::Operation>>& norm_ops) {
+    const std::vector<std::shared_ptr<vx::Operation>>& norm_ops) {
   for (uint32_t i = 0; i < norm_ops.size(); i++) {
-    auto it = std::find(graph->OpVector().begin(), graph->OpVector().end(),
-                        norm_ops[i]);
+    auto it =
+        std::remove_if(graph->OpVector().begin(), graph->OpVector().end(),
+                       [norm_ops, i](std::shared_ptr<vx::Operation> oper) {
+                         return oper == norm_ops[i];
+                       });
     graph->OpVector().erase(it);  //Remove current op from op_vector_
     auto input_tensors = norm_ops[i]->impl()->InputsTensor();
     auto output_tensors = norm_ops[i]->impl()->OutputsTensor();
@@ -136,8 +162,8 @@ void RemoveTensorsAndOps(
   }
 }
 
-bool CheckMediumMul(const std::shared_ptr<vx::Graph>& graph,
-                    std::vector<std::shared_ptr<vx::Operation>>& norm_ops) {
+bool CheckMul0(const std::shared_ptr<vx::Graph>& graph,
+               std::vector<std::shared_ptr<vx::Operation>>& norm_ops) {
   auto mul0_output_tensor =
       norm_ops[NORMALIZATION_INDEX_MUL_0]->impl()->OutputsTensor();
   auto mul0_consumers = graph->GetConsumersOp(mul0_output_tensor[0]);
@@ -156,15 +182,29 @@ bool CheckMediumMul(const std::shared_ptr<vx::Graph>& graph,
   return true;
 }
 
-bool HaveASameInput(const std::shared_ptr<vx::Operation>& op1,
-                    const std::shared_ptr<vx::Operation>& op2) {
-  auto Left = op1->impl()->InputsTensor();
-  auto Right = op2->impl()->InputsTensor();
-  for (auto left_tensor : Left) {
-    if (std::find(Right.begin(), Right.end(), left_tensor) != Right.end())
-      return true;
-  }
-  return false;
+bool CheckMean0Sub0Mul1SameInput(
+    std::vector<std::shared_ptr<vx::Operation>>& norm_ops) {
+  auto mean0_inputs_tensors =
+      norm_ops[NORMALIZATION_INDEX_MEAN_0]->impl()->InputsTensor();
+  auto sub0_inputs_tensors =
+      norm_ops[NORMALIZATION_INDEX_SUB_0]->impl()->InputsTensor();
+  auto mul1_inputs_tensors =
+      norm_ops[NORMALIZATION_INDEX_MUL_1]->impl()->InputsTensor();
+  std::sort(mean0_inputs_tensors.begin(), mean0_inputs_tensors.end());
+  std::sort(sub0_inputs_tensors.begin(), sub0_inputs_tensors.end());
+  std::sort(mul1_inputs_tensors.begin(), mul1_inputs_tensors.end());
+  std::vector<std::shared_ptr<vx::Tensor>> intersect1, intersect2;
+  std::set_intersection(mean0_inputs_tensors.begin(),
+                        mean0_inputs_tensors.end(), sub0_inputs_tensors.begin(),
+                        sub0_inputs_tensors.end(),
+                        std::back_inserter(intersect1));
+  std::set_intersection(sub0_inputs_tensors.begin(), sub0_inputs_tensors.end(),
+                        mul1_inputs_tensors.begin(), mul1_inputs_tensors.end(),
+                        std::back_inserter(intersect2));
+  if (intersect2.empty())
+    return false;
+  else
+    return true;
 }
 
 void LayernormConnection(std::shared_ptr<vx::Graph>& graph,
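CheckMean0Sub0Mul1SameInput above relies on sorted-range intersection: both
ranges passed to std::set_intersection must already be sorted, which is why
the input-tensor lists are sorted first (std::shared_ptr's operator< supplies
a usable ordering). A minimal sketch of the technique (standalone
illustration; `Tensor` and ShareAnInput are hypothetical names, not part of
the patch):

#include <algorithm>
#include <iterator>
#include <memory>
#include <vector>

struct Tensor {};
using TensorList = std::vector<std::shared_ptr<Tensor>>;

// True when the two operand lists share at least one tensor. The lists are
// taken by value so that sorting them does not reorder the caller's operands.
bool ShareAnInput(TensorList lhs, TensorList rhs) {
  std::sort(lhs.begin(), lhs.end());
  std::sort(rhs.begin(), rhs.end());
  TensorList common;
  std::set_intersection(lhs.begin(), lhs.end(), rhs.begin(), rhs.end(),
                        std::back_inserter(common));
  return !common.empty();
}

Note that the patched function fills both intersect1 (mean0's inputs against
sub0's) and intersect2 (sub0's inputs against mul1's) but only tests
intersect2, so intersect1 currently has no effect on the result.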
@@ -180,12 +220,15 @@ void LayernormConnection(std::shared_ptr<vx::Graph>& graph,
   axis = src_tensor->GetShape().size() - axis - 1;  // reverse axis
 
   // Get eps, gamma, beta;
-  // Do datatype convert due to InstanceNormlization op requirements
-  int32_t eps_index = graph->GetProducerOp(
-                          norm_ops[5]->impl()->InputsTensor()[0]) == norm_ops[4]
-                          ? 1
-                          : 0;
-  auto org_eps = norm_ops[5]->impl()->InputsTensor()[eps_index];
+  // Do datatype convert due to Layernorm op requirements
+  int32_t eps_index =
+      graph->GetProducerOp(
+          norm_ops[NORMALIZATION_INDEX_ADD_0]->impl()->InputsTensor()[0]) ==
+              norm_ops[NORMALIZATION_INDEX_MEAN_1]
+          ? 1
+          : 0;
+  auto org_eps =
+      norm_ops[NORMALIZATION_INDEX_ADD_0]->impl()->InputsTensor()[eps_index];
   if (!org_eps->IsConstTensor()) {
     org_eps = graph->GetProducerOp(org_eps)->impl()->InputsTensor()[0];
   }
@@ -203,8 +246,8 @@ void LayernormConnection(std::shared_ptr<vx::Graph>& graph,
 
   vx::TensorSpec param_spec(vx::DataType::FLOAT32, shape,
                             vx::TensorAttribute::CONSTANT);
-  auto beta = graph->CreateTensor(param_spec);
-  auto gamma = graph->CreateTensor(param_spec);
+  auto beta = graph->CreateTensor(param_spec, float_beta);
+  auto gamma = graph->CreateTensor(param_spec, float_gamma);
   float eps = *float_eps;
   beta->CopyDataToTensor(float_beta);
   gamma->CopyDataToTensor(float_gamma);
@@ -222,7 +265,7 @@ void LayernormConnection(std::shared_ptr<vx::Graph>& graph,
 
 void InstancenormConnection(
     std::shared_ptr<vx::Graph>& graph,
-    std::vector<std::shared_ptr<vx::Operation>> norm_ops) {
+    const std::vector<std::shared_ptr<vx::Operation>>& norm_ops) {
   auto src_tensor =
       norm_ops[NORMALIZATION_INDEX_MEAN_0]->impl()->InputsTensor()[0];
   auto final_tensor =
@@ -257,11 +300,9 @@ void InstancenormConnection(
 
   vx::TensorSpec param_spec(vx::DataType::FLOAT32, shape,
                             vx::TensorAttribute::CONSTANT);
-  auto beta = graph->CreateTensor(param_spec);
-  auto gamma = graph->CreateTensor(param_spec);
+  auto beta = graph->CreateTensor(param_spec, float_beta);
+  auto gamma = graph->CreateTensor(param_spec, float_gamma);
   float eps = *float_eps;
-  beta->CopyDataToTensor(float_beta);
-  gamma->CopyDataToTensor(float_gamma);
   vsi_nn_Free(float_gamma);
   vsi_nn_Free(float_beta);
   vsi_nn_Free(float_eps);
@@ -299,9 +340,11 @@ void InstancenormConnection(
 
     output
 */
 void MeanStdDevNormalization(std::shared_ptr<vx::Graph>& graph) {
-  std::vector<std::shared_ptr<vx::Operation>> op_vector = graph->OpVector();
+  std::vector<std::shared_ptr<vx::Operation>>& op_vector = graph->OpVector();
 
-  for (const auto& op : op_vector) {
+  for (auto& op : op_vector) {
+    if (std::find(op_vector.begin(), op_vector.end(), op) == op_vector.end())
+      continue;  //Avoid reading deleted data in op_vector
     if (op->impl()->kind_ != VSI_NN_OP_REDUCE) continue;
     std::vector<std::shared_ptr<vx::Operation>> temp;
@@ -328,14 +371,8 @@ void MeanStdDevNormalization(std::shared_ptr<vx::Graph>& graph) {
                           {VSI_NN_OP_MULTIPLY}))
       continue;  //Mul0
 
-    if (!CheckMediumMul(graph, temp)) continue;
-    if (!HaveASameInput(temp[NORMALIZATION_INDEX_MEAN_0],
-                        temp[NORMALIZATION_INDEX_SUB_0]) &&
-        !HaveASameInput(temp[NORMALIZATION_INDEX_MEAN_0],
-                        temp[NORMALIZATION_INDEX_MUL_1]) &&
-        !HaveASameInput(temp[NORMALIZATION_INDEX_SUB_0],
-                        temp[NORMALIZATION_INDEX_MUL_1]))
-      continue;
+    if (!CheckMul0(graph, temp)) continue;
+    if (!CheckMean0Sub0Mul1SameInput(temp)) continue;
 
     if (!UpdateTempVector(temp, NORMALIZATION_INDEX_MUL_1, graph,
                           {VSI_NN_OP_ADD}))
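The guard added for item 4 runs inside a range-for over the very vector that
RemoveTensorsAndOps erases from, and vector::erase invalidates iterators at
and after the erase point, so the loop's cached begin/end iterators can still
go stale despite the std::find check. A minimal sketch that combines the
pre-patch snapshot copy with the new guard (standalone illustration; `Op` and
VisitOps are hypothetical names):

#include <algorithm>
#include <memory>
#include <vector>

struct Op {};

// Iterate over a snapshot taken up front: erasing from `live` during the
// loop cannot invalidate the snapshot's iterators, and the std::find guard
// skips any op that an earlier iteration already removed from `live`.
void VisitOps(std::vector<std::shared_ptr<Op>>& live) {
  const std::vector<std::shared_ptr<Op>> snapshot = live;
  for (const auto& op : snapshot) {
    if (std::find(live.begin(), live.end(), op) == live.end()) continue;
    // ... match the normalization pattern starting at `op`, possibly
    // erasing fused ops from `live` ...
  }
}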