Add a pass to decompose ONNX operations (#9)
This commit is contained in:
parent
7c1dd0279b
commit
e97df0b343
|
@ -51,8 +51,7 @@ OpsWithShapeInference = [
|
||||||
|
|
||||||
# Operations supporting canonicalization.
|
# Operations supporting canonicalization.
|
||||||
OpsWithCanonicalizer = [
|
OpsWithCanonicalizer = [
|
||||||
'Add', 'Identity', 'ReduceL1', 'ReduceL2', 'ReduceLogSum',
|
'Add', 'Identity', 'Gemm'
|
||||||
'ReduceLogSumExp', 'ReduceSumSquare', 'Gemm'
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# Add an Op in this list if the Op needs result type deduction which is required
|
# Add an Op in this list if the Op needs result type deduction which is required
|
||||||
|
|
|
@ -11,6 +11,7 @@ add_library(compiler
|
||||||
dialect/onnx/onnxop.inc
|
dialect/onnx/onnxop.inc
|
||||||
pass/onnx_combine.cpp
|
pass/onnx_combine.cpp
|
||||||
pass/onnx_rewrite.cpp
|
pass/onnx_rewrite.cpp
|
||||||
|
pass/onnx_decompose.cpp
|
||||||
pass/passes.hpp)
|
pass/passes.hpp)
|
||||||
|
|
||||||
# Include root src directory.
|
# Include root src directory.
|
||||||
|
@ -25,6 +26,11 @@ target_link_libraries(compiler
|
||||||
${MLIRLibs}
|
${MLIRLibs}
|
||||||
curses)
|
curses)
|
||||||
|
|
||||||
|
set(LLVM_TARGET_DEFINITIONS pass/onnx_decompose.td)
|
||||||
|
onnf_tablegen(onnx_decompose.inc -gen-rewriters)
|
||||||
|
add_public_tablegen_target(gen_onnx_decompose)
|
||||||
|
add_dependencies(compiler gen_onnx_decompose)
|
||||||
|
|
||||||
set(LLVM_TARGET_DEFINITIONS pass/shape_inference_interface.td)
|
set(LLVM_TARGET_DEFINITIONS pass/shape_inference_interface.td)
|
||||||
onnf_tablegen(shape_inference.hpp.inc -gen-op-interface-decls)
|
onnf_tablegen(shape_inference.hpp.inc -gen-op-interface-decls)
|
||||||
onnf_tablegen(shape_inference.cpp.inc -gen-op-interface-defs)
|
onnf_tablegen(shape_inference.cpp.inc -gen-op-interface-defs)
|
||||||
|
@ -55,6 +61,13 @@ onnf_tablegen(krnl.cpp.inc -gen-op-defs)
|
||||||
add_public_tablegen_target(gen_krnl_ops)
|
add_public_tablegen_target(gen_krnl_ops)
|
||||||
add_dependencies(compiler gen_krnl_ops)
|
add_dependencies(compiler gen_krnl_ops)
|
||||||
|
|
||||||
|
add_library(onnf_onnx_decompose pass/onnx_decompose.cpp)
|
||||||
|
target_include_directories(onnf_onnx_decompose
|
||||||
|
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
|
||||||
|
${ONNF_SRC_ROOT})
|
||||||
|
target_link_libraries(onnf_onnx_decompose ${MLIRLibs})
|
||||||
|
add_dependencies(onnf_onnx_decompose gen_krnl_ops)
|
||||||
|
|
||||||
add_library(onnf_shape_inference pass/shape_inference_pass.cpp)
|
add_library(onnf_shape_inference pass/shape_inference_pass.cpp)
|
||||||
target_include_directories(onnf_shape_inference
|
target_include_directories(onnf_shape_inference
|
||||||
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
|
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
|
||||||
|
@ -90,7 +103,7 @@ add_subdirectory(runtime)
|
||||||
|
|
||||||
add_executable(onnf main.cpp)
|
add_executable(onnf main.cpp)
|
||||||
|
|
||||||
target_link_libraries(onnf builder ${MLIRLibs} onnf_transform onnf_shape_inference onnf_lower_frontend)
|
target_link_libraries(onnf builder ${MLIRLibs} onnf_transform onnf_onnx_decompose onnf_shape_inference onnf_lower_frontend)
|
||||||
whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs})
|
whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs})
|
||||||
find_package(ZLIB REQUIRED)
|
find_package(ZLIB REQUIRED)
|
||||||
target_link_libraries(onnf ${ZLIB_LIBRARIES})
|
target_link_libraries(onnf ${ZLIB_LIBRARIES})
|
||||||
|
|
|
@ -2296,7 +2296,6 @@ def ONNXReciprocalOp:ONNX_Op<"Reciprocal",
|
||||||
|
|
||||||
def ONNXReduceL1Op:ONNX_Op<"ReduceL1",
|
def ONNXReduceL1Op:ONNX_Op<"ReduceL1",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect]> {
|
||||||
let hasCanonicalizer = 1;
|
|
||||||
let summary = "ONNX ReduceL1 operation";
|
let summary = "ONNX ReduceL1 operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Computes the L1 norm of the input tensor's element along the provided axes. The resulted"
|
"Computes the L1 norm of the input tensor's element along the provided axes. The resulted"
|
||||||
|
@ -2314,7 +2313,6 @@ def ONNXReduceL1Op:ONNX_Op<"ReduceL1",
|
||||||
|
|
||||||
def ONNXReduceL2Op:ONNX_Op<"ReduceL2",
|
def ONNXReduceL2Op:ONNX_Op<"ReduceL2",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect]> {
|
||||||
let hasCanonicalizer = 1;
|
|
||||||
let summary = "ONNX ReduceL2 operation";
|
let summary = "ONNX ReduceL2 operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Computes the L2 norm of the input tensor's element along the provided axes. The resulted"
|
"Computes the L2 norm of the input tensor's element along the provided axes. The resulted"
|
||||||
|
@ -2332,7 +2330,6 @@ def ONNXReduceL2Op:ONNX_Op<"ReduceL2",
|
||||||
|
|
||||||
def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum",
|
def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect]> {
|
||||||
let hasCanonicalizer = 1;
|
|
||||||
let summary = "ONNX ReduceLogSum operation";
|
let summary = "ONNX ReduceLogSum operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Computes the log sum of the input tensor's element along the provided axes. The resulted"
|
"Computes the log sum of the input tensor's element along the provided axes. The resulted"
|
||||||
|
@ -2350,7 +2347,6 @@ def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum",
|
||||||
|
|
||||||
def ONNXReduceLogSumExpOp:ONNX_Op<"ReduceLogSumExp",
|
def ONNXReduceLogSumExpOp:ONNX_Op<"ReduceLogSumExp",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect]> {
|
||||||
let hasCanonicalizer = 1;
|
|
||||||
let summary = "ONNX ReduceLogSumExp operation";
|
let summary = "ONNX ReduceLogSumExp operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Computes the log sum exponent of the input tensor's element along the provided axes. The resulted"
|
"Computes the log sum exponent of the input tensor's element along the provided axes. The resulted"
|
||||||
|
@ -2465,7 +2461,6 @@ def ONNXReduceSumOp:ONNX_Op<"ReduceSum",
|
||||||
|
|
||||||
def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare",
|
def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect]> {
|
||||||
let hasCanonicalizer = 1;
|
|
||||||
let summary = "ONNX ReduceSumSquare operation";
|
let summary = "ONNX ReduceSumSquare operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Computes the sum square of the input tensor's element along the provided axes. The resulted"
|
"Computes the sum square of the input tensor's element along the provided axes. The resulted"
|
||||||
|
|
|
@ -122,8 +122,9 @@ int main(int argc, char *argv[]) {
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::PassManager pm(&context);
|
mlir::PassManager pm(&context);
|
||||||
pm.addPass(mlir::createCanonicalizerPass());
|
pm.addPass(mlir::createDecomposeONNXToONNXPass());
|
||||||
pm.addPass(mlir::createShapeInferencePass());
|
pm.addPass(mlir::createShapeInferencePass());
|
||||||
|
pm.addPass(mlir::createCanonicalizerPass());
|
||||||
|
|
||||||
if (emissionTarget >= EmitMLIR) {
|
if (emissionTarget >= EmitMLIR) {
|
||||||
pm.addPass(mlir::createLowerToKrnlPass());
|
pm.addPass(mlir::createLowerToKrnlPass());
|
||||||
|
|
|
@ -0,0 +1,65 @@
|
||||||
|
//===- onnx_decompose.cpp - ONNX High Level Rewriting ---------------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file implements a set of rewriters to decompose an ONNX operation into
|
||||||
|
// composition of other ONNX operations.
|
||||||
|
//
|
||||||
|
// This pass is applied before any other pass so that there is no need to
|
||||||
|
// implement shape inference for the decomposed operation. Hence, it is expected
|
||||||
|
// that there is no knowledge about tensor shape at this point
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include "src/dialect/onnx/onnx_ops.hpp"
|
||||||
|
#include "src/pass/passes.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
/// Include the patterns defined in the Declarative Rewrite framework.
|
||||||
|
#include "src/onnx_decompose.inc"
|
||||||
|
|
||||||
|
struct DecomposeONNXToONNXPass : public FunctionPass<DecomposeONNXToONNXPass> {
|
||||||
|
void runOnFunction() final;
|
||||||
|
};
|
||||||
|
} // end anonymous namespace.
|
||||||
|
|
||||||
|
void DecomposeONNXToONNXPass::runOnFunction() {
|
||||||
|
auto function = getFunction();
|
||||||
|
MLIRContext *context = &getContext();
|
||||||
|
|
||||||
|
ConversionTarget target(getContext());
|
||||||
|
target.addLegalDialect<ONNXOpsDialect>();
|
||||||
|
|
||||||
|
// These ops will be decomposed into other ONNX ops. Hence, they will not be
|
||||||
|
// available after this pass.
|
||||||
|
target.addIllegalOp<ONNXReduceL1Op>();
|
||||||
|
target.addIllegalOp<ONNXReduceL2Op>();
|
||||||
|
target.addIllegalOp<ONNXReduceLogSumOp>();
|
||||||
|
target.addIllegalOp<ONNXReduceLogSumExpOp>();
|
||||||
|
target.addIllegalOp<ONNXReduceSumSquareOp>();
|
||||||
|
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
populateWithGenerated(context, &patterns);
|
||||||
|
|
||||||
|
if (failed(applyPartialConversion(function, target, patterns)))
|
||||||
|
signalPassFailure();
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* Create a DecomposeONNX pass.
|
||||||
|
*/
|
||||||
|
std::unique_ptr<mlir::Pass> mlir::createDecomposeONNXToONNXPass() {
|
||||||
|
return std::make_unique<DecomposeONNXToONNXPass>();
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<DecomposeONNXToONNXPass> pass("decompose-onnx",
|
||||||
|
"Decompose ONNX operations into composition of other ONNX operations.");
|
|
@ -0,0 +1,57 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//=- onnx_decompose.td - Rewriting for decomposing ONNX Ops -*- tablegen -*===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// Defines language-specific pattern match rewritings for ONNX using
|
||||||
|
// Declarative Rewrite Rules (DRR) specified using TableGen records.
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef ONNX_DECOMPOSE
|
||||||
|
#define ONNX_DECOMPOSE
|
||||||
|
|
||||||
|
#ifndef OP_BASE
|
||||||
|
include "dialect/onnx/onnx.td"
|
||||||
|
#endif // OP_BASE
|
||||||
|
|
||||||
|
/// Note: The DRR definition used for defining patterns is shown below:
|
||||||
|
///
|
||||||
|
/// class Pattern<
|
||||||
|
/// dag sourcePattern, list<dag> resultPatterns,
|
||||||
|
/// list<dag> additionalConstraints = [],
|
||||||
|
/// dag benefitsAdded = (addBenefit 0)
|
||||||
|
/// >;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X)
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
def ReduceL1OpPattern: Pat<(ONNXReduceL1Op $oprd, $axes, $keepdims),
|
||||||
|
(ONNXReduceSumOp (ONNXAbsOp $oprd), $axes, $keepdims)>;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ONNXReduceL2Op %X = ONNXSqrtOp (ONNXReduceSumSquareOp (%X))
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
def ReduceL2OpPattern: Pat<(ONNXReduceL2Op $oprd, $axes, $keepdims),
|
||||||
|
(ONNXSqrtOp (ONNXReduceSumSquareOp $oprd, $axes, $keepdims))>;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ONNXReduceLogSumOp %X = ONNXLogOp (ONNXReduceSumOp (%X))
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
def ReduceLogSumOpPattern: Pat<(ONNXReduceLogSumOp $oprd, $axes, $keepdims),
|
||||||
|
(ONNXLogOp (ONNXReduceSumOp $oprd, $axes, $keepdims))>;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ONNXReduceLogSumExpOp %X = ONNXReduceLogSumOp (ONNXExpOp %X)
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
def ReduceLogSumExpOpPattern: Pat<(ONNXReduceLogSumExpOp $oprd, $axes, $keepdims),
|
||||||
|
(ONNXReduceLogSumOp (ONNXExpOp $oprd), $axes, $keepdims)>;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ONNXReduceSumSquareOp %X = ONNXReduceSumOp (ONNXMulOp %X, %X)
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
def ReduceSumSquareOpPattern: Pat<(ONNXReduceSumSquareOp $oprd, $axes, $keepdims),
|
||||||
|
(ONNXReduceSumOp (ONNXMulOp $oprd, $oprd), $axes, $keepdims)>;
|
||||||
|
|
||||||
|
#endif // ONNX_DECOMPOSE
|
|
@ -118,35 +118,6 @@ struct SplitConvOpPattern : public RewritePattern {
|
||||||
};
|
};
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
|
||||||
/// on the ONNXReduceL1Op.
|
|
||||||
void ONNXReduceL1Op::getCanonicalizationPatterns(
|
|
||||||
OwningRewritePatternList &results, MLIRContext *context) {
|
|
||||||
results.insert<ReduceL1OpPattern>(context);
|
|
||||||
}
|
|
||||||
/// on the ONNXReduceL2Op.
|
|
||||||
void ONNXReduceL2Op::getCanonicalizationPatterns(
|
|
||||||
OwningRewritePatternList &results, MLIRContext *context) {
|
|
||||||
results.insert<ReduceL2OpPattern>(context);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// on the ONNXReduceLogSumOp.
|
|
||||||
void ONNXReduceLogSumOp::getCanonicalizationPatterns(
|
|
||||||
OwningRewritePatternList &results, MLIRContext *context) {
|
|
||||||
results.insert<ReduceLogSumOpPattern>(context);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// on the ONNXReduceLogSumExpOp.
|
|
||||||
void ONNXReduceLogSumExpOp::getCanonicalizationPatterns(
|
|
||||||
OwningRewritePatternList &results, MLIRContext *context) {
|
|
||||||
results.insert<ReduceLogSumExpOpPattern>(context);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// on the ONNXReduceSumSquareOp.
|
|
||||||
void ONNXReduceSumSquareOp::getCanonicalizationPatterns(
|
|
||||||
OwningRewritePatternList &results, MLIRContext *context) {
|
|
||||||
results.insert<ReduceSumSquareOpPattern>(context);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// on the ONNXReduceSumSquareOp.
|
/// on the ONNXReduceSumSquareOp.
|
||||||
void ONNXConvNoBiasOp::getCanonicalizationPatterns(
|
void ONNXConvNoBiasOp::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList &results, MLIRContext *context) {
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
|
|
|
@ -24,34 +24,4 @@ include "dialect/onnx/onnx.td"
|
||||||
/// dag benefitsAdded = (addBenefit 0)
|
/// dag benefitsAdded = (addBenefit 0)
|
||||||
/// >;
|
/// >;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X)
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
def ReduceL1OpPattern: Pat<(ONNXReduceL1Op $oprd, $axes, $keepdims),
|
|
||||||
(ONNXReduceSumOp (ONNXAbsOp $oprd), $axes, $keepdims)>;
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// ONNXReduceL2Op %X = ONNXSqrtOp (ONNXReduceSumSquareOp (%X))
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
def ReduceL2OpPattern: Pat<(ONNXReduceL2Op $oprd, $axes, $keepdims),
|
|
||||||
(ONNXSqrtOp (ONNXReduceSumSquareOp $oprd, $axes, $keepdims))>;
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// ONNXReduceLogSumOp %X = ONNXLogOp (ONNXReduceSumOp (%X))
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
def ReduceLogSumOpPattern: Pat<(ONNXReduceLogSumOp $oprd, $axes, $keepdims),
|
|
||||||
(ONNXLogOp (ONNXReduceSumOp $oprd, $axes, $keepdims))>;
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// ONNXReduceLogSumExpOp %X = ONNXReduceLogSumOp (ONNXExpOp %X)
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
def ReduceLogSumExpOpPattern: Pat<(ONNXReduceLogSumExpOp $oprd, $axes, $keepdims),
|
|
||||||
(ONNXReduceLogSumOp (ONNXExpOp $oprd), $axes, $keepdims)>;
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// ONNXReduceSumSquareOp %X = ONNXReduceSumOp (ONNXMulOp %X, %X)
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
def ReduceSumSquareOpPattern: Pat<(ONNXReduceSumSquareOp $oprd, $axes, $keepdims),
|
|
||||||
(ONNXReduceSumOp (ONNXMulOp $oprd, $oprd), $axes, $keepdims)>;
|
|
||||||
|
|
||||||
#endif // ONNX_REWRITE
|
#endif // ONNX_REWRITE
|
||||||
|
|
|
@ -15,6 +15,9 @@
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
class Pass;
|
class Pass;
|
||||||
|
|
||||||
|
/// Pass for rewriting inside frontend dialect.
|
||||||
|
std::unique_ptr<Pass> createDecomposeONNXToONNXPass();
|
||||||
|
|
||||||
std::unique_ptr<Pass> createShapeInferencePass();
|
std::unique_ptr<Pass> createShapeInferencePass();
|
||||||
|
|
||||||
/// Add pass for lowering to Krnl IR.
|
/// Add pass for lowering to Krnl IR.
|
||||||
|
|
|
@ -38,53 +38,6 @@ func @test_identity_identity(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>) ->
|
||||||
"std.return"(%2) : (tensor<10x10xf32>) -> ()
|
"std.return"(%2) : (tensor<10x10xf32>) -> ()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @test_reducel1(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
|
|
||||||
func @test_reducel1(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
|
|
||||||
%0 ="onnx.ReduceL1"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
|
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
|
||||||
|
|
||||||
// CHECK-NEXT: [[ABS:%.+]] = "onnx.Abs"(%arg0) : (tensor<?x?x?xf32>) -> tensor<*xf32>
|
|
||||||
// CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[ABS]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: @test_reducel2(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
|
|
||||||
func @test_reducel2(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
|
|
||||||
%0 ="onnx.ReduceL2"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
|
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
|
||||||
|
|
||||||
// CHECK-NEXT: [[MUL:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
|
|
||||||
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[MUL]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
|
|
||||||
// CHECK-NEXT: [[SQRT:%.+]] = "onnx.Sqrt"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: @test_reducelogsum(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
|
|
||||||
func @test_reducelogsum(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
|
|
||||||
%0 ="onnx.ReduceLogSum"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
|
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
|
||||||
|
|
||||||
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"(%arg0) {axes = [1], keepdims = 0 : i64} : (tensor<?x?x?xf32>) -> tensor<*xf32>
|
|
||||||
// CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: @test_reducelogsumexp(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
|
|
||||||
func @test_reducelogsumexp(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
|
|
||||||
%0 ="onnx.ReduceLogSumExp"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
|
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
|
||||||
|
|
||||||
// CHECK-NEXT: [[EXP:%.+]] = "onnx.Exp"(%arg0) : (tensor<?x?x?xf32>) -> tensor<*xf32>
|
|
||||||
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[EXP]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
|
|
||||||
// CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: @test_reducesumsquare(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
|
|
||||||
func @test_reducesumsquare(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
|
|
||||||
%0 ="onnx.ReduceSumSquare"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
|
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
|
||||||
|
|
||||||
// CHECK-NEXT: [[SQUARE:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
|
|
||||||
// CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[SQUARE]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: @test_constant_pad(%{{.*}}: tensor<?x?xf32>) -> tensor<*xf32> {
|
// CHECK-LABEL: @test_constant_pad(%{{.*}}: tensor<?x?xf32>) -> tensor<*xf32> {
|
||||||
func @test_constant_pad(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
|
func @test_constant_pad(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
|
||||||
// CHECK-NEXT: [[SQUARE:%.+]] = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 2, 0, 0]} : (tensor<?x?xf32>) -> tensor<*xf32>
|
// CHECK-NEXT: [[SQUARE:%.+]] = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 2, 0, 0]} : (tensor<?x?xf32>) -> tensor<*xf32>
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
// RUN: onnf-opt --decompose-onnx %s -split-input-file | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_reducel1(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||||
|
func @test_reducel1(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
|
||||||
|
%0 ="onnx.ReduceL1"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-NEXT: [[ABS:%.+]] = "onnx.Abs"(%arg0) : (tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[ABS]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_reducel2(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||||
|
func @test_reducel2(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
|
||||||
|
%0 ="onnx.ReduceL2"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-NEXT: [[MUL:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[MUL]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: [[SQRT:%.+]] = "onnx.Sqrt"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_reducelogsum(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||||
|
func @test_reducelogsum(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
|
||||||
|
%0 ="onnx.ReduceLogSum"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"(%arg0) {axes = [1], keepdims = 0 : i64} : (tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_reducelogsumexp(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||||
|
func @test_reducelogsumexp(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
|
||||||
|
%0 ="onnx.ReduceLogSumExp"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-NEXT: [[EXP:%.+]] = "onnx.Exp"(%arg0) : (tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[EXP]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_reducesumsquare(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||||
|
func @test_reducesumsquare(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
|
||||||
|
%0 ="onnx.ReduceSumSquare"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-NEXT: [[SQUARE:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[SQUARE]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue