Code refinement for mean-stddev-normalization fuse (#632)
1.Added copyright && Added reference or const reference for functions 2.Rewrite function of determing whether there is a common input 3.Use std::remove_if instead of std::find before doing erase 4.Added security check to prevent access to deleted ops Type: Code Improvement Signed-off-by: Feiyue Chen <Feiyue.Chen@verisilicon.com>
This commit is contained in:
parent
af50cc5e3f
commit
2f018cc088
|
|
@ -1,3 +1,26 @@
|
|||
/****************************************************************************
|
||||
*
|
||||
* Copyright (c) 2020-2023 Vivante Corporation
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||
* copy of this software and associated documentation files (the "Software"),
|
||||
* to deal in the Software without restriction, including without limitation
|
||||
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
* and/or sell copies of the Software, and to permit persons to whom the
|
||||
* Software is furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
* DEALINGS IN THE SOFTWARE.
|
||||
*
|
||||
*****************************************************************************/
|
||||
#include <algorithm>
|
||||
#include <stdarg.h>
|
||||
#include "tim/transform/mean_stddev_normalize_fusion.h"
|
||||
|
|
@ -84,10 +107,13 @@ bool UpdateTempVector(std::vector<std::shared_ptr<vx::Operation>>& temp,
|
|||
// Remove ops and tensors in each matched normlization patten
|
||||
void RemoveTensorsAndOps(
|
||||
std::shared_ptr<vx::Graph>& graph,
|
||||
std::vector<std::shared_ptr<vx::Operation>>& norm_ops) {
|
||||
const std::vector<std::shared_ptr<vx::Operation>>& norm_ops) {
|
||||
for (uint32_t i = 0; i < norm_ops.size(); i++) {
|
||||
auto it = std::find(graph->OpVector().begin(), graph->OpVector().end(),
|
||||
norm_ops[i]);
|
||||
auto it =
|
||||
std::remove_if(graph->OpVector().begin(), graph->OpVector().end(),
|
||||
[norm_ops, i](std::shared_ptr<vx::Operation> oper) {
|
||||
return oper == norm_ops[i];
|
||||
});
|
||||
graph->OpVector().erase(it); //Remove current op from op_vector_
|
||||
auto input_tensors = norm_ops[i]->impl()->InputsTensor();
|
||||
auto output_tensors = norm_ops[i]->impl()->OutputsTensor();
|
||||
|
|
@ -136,8 +162,8 @@ void RemoveTensorsAndOps(
|
|||
}
|
||||
}
|
||||
|
||||
bool CheckMediumMul(const std::shared_ptr<vx::Graph>& graph,
|
||||
std::vector<std::shared_ptr<vx::Operation>>& norm_ops) {
|
||||
bool CheckMul0(const std::shared_ptr<vx::Graph>& graph,
|
||||
std::vector<std::shared_ptr<vx::Operation>>& norm_ops) {
|
||||
auto mul0_output_tensor =
|
||||
norm_ops[NORMALIZATION_INDEX_MUL_0]->impl()->OutputsTensor();
|
||||
auto mul0_consumers = graph->GetConsumersOp(mul0_output_tensor[0]);
|
||||
|
|
@ -156,15 +182,29 @@ bool CheckMediumMul(const std::shared_ptr<vx::Graph>& graph,
|
|||
return true;
|
||||
}
|
||||
|
||||
bool HaveASameInput(const std::shared_ptr<vx::Operation>& op1,
|
||||
const std::shared_ptr<vx::Operation>& op2) {
|
||||
auto Left = op1->impl()->InputsTensor();
|
||||
auto Right = op2->impl()->InputsTensor();
|
||||
for (auto left_tensor : Left) {
|
||||
if (std::find(Right.begin(), Right.end(), left_tensor) != Right.end())
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
bool CheckMean0Sub0Mul1SameInput(
|
||||
std::vector<std::shared_ptr<vx::Operation>>& norm_ops) {
|
||||
auto mean0_inputs_tensors =
|
||||
norm_ops[NORMALIZATION_INDEX_MEAN_0]->impl()->InputsTensor();
|
||||
auto sub0_inputs_tensors =
|
||||
norm_ops[NORMALIZATION_INDEX_SUB_0]->impl()->InputsTensor();
|
||||
auto mul1_inputs_tensors =
|
||||
norm_ops[NORMALIZATION_INDEX_MUL_1]->impl()->InputsTensor();
|
||||
std::sort(mean0_inputs_tensors.begin(), mean0_inputs_tensors.end());
|
||||
std::sort(sub0_inputs_tensors.begin(), sub0_inputs_tensors.end());
|
||||
std::sort(mul1_inputs_tensors.begin(), mul1_inputs_tensors.end());
|
||||
std::vector<std::shared_ptr<vx::Tensor>> intersect1, intersect2;
|
||||
std::set_intersection(mean0_inputs_tensors.begin(),
|
||||
mean0_inputs_tensors.end(), sub0_inputs_tensors.begin(),
|
||||
sub0_inputs_tensors.end(),
|
||||
std::back_inserter(intersect1));
|
||||
std::set_intersection(sub0_inputs_tensors.begin(), sub0_inputs_tensors.end(),
|
||||
mul1_inputs_tensors.begin(), mul1_inputs_tensors.end(),
|
||||
std::back_inserter(intersect2));
|
||||
if (intersect2.empty())
|
||||
return false;
|
||||
else
|
||||
return true;
|
||||
}
|
||||
|
||||
void LayernormConnection(std::shared_ptr<vx::Graph>& graph,
|
||||
|
|
@ -180,12 +220,15 @@ void LayernormConnection(std::shared_ptr<vx::Graph>& graph,
|
|||
axis = src_tensor->GetShape().size() - axis - 1; // reverse axis
|
||||
|
||||
// Get eps, gamma,beta;
|
||||
// Do datatype convert due to InstanceNormlization op requirements
|
||||
int32_t eps_index = graph->GetProducerOp(
|
||||
norm_ops[5]->impl()->InputsTensor()[0]) == norm_ops[4]
|
||||
? 1
|
||||
: 0;
|
||||
auto org_eps = norm_ops[5]->impl()->InputsTensor()[eps_index];
|
||||
// Do datatype convert due to Layernorm op requirements
|
||||
int32_t eps_index =
|
||||
graph->GetProducerOp(
|
||||
norm_ops[NORMALIZATION_INDEX_ADD_0]->impl()->InputsTensor()[0]) ==
|
||||
norm_ops[NORMALIZATION_INDEX_MEAN_1]
|
||||
? 1
|
||||
: 0;
|
||||
auto org_eps =
|
||||
norm_ops[NORMALIZATION_INDEX_ADD_0]->impl()->InputsTensor()[eps_index];
|
||||
if (!org_eps->IsConstTensor()) {
|
||||
org_eps = graph->GetProducerOp(org_eps)->impl()->InputsTensor()[0];
|
||||
}
|
||||
|
|
@ -203,8 +246,8 @@ void LayernormConnection(std::shared_ptr<vx::Graph>& graph,
|
|||
vx::TensorSpec param_spec(vx::DataType::FLOAT32, shape,
|
||||
vx::TensorAttribute::CONSTANT);
|
||||
|
||||
auto beta = graph->CreateTensor(param_spec);
|
||||
auto gamma = graph->CreateTensor(param_spec);
|
||||
auto beta = graph->CreateTensor(param_spec, float_beta);
|
||||
auto gamma = graph->CreateTensor(param_spec, float_gamma);
|
||||
float eps = *float_eps;
|
||||
beta->CopyDataToTensor(float_beta);
|
||||
gamma->CopyDataToTensor(float_gamma);
|
||||
|
|
@ -222,7 +265,7 @@ void LayernormConnection(std::shared_ptr<vx::Graph>& graph,
|
|||
|
||||
void InstancenormConnection(
|
||||
std::shared_ptr<vx::Graph>& graph,
|
||||
std::vector<std::shared_ptr<vx::Operation>> norm_ops) {
|
||||
const std::vector<std::shared_ptr<vx::Operation>>& norm_ops) {
|
||||
auto src_tensor =
|
||||
norm_ops[NORMALIZATION_INDEX_MEAN_0]->impl()->InputsTensor()[0];
|
||||
auto final_tensor =
|
||||
|
|
@ -257,11 +300,9 @@ void InstancenormConnection(
|
|||
vx::TensorSpec param_spec(vx::DataType::FLOAT32, shape,
|
||||
vx::TensorAttribute::CONSTANT);
|
||||
|
||||
auto beta = graph->CreateTensor(param_spec);
|
||||
auto gamma = graph->CreateTensor(param_spec);
|
||||
auto beta = graph->CreateTensor(param_spec, float_beta);
|
||||
auto gamma = graph->CreateTensor(param_spec, float_gamma);
|
||||
float eps = *float_eps;
|
||||
beta->CopyDataToTensor(float_beta);
|
||||
gamma->CopyDataToTensor(float_gamma);
|
||||
vsi_nn_Free(float_gamma);
|
||||
vsi_nn_Free(float_beta);
|
||||
vsi_nn_Free(float_eps);
|
||||
|
|
@ -299,9 +340,11 @@ void InstancenormConnection(
|
|||
output
|
||||
*/
|
||||
void MeanStdDevNormalization(std::shared_ptr<vx::Graph>& graph) {
|
||||
std::vector<std::shared_ptr<vx::Operation>> op_vector = graph->OpVector();
|
||||
std::vector<std::shared_ptr<vx::Operation>>& op_vector = graph->OpVector();
|
||||
|
||||
for (const auto& op : op_vector) {
|
||||
for (auto& op : op_vector) {
|
||||
if (std::find(op_vector.begin(), op_vector.end(), op) == op_vector.end())
|
||||
continue; //Avoid read deleted data in op_vector
|
||||
if (op->impl()->kind_ != VSI_NN_OP_REDUCE) continue;
|
||||
|
||||
std::vector<std::shared_ptr<vx::Operation>> temp;
|
||||
|
|
@ -328,14 +371,8 @@ void MeanStdDevNormalization(std::shared_ptr<vx::Graph>& graph) {
|
|||
{VSI_NN_OP_MULTIPLY}))
|
||||
continue; //Mul0
|
||||
|
||||
if (!CheckMediumMul(graph, temp)) continue;
|
||||
if (!HaveASameInput(temp[NORMALIZATION_INDEX_MEAN_0],
|
||||
temp[NORMALIZATION_INDEX_SUB_0]) &&
|
||||
!HaveASameInput(temp[NORMALIZATION_INDEX_MEAN_0],
|
||||
temp[NORMALIZATION_INDEX_MUL_1]) &&
|
||||
!HaveASameInput(temp[NORMALIZATION_INDEX_SUB_0],
|
||||
temp[NORMALIZATION_INDEX_MUL_1]))
|
||||
continue;
|
||||
if (!CheckMul0(graph, temp)) continue;
|
||||
if (!CheckMean0Sub0Mul1SameInput(temp)) continue;
|
||||
|
||||
if (!UpdateTempVector(temp, NORMALIZATION_INDEX_MUL_1, graph,
|
||||
{VSI_NN_OP_ADD}))
|
||||
|
|
|
|||
Loading…
Reference in New Issue