Code refinement for mean-stddev-normalization fusion (#632)

1. Added copyright header and pass function parameters by reference or const reference
2. Rewrote the function that determines whether the pattern ops share a common input
3. Use std::remove_if instead of std::find before erasing (see the sketch after this list)
4. Added a safety check to prevent access to deleted ops
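
A minimal sketch of the erase-remove idiom behind point 3, using a hypothetical `Operation` stand-in and `EraseOp` helper rather than the actual tim-vx classes:

```cpp
#include <algorithm>
#include <memory>
#include <vector>

struct Operation {};  // stand-in for the real vx::Operation

// Erase every occurrence of `target` from `ops`. std::remove_if shifts the
// surviving elements toward the front and returns the new logical end, so a
// single erase(first, last) call drops the tail; unlike erasing the result
// of std::find, this is safe even when no element matches.
void EraseOp(std::vector<std::shared_ptr<Operation>>& ops,
             const std::shared_ptr<Operation>& target) {
  ops.erase(std::remove_if(ops.begin(), ops.end(),
                           [&target](const std::shared_ptr<Operation>& op) {
                             return op == target;
                           }),
            ops.end());
}
```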

Type: Code Improvement

Signed-off-by: Feiyue Chen <Feiyue.Chen@verisilicon.com>
Chen Feiyue 2023-08-15 13:15:03 +08:00 committed by GitHub
parent af50cc5e3f
commit 2f018cc088
1 changed file with 74 additions and 37 deletions

@@ -1,3 +1,26 @@
+/****************************************************************************
+*
+* Copyright (c) 2020-2023 Vivante Corporation
+*
+* Permission is hereby granted, free of charge, to any person obtaining a
+* copy of this software and associated documentation files (the "Software"),
+* to deal in the Software without restriction, including without limitation
+* the rights to use, copy, modify, merge, publish, distribute, sublicense,
+* and/or sell copies of the Software, and to permit persons to whom the
+* Software is furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in
+* all copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+* DEALINGS IN THE SOFTWARE.
+*
+*****************************************************************************/
#include <algorithm>
#include <stdarg.h>
#include "tim/transform/mean_stddev_normalize_fusion.h"
@@ -84,10 +107,13 @@ bool UpdateTempVector(std::vector<std::shared_ptr<vx::Operation>>& temp,
// Remove ops and tensors in each matched normlization patten
void RemoveTensorsAndOps(
std::shared_ptr<vx::Graph>& graph,
-std::vector<std::shared_ptr<vx::Operation>>& norm_ops) {
+const std::vector<std::shared_ptr<vx::Operation>>& norm_ops) {
for (uint32_t i = 0; i < norm_ops.size(); i++) {
-auto it = std::find(graph->OpVector().begin(), graph->OpVector().end(),
-norm_ops[i]);
+auto it =
+std::remove_if(graph->OpVector().begin(), graph->OpVector().end(),
+[norm_ops, i](std::shared_ptr<vx::Operation> oper) {
+return oper == norm_ops[i];
+});
graph->OpVector().erase(it); //Remove current op from op_vector_
auto input_tensors = norm_ops[i]->impl()->InputsTensor();
auto output_tensors = norm_ops[i]->impl()->OutputsTensor();
@@ -136,8 +162,8 @@ void RemoveTensorsAndOps(
}
}
-bool CheckMediumMul(const std::shared_ptr<vx::Graph>& graph,
-std::vector<std::shared_ptr<vx::Operation>>& norm_ops) {
+bool CheckMul0(const std::shared_ptr<vx::Graph>& graph,
+std::vector<std::shared_ptr<vx::Operation>>& norm_ops) {
auto mul0_output_tensor =
norm_ops[NORMALIZATION_INDEX_MUL_0]->impl()->OutputsTensor();
auto mul0_consumers = graph->GetConsumersOp(mul0_output_tensor[0]);
@@ -156,15 +182,29 @@ bool CheckMediumMul(const std::shared_ptr<vx::Graph>& graph,
return true;
}
-bool HaveASameInput(const std::shared_ptr<vx::Operation>& op1,
-const std::shared_ptr<vx::Operation>& op2) {
-auto Left = op1->impl()->InputsTensor();
-auto Right = op2->impl()->InputsTensor();
-for (auto left_tensor : Left) {
-if (std::find(Right.begin(), Right.end(), left_tensor) != Right.end())
-return true;
-}
-return false;
+bool CheckMean0Sub0Mul1SameInput(
+std::vector<std::shared_ptr<vx::Operation>>& norm_ops) {
+auto mean0_inputs_tensors =
+norm_ops[NORMALIZATION_INDEX_MEAN_0]->impl()->InputsTensor();
+auto sub0_inputs_tensors =
+norm_ops[NORMALIZATION_INDEX_SUB_0]->impl()->InputsTensor();
+auto mul1_inputs_tensors =
+norm_ops[NORMALIZATION_INDEX_MUL_1]->impl()->InputsTensor();
+std::sort(mean0_inputs_tensors.begin(), mean0_inputs_tensors.end());
+std::sort(sub0_inputs_tensors.begin(), sub0_inputs_tensors.end());
+std::sort(mul1_inputs_tensors.begin(), mul1_inputs_tensors.end());
+std::vector<std::shared_ptr<vx::Tensor>> intersect1, intersect2;
+std::set_intersection(mean0_inputs_tensors.begin(),
+mean0_inputs_tensors.end(), sub0_inputs_tensors.begin(),
+sub0_inputs_tensors.end(),
+std::back_inserter(intersect1));
+std::set_intersection(sub0_inputs_tensors.begin(), sub0_inputs_tensors.end(),
+mul1_inputs_tensors.begin(), mul1_inputs_tensors.end(),
+std::back_inserter(intersect2));
+if (intersect2.empty())
+return false;
+else
+return true;
}
void LayernormConnection(std::shared_ptr<vx::Graph>& graph,
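
For context on the rewritten common-input test above, here is a stripped-down sketch (with a placeholder `Tensor` type and hypothetical `ShareAnInput` helper, not the library's API) of how std::set_intersection detects a shared input tensor:

```cpp
#include <algorithm>
#include <iterator>
#include <memory>
#include <vector>

struct Tensor {};  // placeholder for vx::Tensor

// Returns true when two input-tensor lists share at least one tensor.
// std::set_intersection requires both ranges to be sorted with the same
// ordering; shared_ptr's operator< compares the stored pointers, so sorting
// the local copies first gives a consistent order.
bool ShareAnInput(std::vector<std::shared_ptr<Tensor>> lhs,
                  std::vector<std::shared_ptr<Tensor>> rhs) {
  std::sort(lhs.begin(), lhs.end());
  std::sort(rhs.begin(), rhs.end());
  std::vector<std::shared_ptr<Tensor>> common;
  std::set_intersection(lhs.begin(), lhs.end(), rhs.begin(), rhs.end(),
                        std::back_inserter(common));
  return !common.empty();
}
```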
@@ -180,12 +220,15 @@ void LayernormConnection(std::shared_ptr<vx::Graph>& graph,
axis = src_tensor->GetShape().size() - axis - 1; // reverse axis
// Get eps, gamma,beta;
-// Do datatype convert due to InstanceNormlization op requirements
-int32_t eps_index = graph->GetProducerOp(
-norm_ops[5]->impl()->InputsTensor()[0]) == norm_ops[4]
-? 1
-: 0;
-auto org_eps = norm_ops[5]->impl()->InputsTensor()[eps_index];
+// Do datatype convert due to Layernorm op requirements
+int32_t eps_index =
+graph->GetProducerOp(
+norm_ops[NORMALIZATION_INDEX_ADD_0]->impl()->InputsTensor()[0]) ==
+norm_ops[NORMALIZATION_INDEX_MEAN_1]
+? 1
+: 0;
+auto org_eps =
+norm_ops[NORMALIZATION_INDEX_ADD_0]->impl()->InputsTensor()[eps_index];
if (!org_eps->IsConstTensor()) {
org_eps = graph->GetProducerOp(org_eps)->impl()->InputsTensor()[0];
}
@@ -203,8 +246,8 @@ void LayernormConnection(std::shared_ptr<vx::Graph>& graph,
vx::TensorSpec param_spec(vx::DataType::FLOAT32, shape,
vx::TensorAttribute::CONSTANT);
-auto beta = graph->CreateTensor(param_spec);
-auto gamma = graph->CreateTensor(param_spec);
+auto beta = graph->CreateTensor(param_spec, float_beta);
+auto gamma = graph->CreateTensor(param_spec, float_gamma);
float eps = *float_eps;
beta->CopyDataToTensor(float_beta);
gamma->CopyDataToTensor(float_gamma);
@@ -222,7 +265,7 @@ void LayernormConnection(std::shared_ptr<vx::Graph>& graph,
void InstancenormConnection(
std::shared_ptr<vx::Graph>& graph,
-std::vector<std::shared_ptr<vx::Operation>> norm_ops) {
+const std::vector<std::shared_ptr<vx::Operation>>& norm_ops) {
auto src_tensor =
norm_ops[NORMALIZATION_INDEX_MEAN_0]->impl()->InputsTensor()[0];
auto final_tensor =
@@ -257,11 +300,9 @@ void InstancenormConnection(
vx::TensorSpec param_spec(vx::DataType::FLOAT32, shape,
vx::TensorAttribute::CONSTANT);
-auto beta = graph->CreateTensor(param_spec);
-auto gamma = graph->CreateTensor(param_spec);
+auto beta = graph->CreateTensor(param_spec, float_beta);
+auto gamma = graph->CreateTensor(param_spec, float_gamma);
float eps = *float_eps;
-beta->CopyDataToTensor(float_beta);
-gamma->CopyDataToTensor(float_gamma);
vsi_nn_Free(float_gamma);
vsi_nn_Free(float_beta);
vsi_nn_Free(float_eps);
@@ -299,9 +340,11 @@ void InstancenormConnection(
output
*/
void MeanStdDevNormalization(std::shared_ptr<vx::Graph>& graph) {
-std::vector<std::shared_ptr<vx::Operation>> op_vector = graph->OpVector();
+std::vector<std::shared_ptr<vx::Operation>>& op_vector = graph->OpVector();
-for (const auto& op : op_vector) {
+for (auto& op : op_vector) {
+if (std::find(op_vector.begin(), op_vector.end(), op) == op_vector.end())
+continue; //Avoid read deleted data in op_vector
if (op->impl()->kind_ != VSI_NN_OP_REDUCE) continue;
std::vector<std::shared_ptr<vx::Operation>> temp;
@@ -328,14 +371,8 @@ void MeanStdDevNormalization(std::shared_ptr<vx::Graph>& graph) {
{VSI_NN_OP_MULTIPLY}))
continue; //Mul0
-if (!CheckMediumMul(graph, temp)) continue;
-if (!HaveASameInput(temp[NORMALIZATION_INDEX_MEAN_0],
-temp[NORMALIZATION_INDEX_SUB_0]) &&
-!HaveASameInput(temp[NORMALIZATION_INDEX_MEAN_0],
-temp[NORMALIZATION_INDEX_MUL_1]) &&
-!HaveASameInput(temp[NORMALIZATION_INDEX_SUB_0],
-temp[NORMALIZATION_INDEX_MUL_1]))
-continue;
+if (!CheckMul0(graph, temp)) continue;
+if (!CheckMean0Sub0Mul1SameInput(temp)) continue;
if (!UpdateTempVector(temp, NORMALIZATION_INDEX_MUL_1, graph,
{VSI_NN_OP_ADD}))