diff --git a/doc/gen_doc.py b/doc/gen_doc.py
index 1c593a5..c654a69 100644
--- a/doc/gen_doc.py
+++ b/doc/gen_doc.py
@@ -51,8 +51,7 @@ OpsWithShapeInference = [
 
 # Operations supporting canonicalization.
 OpsWithCanonicalizer = [
-    'Add', 'Identity', 'ReduceL1', 'ReduceL2', 'ReduceLogSum',
-    'ReduceLogSumExp', 'ReduceSumSquare', 'Gemm'
+    'Add', 'Identity', 'Gemm'
 ]
 
 # Add an Op in this list if the Op needs result type deduction which is required
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index b210275..99efe8b 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -11,6 +11,7 @@ add_library(compiler
         dialect/onnx/onnxop.inc
         pass/onnx_combine.cpp
         pass/onnx_rewrite.cpp
+        pass/onnx_decompose.cpp
         pass/passes.hpp)
 
 # Include root src directory.
@@ -25,6 +26,11 @@ target_link_libraries(compiler
         ${MLIRLibs}
         curses)
 
+set(LLVM_TARGET_DEFINITIONS pass/onnx_decompose.td)
+onnf_tablegen(onnx_decompose.inc -gen-rewriters)
+add_public_tablegen_target(gen_onnx_decompose)
+add_dependencies(compiler gen_onnx_decompose)
+
 set(LLVM_TARGET_DEFINITIONS pass/shape_inference_interface.td)
 onnf_tablegen(shape_inference.hpp.inc -gen-op-interface-decls)
 onnf_tablegen(shape_inference.cpp.inc -gen-op-interface-defs)
@@ -55,6 +61,13 @@ onnf_tablegen(krnl.cpp.inc -gen-op-defs)
 add_public_tablegen_target(gen_krnl_ops)
 add_dependencies(compiler gen_krnl_ops)
 
+add_library(onnf_onnx_decompose pass/onnx_decompose.cpp)
+target_include_directories(onnf_onnx_decompose
+        PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
+        ${ONNF_SRC_ROOT})
+target_link_libraries(onnf_onnx_decompose ${MLIRLibs})
+add_dependencies(onnf_onnx_decompose gen_krnl_ops)
+
 add_library(onnf_shape_inference pass/shape_inference_pass.cpp)
 target_include_directories(onnf_shape_inference
         PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
@@ -90,7 +103,7 @@ add_subdirectory(runtime)
 
 add_executable(onnf main.cpp)
 
-target_link_libraries(onnf builder ${MLIRLibs} onnf_transform onnf_shape_inference onnf_lower_frontend)
+target_link_libraries(onnf builder ${MLIRLibs} onnf_transform onnf_onnx_decompose onnf_shape_inference onnf_lower_frontend)
 whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs})
 find_package(ZLIB REQUIRED)
 target_link_libraries(onnf ${ZLIB_LIBRARIES})
diff --git a/src/dialect/onnx/onnxop.inc b/src/dialect/onnx/onnxop.inc
index abbda6d..30f00bd 100644
--- a/src/dialect/onnx/onnxop.inc
+++ b/src/dialect/onnx/onnxop.inc
@@ -2296,7 +2296,6 @@ def ONNXReciprocalOp:ONNX_Op<"Reciprocal",
 
 def ONNXReduceL1Op:ONNX_Op<"ReduceL1",
     [NoSideEffect]> {
-  let hasCanonicalizer = 1;
   let summary = "ONNX ReduceL1 operation";
   let description = [{
     "Computes the L1 norm of the input tensor's element along the provided axes. The resulted"
@@ -2314,7 +2313,6 @@
 
 def ONNXReduceL2Op:ONNX_Op<"ReduceL2",
     [NoSideEffect]> {
-  let hasCanonicalizer = 1;
   let summary = "ONNX ReduceL2 operation";
   let description = [{
     "Computes the L2 norm of the input tensor's element along the provided axes. The resulted"
@@ -2332,7 +2330,6 @@
 
 def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum",
     [NoSideEffect]> {
-  let hasCanonicalizer = 1;
   let summary = "ONNX ReduceLogSum operation";
   let description = [{
     "Computes the log sum of the input tensor's element along the provided axes. The resulted"
@@ -2350,7 +2347,6 @@ def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum",
 
 def ONNXReduceLogSumExpOp:ONNX_Op<"ReduceLogSumExp",
     [NoSideEffect]> {
-  let hasCanonicalizer = 1;
   let summary = "ONNX ReduceLogSumExp operation";
   let description = [{
     "Computes the log sum exponent of the input tensor's element along the provided axes. The resulted"
@@ -2465,7 +2461,6 @@ def ONNXReduceSumOp:ONNX_Op<"ReduceSum",
 
 def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare",
     [NoSideEffect]> {
-  let hasCanonicalizer = 1;
   let summary = "ONNX ReduceSumSquare operation";
   let description = [{
     "Computes the sum square of the input tensor's element along the provided axes. The resulted"
diff --git a/src/main.cpp b/src/main.cpp
index e99329b..8893382 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -122,8 +122,9 @@ int main(int argc, char *argv[]) {
   }
 
   mlir::PassManager pm(&context);
-  pm.addPass(mlir::createCanonicalizerPass());
+  pm.addPass(mlir::createDecomposeONNXToONNXPass());
   pm.addPass(mlir::createShapeInferencePass());
+  pm.addPass(mlir::createCanonicalizerPass());
 
   if (emissionTarget >= EmitMLIR) {
     pm.addPass(mlir::createLowerToKrnlPass());
diff --git a/src/pass/onnx_decompose.cpp b/src/pass/onnx_decompose.cpp
new file mode 100644
index 0000000..0949c3b
--- /dev/null
+++ b/src/pass/onnx_decompose.cpp
@@ -0,0 +1,65 @@
+//===- onnx_decompose.cpp - ONNX High Level Rewriting ---------------------===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file implements a set of rewriters to decompose an ONNX operation into
+// a composition of other ONNX operations.
+//
+// This pass is applied before any other pass so that there is no need to
+// implement shape inference for the decomposed operations. Hence, it is
+// expected that there is no knowledge about tensor shapes at this point.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "src/dialect/onnx/onnx_ops.hpp"
+#include "src/pass/passes.hpp"
+
+using namespace mlir;
+
+namespace {
+/// Include the patterns defined in the Declarative Rewrite framework.
+#include "src/onnx_decompose.inc"
+
+struct DecomposeONNXToONNXPass : public FunctionPass<DecomposeONNXToONNXPass> {
+  void runOnFunction() final;
+};
+} // end anonymous namespace.
+
+void DecomposeONNXToONNXPass::runOnFunction() {
+  auto function = getFunction();
+  MLIRContext *context = &getContext();
+
+  ConversionTarget target(getContext());
+  target.addLegalDialect<ONNXOpsDialect>();
+
+  // These ops will be decomposed into other ONNX ops. Hence, they will not be
+  // available after this pass.
+  target.addIllegalOp<ONNXReduceL1Op>();
+  target.addIllegalOp<ONNXReduceL2Op>();
+  target.addIllegalOp<ONNXReduceLogSumOp>();
+  target.addIllegalOp<ONNXReduceLogSumExpOp>();
+  target.addIllegalOp<ONNXReduceSumSquareOp>();
+
+  OwningRewritePatternList patterns;
+  populateWithGenerated(context, &patterns);
+
+  if (failed(applyPartialConversion(function, target, patterns)))
+    signalPassFailure();
+}
+
+/*!
+ * Create a DecomposeONNX pass.
+ */
+std::unique_ptr<mlir::Pass> mlir::createDecomposeONNXToONNXPass() {
+  return std::make_unique<DecomposeONNXToONNXPass>();
+}
+
+static PassRegistration<DecomposeONNXToONNXPass> pass("decompose-onnx",
+    "Decompose ONNX operations into composition of other ONNX operations.");
diff --git a/src/pass/onnx_decompose.td b/src/pass/onnx_decompose.td
new file mode 100644
index 0000000..087a1f8
--- /dev/null
+++ b/src/pass/onnx_decompose.td
@@ -0,0 +1,57 @@
+//===----------------------------------------------------------------------===//
+//=- onnx_decompose.td - Rewriting for decomposing ONNX Ops -*- tablegen -*===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// Defines language-specific pattern match rewritings for ONNX using
+// Declarative Rewrite Rules (DRR) specified using TableGen records.
+//
+
+#ifndef ONNX_DECOMPOSE
+#define ONNX_DECOMPOSE
+
+#ifndef OP_BASE
+include "dialect/onnx/onnx.td"
+#endif // OP_BASE
+
+/// Note: The DRR definition used for defining patterns is shown below:
+///
+/// class Pattern<
+///    dag sourcePattern, list<dag> resultPatterns,
+///    list<dag> additionalConstraints = [],
+///    dag benefitsAdded = (addBenefit 0)
+/// >;
+
+//===----------------------------------------------------------------------===//
+// ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X)
+//===----------------------------------------------------------------------===//
+def ReduceL1OpPattern: Pat<(ONNXReduceL1Op $oprd, $axes, $keepdims),
+                           (ONNXReduceSumOp (ONNXAbsOp $oprd), $axes, $keepdims)>;
+
+//===----------------------------------------------------------------------===//
+// ONNXReduceL2Op %X = ONNXSqrtOp (ONNXReduceSumSquareOp (%X))
+//===----------------------------------------------------------------------===//
+def ReduceL2OpPattern: Pat<(ONNXReduceL2Op $oprd, $axes, $keepdims),
+                           (ONNXSqrtOp (ONNXReduceSumSquareOp $oprd, $axes, $keepdims))>;
+
+//===----------------------------------------------------------------------===//
+// ONNXReduceLogSumOp %X = ONNXLogOp (ONNXReduceSumOp (%X))
+//===----------------------------------------------------------------------===//
+def ReduceLogSumOpPattern: Pat<(ONNXReduceLogSumOp $oprd, $axes, $keepdims),
+                               (ONNXLogOp (ONNXReduceSumOp $oprd, $axes, $keepdims))>;
+
+//===----------------------------------------------------------------------===//
+// ONNXReduceLogSumExpOp %X = ONNXReduceLogSumOp (ONNXExpOp %X)
+//===----------------------------------------------------------------------===//
+def ReduceLogSumExpOpPattern: Pat<(ONNXReduceLogSumExpOp $oprd, $axes, $keepdims),
+                                  (ONNXReduceLogSumOp (ONNXExpOp $oprd), $axes, $keepdims)>;
+
+//===----------------------------------------------------------------------===//
+// ONNXReduceSumSquareOp %X = ONNXReduceSumOp (ONNXMulOp %X, %X)
+//===----------------------------------------------------------------------===//
+def ReduceSumSquareOpPattern: Pat<(ONNXReduceSumSquareOp $oprd, $axes, $keepdims),
+                                  (ONNXReduceSumOp (ONNXMulOp $oprd, $oprd), $axes, $keepdims)>;
+
+#endif // ONNX_DECOMPOSE
diff --git a/src/pass/onnx_rewrite.cpp b/src/pass/onnx_rewrite.cpp
index afe43c1..dcc4dc1 100644
--- a/src/pass/onnx_rewrite.cpp
+++ b/src/pass/onnx_rewrite.cpp
@@ -118,35 +118,6 @@ struct SplitConvOpPattern : public RewritePattern {
 };
 } // end anonymous namespace
 
-/// on the ONNXReduceL1Op.
-void ONNXReduceL1Op::getCanonicalizationPatterns(
-    OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<ReduceL1OpPattern>(context);
-}
-/// on the ONNXReduceL2Op.
-void ONNXReduceL2Op::getCanonicalizationPatterns(
-    OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<ReduceL2OpPattern>(context);
-}
-
-/// on the ONNXReduceLogSumOp.
-void ONNXReduceLogSumOp::getCanonicalizationPatterns(
-    OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<ReduceLogSumOpPattern>(context);
-}
-
-/// on the ONNXReduceLogSumExpOp.
-void ONNXReduceLogSumExpOp::getCanonicalizationPatterns(
-    OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<ReduceLogSumExpOpPattern>(context);
-}
-
-/// on the ONNXReduceSumSquareOp.
-void ONNXReduceSumSquareOp::getCanonicalizationPatterns(
-    OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<ReduceSumSquareOpPattern>(context);
-}
-
 /// on the ONNXReduceSumSquareOp.
 void ONNXConvNoBiasOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
diff --git a/src/pass/onnx_rewrite.td b/src/pass/onnx_rewrite.td
index 43dc99c..ab73989 100644
--- a/src/pass/onnx_rewrite.td
+++ b/src/pass/onnx_rewrite.td
@@ -24,34 +24,4 @@ include "dialect/onnx/onnx.td"
 /// dag benefitsAdded = (addBenefit 0)
 /// >;
 
-//===----------------------------------------------------------------------===//
-// ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X)
-//===----------------------------------------------------------------------===//
-def ReduceL1OpPattern: Pat<(ONNXReduceL1Op $oprd, $axes, $keepdims),
-                           (ONNXReduceSumOp (ONNXAbsOp $oprd), $axes, $keepdims)>;
-
-//===----------------------------------------------------------------------===//
-// ONNXReduceL2Op %X = ONNXSqrtOp (ONNXReduceSumSquareOp (%X))
-//===----------------------------------------------------------------------===//
-def ReduceL2OpPattern: Pat<(ONNXReduceL2Op $oprd, $axes, $keepdims),
-                           (ONNXSqrtOp (ONNXReduceSumSquareOp $oprd, $axes, $keepdims))>;
-
-//===----------------------------------------------------------------------===//
-// ONNXReduceLogSumOp %X = ONNXLogOp (ONNXReduceSumOp (%X))
-//===----------------------------------------------------------------------===//
-def ReduceLogSumOpPattern: Pat<(ONNXReduceLogSumOp $oprd, $axes, $keepdims),
-                               (ONNXLogOp (ONNXReduceSumOp $oprd, $axes, $keepdims))>;
-
-//===----------------------------------------------------------------------===//
-// ONNXReduceLogSumExpOp %X = ONNXReduceLogSumOp (ONNXExpOp %X)
-//===----------------------------------------------------------------------===//
-def ReduceLogSumExpOpPattern: Pat<(ONNXReduceLogSumExpOp $oprd, $axes, $keepdims),
-                                  (ONNXReduceLogSumOp (ONNXExpOp $oprd), $axes, $keepdims)>;
-
-//===----------------------------------------------------------------------===//
-// ONNXReduceSumSquareOp %X = ONNXReduceSumOp (ONNXMulOp %X, %X)
-//===----------------------------------------------------------------------===//
-def ReduceSumSquareOpPattern: Pat<(ONNXReduceSumSquareOp $oprd, $axes, $keepdims),
-                                  (ONNXReduceSumOp (ONNXMulOp $oprd, $oprd), $axes, $keepdims)>;
-
 #endif // ONNX_REWRITE
diff --git a/src/pass/passes.hpp b/src/pass/passes.hpp
index 26fa543..b7bdc96 100644
--- a/src/pass/passes.hpp
+++ b/src/pass/passes.hpp
@@ -15,6 +15,9 @@ namespace mlir {
 
 class Pass;
 
+/// Pass for rewriting inside frontend dialect.
+std::unique_ptr<Pass> createDecomposeONNXToONNXPass();
+
 std::unique_ptr<Pass> createShapeInferencePass();
 
 /// Add pass for lowering to Krnl IR.
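Note on the generated patterns: each Pat<...> record in onnx_decompose.td is expanded by the new onnf_tablegen(onnx_decompose.inc -gen-rewriters) step into an ordinary C++ RewritePattern, all of which populateWithGenerated(context, &patterns) registers in onnx_decompose.cpp. A hand-written sketch of roughly what the generated code for ReduceL1OpPattern amounts to is shown below; it is illustrative only, and the attribute plumbing, result-type handling, and builder signatures in the real onnx_decompose.inc differ (MLIR API details also vary across snapshots of this era):

    // Hypothetical sketch of a tablegen-generated DRR pattern (pre-2020 MLIR API).
    struct ReduceL1OpPattern : public RewritePattern {
      ReduceL1OpPattern(MLIRContext *context)
          : RewritePattern(ONNXReduceL1Op::getOperationName(),
                           /*benefit=*/1, context) {}

      PatternMatchResult matchAndRewrite(Operation *op,
                                         PatternRewriter &rewriter) const override {
        auto loc = op->getLoc();
        // Source pattern: bind $oprd, $axes, and $keepdims from the matched op.
        auto oprd = op->getOperand(0);
        auto axes = op->getAttrOfType<ArrayAttr>("axes");
        auto keepdims = op->getAttrOfType<IntegerAttr>("keepdims");
        auto resultType = op->getResult(0)->getType();

        // Result pattern: ReduceL1(%X) => ReduceSum(Abs(%X)).
        auto abs = rewriter.create<ONNXAbsOp>(loc, resultType, oprd);
        rewriter.replaceOpWithNewOp<ONNXReduceSumOp>(
            op, resultType, abs.getResult(), axes, keepdims);
        return matchSuccess();
      }
    };

Because ONNXReduceSumSquareOp is itself marked illegal in this pass, the ReduceL2 pattern composes: Sqrt(ReduceSumSquare(%X)) is emitted first, and the inner op is then legalized by ReduceSumSquareOpPattern during the same partial conversion, which is why the tests below check for Mul, then ReduceSum, then Sqrt.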
diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir
index 7661f28..bbc2686 100644
--- a/test/mlir/onnx/onnx_canonicalization.mlir
+++ b/test/mlir/onnx/onnx_canonicalization.mlir
@@ -38,53 +38,6 @@ func @test_identity_identity(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>) ->
   "std.return"(%2) : (tensor<10x10xf32>) -> ()
 }
 
-// CHECK-LABEL: @test_reducel1(%{{.*}}: tensor<?x?xf32>) -> tensor<*xf32>
-func @test_reducel1(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
-  %0 ="onnx.ReduceL1"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?xf32>)-> tensor<*xf32>
-  "std.return"(%0) : (tensor<*xf32>) -> ()
-
-  // CHECK-NEXT: [[ABS:%.+]] = "onnx.Abs"(%arg0) : (tensor<?x?xf32>) -> tensor<*xf32>
-  // CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[ABS]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
-}
-
-// CHECK-LABEL: @test_reducel2(%{{.*}}: tensor<?x?xf32>) -> tensor<*xf32>
-func @test_reducel2(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
-  %0 ="onnx.ReduceL2"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?xf32>)-> tensor<*xf32>
-  "std.return"(%0) : (tensor<*xf32>) -> ()
-
-  // CHECK-NEXT: [[MUL:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<*xf32>
-  // CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[MUL]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
-  // CHECK-NEXT: [[SQRT:%.+]] = "onnx.Sqrt"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
-}
-
-// CHECK-LABEL: @test_reducelogsum(%{{.*}}: tensor<?x?xf32>) -> tensor<*xf32>
-func @test_reducelogsum(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
-  %0 ="onnx.ReduceLogSum"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?xf32>)-> tensor<*xf32>
-  "std.return"(%0) : (tensor<*xf32>) -> ()
-
-  // CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"(%arg0) {axes = [1], keepdims = 0 : i64} : (tensor<?x?xf32>) -> tensor<*xf32>
-  // CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
-}
-
-// CHECK-LABEL: @test_reducelogsumexp(%{{.*}}: tensor<?x?xf32>) -> tensor<*xf32>
-func @test_reducelogsumexp(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
-  %0 ="onnx.ReduceLogSumExp"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?xf32>)-> tensor<*xf32>
-  "std.return"(%0) : (tensor<*xf32>) -> ()
-
-  // CHECK-NEXT: [[EXP:%.+]] = "onnx.Exp"(%arg0) : (tensor<?x?xf32>) -> tensor<*xf32>
-  // CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[EXP]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
-  // CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
-}
-
-// CHECK-LABEL: @test_reducesumsquare(%{{.*}}: tensor<?x?xf32>) -> tensor<*xf32>
-func @test_reducesumsquare(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
-  %0 ="onnx.ReduceSumSquare"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?xf32>)-> tensor<*xf32>
-  "std.return"(%0) : (tensor<*xf32>) -> ()
-
-  // CHECK-NEXT: [[SQUARE:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<*xf32>
-  // CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[SQUARE]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
-}
-
 // CHECK-LABEL: @test_constant_pad(%{{.*}}: tensor<?x?xf32>) -> tensor<*xf32> {
 func @test_constant_pad(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
   // CHECK-NEXT: [[SQUARE:%.+]] = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 2, 0, 0]} : (tensor<?x?xf32>) -> tensor<*xf32>
diff --git a/test/mlir/onnx/onnx_decompose.mlir b/test/mlir/onnx/onnx_decompose.mlir
new file mode 100644
index 0000000..f01001b
--- /dev/null
+++ b/test/mlir/onnx/onnx_decompose.mlir
@@ -0,0 +1,49 @@
+// RUN: onnf-opt --decompose-onnx %s -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @test_reducel1(%{{.*}}: tensor<?x?xf32>) -> tensor<*xf32>
+func @test_reducel1(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
+  %0 ="onnx.ReduceL1"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?xf32>)-> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-NEXT: [[ABS:%.+]] = "onnx.Abs"(%arg0) : (tensor<?x?xf32>) -> tensor<*xf32>
+  // CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[ABS]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
+}
+
+// CHECK-LABEL: @test_reducel2(%{{.*}}: tensor<?x?xf32>) -> tensor<*xf32>
+func @test_reducel2(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
+  %0 ="onnx.ReduceL2"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?xf32>)-> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-NEXT: [[MUL:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<*xf32>
+  // CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[MUL]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK-NEXT: [[SQRT:%.+]] = "onnx.Sqrt"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
+}
+
+// CHECK-LABEL: @test_reducelogsum(%{{.*}}: tensor<?x?xf32>) -> tensor<*xf32>
+func @test_reducelogsum(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
+  %0 ="onnx.ReduceLogSum"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?xf32>)-> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"(%arg0) {axes = [1], keepdims = 0 : i64} : (tensor<?x?xf32>) -> tensor<*xf32>
+  // CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
+}
+
+// CHECK-LABEL: @test_reducelogsumexp(%{{.*}}: tensor<?x?xf32>) -> tensor<*xf32>
+func @test_reducelogsumexp(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
+  %0 ="onnx.ReduceLogSumExp"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?xf32>)-> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-NEXT: [[EXP:%.+]] = "onnx.Exp"(%arg0) : (tensor<?x?xf32>) -> tensor<*xf32>
+  // CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[EXP]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
+}
+
+// CHECK-LABEL: @test_reducesumsquare(%{{.*}}: tensor<?x?xf32>) -> tensor<*xf32>
+func @test_reducesumsquare(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
+  %0 ="onnx.ReduceSumSquare"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?xf32>)-> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-NEXT: [[SQUARE:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<*xf32>
+  // CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[SQUARE]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
+}
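For reference, the new pass can also be exercised by hand using the flag registered by the PassRegistration above. Running something along the lines of

    onnf-opt --decompose-onnx input.mlir

on a file containing

    %0 = "onnx.ReduceL1"(%arg0) {axes = [1], keepdims = 0 : i64} : (tensor<?x?xf32>) -> tensor<*xf32>

should produce the two-op sequence the FileCheck lines above expect:

    %0 = "onnx.Abs"(%arg0) : (tensor<?x?xf32>) -> tensor<*xf32>
    %1 = "onnx.ReduceSum"(%0) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>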