Adding api-compatible guarding for updated ops (#695)

* Fixed typing error in gather test

Deleted the external output tensor creation in scalar_1d test

Type: Bug Fix
Signed-off-by: Feiyue Chen <Feiyue.Chen@verisilicon.com>

* Added ifdef marco for some later added ops

Type: Code Improvement
Signed-off-by: Feiyue Chen <Feiyue.Chen@verisilicon.com>

---------

Signed-off-by: Feiyue Chen <Feiyue.Chen@verisilicon.com>
This commit is contained in:
Chen Feiyue 2024-05-22 09:51:58 +08:00 committed by GitHub
parent e1c2f0a18d
commit 8894360c74
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 88 additions and 23 deletions

View File

@ -25,6 +25,8 @@
#define TIM_VX_OPS_SCATTERND_ONNX_V16_H_
#include "tim/vx/builtin_op.h"
#ifdef VSI_FEAT_OP_SCATTER_ND_UPDATE
namespace tim {
namespace vx {
namespace ops {
@ -57,4 +59,5 @@ class ScatterND_ONNX_V16 : public BuiltinOp {
} // namespace vx
} // namespace tim
#endif
#endif /* TIM_VX_OPS_SCATTERND_ONNX_V16_H_ */

View File

@ -119,11 +119,6 @@ DECLARE_SIMPLE_OP(DataConvert)
DECLARE_SIMPLE_OP(Neg)
DECLARE_SIMPLE_OP(Abs)
DECLARE_SIMPLE_OP(Sin)
DECLARE_SIMPLE_OP(Cos)
DECLARE_SIMPLE_OP(Tan)
DECLARE_SIMPLE_OP(ATan)
DECLARE_SIMPLE_OP(ATanh)
DECLARE_SIMPLE_OP(ACosh)
DECLARE_SIMPLE_OP(Exp)
DECLARE_SIMPLE_OP(Log)
DECLARE_SIMPLE_OP(Sqrt)
@ -136,6 +131,26 @@ DECLARE_SIMPLE_OP(Round)
DECLARE_SIMPLE_OP(Cast)
DECLARE_SIMPLE_OP(Rcp)
#ifdef VSI_FEAT_OP_COS
DECLARE_SIMPLE_OP(Cos)
#endif
#ifdef VSI_FEAT_OP_TAN
DECLARE_SIMPLE_OP(Tan)
#endif
#ifdef VSI_FEAT_OP_ATAN
DECLARE_SIMPLE_OP(ATan)
#endif
#ifdef VSI_FEAT_OP_ATANH
DECLARE_SIMPLE_OP(ATanh)
#endif
#ifdef VSI_FEAT_OP_ACOSH
DECLARE_SIMPLE_OP(ACosh)
#endif
#undef DECLARE_SIMPLE_OP
} // namespace ops

View File

@ -262,12 +262,22 @@ std::vector<std::shared_ptr<vx::Tensor>> HandleLayoutInfer(
REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_NEG, Neg);
REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_ABS, Abs);
REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_SIN, Sin);
REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_COS, Cos);
REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_TAN, Tan);
REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_TANH, Tanh);
#ifdef VSI_FEAT_OP_COS
REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_COS, Cos);
#endif
#ifdef VSI_FEAT_OP_TAN
REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_TAN, Tan);
#endif
#ifdef VSI_FEAT_OP_ATAN
REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_ATAN, ATan);
#endif
#ifdef VSI_FEAT_OP_ATANH
REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_ATANH, ATanh);
#endif
#ifdef VSI_FEAT_OP_ACOSH
REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_ACOSH, ACosh);
#endif
REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_EXP, Exp);
REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_LOG, Log);
REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_SQRT, Sqrt);

View File

@ -60,11 +60,7 @@ using DataConvertLayoutInfer = SimpleOpsLayoutInfer<vx::ops::DataConvert>;
using NegLayoutInfer = SimpleOpsLayoutInfer<vx::ops::Neg>;
using AbsLayoutInfer = SimpleOpsLayoutInfer<vx::ops::Abs>;
using SinLayoutInfer = SimpleOpsLayoutInfer<vx::ops::Sin>;
using CosLayoutInfer = SimpleOpsLayoutInfer<vx::ops::Cos>;
using TanLayoutInfer = SimpleOpsLayoutInfer<vx::ops::Tan>;
using ATanLayoutInfer = SimpleOpsLayoutInfer<vx::ops::ATan>;
using ATanhLayoutInfer = SimpleOpsLayoutInfer<vx::ops::ATanh>;
using ACoshLayoutInfer = SimpleOpsLayoutInfer<vx::ops::ACosh>;
using ExpLayoutInfer = SimpleOpsLayoutInfer<vx::ops::Exp>;
using LogLayoutInfer = SimpleOpsLayoutInfer<vx::ops::Log>;
using SqrtLayoutInfer = SimpleOpsLayoutInfer<vx::ops::Sqrt>;
@ -72,6 +68,26 @@ using RsqrtLayoutInfer = SimpleOpsLayoutInfer<vx::ops::Rsqrt>;
using SquareLayoutInfer = SimpleOpsLayoutInfer<vx::ops::Square>;
using LogicalNotLayoutInfer = SimpleOpsLayoutInfer<vx::ops::LogicalNot>;
#ifdef VSI_FEAT_OP_COS
using CosLayoutInfer = SimpleOpsLayoutInfer<vx::ops::Cos>;
#endif
#ifdef VSI_FEAT_OP_TAN
using TanLayoutInfer = SimpleOpsLayoutInfer<vx::ops::Tan>;
#endif
#ifdef VSI_FEAT_OP_ATAN
using ATanLayoutInfer = SimpleOpsLayoutInfer<vx::ops::ATan>;
#endif
#ifdef VSI_FEAT_OP_ATANH
using ATanhLayoutInfer = SimpleOpsLayoutInfer<vx::ops::ATanh>;
#endif
#ifdef VSI_FEAT_OP_ACOSH
using ACoshLayoutInfer = SimpleOpsLayoutInfer<vx::ops::ACosh>;
#endif
} // namespace transform
} // namespace tim

