Rewrite ReduceL1, ReduceL2, ReduceLogSum, ReduceLogSumExp, ReduceSumSquare in the ONNX dialect (#38)
* Rewrite ReduceSumSquare * Edit gen_doc.py * Revise the code * Do shape inference after canonicalization so that there is no need to implement shape inference of rewritten ops * Rewrite ReduceL2 * Add onnx_rewrite.cpp for all rewriting for ONNX ops * Rewrite ReduceL1, ReduceLogSum, ReduceLogSumExp * Edit comments * Change the use of -> to . * Checkout gen_doc.py from the master branch * Use emplace_back instead of push_back * Revise the code * Edit comments Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
0d77840969
commit
2b56c09454
|
@ -47,7 +47,8 @@ ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
|
||||||
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax',
|
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax',
|
||||||
'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze']
|
'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze']
|
||||||
|
|
||||||
CanonicalList=['Add', 'Identity']
|
CanonicalList=['Add', 'Identity', 'ReduceL1', 'ReduceL2', 'ReduceLogSum',
|
||||||
|
'ReduceLogSumExp', 'ReduceSumSquare']
|
||||||
|
|
||||||
manual_code_in_op_def = dict([
|
manual_code_in_op_def = dict([
|
||||||
('DummyExample', ' let extraClassDeclaration = [{ \n'+
|
('DummyExample', ' let extraClassDeclaration = [{ \n'+
|
||||||
|
|
|
@ -10,6 +10,7 @@ add_library(compiler
|
||||||
pass/shape_inference_interface.hpp
|
pass/shape_inference_interface.hpp
|
||||||
dialect/onnx/onnxop.inc
|
dialect/onnx/onnxop.inc
|
||||||
pass/onnx_combine.cpp
|
pass/onnx_combine.cpp
|
||||||
|
pass/onnx_rewrite.cpp
|
||||||
pass/passes.hpp)
|
pass/passes.hpp)
|
||||||
|
|
||||||
# Include root src directory.
|
# Include root src directory.
|
||||||
|
|
|
@ -2266,6 +2266,7 @@ def ONNXReciprocalOp:ONNX_Op<"Reciprocal",
|
||||||
|
|
||||||
def ONNXReduceL1Op:ONNX_Op<"ReduceL1",
|
def ONNXReduceL1Op:ONNX_Op<"ReduceL1",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect]> {
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
let summary = "ONNX ReduceL1 operation";
|
let summary = "ONNX ReduceL1 operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Computes the L1 norm of the input tensor's element along the provided axes. The resulted"
|
"Computes the L1 norm of the input tensor's element along the provided axes. The resulted"
|
||||||
|
@ -2283,6 +2284,7 @@ def ONNXReduceL1Op:ONNX_Op<"ReduceL1",
|
||||||
|
|
||||||
def ONNXReduceL2Op:ONNX_Op<"ReduceL2",
|
def ONNXReduceL2Op:ONNX_Op<"ReduceL2",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect]> {
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
let summary = "ONNX ReduceL2 operation";
|
let summary = "ONNX ReduceL2 operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Computes the L2 norm of the input tensor's element along the provided axes. The resulted"
|
"Computes the L2 norm of the input tensor's element along the provided axes. The resulted"
|
||||||
|
@ -2300,6 +2302,7 @@ def ONNXReduceL2Op:ONNX_Op<"ReduceL2",
|
||||||
|
|
||||||
def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum",
|
def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect]> {
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
let summary = "ONNX ReduceLogSum operation";
|
let summary = "ONNX ReduceLogSum operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Computes the log sum of the input tensor's element along the provided axes. The resulted"
|
"Computes the log sum of the input tensor's element along the provided axes. The resulted"
|
||||||
|
@ -2317,6 +2320,7 @@ def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum",
|
||||||
|
|
||||||
def ONNXReduceLogSumExpOp:ONNX_Op<"ReduceLogSumExp",
|
def ONNXReduceLogSumExpOp:ONNX_Op<"ReduceLogSumExp",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect]> {
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
let summary = "ONNX ReduceLogSumExp operation";
|
let summary = "ONNX ReduceLogSumExp operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Computes the log sum exponent of the input tensor's element along the provided axes. The resulted"
|
"Computes the log sum exponent of the input tensor's element along the provided axes. The resulted"
|
||||||
|
@ -2419,6 +2423,7 @@ def ONNXReduceSumOp:ONNX_Op<"ReduceSum",
|
||||||
|
|
||||||
def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare",
|
def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect]> {
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
let summary = "ONNX ReduceSumSquare operation";
|
let summary = "ONNX ReduceSumSquare operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Computes the sum square of the input tensor's element along the provided axes. The resulted"
|
"Computes the sum square of the input tensor's element along the provided axes. The resulted"
|
||||||
|
|
|
@ -117,8 +117,8 @@ int main(int argc, char *argv[]) {
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::PassManager pm(&context);
|
mlir::PassManager pm(&context);
|
||||||
pm.addPass(mlir::createShapeInferencePass());
|
|
||||||
pm.addPass(mlir::createCanonicalizerPass());
|
pm.addPass(mlir::createCanonicalizerPass());
|
||||||
|
pm.addPass(mlir::createShapeInferencePass());
|
||||||
|
|
||||||
if (emissionTarget >= EmitMLIR) {
|
if (emissionTarget >= EmitMLIR) {
|
||||||
pm.addPass(mlir::createLowerToKrnlPass());
|
pm.addPass(mlir::createLowerToKrnlPass());
|
||||||
|
|
|
@ -0,0 +1,295 @@
|
||||||
|
//===- onnx_rewrite.cpp - ONNX High Level Optimizer -----------------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file implements a set of rewriters for operations in the ONNX dialect
|
||||||
|
// that can be rewritten by using other ONNX operations.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
#include "src/dialect/onnx/onnx_ops.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// There are two ways to write rewrite rules:
|
||||||
|
// - Declarative manner: specify rewrite rules in a TableGen record, and
|
||||||
|
// - Manual Manner: subclass the mlir::RewritePattern.
|
||||||
|
//
|
||||||
|
// We prefer to use the former way as much as possible. However, there is a
|
||||||
|
// limitation about operation definition specification (ODS) in TableGen that
|
||||||
|
// requires us to write custom builders, that is
|
||||||
|
// "all ODS-generated `build()` methods require specifying the result type(s),
|
||||||
|
// unless the op has known traits like `SameOperandsAndResultType` that we can
|
||||||
|
// use to auto-generate a `build()` method with result type deduction".
|
||||||
|
//
|
||||||
|
// More information about the limitation can be found here:
|
||||||
|
// https://github.com/llvm/llvm-project/blob/master/mlir/docs/DeclarativeRewrites.md#building-operations
|
||||||
|
//
|
||||||
|
// Currently, we use the latter way of writing rewrite rules. There are two
|
||||||
|
// reasons for this decision:
|
||||||
|
// - To insert custom builders for operations, it is better to change the script
|
||||||
|
// gen_doc.py to generate all possibles custom builders for a large class of
|
||||||
|
// operations. At the time of this patch created, the gen_doc.py was changing,
|
||||||
|
// so we decided to write manually to reduce conflicts.
|
||||||
|
// - In declarative rewriting, we should deal with optional attributes. E.g. for
|
||||||
|
// to handle optional attributes, but I haven't tried it yet.
|
||||||
|
//
|
||||||
|
// Once we have done the above issues, we will switch to use the declarative
|
||||||
|
// manner.
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X)
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
struct ReduceL1OpPattern : public RewritePattern {
|
||||||
|
ReduceL1OpPattern(MLIRContext *context)
|
||||||
|
: RewritePattern(ONNXReduceL1Op::getOperationName(),
|
||||||
|
{ONNXAbsOp::getOperationName(),
|
||||||
|
ONNXReduceSumOp::getOperationName()},
|
||||||
|
1, context) {}
|
||||||
|
|
||||||
|
PatternMatchResult matchAndRewrite(Operation *op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto loc = op->getLoc();
|
||||||
|
auto opInput = op->getOperands()[0]; // %X
|
||||||
|
auto opResults = op->getResults();
|
||||||
|
auto opAttrs = op->getAttrs();
|
||||||
|
|
||||||
|
// Rewrite
|
||||||
|
ONNXAbsOp absOp;
|
||||||
|
{
|
||||||
|
auto elementType = opInput.getType().cast<TensorType>().getElementType();
|
||||||
|
absOp = rewriter.create<ONNXAbsOp>(
|
||||||
|
loc, UnrankedTensorType::get(elementType), opInput);
|
||||||
|
}
|
||||||
|
|
||||||
|
ONNXReduceSumOp sumOp;
|
||||||
|
{
|
||||||
|
SmallVector<Type, 4> types;
|
||||||
|
for (auto v : opResults) {
|
||||||
|
types.emplace_back(v.getType());
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value, 1> values;
|
||||||
|
values.emplace_back(absOp.getResult());
|
||||||
|
|
||||||
|
SmallVector<NamedAttribute, 4> attrs;
|
||||||
|
for (auto attr : opAttrs) {
|
||||||
|
attrs.emplace_back(attr);
|
||||||
|
}
|
||||||
|
|
||||||
|
sumOp = rewriter.create<ONNXReduceSumOp>(loc, types, values, attrs);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, sumOp.getResult());
|
||||||
|
return matchSuccess();
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ONNXReduceL2Op %X = ONNXSqrtOp (ONNXReduceSumSquareOp (%X))
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
struct ReduceL2OpPattern : public RewritePattern {
|
||||||
|
ReduceL2OpPattern(MLIRContext *context)
|
||||||
|
: RewritePattern(ONNXReduceL2Op::getOperationName(),
|
||||||
|
{ONNXSqrtOp::getOperationName(),
|
||||||
|
ONNXReduceSumSquareOp::getOperationName()},
|
||||||
|
1, context) {}
|
||||||
|
|
||||||
|
PatternMatchResult matchAndRewrite(Operation *op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto loc = op->getLoc();
|
||||||
|
auto opInput = op->getOperands()[0]; // %X
|
||||||
|
auto opResults = op->getResults();
|
||||||
|
auto opAttrs = op->getAttrs();
|
||||||
|
|
||||||
|
// Rewrite
|
||||||
|
ONNXReduceSumSquareOp sumSquareOp;
|
||||||
|
{
|
||||||
|
auto elementType = opInput.getType().cast<TensorType>().getElementType();
|
||||||
|
sumSquareOp = rewriter.create<ONNXReduceSumSquareOp>(
|
||||||
|
loc, UnrankedTensorType::get(elementType), opInput, opAttrs);
|
||||||
|
}
|
||||||
|
|
||||||
|
ONNXSqrtOp sqrtOp;
|
||||||
|
{
|
||||||
|
SmallVector<Type, 4> types;
|
||||||
|
for (auto v : opResults) {
|
||||||
|
types.emplace_back(v.getType());
|
||||||
|
}
|
||||||
|
sqrtOp = rewriter.create<ONNXSqrtOp>(loc, types, sumSquareOp.getResult());
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, sqrtOp.getResult());
|
||||||
|
return matchSuccess();
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ONNXReduceLogSumOp %X = ONNXLogOp (ONNXReduceSumOp (%X))
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
struct ReduceLogSumOpPattern : public RewritePattern {
|
||||||
|
ReduceLogSumOpPattern(MLIRContext *context)
|
||||||
|
: RewritePattern(ONNXReduceLogSumOp::getOperationName(),
|
||||||
|
{ONNXReduceSumOp::getOperationName(),
|
||||||
|
ONNXLogOp::getOperationName()},
|
||||||
|
1, context) {}
|
||||||
|
|
||||||
|
PatternMatchResult matchAndRewrite(Operation *op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto loc = op->getLoc();
|
||||||
|
auto opInput = op->getOperands()[0]; // %X
|
||||||
|
auto opResults = op->getResults();
|
||||||
|
auto opAttrs = op->getAttrs();
|
||||||
|
|
||||||
|
// Rewrite
|
||||||
|
ONNXReduceSumOp sumOp;
|
||||||
|
{
|
||||||
|
auto elementType = opInput.getType().cast<TensorType>().getElementType();
|
||||||
|
sumOp = rewriter.create<ONNXReduceSumOp>(
|
||||||
|
loc, UnrankedTensorType::get(elementType), opInput, opAttrs);
|
||||||
|
}
|
||||||
|
|
||||||
|
ONNXLogOp logOp;
|
||||||
|
{
|
||||||
|
SmallVector<Type, 4> types;
|
||||||
|
for (auto v : opResults) {
|
||||||
|
types.emplace_back(v.getType());
|
||||||
|
}
|
||||||
|
logOp = rewriter.create<ONNXLogOp>(loc, types, sumOp.getResult());
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, logOp.getResult());
|
||||||
|
return matchSuccess();
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ONNXReduceLogSumExpOp %X = ONNXReduceLogSumOp (ONNXExpOp %X)
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
struct ReduceLogSumExpOpPattern : public RewritePattern {
|
||||||
|
ReduceLogSumExpOpPattern(MLIRContext *context)
|
||||||
|
: RewritePattern(ONNXReduceLogSumExpOp::getOperationName(),
|
||||||
|
{ONNXExpOp::getOperationName(),
|
||||||
|
ONNXReduceLogSumOp::getOperationName()},
|
||||||
|
1, context) {}
|
||||||
|
|
||||||
|
PatternMatchResult matchAndRewrite(Operation *op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto loc = op->getLoc();
|
||||||
|
auto opInput = op->getOperands()[0]; // %X
|
||||||
|
auto opResults = op->getResults();
|
||||||
|
auto opAttrs = op->getAttrs();
|
||||||
|
|
||||||
|
// Rewrite
|
||||||
|
ONNXExpOp expOp;
|
||||||
|
{
|
||||||
|
auto elementType = opInput.getType().cast<TensorType>().getElementType();
|
||||||
|
expOp = rewriter.create<ONNXExpOp>(
|
||||||
|
loc, UnrankedTensorType::get(elementType), opInput);
|
||||||
|
}
|
||||||
|
|
||||||
|
ONNXReduceLogSumOp logSumOp;
|
||||||
|
{
|
||||||
|
SmallVector<Type, 4> types;
|
||||||
|
for (auto v : opResults) {
|
||||||
|
types.emplace_back(v.getType());
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value, 1> values;
|
||||||
|
values.emplace_back(expOp.getResult());
|
||||||
|
|
||||||
|
SmallVector<NamedAttribute, 4> attrs;
|
||||||
|
for (auto attr : opAttrs) {
|
||||||
|
attrs.emplace_back(attr);
|
||||||
|
}
|
||||||
|
logSumOp = rewriter.create<ONNXReduceLogSumOp>(loc, types, values, attrs);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, logSumOp.getResult());
|
||||||
|
return matchSuccess();
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ONNXReduceSumSquareOp %X = ONNXReduceSumOp (ONNXMulOp %X, %X)
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
struct ReduceSumSquareOpPattern : public RewritePattern {
|
||||||
|
ReduceSumSquareOpPattern(MLIRContext *context)
|
||||||
|
: RewritePattern(ONNXReduceSumSquareOp::getOperationName(),
|
||||||
|
{ONNXMulOp::getOperationName(),
|
||||||
|
ONNXReduceSumOp::getOperationName()},
|
||||||
|
1, context) {}
|
||||||
|
|
||||||
|
PatternMatchResult matchAndRewrite(Operation *op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto loc = op->getLoc();
|
||||||
|
auto opInput = op->getOperands()[0]; // %X
|
||||||
|
auto opResults = op->getResults();
|
||||||
|
auto opAttrs = op->getAttrs();
|
||||||
|
|
||||||
|
// Rewrite
|
||||||
|
ONNXMulOp mulOp;
|
||||||
|
{
|
||||||
|
auto elementType = opInput.getType().cast<TensorType>().getElementType();
|
||||||
|
mulOp = rewriter.create<ONNXMulOp>(
|
||||||
|
loc, UnrankedTensorType::get(elementType), opInput, opInput);
|
||||||
|
}
|
||||||
|
|
||||||
|
ONNXReduceSumOp sumOp;
|
||||||
|
{
|
||||||
|
SmallVector<Type, 4> types;
|
||||||
|
for (auto v : opResults) {
|
||||||
|
types.emplace_back(v.getType());
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value, 1> values;
|
||||||
|
values.emplace_back(mulOp.getResult());
|
||||||
|
|
||||||
|
SmallVector<NamedAttribute, 4> attrs;
|
||||||
|
for (auto attr : opAttrs) {
|
||||||
|
attrs.emplace_back(attr);
|
||||||
|
}
|
||||||
|
sumOp = rewriter.create<ONNXReduceSumOp>(loc, types, values, attrs);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, sumOp.getResult());
|
||||||
|
return matchSuccess();
|
||||||
|
};
|
||||||
|
};
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
/// on the ONNXReduceL1Op.
|
||||||
|
void ONNXReduceL1Op::getCanonicalizationPatterns(
|
||||||
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
|
results.insert<ReduceL1OpPattern>(context);
|
||||||
|
}
|
||||||
|
/// on the ONNXReduceL2Op.
|
||||||
|
void ONNXReduceL2Op::getCanonicalizationPatterns(
|
||||||
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
|
results.insert<ReduceL2OpPattern>(context);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// on the ONNXReduceLogSumOp.
|
||||||
|
void ONNXReduceLogSumOp::getCanonicalizationPatterns(
|
||||||
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
|
results.insert<ReduceLogSumOpPattern>(context);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// on the ONNXReduceLogSumExpOp.
|
||||||
|
void ONNXReduceLogSumExpOp::getCanonicalizationPatterns(
|
||||||
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
|
results.insert<ReduceLogSumExpOpPattern>(context);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// on the ONNXReduceSumSquareOp.
|
||||||
|
void ONNXReduceSumSquareOp::getCanonicalizationPatterns(
|
||||||
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
|
results.insert<ReduceSumSquareOpPattern>(context);
|
||||||
|
}
|
|
@ -27,3 +27,50 @@ func @test_identity_identity(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>) ->
|
||||||
%2 = "onnx.Add"(%0, %1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
%2 = "onnx.Add"(%0, %1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||||
"std.return"(%2) : (tensor<10x10xf32>) -> ()
|
"std.return"(%2) : (tensor<10x10xf32>) -> ()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_reducel1(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||||
|
func @test_reducel1(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
|
||||||
|
%0 ="onnx.ReduceL1"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-NEXT: [[ABS:%.+]] = "onnx.Abs"(%arg0) : (tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[ABS]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_reducel2(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||||
|
func @test_reducel2(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
|
||||||
|
%0 ="onnx.ReduceL2"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-NEXT: [[MUL:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[MUL]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: [[SQRT:%.+]] = "onnx.Sqrt"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_reducelogsum(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||||
|
func @test_reducelogsum(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
|
||||||
|
%0 ="onnx.ReduceLogSum"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"(%arg0) {axes = [1], keepdims = 0 : i64} : (tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_reducelogsumexp(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||||
|
func @test_reducelogsumexp(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
|
||||||
|
%0 ="onnx.ReduceLogSumExp"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-NEXT: [[EXP:%.+]] = "onnx.Exp"(%arg0) : (tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[EXP]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_reducesumsquare(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||||
|
func @test_reducesumsquare(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
|
||||||
|
%0 ="onnx.ReduceSumSquare"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-NEXT: [[SQUARE:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[SQUARE]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue