From 8894360c74cbbf9465c23ed8147d7b3b62666bcb Mon Sep 17 00:00:00 2001 From: Chen Feiyue <69809761+chenfeiyue-cfy@users.noreply.github.com> Date: Wed, 22 May 2024 09:51:58 +0800 Subject: [PATCH] Adding api-compatible guarding for updated ops (#695) * Fixed typing error in gather test Deleted the external output tensor creation in scalar_1d test Type: Bug Fix Signed-off-by: Feiyue Chen * Added ifdef marco for some later added ops Type: Code Improvement Signed-off-by: Feiyue Chen --------- Signed-off-by: Feiyue Chen --- include/tim/vx/ops/scatternd_onnx_v16.h | 3 +++ include/tim/vx/ops/simple_operations.h | 27 ++++++++++++++----- src/tim/transform/layout_inference.cc | 14 ++++++++-- .../ops/simple_ops_layout_inference.h | 26 ++++++++++++++---- src/tim/vx/ops/gather_test.cc | 3 --- src/tim/vx/ops/scatternd_onnx_v16.cc | 4 ++- src/tim/vx/ops/scatternd_onnx_v16_test.cc | 2 ++ src/tim/vx/ops/simple_operations.cc | 26 ++++++++++++++---- src/tim/vx/ops/simple_operations_test.cc | 6 ++++- 9 files changed, 88 insertions(+), 23 deletions(-) diff --git a/include/tim/vx/ops/scatternd_onnx_v16.h b/include/tim/vx/ops/scatternd_onnx_v16.h index 9698e9a..024c2e2 100644 --- a/include/tim/vx/ops/scatternd_onnx_v16.h +++ b/include/tim/vx/ops/scatternd_onnx_v16.h @@ -25,6 +25,8 @@ #define TIM_VX_OPS_SCATTERND_ONNX_V16_H_ #include "tim/vx/builtin_op.h" +#ifdef VSI_FEAT_OP_SCATTER_ND_UPDATE + namespace tim { namespace vx { namespace ops { @@ -57,4 +59,5 @@ class ScatterND_ONNX_V16 : public BuiltinOp { } // namespace vx } // namespace tim +#endif #endif /* TIM_VX_OPS_SCATTERND_ONNX_V16_H_ */ diff --git a/include/tim/vx/ops/simple_operations.h b/include/tim/vx/ops/simple_operations.h index ea26dcf..2f4d526 100644 --- a/include/tim/vx/ops/simple_operations.h +++ b/include/tim/vx/ops/simple_operations.h @@ -65,7 +65,7 @@ namespace ops { * ## ATan * * ATan(x) : arctan(x) - * + * * ## ACosh * * ACosh(x) : arccosh(x) @@ -119,11 +119,6 @@ DECLARE_SIMPLE_OP(DataConvert) DECLARE_SIMPLE_OP(Neg) DECLARE_SIMPLE_OP(Abs) DECLARE_SIMPLE_OP(Sin) -DECLARE_SIMPLE_OP(Cos) -DECLARE_SIMPLE_OP(Tan) -DECLARE_SIMPLE_OP(ATan) -DECLARE_SIMPLE_OP(ATanh) -DECLARE_SIMPLE_OP(ACosh) DECLARE_SIMPLE_OP(Exp) DECLARE_SIMPLE_OP(Log) DECLARE_SIMPLE_OP(Sqrt) @@ -136,6 +131,26 @@ DECLARE_SIMPLE_OP(Round) DECLARE_SIMPLE_OP(Cast) DECLARE_SIMPLE_OP(Rcp) +#ifdef VSI_FEAT_OP_COS +DECLARE_SIMPLE_OP(Cos) +#endif + +#ifdef VSI_FEAT_OP_TAN +DECLARE_SIMPLE_OP(Tan) +#endif + +#ifdef VSI_FEAT_OP_ATAN +DECLARE_SIMPLE_OP(ATan) +#endif + +#ifdef VSI_FEAT_OP_ATANH +DECLARE_SIMPLE_OP(ATanh) +#endif + +#ifdef VSI_FEAT_OP_ACOSH +DECLARE_SIMPLE_OP(ACosh) +#endif + #undef DECLARE_SIMPLE_OP } // namespace ops diff --git a/src/tim/transform/layout_inference.cc b/src/tim/transform/layout_inference.cc index a464c4d..7479470 100644 --- a/src/tim/transform/layout_inference.cc +++ b/src/tim/transform/layout_inference.cc @@ -262,12 +262,22 @@ std::vector> HandleLayoutInfer( REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_NEG, Neg); REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_ABS, Abs); REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_SIN, Sin); - REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_COS, Cos); - REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_TAN, Tan); REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_TANH, Tanh); +#ifdef VSI_FEAT_OP_COS + REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_COS, Cos); +#endif +#ifdef VSI_FEAT_OP_TAN + REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_TAN, Tan); +#endif +#ifdef VSI_FEAT_OP_ATAN REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_ATAN, ATan); +#endif +#ifdef VSI_FEAT_OP_ATANH REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_ATANH, ATanh); +#endif +#ifdef VSI_FEAT_OP_ACOSH REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_ACOSH, ACosh); +#endif REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_EXP, Exp); REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_LOG, Log); REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_SQRT, Sqrt); diff --git a/src/tim/transform/ops/simple_ops_layout_inference.h b/src/tim/transform/ops/simple_ops_layout_inference.h index 7100fc6..01b313a 100644 --- a/src/tim/transform/ops/simple_ops_layout_inference.h +++ b/src/tim/transform/ops/simple_ops_layout_inference.h @@ -60,11 +60,7 @@ using DataConvertLayoutInfer = SimpleOpsLayoutInfer; using NegLayoutInfer = SimpleOpsLayoutInfer; using AbsLayoutInfer = SimpleOpsLayoutInfer; using SinLayoutInfer = SimpleOpsLayoutInfer; -using CosLayoutInfer = SimpleOpsLayoutInfer; -using TanLayoutInfer = SimpleOpsLayoutInfer; -using ATanLayoutInfer = SimpleOpsLayoutInfer; -using ATanhLayoutInfer = SimpleOpsLayoutInfer; -using ACoshLayoutInfer = SimpleOpsLayoutInfer; + using ExpLayoutInfer = SimpleOpsLayoutInfer; using LogLayoutInfer = SimpleOpsLayoutInfer; using SqrtLayoutInfer = SimpleOpsLayoutInfer; @@ -72,6 +68,26 @@ using RsqrtLayoutInfer = SimpleOpsLayoutInfer; using SquareLayoutInfer = SimpleOpsLayoutInfer; using LogicalNotLayoutInfer = SimpleOpsLayoutInfer; +#ifdef VSI_FEAT_OP_COS +using CosLayoutInfer = SimpleOpsLayoutInfer; +#endif + +#ifdef VSI_FEAT_OP_TAN +using TanLayoutInfer = SimpleOpsLayoutInfer; +#endif + +#ifdef VSI_FEAT_OP_ATAN +using ATanLayoutInfer = SimpleOpsLayoutInfer; +#endif + +#ifdef VSI_FEAT_OP_ATANH +using ATanhLayoutInfer = SimpleOpsLayoutInfer; +#endif + +#ifdef VSI_FEAT_OP_ACOSH +using ACoshLayoutInfer = SimpleOpsLayoutInfer; +#endif + } // namespace transform } // namespace tim diff --git a/src/tim/vx/ops/gather_test.cc b/src/tim/vx/ops/gather_test.cc index 205d1cd..f04f394 100644 --- a/src/tim/vx/ops/gather_test.cc +++ b/src/tim/vx/ops/gather_test.cc @@ -168,14 +168,11 @@ TEST(Gather, scalar_index_input1D) { tim::vx::TensorAttribute::INPUT); tim::vx::TensorSpec gatherout_spec(tim::vx::DataType::FLOAT32, out_shape, tim::vx::TensorAttribute::OUTPUT); - tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, out_shape, - tim::vx::TensorAttribute::OUTPUT); auto input_tensor = graph->CreateTensor(input_spec); auto index_tensor = graph->CreateTensor(index_spec); index_tensor->SetScalar(1); auto gatherout_tensor = graph->CreateTensor(gatherout_spec); - auto output_tensor = graph->CreateTensor(output_spec); std::vector in_data = { 1,2,3,4,5}; diff --git a/src/tim/vx/ops/scatternd_onnx_v16.cc b/src/tim/vx/ops/scatternd_onnx_v16.cc index 7e9d737..8f64d54 100644 --- a/src/tim/vx/ops/scatternd_onnx_v16.cc +++ b/src/tim/vx/ops/scatternd_onnx_v16.cc @@ -22,6 +22,7 @@ * *****************************************************************************/ #include "tim/vx/ops/scatternd_onnx_v16.h" +#ifdef VSI_FEAT_OP_SCATTER_ND_UPDATE #include "builtin_op_impl.h" #include "vsi_nn_pub.h" @@ -60,4 +61,5 @@ std::shared_ptr ScatterND_ONNX_V16::Clone(std::shared_ptr& gra } // namespace ops } // namespace vx -} // namespace tim \ No newline at end of file +} // namespace tim +#endif \ No newline at end of file diff --git a/src/tim/vx/ops/scatternd_onnx_v16_test.cc b/src/tim/vx/ops/scatternd_onnx_v16_test.cc index ef8c28e..dff3c85 100644 --- a/src/tim/vx/ops/scatternd_onnx_v16_test.cc +++ b/src/tim/vx/ops/scatternd_onnx_v16_test.cc @@ -24,6 +24,7 @@ #include "tim/vx/context.h" #include "tim/vx/graph.h" #include "tim/vx/ops/scatternd_onnx_v16.h" +#ifdef VSI_FEAT_OP_SCATTER_ND_UPDATE #include "gtest/gtest.h" @@ -71,3 +72,4 @@ TEST(ScatterND_ONNX_V16, shape_8) { EXPECT_EQ(golden, output); } +#endif \ No newline at end of file diff --git a/src/tim/vx/ops/simple_operations.cc b/src/tim/vx/ops/simple_operations.cc index d6b881a..ff8a274 100644 --- a/src/tim/vx/ops/simple_operations.cc +++ b/src/tim/vx/ops/simple_operations.cc @@ -40,11 +40,6 @@ DEFINE_SIMPLE_OP(DataConvert, VSI_NN_OP_DATACONVERT) DEFINE_SIMPLE_OP(Neg, VSI_NN_OP_NEG) DEFINE_SIMPLE_OP(Abs, VSI_NN_OP_ABS) DEFINE_SIMPLE_OP(Sin, VSI_NN_OP_SIN) -DEFINE_SIMPLE_OP(Cos, VSI_NN_OP_COS) -DEFINE_SIMPLE_OP(Tan, VSI_NN_OP_TAN) -DEFINE_SIMPLE_OP(ATan, VSI_NN_OP_ATAN) -DEFINE_SIMPLE_OP(ATanh, VSI_NN_OP_ATANH) -DEFINE_SIMPLE_OP(ACosh, VSI_NN_OP_ACOSH) DEFINE_SIMPLE_OP(Exp, VSI_NN_OP_EXP) DEFINE_SIMPLE_OP(Log, VSI_NN_OP_LOG) DEFINE_SIMPLE_OP(Sqrt, VSI_NN_OP_SQRT) @@ -57,6 +52,27 @@ DEFINE_SIMPLE_OP(Round, VSI_NN_OP_ROUND) DEFINE_SIMPLE_OP(Cast, VSI_NN_OP_CAST) DEFINE_SIMPLE_OP(Rcp, VSI_NN_OP_RCP) +#ifdef VSI_FEAT_OP_COS +DEFINE_SIMPLE_OP(Cos, VSI_NN_OP_COS) +#endif + +#ifdef VSI_FEAT_OP_TAN +DEFINE_SIMPLE_OP(Tan, VSI_NN_OP_TAN) +#endif + +#ifdef VSI_FEAT_OP_ATAN +DEFINE_SIMPLE_OP(ATan, VSI_NN_OP_ATAN) +#endif + +#ifdef VSI_FEAT_OP_ATANH +DEFINE_SIMPLE_OP(ATanh, VSI_NN_OP_ATANH) +#endif + +#ifdef VSI_FEAT_OP_ACOSH +DEFINE_SIMPLE_OP(ACosh, VSI_NN_OP_ACOSH) +#endif + + #undef DEFINE_SIMPLE_OP } // namespace ops diff --git a/src/tim/vx/ops/simple_operations_test.cc b/src/tim/vx/ops/simple_operations_test.cc index 81cb19a..7db4ce1 100644 --- a/src/tim/vx/ops/simple_operations_test.cc +++ b/src/tim/vx/ops/simple_operations_test.cc @@ -263,6 +263,7 @@ TEST(Rcp, shape_5_1_fp32) { EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f)); } +#ifdef VSI_FEAT_OP_COS TEST(Cos, shape_5_1_fp32) { auto ctx = tim::vx::Context::Create(); auto graph = ctx->CreateGraph(); @@ -290,7 +291,9 @@ TEST(Cos, shape_5_1_fp32) { EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f)); } +#endif +#ifdef VSI_FEAT_OP_TAN TEST(Tan, shape_5_1_fp32) { auto ctx = tim::vx::Context::Create(); auto graph = ctx->CreateGraph(); @@ -317,4 +320,5 @@ TEST(Tan, shape_5_1_fp32) { std::vector output(5); EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); EXPECT_TRUE(ArraysMatch(golden, output, 1e-4f)); -} \ No newline at end of file +} +#endif \ No newline at end of file