View File

@ -168,14 +168,11 @@ TEST(Gather, scalar_index_input1D) {
tim::vx::TensorAttribute::INPUT);
tim::vx::TensorSpec gatherout_spec(tim::vx::DataType::FLOAT32, out_shape,
tim::vx::TensorAttribute::OUTPUT);
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, out_shape,
tim::vx::TensorAttribute::OUTPUT);
auto input_tensor = graph->CreateTensor(input_spec);
auto index_tensor = graph->CreateTensor(index_spec);
index_tensor->SetScalar(1);
auto gatherout_tensor = graph->CreateTensor(gatherout_spec);
auto output_tensor = graph->CreateTensor(output_spec);
std::vector<float> in_data = {
1,2,3,4,5};

View File

@ -22,6 +22,7 @@
*
*****************************************************************************/
#include "tim/vx/ops/scatternd_onnx_v16.h"
#ifdef VSI_FEAT_OP_SCATTER_ND_UPDATE
#include "builtin_op_impl.h"
#include "vsi_nn_pub.h"
@ -61,3 +62,4 @@ std::shared_ptr<Operation> ScatterND_ONNX_V16::Clone(std::shared_ptr<Graph>& gra
} // namespace ops
} // namespace vx
} // namespace tim
#endif

View File

@ -24,6 +24,7 @@
#include "tim/vx/context.h"
#include "tim/vx/graph.h"
#include "tim/vx/ops/scatternd_onnx_v16.h"
#ifdef VSI_FEAT_OP_SCATTER_ND_UPDATE
#include "gtest/gtest.h"
@ -71,3 +72,4 @@ TEST(ScatterND_ONNX_V16, shape_8) {
EXPECT_EQ(golden, output);
}
#endif

View File

@ -40,11 +40,6 @@ DEFINE_SIMPLE_OP(DataConvert, VSI_NN_OP_DATACONVERT)
DEFINE_SIMPLE_OP(Neg, VSI_NN_OP_NEG)
DEFINE_SIMPLE_OP(Abs, VSI_NN_OP_ABS)
DEFINE_SIMPLE_OP(Sin, VSI_NN_OP_SIN)
DEFINE_SIMPLE_OP(Cos, VSI_NN_OP_COS)
DEFINE_SIMPLE_OP(Tan, VSI_NN_OP_TAN)
DEFINE_SIMPLE_OP(ATan, VSI_NN_OP_ATAN)
DEFINE_SIMPLE_OP(ATanh, VSI_NN_OP_ATANH)
DEFINE_SIMPLE_OP(ACosh, VSI_NN_OP_ACOSH)
DEFINE_SIMPLE_OP(Exp, VSI_NN_OP_EXP)
DEFINE_SIMPLE_OP(Log, VSI_NN_OP_LOG)
DEFINE_SIMPLE_OP(Sqrt, VSI_NN_OP_SQRT)
@ -57,6 +52,27 @@ DEFINE_SIMPLE_OP(Round, VSI_NN_OP_ROUND)
DEFINE_SIMPLE_OP(Cast, VSI_NN_OP_CAST)
DEFINE_SIMPLE_OP(Rcp, VSI_NN_OP_RCP)
#ifdef VSI_FEAT_OP_COS
DEFINE_SIMPLE_OP(Cos, VSI_NN_OP_COS)
#endif
#ifdef VSI_FEAT_OP_TAN
DEFINE_SIMPLE_OP(Tan, VSI_NN_OP_TAN)
#endif
#ifdef VSI_FEAT_OP_ATAN
DEFINE_SIMPLE_OP(ATan, VSI_NN_OP_ATAN)
#endif
#ifdef VSI_FEAT_OP_ATANH
DEFINE_SIMPLE_OP(ATanh, VSI_NN_OP_ATANH)
#endif
#ifdef VSI_FEAT_OP_ACOSH
DEFINE_SIMPLE_OP(ACosh, VSI_NN_OP_ACOSH)
#endif
#undef DEFINE_SIMPLE_OP
} // namespace ops

View File

@ -263,6 +263,7 @@ TEST(Rcp, shape_5_1_fp32) {
EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f));
}
#ifdef VSI_FEAT_OP_COS
TEST(Cos, shape_5_1_fp32) {
auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph();
@ -290,7 +291,9 @@ TEST(Cos, shape_5_1_fp32) {
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f));
}
#endif
#ifdef VSI_FEAT_OP_TAN
TEST(Tan, shape_5_1_fp32) {
auto ctx = tim::vx::Context::Create();
auto graph = ctx->CreateGraph();
@ -318,3 +321,4 @@ TEST(Tan, shape_5_1_fp32) {
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
EXPECT_TRUE(ArraysMatch(golden, output, 1e-4f));
}
#endif