Add a pass to decompose ONNX operations (#9)
This commit is contained in:
parent
7c1dd0279b
commit
e97df0b343
|
@ -51,8 +51,7 @@ OpsWithShapeInference = [
|
|||
|
||||
# Operations supporting canonicalization.
|
||||
OpsWithCanonicalizer = [
|
||||
'Add', 'Identity', 'ReduceL1', 'ReduceL2', 'ReduceLogSum',
|
||||
'ReduceLogSumExp', 'ReduceSumSquare', 'Gemm'
|
||||
'Add', 'Identity', 'Gemm'
|
||||
]
|
||||
|
||||
# Add an Op in this list if the Op needs result type deduction which is required
|
||||
|
|
|
@ -11,6 +11,7 @@ add_library(compiler
|
|||
dialect/onnx/onnxop.inc
|
||||
pass/onnx_combine.cpp
|
||||
pass/onnx_rewrite.cpp
|
||||
pass/onnx_decompose.cpp
|
||||
pass/passes.hpp)
|
||||
|
||||
# Include root src directory.
|
||||
|
@ -25,6 +26,11 @@ target_link_libraries(compiler
|
|||
${MLIRLibs}
|
||||
curses)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS pass/onnx_decompose.td)
|
||||
onnf_tablegen(onnx_decompose.inc -gen-rewriters)
|
||||
add_public_tablegen_target(gen_onnx_decompose)
|
||||
add_dependencies(compiler gen_onnx_decompose)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS pass/shape_inference_interface.td)
|
||||
onnf_tablegen(shape_inference.hpp.inc -gen-op-interface-decls)
|
||||
onnf_tablegen(shape_inference.cpp.inc -gen-op-interface-defs)
|
||||
|
@ -55,6 +61,13 @@ onnf_tablegen(krnl.cpp.inc -gen-op-defs)
|
|||
add_public_tablegen_target(gen_krnl_ops)
|
||||
add_dependencies(compiler gen_krnl_ops)
|
||||
|
||||
add_library(onnf_onnx_decompose pass/onnx_decompose.cpp)
|
||||
target_include_directories(onnf_onnx_decompose
|
||||
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
|
||||
${ONNF_SRC_ROOT})
|
||||
target_link_libraries(onnf_onnx_decompose ${MLIRLibs})
|
||||
add_dependencies(onnf_onnx_decompose gen_krnl_ops)
|
||||
|
||||
add_library(onnf_shape_inference pass/shape_inference_pass.cpp)
|
||||
target_include_directories(onnf_shape_inference
|
||||
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
|
||||
|
@ -90,7 +103,7 @@ add_subdirectory(runtime)
|
|||
|
||||
add_executable(onnf main.cpp)
|
||||
|
||||
target_link_libraries(onnf builder ${MLIRLibs} onnf_transform onnf_shape_inference onnf_lower_frontend)
|
||||
target_link_libraries(onnf builder ${MLIRLibs} onnf_transform onnf_onnx_decompose onnf_shape_inference onnf_lower_frontend)
|
||||
whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs})
|
||||
find_package(ZLIB REQUIRED)
|
||||
target_link_libraries(onnf ${ZLIB_LIBRARIES})
|
||||
|
|
|
@ -2296,7 +2296,6 @@ def ONNXReciprocalOp:ONNX_Op<"Reciprocal",
|
|||
|
||||
def ONNXReduceL1Op:ONNX_Op<"ReduceL1",
|
||||
[NoSideEffect]> {
|
||||
let hasCanonicalizer = 1;
|
||||
let summary = "ONNX ReduceL1 operation";
|
||||
let description = [{
|
||||
"Computes the L1 norm of the input tensor's element along the provided axes. The resulted"
|
||||
|
@ -2314,7 +2313,6 @@ def ONNXReduceL1Op:ONNX_Op<"ReduceL1",
|
|||
|
||||
def ONNXReduceL2Op:ONNX_Op<"ReduceL2",
|
||||
[NoSideEffect]> {
|
||||
let hasCanonicalizer = 1;
|
||||
let summary = "ONNX ReduceL2 operation";
|
||||
let description = [{
|
||||
"Computes the L2 norm of the input tensor's element along the provided axes. The resulted"
|
||||
|
@ -2332,7 +2330,6 @@ def ONNXReduceL2Op:ONNX_Op<"ReduceL2",
|
|||
|
||||
def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum",
|
||||
[NoSideEffect]> {
|
||||
let hasCanonicalizer = 1;
|
||||
let summary = "ONNX ReduceLogSum operation";
|
||||
let description = [{
|
||||
"Computes the log sum of the input tensor's element along the provided axes. The resulted"
|
||||
|
@ -2350,7 +2347,6 @@ def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum",
|
|||
|
||||
def ONNXReduceLogSumExpOp:ONNX_Op<"ReduceLogSumExp",
|
||||
[NoSideEffect]> {
|
||||
let hasCanonicalizer = 1;
|
||||
let summary = "ONNX ReduceLogSumExp operation";
|
||||
let description = [{
|
||||
"Computes the log sum exponent of the input tensor's element along the provided axes. The resulted"
|
||||
|
@ -2465,7 +2461,6 @@ def ONNXReduceSumOp:ONNX_Op<"ReduceSum",
|
|||
|
||||
def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare",
|
||||
[NoSideEffect]> {
|
||||
let hasCanonicalizer = 1;
|
||||
let summary = "ONNX ReduceSumSquare operation";
|
||||
let description = [{
|
||||
"Computes the sum square of the input tensor's element along the provided axes. The resulted"
|
||||
|
|
|
@ -122,8 +122,9 @@ int main(int argc, char *argv[]) {
|
|||
}
|
||||
|
||||
mlir::PassManager pm(&context);
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.addPass(mlir::createDecomposeONNXToONNXPass());
|
||||
pm.addPass(mlir::createShapeInferencePass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
|
||||
if (emissionTarget >= EmitMLIR) {
|
||||
pm.addPass(mlir::createLowerToKrnlPass());
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
//===- onnx_decompose.cpp - ONNX High Level Rewriting ---------------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements a set of rewriters to decompose an ONNX operation into
|
||||
// composition of other ONNX operations.
|
||||
//
|
||||
// This pass is applied before any other pass so that there is no need to
|
||||
// implement shape inference for the decomposed operation. Hence, it is expected
|
||||
// that there is no knowledge about tensor shape at this point
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/dialect/onnx/onnx_ops.hpp"
|
||||
#include "src/pass/passes.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
/// Include the patterns defined in the Declarative Rewrite framework.
|
||||
#include "src/onnx_decompose.inc"
|
||||
|
||||
struct DecomposeONNXToONNXPass : public FunctionPass<DecomposeONNXToONNXPass> {
|
||||
void runOnFunction() final;
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
||||
void DecomposeONNXToONNXPass::runOnFunction() {
|
||||
auto function = getFunction();
|
||||
MLIRContext *context = &getContext();
|
||||
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<ONNXOpsDialect>();
|
||||
|
||||
// These ops will be decomposed into other ONNX ops. Hence, they will not be
|
||||
// available after this pass.
|
||||
target.addIllegalOp<ONNXReduceL1Op>();
|
||||
target.addIllegalOp<ONNXReduceL2Op>();
|
||||
target.addIllegalOp<ONNXReduceLogSumOp>();
|
||||
target.addIllegalOp<ONNXReduceLogSumExpOp>();
|
||||
target.addIllegalOp<ONNXReduceSumSquareOp>();
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
populateWithGenerated(context, &patterns);
|
||||
|
||||
if (failed(applyPartialConversion(function, target, patterns)))
|
||||
signalPassFailure();
|
||||
} // end anonymous namespace
|
||||
|
||||
/*!
|
||||
* Create a DecomposeONNX pass.
|
||||
*/
|
||||
std::unique_ptr<mlir::Pass> mlir::createDecomposeONNXToONNXPass() {
|
||||
return std::make_unique<DecomposeONNXToONNXPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<DecomposeONNXToONNXPass> pass("decompose-onnx",
|
||||
"Decompose ONNX operations into composition of other ONNX operations.");
|
|
@ -0,0 +1,57 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//=- onnx_decompose.td - Rewriting for decomposing ONNX Ops -*- tablegen -*===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// Defines language-specific pattern match rewritings for ONNX using
|
||||
// Declarative Rewrite Rules (DRR) specified using TableGen records.
|
||||
//
|
||||
|
||||
#ifndef ONNX_DECOMPOSE
|
||||
#define ONNX_DECOMPOSE
|
||||
|
||||
#ifndef OP_BASE
|
||||
include "dialect/onnx/onnx.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
/// Note: The DRR definition used for defining patterns is shown below:
|
||||
///
|
||||
/// class Pattern<
|
||||
/// dag sourcePattern, list<dag> resultPatterns,
|
||||
/// list<dag> additionalConstraints = [],
|
||||
/// dag benefitsAdded = (addBenefit 0)
|
||||
/// >;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X)
|
||||
//===----------------------------------------------------------------------===//
|
||||
def ReduceL1OpPattern: Pat<(ONNXReduceL1Op $oprd, $axes, $keepdims),
|
||||
(ONNXReduceSumOp (ONNXAbsOp $oprd), $axes, $keepdims)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNXReduceL2Op %X = ONNXSqrtOp (ONNXReduceSumSquareOp (%X))
|
||||
//===----------------------------------------------------------------------===//
|
||||
def ReduceL2OpPattern: Pat<(ONNXReduceL2Op $oprd, $axes, $keepdims),
|
||||
(ONNXSqrtOp (ONNXReduceSumSquareOp $oprd, $axes, $keepdims))>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNXReduceLogSumOp %X = ONNXLogOp (ONNXReduceSumOp (%X))
|
||||
//===----------------------------------------------------------------------===//
|
||||
def ReduceLogSumOpPattern: Pat<(ONNXReduceLogSumOp $oprd, $axes, $keepdims),
|
||||
(ONNXLogOp (ONNXReduceSumOp $oprd, $axes, $keepdims))>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNXReduceLogSumExpOp %X = ONNXReduceLogSumOp (ONNXExpOp %X)
|
||||
//===----------------------------------------------------------------------===//
|
||||
def ReduceLogSumExpOpPattern: Pat<(ONNXReduceLogSumExpOp $oprd, $axes, $keepdims),
|
||||
(ONNXReduceLogSumOp (ONNXExpOp $oprd), $axes, $keepdims)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNXReduceSumSquareOp %X = ONNXReduceSumOp (ONNXMulOp %X, %X)
|
||||
//===----------------------------------------------------------------------===//
|
||||
def ReduceSumSquareOpPattern: Pat<(ONNXReduceSumSquareOp $oprd, $axes, $keepdims),
|
||||
(ONNXReduceSumOp (ONNXMulOp $oprd, $oprd), $axes, $keepdims)>;
|
||||
|
||||
#endif // ONNX_DECOMPOSE
|
|
@ -118,35 +118,6 @@ struct SplitConvOpPattern : public RewritePattern {
|
|||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
/// on the ONNXReduceL1Op.
|
||||
void ONNXReduceL1Op::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<ReduceL1OpPattern>(context);
|
||||
}
|
||||
/// on the ONNXReduceL2Op.
|
||||
void ONNXReduceL2Op::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<ReduceL2OpPattern>(context);
|
||||
}
|
||||
|
||||
/// on the ONNXReduceLogSumOp.
|
||||
void ONNXReduceLogSumOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<ReduceLogSumOpPattern>(context);
|
||||
}
|
||||
|
||||
/// on the ONNXReduceLogSumExpOp.
|
||||
void ONNXReduceLogSumExpOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<ReduceLogSumExpOpPattern>(context);
|
||||
}
|
||||
|
||||
/// on the ONNXReduceSumSquareOp.
|
||||
void ONNXReduceSumSquareOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<ReduceSumSquareOpPattern>(context);
|
||||
}
|
||||
|
||||
/// on the ONNXReduceSumSquareOp.
|
||||
void ONNXConvNoBiasOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
|
|
|
@ -24,34 +24,4 @@ include "dialect/onnx/onnx.td"
|
|||
/// dag benefitsAdded = (addBenefit 0)
|
||||
/// >;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X)
|
||||
//===----------------------------------------------------------------------===//
|
||||
def ReduceL1OpPattern: Pat<(ONNXReduceL1Op $oprd, $axes, $keepdims),
|
||||
(ONNXReduceSumOp (ONNXAbsOp $oprd), $axes, $keepdims)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNXReduceL2Op %X = ONNXSqrtOp (ONNXReduceSumSquareOp (%X))
|
||||
//===----------------------------------------------------------------------===//
|
||||
def ReduceL2OpPattern: Pat<(ONNXReduceL2Op $oprd, $axes, $keepdims),
|
||||
(ONNXSqrtOp (ONNXReduceSumSquareOp $oprd, $axes, $keepdims))>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNXReduceLogSumOp %X = ONNXLogOp (ONNXReduceSumOp (%X))
|
||||
//===----------------------------------------------------------------------===//
|
||||
def ReduceLogSumOpPattern: Pat<(ONNXReduceLogSumOp $oprd, $axes, $keepdims),
|
||||
(ONNXLogOp (ONNXReduceSumOp $oprd, $axes, $keepdims))>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNXReduceLogSumExpOp %X = ONNXReduceLogSumOp (ONNXExpOp %X)
|
||||
//===----------------------------------------------------------------------===//
|
||||
def ReduceLogSumExpOpPattern: Pat<(ONNXReduceLogSumExpOp $oprd, $axes, $keepdims),
|
||||
(ONNXReduceLogSumOp (ONNXExpOp $oprd), $axes, $keepdims)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNXReduceSumSquareOp %X = ONNXReduceSumOp (ONNXMulOp %X, %X)
|
||||
//===----------------------------------------------------------------------===//
|
||||
def ReduceSumSquareOpPattern: Pat<(ONNXReduceSumSquareOp $oprd, $axes, $keepdims),
|
||||
(ONNXReduceSumOp (ONNXMulOp $oprd, $oprd), $axes, $keepdims)>;
|
||||
|
||||
#endif // ONNX_REWRITE
|
||||
|
|
|
@ -15,6 +15,9 @@
|
|||
namespace mlir {
|
||||
class Pass;
|
||||
|
||||
/// Pass for rewriting inside frontend dialect.
|
||||
std::unique_ptr<Pass> createDecomposeONNXToONNXPass();
|
||||
|
||||
std::unique_ptr<Pass> createShapeInferencePass();
|
||||
|
||||
/// Add pass for lowering to Krnl IR.
|
||||
|
|
|
@ -38,53 +38,6 @@ func @test_identity_identity(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>) ->
|
|||
"std.return"(%2) : (tensor<10x10xf32>) -> ()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_reducel1(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||
func @test_reducel1(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
|
||||
%0 ="onnx.ReduceL1"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-NEXT: [[ABS:%.+]] = "onnx.Abs"(%arg0) : (tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[ABS]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_reducel2(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||
func @test_reducel2(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
|
||||
%0 ="onnx.ReduceL2"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-NEXT: [[MUL:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[MUL]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: [[SQRT:%.+]] = "onnx.Sqrt"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_reducelogsum(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||
func @test_reducelogsum(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
|
||||
%0 ="onnx.ReduceLogSum"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"(%arg0) {axes = [1], keepdims = 0 : i64} : (tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_reducelogsumexp(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||
func @test_reducelogsumexp(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
|
||||
%0 ="onnx.ReduceLogSumExp"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-NEXT: [[EXP:%.+]] = "onnx.Exp"(%arg0) : (tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[EXP]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_reducesumsquare(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||
func @test_reducesumsquare(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
|
||||
%0 ="onnx.ReduceSumSquare"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-NEXT: [[SQUARE:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[SQUARE]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_constant_pad(%{{.*}}: tensor<?x?xf32>) -> tensor<*xf32> {
|
||||
func @test_constant_pad(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
|
||||
// CHECK-NEXT: [[SQUARE:%.+]] = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 2, 0, 0]} : (tensor<?x?xf32>) -> tensor<*xf32>
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
// RUN: onnf-opt --decompose-onnx %s -split-input-file | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @test_reducel1(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||
func @test_reducel1(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
|
||||
%0 ="onnx.ReduceL1"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-NEXT: [[ABS:%.+]] = "onnx.Abs"(%arg0) : (tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[ABS]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_reducel2(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||
func @test_reducel2(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
|
||||
%0 ="onnx.ReduceL2"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-NEXT: [[MUL:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[MUL]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: [[SQRT:%.+]] = "onnx.Sqrt"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_reducelogsum(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||
func @test_reducelogsum(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
|
||||
%0 ="onnx.ReduceLogSum"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"(%arg0) {axes = [1], keepdims = 0 : i64} : (tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_reducelogsumexp(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||
func @test_reducelogsumexp(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
|
||||
%0 ="onnx.ReduceLogSumExp"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-NEXT: [[EXP:%.+]] = "onnx.Exp"(%arg0) : (tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[EXP]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_reducesumsquare(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||
func @test_reducesumsquare(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
|
||||
%0 ="onnx.ReduceSumSquare"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-NEXT: [[SQUARE:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[SQUARE]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
|
||||
}
|
||||
|
Loading…
Reference in New Issue