Code refinement for mean-stddev-normalization fuse (#632)
1.Added copyright && Added reference or const reference for functions 2.Rewrite function of determing whether there is a common input 3.Use std::remove_if instead of std::find before doing erase 4.Added security check to prevent access to deleted ops Type: Code Improvement Signed-off-by: Feiyue Chen <Feiyue.Chen@verisilicon.com>
This commit is contained in:
parent
af50cc5e3f
commit
2f018cc088
|
|
@ -1,3 +1,26 @@
|
||||||
|
/****************************************************************************
|
||||||
|
*
|
||||||
|
* Copyright (c) 2020-2023 Vivante Corporation
|
||||||
|
*
|
||||||
|
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
* copy of this software and associated documentation files (the "Software"),
|
||||||
|
* to deal in the Software without restriction, including without limitation
|
||||||
|
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
* and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
* Software is furnished to do so, subject to the following conditions:
|
||||||
|
*
|
||||||
|
* The above copyright notice and this permission notice shall be included in
|
||||||
|
* all copies or substantial portions of the Software.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
* DEALINGS IN THE SOFTWARE.
|
||||||
|
*
|
||||||
|
*****************************************************************************/
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <stdarg.h>
|
#include <stdarg.h>
|
||||||
#include "tim/transform/mean_stddev_normalize_fusion.h"
|
#include "tim/transform/mean_stddev_normalize_fusion.h"
|
||||||
|
|
@ -84,10 +107,13 @@ bool UpdateTempVector(std::vector<std::shared_ptr<vx::Operation>>& temp,
|
||||||
// Remove ops and tensors in each matched normlization patten
|
// Remove ops and tensors in each matched normlization patten
|
||||||
void RemoveTensorsAndOps(
|
void RemoveTensorsAndOps(
|
||||||
std::shared_ptr<vx::Graph>& graph,
|
std::shared_ptr<vx::Graph>& graph,
|
||||||
std::vector<std::shared_ptr<vx::Operation>>& norm_ops) {
|
const std::vector<std::shared_ptr<vx::Operation>>& norm_ops) {
|
||||||
for (uint32_t i = 0; i < norm_ops.size(); i++) {
|
for (uint32_t i = 0; i < norm_ops.size(); i++) {
|
||||||
auto it = std::find(graph->OpVector().begin(), graph->OpVector().end(),
|
auto it =
|
||||||
norm_ops[i]);
|
std::remove_if(graph->OpVector().begin(), graph->OpVector().end(),
|
||||||
|
[norm_ops, i](std::shared_ptr<vx::Operation> oper) {
|
||||||
|
return oper == norm_ops[i];
|
||||||
|
});
|
||||||
graph->OpVector().erase(it); //Remove current op from op_vector_
|
graph->OpVector().erase(it); //Remove current op from op_vector_
|
||||||
auto input_tensors = norm_ops[i]->impl()->InputsTensor();
|
auto input_tensors = norm_ops[i]->impl()->InputsTensor();
|
||||||
auto output_tensors = norm_ops[i]->impl()->OutputsTensor();
|
auto output_tensors = norm_ops[i]->impl()->OutputsTensor();
|
||||||
|
|
@ -136,7 +162,7 @@ void RemoveTensorsAndOps(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CheckMediumMul(const std::shared_ptr<vx::Graph>& graph,
|
bool CheckMul0(const std::shared_ptr<vx::Graph>& graph,
|
||||||
std::vector<std::shared_ptr<vx::Operation>>& norm_ops) {
|
std::vector<std::shared_ptr<vx::Operation>>& norm_ops) {
|
||||||
auto mul0_output_tensor =
|
auto mul0_output_tensor =
|
||||||
norm_ops[NORMALIZATION_INDEX_MUL_0]->impl()->OutputsTensor();
|
norm_ops[NORMALIZATION_INDEX_MUL_0]->impl()->OutputsTensor();
|
||||||
|
|
@ -156,15 +182,29 @@ bool CheckMediumMul(const std::shared_ptr<vx::Graph>& graph,
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HaveASameInput(const std::shared_ptr<vx::Operation>& op1,
|
bool CheckMean0Sub0Mul1SameInput(
|
||||||
const std::shared_ptr<vx::Operation>& op2) {
|
std::vector<std::shared_ptr<vx::Operation>>& norm_ops) {
|
||||||
auto Left = op1->impl()->InputsTensor();
|
auto mean0_inputs_tensors =
|
||||||
auto Right = op2->impl()->InputsTensor();
|
norm_ops[NORMALIZATION_INDEX_MEAN_0]->impl()->InputsTensor();
|
||||||
for (auto left_tensor : Left) {
|
auto sub0_inputs_tensors =
|
||||||
if (std::find(Right.begin(), Right.end(), left_tensor) != Right.end())
|
norm_ops[NORMALIZATION_INDEX_SUB_0]->impl()->InputsTensor();
|
||||||
return true;
|
auto mul1_inputs_tensors =
|
||||||
}
|
norm_ops[NORMALIZATION_INDEX_MUL_1]->impl()->InputsTensor();
|
||||||
|
std::sort(mean0_inputs_tensors.begin(), mean0_inputs_tensors.end());
|
||||||
|
std::sort(sub0_inputs_tensors.begin(), sub0_inputs_tensors.end());
|
||||||
|
std::sort(mul1_inputs_tensors.begin(), mul1_inputs_tensors.end());
|
||||||
|
std::vector<std::shared_ptr<vx::Tensor>> intersect1, intersect2;
|
||||||
|
std::set_intersection(mean0_inputs_tensors.begin(),
|
||||||
|
mean0_inputs_tensors.end(), sub0_inputs_tensors.begin(),
|
||||||
|
sub0_inputs_tensors.end(),
|
||||||
|
std::back_inserter(intersect1));
|
||||||
|
std::set_intersection(sub0_inputs_tensors.begin(), sub0_inputs_tensors.end(),
|
||||||
|
mul1_inputs_tensors.begin(), mul1_inputs_tensors.end(),
|
||||||
|
std::back_inserter(intersect2));
|
||||||
|
if (intersect2.empty())
|
||||||
return false;
|
return false;
|
||||||
|
else
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void LayernormConnection(std::shared_ptr<vx::Graph>& graph,
|
void LayernormConnection(std::shared_ptr<vx::Graph>& graph,
|
||||||
|
|
@ -180,12 +220,15 @@ void LayernormConnection(std::shared_ptr<vx::Graph>& graph,
|
||||||
axis = src_tensor->GetShape().size() - axis - 1; // reverse axis
|
axis = src_tensor->GetShape().size() - axis - 1; // reverse axis
|
||||||
|
|
||||||
// Get eps, gamma,beta;
|
// Get eps, gamma,beta;
|
||||||
// Do datatype convert due to InstanceNormlization op requirements
|
// Do datatype convert due to Layernorm op requirements
|
||||||
int32_t eps_index = graph->GetProducerOp(
|
int32_t eps_index =
|
||||||
norm_ops[5]->impl()->InputsTensor()[0]) == norm_ops[4]
|
graph->GetProducerOp(
|
||||||
|
norm_ops[NORMALIZATION_INDEX_ADD_0]->impl()->InputsTensor()[0]) ==
|
||||||
|
norm_ops[NORMALIZATION_INDEX_MEAN_1]
|
||||||
? 1
|
? 1
|
||||||
: 0;
|
: 0;
|
||||||
auto org_eps = norm_ops[5]->impl()->InputsTensor()[eps_index];
|
auto org_eps =
|
||||||
|
norm_ops[NORMALIZATION_INDEX_ADD_0]->impl()->InputsTensor()[eps_index];
|
||||||
if (!org_eps->IsConstTensor()) {
|
if (!org_eps->IsConstTensor()) {
|
||||||
org_eps = graph->GetProducerOp(org_eps)->impl()->InputsTensor()[0];
|
org_eps = graph->GetProducerOp(org_eps)->impl()->InputsTensor()[0];
|
||||||
}
|
}
|
||||||
|
|
@ -203,8 +246,8 @@ void LayernormConnection(std::shared_ptr<vx::Graph>& graph,
|
||||||
vx::TensorSpec param_spec(vx::DataType::FLOAT32, shape,
|
vx::TensorSpec param_spec(vx::DataType::FLOAT32, shape,
|
||||||
vx::TensorAttribute::CONSTANT);
|
vx::TensorAttribute::CONSTANT);
|
||||||
|
|
||||||
auto beta = graph->CreateTensor(param_spec);
|
auto beta = graph->CreateTensor(param_spec, float_beta);
|
||||||
auto gamma = graph->CreateTensor(param_spec);
|
auto gamma = graph->CreateTensor(param_spec, float_gamma);
|
||||||
float eps = *float_eps;
|
float eps = *float_eps;
|
||||||
beta->CopyDataToTensor(float_beta);
|
beta->CopyDataToTensor(float_beta);
|
||||||
gamma->CopyDataToTensor(float_gamma);
|
gamma->CopyDataToTensor(float_gamma);
|
||||||
|
|
@ -222,7 +265,7 @@ void LayernormConnection(std::shared_ptr<vx::Graph>& graph,
|
||||||
|
|
||||||
void InstancenormConnection(
|
void InstancenormConnection(
|
||||||
std::shared_ptr<vx::Graph>& graph,
|
std::shared_ptr<vx::Graph>& graph,
|
||||||
std::vector<std::shared_ptr<vx::Operation>> norm_ops) {
|
const std::vector<std::shared_ptr<vx::Operation>>& norm_ops) {
|
||||||
auto src_tensor =
|
auto src_tensor =
|
||||||
norm_ops[NORMALIZATION_INDEX_MEAN_0]->impl()->InputsTensor()[0];
|
norm_ops[NORMALIZATION_INDEX_MEAN_0]->impl()->InputsTensor()[0];
|
||||||
auto final_tensor =
|
auto final_tensor =
|
||||||
|
|
@ -257,11 +300,9 @@ void InstancenormConnection(
|
||||||
vx::TensorSpec param_spec(vx::DataType::FLOAT32, shape,
|
vx::TensorSpec param_spec(vx::DataType::FLOAT32, shape,
|
||||||
vx::TensorAttribute::CONSTANT);
|
vx::TensorAttribute::CONSTANT);
|
||||||
|
|
||||||
auto beta = graph->CreateTensor(param_spec);
|
auto beta = graph->CreateTensor(param_spec, float_beta);
|
||||||
auto gamma = graph->CreateTensor(param_spec);
|
auto gamma = graph->CreateTensor(param_spec, float_gamma);
|
||||||
float eps = *float_eps;
|
float eps = *float_eps;
|
||||||
beta->CopyDataToTensor(float_beta);
|
|
||||||
gamma->CopyDataToTensor(float_gamma);
|
|
||||||
vsi_nn_Free(float_gamma);
|
vsi_nn_Free(float_gamma);
|
||||||
vsi_nn_Free(float_beta);
|
vsi_nn_Free(float_beta);
|
||||||
vsi_nn_Free(float_eps);
|
vsi_nn_Free(float_eps);
|
||||||
|
|
@ -299,9 +340,11 @@ void InstancenormConnection(
|
||||||
output
|
output
|
||||||
*/
|
*/
|
||||||
void MeanStdDevNormalization(std::shared_ptr<vx::Graph>& graph) {
|
void MeanStdDevNormalization(std::shared_ptr<vx::Graph>& graph) {
|
||||||
std::vector<std::shared_ptr<vx::Operation>> op_vector = graph->OpVector();
|
std::vector<std::shared_ptr<vx::Operation>>& op_vector = graph->OpVector();
|
||||||
|
|
||||||
for (const auto& op : op_vector) {
|
for (auto& op : op_vector) {
|
||||||
|
if (std::find(op_vector.begin(), op_vector.end(), op) == op_vector.end())
|
||||||
|
continue; //Avoid read deleted data in op_vector
|
||||||
if (op->impl()->kind_ != VSI_NN_OP_REDUCE) continue;
|
if (op->impl()->kind_ != VSI_NN_OP_REDUCE) continue;
|
||||||
|
|
||||||
std::vector<std::shared_ptr<vx::Operation>> temp;
|
std::vector<std::shared_ptr<vx::Operation>> temp;
|
||||||
|
|
@ -328,14 +371,8 @@ void MeanStdDevNormalization(std::shared_ptr<vx::Graph>& graph) {
|
||||||
{VSI_NN_OP_MULTIPLY}))
|
{VSI_NN_OP_MULTIPLY}))
|
||||||
continue; //Mul0
|
continue; //Mul0
|
||||||
|
|
||||||
if (!CheckMediumMul(graph, temp)) continue;
|
if (!CheckMul0(graph, temp)) continue;
|
||||||
if (!HaveASameInput(temp[NORMALIZATION_INDEX_MEAN_0],
|
if (!CheckMean0Sub0Mul1SameInput(temp)) continue;
|
||||||
temp[NORMALIZATION_INDEX_SUB_0]) &&
|
|
||||||
!HaveASameInput(temp[NORMALIZATION_INDEX_MEAN_0],
|
|
||||||
temp[NORMALIZATION_INDEX_MUL_1]) &&
|
|
||||||
!HaveASameInput(temp[NORMALIZATION_INDEX_SUB_0],
|
|
||||||
temp[NORMALIZATION_INDEX_MUL_1]))
|
|
||||||
continue;
|
|
||||||
|
|
||||||
if (!UpdateTempVector(temp, NORMALIZATION_INDEX_MUL_1, graph,
|
if (!UpdateTempVector(temp, NORMALIZATION_INDEX_MUL_1, graph,
|
||||||
{VSI_NN_OP_ADD}))
|
{VSI_NN_OP_ADD}))
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue