Add a pass to decompose ONNX operations (#9)

This commit is contained in:
Tung D. Le 2020-03-05 00:53:59 +09:00 committed by GitHub
parent 7c1dd0279b
commit e97df0b343
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 191 additions and 115 deletions

View File

@ -51,8 +51,7 @@ OpsWithShapeInference = [
# Operations supporting canonicalization. # Operations supporting canonicalization.
OpsWithCanonicalizer = [ OpsWithCanonicalizer = [
'Add', 'Identity', 'ReduceL1', 'ReduceL2', 'ReduceLogSum', 'Add', 'Identity', 'Gemm'
'ReduceLogSumExp', 'ReduceSumSquare', 'Gemm'
] ]
# Add an Op in this list if the Op needs result type deduction which is required # Add an Op in this list if the Op needs result type deduction which is required

View File

@ -11,6 +11,7 @@ add_library(compiler
dialect/onnx/onnxop.inc dialect/onnx/onnxop.inc
pass/onnx_combine.cpp pass/onnx_combine.cpp
pass/onnx_rewrite.cpp pass/onnx_rewrite.cpp
pass/onnx_decompose.cpp
pass/passes.hpp) pass/passes.hpp)
# Include root src directory. # Include root src directory.
@ -25,6 +26,11 @@ target_link_libraries(compiler
${MLIRLibs} ${MLIRLibs}
curses) curses)
set(LLVM_TARGET_DEFINITIONS pass/onnx_decompose.td)
onnf_tablegen(onnx_decompose.inc -gen-rewriters)
add_public_tablegen_target(gen_onnx_decompose)
add_dependencies(compiler gen_onnx_decompose)
set(LLVM_TARGET_DEFINITIONS pass/shape_inference_interface.td) set(LLVM_TARGET_DEFINITIONS pass/shape_inference_interface.td)
onnf_tablegen(shape_inference.hpp.inc -gen-op-interface-decls) onnf_tablegen(shape_inference.hpp.inc -gen-op-interface-decls)
onnf_tablegen(shape_inference.cpp.inc -gen-op-interface-defs) onnf_tablegen(shape_inference.cpp.inc -gen-op-interface-defs)
@ -55,6 +61,13 @@ onnf_tablegen(krnl.cpp.inc -gen-op-defs)
add_public_tablegen_target(gen_krnl_ops) add_public_tablegen_target(gen_krnl_ops)
add_dependencies(compiler gen_krnl_ops) add_dependencies(compiler gen_krnl_ops)
add_library(onnf_onnx_decompose pass/onnx_decompose.cpp)
target_include_directories(onnf_onnx_decompose
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
${ONNF_SRC_ROOT})
target_link_libraries(onnf_onnx_decompose ${MLIRLibs})
add_dependencies(onnf_onnx_decompose gen_krnl_ops)
add_library(onnf_shape_inference pass/shape_inference_pass.cpp) add_library(onnf_shape_inference pass/shape_inference_pass.cpp)
target_include_directories(onnf_shape_inference target_include_directories(onnf_shape_inference
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT} PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
@ -90,7 +103,7 @@ add_subdirectory(runtime)
add_executable(onnf main.cpp) add_executable(onnf main.cpp)
target_link_libraries(onnf builder ${MLIRLibs} onnf_transform onnf_shape_inference onnf_lower_frontend) target_link_libraries(onnf builder ${MLIRLibs} onnf_transform onnf_onnx_decompose onnf_shape_inference onnf_lower_frontend)
whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs}) whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs})
find_package(ZLIB REQUIRED) find_package(ZLIB REQUIRED)
target_link_libraries(onnf ${ZLIB_LIBRARIES}) target_link_libraries(onnf ${ZLIB_LIBRARIES})

View File

@ -2296,7 +2296,6 @@ def ONNXReciprocalOp:ONNX_Op<"Reciprocal",
def ONNXReduceL1Op:ONNX_Op<"ReduceL1", def ONNXReduceL1Op:ONNX_Op<"ReduceL1",
[NoSideEffect]> { [NoSideEffect]> {
let hasCanonicalizer = 1;
let summary = "ONNX ReduceL1 operation"; let summary = "ONNX ReduceL1 operation";
let description = [{ let description = [{
"Computes the L1 norm of the input tensor's element along the provided axes. The resulted" "Computes the L1 norm of the input tensor's element along the provided axes. The resulted"
@ -2314,7 +2313,6 @@ def ONNXReduceL1Op:ONNX_Op<"ReduceL1",
def ONNXReduceL2Op:ONNX_Op<"ReduceL2", def ONNXReduceL2Op:ONNX_Op<"ReduceL2",
[NoSideEffect]> { [NoSideEffect]> {
let hasCanonicalizer = 1;
let summary = "ONNX ReduceL2 operation"; let summary = "ONNX ReduceL2 operation";
let description = [{ let description = [{
"Computes the L2 norm of the input tensor's element along the provided axes. The resulted" "Computes the L2 norm of the input tensor's element along the provided axes. The resulted"
@ -2332,7 +2330,6 @@ def ONNXReduceL2Op:ONNX_Op<"ReduceL2",
def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum", def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum",
[NoSideEffect]> { [NoSideEffect]> {
let hasCanonicalizer = 1;
let summary = "ONNX ReduceLogSum operation"; let summary = "ONNX ReduceLogSum operation";
let description = [{ let description = [{
"Computes the log sum of the input tensor's element along the provided axes. The resulted" "Computes the log sum of the input tensor's element along the provided axes. The resulted"
@ -2350,7 +2347,6 @@ def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum",
def ONNXReduceLogSumExpOp:ONNX_Op<"ReduceLogSumExp", def ONNXReduceLogSumExpOp:ONNX_Op<"ReduceLogSumExp",
[NoSideEffect]> { [NoSideEffect]> {
let hasCanonicalizer = 1;
let summary = "ONNX ReduceLogSumExp operation"; let summary = "ONNX ReduceLogSumExp operation";
let description = [{ let description = [{
"Computes the log sum exponent of the input tensor's element along the provided axes. The resulted" "Computes the log sum exponent of the input tensor's element along the provided axes. The resulted"
@ -2465,7 +2461,6 @@ def ONNXReduceSumOp:ONNX_Op<"ReduceSum",
def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare", def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare",
[NoSideEffect]> { [NoSideEffect]> {
let hasCanonicalizer = 1;
let summary = "ONNX ReduceSumSquare operation"; let summary = "ONNX ReduceSumSquare operation";
let description = [{ let description = [{
"Computes the sum square of the input tensor's element along the provided axes. The resulted" "Computes the sum square of the input tensor's element along the provided axes. The resulted"

View File

@ -122,8 +122,9 @@ int main(int argc, char *argv[]) {
} }
mlir::PassManager pm(&context); mlir::PassManager pm(&context);
pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createDecomposeONNXToONNXPass());
pm.addPass(mlir::createShapeInferencePass()); pm.addPass(mlir::createShapeInferencePass());
pm.addPass(mlir::createCanonicalizerPass());
if (emissionTarget >= EmitMLIR) { if (emissionTarget >= EmitMLIR) {
pm.addPass(mlir::createLowerToKrnlPass()); pm.addPass(mlir::createLowerToKrnlPass());

View File

@ -0,0 +1,65 @@
//===- onnx_decompose.cpp - ONNX High Level Rewriting ---------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file implements a set of rewriters to decompose an ONNX operation into
// composition of other ONNX operations.
//
// This pass is applied before any other pass so that there is no need to
// implement shape inference for the decomposed operation. Hence, it is expected
// that there is no knowledge about tensor shape at this point
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/dialect/onnx/onnx_ops.hpp"
#include "src/pass/passes.hpp"
using namespace mlir;
namespace {
/// Include the patterns defined in the Declarative Rewrite framework.
#include "src/onnx_decompose.inc"
struct DecomposeONNXToONNXPass : public FunctionPass<DecomposeONNXToONNXPass> {
void runOnFunction() final;
};
} // end anonymous namespace.
void DecomposeONNXToONNXPass::runOnFunction() {
auto function = getFunction();
MLIRContext *context = &getContext();
ConversionTarget target(getContext());
target.addLegalDialect<ONNXOpsDialect>();
// These ops will be decomposed into other ONNX ops. Hence, they will not be
// available after this pass.
target.addIllegalOp<ONNXReduceL1Op>();
target.addIllegalOp<ONNXReduceL2Op>();
target.addIllegalOp<ONNXReduceLogSumOp>();
target.addIllegalOp<ONNXReduceLogSumExpOp>();
target.addIllegalOp<ONNXReduceSumSquareOp>();
OwningRewritePatternList patterns;
populateWithGenerated(context, &patterns);
if (failed(applyPartialConversion(function, target, patterns)))
signalPassFailure();
} // end anonymous namespace
/*!
* Create a DecomposeONNX pass.
*/
std::unique_ptr<mlir::Pass> mlir::createDecomposeONNXToONNXPass() {
return std::make_unique<DecomposeONNXToONNXPass>();
}
static PassRegistration<DecomposeONNXToONNXPass> pass("decompose-onnx",
"Decompose ONNX operations into composition of other ONNX operations.");

View File

@ -0,0 +1,57 @@
//===----------------------------------------------------------------------===//
//=- onnx_decompose.td - Rewriting for decomposing ONNX Ops -*- tablegen -*===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// Defines language-specific pattern match rewritings for ONNX using
// Declarative Rewrite Rules (DRR) specified using TableGen records.
//
#ifndef ONNX_DECOMPOSE
#define ONNX_DECOMPOSE
#ifndef OP_BASE
include "dialect/onnx/onnx.td"
#endif // OP_BASE
/// Note: The DRR definition used for defining patterns is shown below:
///
/// class Pattern<
/// dag sourcePattern, list<dag> resultPatterns,
/// list<dag> additionalConstraints = [],
/// dag benefitsAdded = (addBenefit 0)
/// >;
//===----------------------------------------------------------------------===//
// ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X)
//===----------------------------------------------------------------------===//
def ReduceL1OpPattern: Pat<(ONNXReduceL1Op $oprd, $axes, $keepdims),
(ONNXReduceSumOp (ONNXAbsOp $oprd), $axes, $keepdims)>;
//===----------------------------------------------------------------------===//
// ONNXReduceL2Op %X = ONNXSqrtOp (ONNXReduceSumSquareOp (%X))
//===----------------------------------------------------------------------===//
def ReduceL2OpPattern: Pat<(ONNXReduceL2Op $oprd, $axes, $keepdims),
(ONNXSqrtOp (ONNXReduceSumSquareOp $oprd, $axes, $keepdims))>;
//===----------------------------------------------------------------------===//
// ONNXReduceLogSumOp %X = ONNXLogOp (ONNXReduceSumOp (%X))
//===----------------------------------------------------------------------===//
def ReduceLogSumOpPattern: Pat<(ONNXReduceLogSumOp $oprd, $axes, $keepdims),
(ONNXLogOp (ONNXReduceSumOp $oprd, $axes, $keepdims))>;
//===----------------------------------------------------------------------===//
// ONNXReduceLogSumExpOp %X = ONNXReduceLogSumOp (ONNXExpOp %X)
//===----------------------------------------------------------------------===//
def ReduceLogSumExpOpPattern: Pat<(ONNXReduceLogSumExpOp $oprd, $axes, $keepdims),
(ONNXReduceLogSumOp (ONNXExpOp $oprd), $axes, $keepdims)>;
//===----------------------------------------------------------------------===//
// ONNXReduceSumSquareOp %X = ONNXReduceSumOp (ONNXMulOp %X, %X)
//===----------------------------------------------------------------------===//
def ReduceSumSquareOpPattern: Pat<(ONNXReduceSumSquareOp $oprd, $axes, $keepdims),
(ONNXReduceSumOp (ONNXMulOp $oprd, $oprd), $axes, $keepdims)>;
#endif // ONNX_DECOMPOSE

View File

@ -118,35 +118,6 @@ struct SplitConvOpPattern : public RewritePattern {
}; };
} // end anonymous namespace } // end anonymous namespace
/// on the ONNXReduceL1Op.
void ONNXReduceL1Op::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ReduceL1OpPattern>(context);
}
/// on the ONNXReduceL2Op.
void ONNXReduceL2Op::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ReduceL2OpPattern>(context);
}
/// on the ONNXReduceLogSumOp.
void ONNXReduceLogSumOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ReduceLogSumOpPattern>(context);
}
/// on the ONNXReduceLogSumExpOp.
void ONNXReduceLogSumExpOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ReduceLogSumExpOpPattern>(context);
}
/// on the ONNXReduceSumSquareOp.
void ONNXReduceSumSquareOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ReduceSumSquareOpPattern>(context);
}
/// on the ONNXReduceSumSquareOp. /// on the ONNXReduceSumSquareOp.
void ONNXConvNoBiasOp::getCanonicalizationPatterns( void ONNXConvNoBiasOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) { OwningRewritePatternList &results, MLIRContext *context) {

View File

@ -24,34 +24,4 @@ include "dialect/onnx/onnx.td"
/// dag benefitsAdded = (addBenefit 0) /// dag benefitsAdded = (addBenefit 0)
/// >; /// >;
//===----------------------------------------------------------------------===//
// ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X)
//===----------------------------------------------------------------------===//
def ReduceL1OpPattern: Pat<(ONNXReduceL1Op $oprd, $axes, $keepdims),
(ONNXReduceSumOp (ONNXAbsOp $oprd), $axes, $keepdims)>;
//===----------------------------------------------------------------------===//
// ONNXReduceL2Op %X = ONNXSqrtOp (ONNXReduceSumSquareOp (%X))
//===----------------------------------------------------------------------===//
def ReduceL2OpPattern: Pat<(ONNXReduceL2Op $oprd, $axes, $keepdims),
(ONNXSqrtOp (ONNXReduceSumSquareOp $oprd, $axes, $keepdims))>;
//===----------------------------------------------------------------------===//
// ONNXReduceLogSumOp %X = ONNXLogOp (ONNXReduceSumOp (%X))
//===----------------------------------------------------------------------===//
def ReduceLogSumOpPattern: Pat<(ONNXReduceLogSumOp $oprd, $axes, $keepdims),
(ONNXLogOp (ONNXReduceSumOp $oprd, $axes, $keepdims))>;
//===----------------------------------------------------------------------===//
// ONNXReduceLogSumExpOp %X = ONNXReduceLogSumOp (ONNXExpOp %X)
//===----------------------------------------------------------------------===//
def ReduceLogSumExpOpPattern: Pat<(ONNXReduceLogSumExpOp $oprd, $axes, $keepdims),
(ONNXReduceLogSumOp (ONNXExpOp $oprd), $axes, $keepdims)>;
//===----------------------------------------------------------------------===//
// ONNXReduceSumSquareOp %X = ONNXReduceSumOp (ONNXMulOp %X, %X)
//===----------------------------------------------------------------------===//
def ReduceSumSquareOpPattern: Pat<(ONNXReduceSumSquareOp $oprd, $axes, $keepdims),
(ONNXReduceSumOp (ONNXMulOp $oprd, $oprd), $axes, $keepdims)>;
#endif // ONNX_REWRITE #endif // ONNX_REWRITE

View File

@ -15,6 +15,9 @@
namespace mlir { namespace mlir {
class Pass; class Pass;
/// Pass for rewriting inside frontend dialect.
std::unique_ptr<Pass> createDecomposeONNXToONNXPass();
std::unique_ptr<Pass> createShapeInferencePass(); std::unique_ptr<Pass> createShapeInferencePass();
/// Add pass for lowering to Krnl IR. /// Add pass for lowering to Krnl IR.

View File

@ -38,53 +38,6 @@ func @test_identity_identity(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>) ->
"std.return"(%2) : (tensor<10x10xf32>) -> () "std.return"(%2) : (tensor<10x10xf32>) -> ()
} }
// CHECK-LABEL: @test_reducel1(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
func @test_reducel1(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
%0 ="onnx.ReduceL1"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-NEXT: [[ABS:%.+]] = "onnx.Abs"(%arg0) : (tensor<?x?x?xf32>) -> tensor<*xf32>
// CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[ABS]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
}
// CHECK-LABEL: @test_reducel2(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
func @test_reducel2(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
%0 ="onnx.ReduceL2"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-NEXT: [[MUL:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[MUL]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: [[SQRT:%.+]] = "onnx.Sqrt"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
}
// CHECK-LABEL: @test_reducelogsum(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
func @test_reducelogsum(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
%0 ="onnx.ReduceLogSum"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"(%arg0) {axes = [1], keepdims = 0 : i64} : (tensor<?x?x?xf32>) -> tensor<*xf32>
// CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
}
// CHECK-LABEL: @test_reducelogsumexp(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
func @test_reducelogsumexp(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
%0 ="onnx.ReduceLogSumExp"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-NEXT: [[EXP:%.+]] = "onnx.Exp"(%arg0) : (tensor<?x?x?xf32>) -> tensor<*xf32>
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[EXP]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
}
// CHECK-LABEL: @test_reducesumsquare(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
func @test_reducesumsquare(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
%0 ="onnx.ReduceSumSquare"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-NEXT: [[SQUARE:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
// CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[SQUARE]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
}
// CHECK-LABEL: @test_constant_pad(%{{.*}}: tensor<?x?xf32>) -> tensor<*xf32> { // CHECK-LABEL: @test_constant_pad(%{{.*}}: tensor<?x?xf32>) -> tensor<*xf32> {
func @test_constant_pad(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> { func @test_constant_pad(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
// CHECK-NEXT: [[SQUARE:%.+]] = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 2, 0, 0]} : (tensor<?x?xf32>) -> tensor<*xf32> // CHECK-NEXT: [[SQUARE:%.+]] = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 2, 0, 0]} : (tensor<?x?xf32>) -> tensor<*xf32>

View File

@ -0,0 +1,49 @@
// RUN: onnf-opt --decompose-onnx %s -split-input-file | FileCheck %s
// CHECK-LABEL: @test_reducel1(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
func @test_reducel1(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
%0 ="onnx.ReduceL1"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-NEXT: [[ABS:%.+]] = "onnx.Abs"(%arg0) : (tensor<?x?x?xf32>) -> tensor<*xf32>
// CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[ABS]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
}
// CHECK-LABEL: @test_reducel2(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
func @test_reducel2(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
%0 ="onnx.ReduceL2"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-NEXT: [[MUL:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[MUL]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: [[SQRT:%.+]] = "onnx.Sqrt"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
}
// CHECK-LABEL: @test_reducelogsum(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
func @test_reducelogsum(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
%0 ="onnx.ReduceLogSum"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"(%arg0) {axes = [1], keepdims = 0 : i64} : (tensor<?x?x?xf32>) -> tensor<*xf32>
// CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
}
// CHECK-LABEL: @test_reducelogsumexp(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
func @test_reducelogsumexp(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
%0 ="onnx.ReduceLogSumExp"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-NEXT: [[EXP:%.+]] = "onnx.Exp"(%arg0) : (tensor<?x?x?xf32>) -> tensor<*xf32>
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[EXP]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
}
// CHECK-LABEL: @test_reducesumsquare(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
func @test_reducesumsquare(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
%0 ="onnx.ReduceSumSquare"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-NEXT: [[SQUARE:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
// CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[SQUARE]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
}