[MLIR] Add support for Max, Min, Sum, Elu, Selu, LeakyRelu, HardSigmoid (#395)

* Lower ONNXSumOp

* Add inferShapes() and test cases

* Load the first operand to the result

* Update SharingWork.md

* Update SharingWork.md

* Update SharingWork.md

* Add support for Max, Min

* Pass operation instead of location to mapToLowerScalarOp

* Add support for Elu, Selu, LeakyRelu, HardSigmoid

* Add test cases

* Update SharingWork.md

* Rewrite the part of lowering variadic ops and use it for binary ops

* Use two different templates for Unary and Variadic Ops

* Revise the code
TUNG LEDUC 2019-12-12 11:49:50 +09:00 committed by Tian Jin
parent fb1b43f842
commit 5ed79083d5
8 changed files with 924 additions and 68 deletions
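For reference, these are the scalar formulas that the new lowerings implement, restated as a minimal plain-C++ sketch (not part of the commit; the function names and the use of `float` are illustrative only). They mirror the comments in the `mapToLowerScalarOp` specializations further down: Elu selects `alpha*(exp(x)-1)` for negative inputs, and HardSigmoid clamps `alpha*x + beta` into `[0, 1]` with two selects.

```cpp
// Plain-C++ restatement of the scalar formulas lowered in this commit.
// Illustrative only; mirrors the formula comments in the lowering code below.
#include <algorithm>
#include <cmath>
#include <cstdio>

float elu(float x, float alpha) {
  // SelectOp(CmpFOp(OLT, x, 0), alpha * (exp(x) - 1), x)
  return x < 0.0f ? alpha * (std::exp(x) - 1.0f) : x;
}

float leaky_relu(float x, float alpha) {
  // SelectOp(CmpFOp(OLT, x, 0), alpha * x, x)
  return x < 0.0f ? alpha * x : x;
}

float selu(float x, float alpha, float gamma) {
  // gamma * SelectOp(CmpFOp(OGT, x, 0), x, alpha * exp(x) - alpha)
  return gamma * (x > 0.0f ? x : alpha * std::exp(x) - alpha);
}

float hard_sigmoid(float x, float alpha, float beta) {
  // Clamp alpha * x + beta into [0, 1], as done with two CmpFOp/SelectOp pairs.
  return std::min(std::max(alpha * x + beta, 0.0f), 1.0f);
}

int main() {
  std::printf("elu(-1, 2) = %f\n", elu(-1.0f, 2.0f));
  std::printf("hard_sigmoid(3, 1, 2) = %f\n", hard_sigmoid(3.0f, 1.0f, 2.0f));
  return 0;
}
```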


@@ -6,27 +6,34 @@ ONNX operations for which some work is needed.
* M for multi-broadcast, U for unidir-broadcast

ONNX Oper | Person working on it | ONNX 2 KRNL | Basic functionality | Extended functionality (e.g. broadcast) |
----------|----------------------|-------------|---------------------|------------------------------------------|
Add | Tung (updated) | v | v | noM |
And | Tung | v | v | noM |
Cosh | Tung | v | v | noM |
Div | Tung | v | v | |
Elu | Tung | v | v | |
Exp | Tung | v | v | |
FullGemm | | | | noU |
Gemm | | | | noU |
HardSigmoid | Tung | v | v | |
LeakyRelu | Tung | v | v | |
MatMul | | | | noM |
Max | Tung | v | v | noM |
Min | Tung | v | v | noM |
Mul | Tung | v | v | noM |
Or | Tung | v | v | noM |
Relu | Tung | v | v | |
Selu | Tung | v | v | |
Sigmoid | Tung | v | v | |
Sinh | Tung | v | v | |
Sub | Tung | v | v | noM |
Sum | Tung | v | v | noM |
Tanh | Tung | v | v | |
Xor | Tung | v | v | noM |

ONNX operations for which the work is completed (full functionality) and tested

ONNX Oper | Person working on it | Initial work | Basic functionality | Extended functionality (e.g. broadcast) |
----------|----------------------|--------------|---------------------|------------------------------------------|


@@ -265,7 +265,8 @@ def collect_types(schema, input) :
def gen_schema(schema) :
    ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
                   'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
                   'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
                   'Elu', 'Selu', 'HardSigmoid']
    CanonicalList=['Add', 'Identity']
    line_indent = '  '


@@ -70,6 +70,14 @@ void ONNXCoshOp::inferShapes() {
getResult()->setType(getOperand()->getType());
}
//===----------------------------------------------------------------------===//
// HardSigmoid
/// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
/// the shape inference interface.
void ONNXHardSigmoidOp::inferShapes() {
getResult()->setType(getOperand()->getType());
}
//===----------------------------------------------------------------------===//
// Sigmoid
/// Infer the output shape of the ONNXSigmoidOp. This method is required by the
@@ -78,6 +86,14 @@ void ONNXSigmoidOp::inferShapes() {
getResult()->setType(getOperand()->getType());
}
//===----------------------------------------------------------------------===//
// Elu
/// Infer the output shape of the ONNXEluOp. This method is required by the
/// shape inference interface.
void ONNXEluOp::inferShapes() {
getResult()->setType(getOperand()->getType());
}
//===----------------------------------------------------------------------===//
// Relu
/// Infer the output shape of the ONNXReluOp. This method is required by the
@@ -86,6 +102,22 @@ void ONNXReluOp::inferShapes() {
getResult()->setType(getOperand()->getType());
}
//===----------------------------------------------------------------------===//
// LeakyRelu
/// Infer the output shape of the ONNXLeakyReluOp. This method is required by
/// the shape inference interface.
void ONNXLeakyReluOp::inferShapes() {
getResult()->setType(getOperand()->getType());
}
//===----------------------------------------------------------------------===//
// Selu
/// Infer the output shape of the ONNXSeluOp. This method is required by
/// the shape inference interface.
void ONNXSeluOp::inferShapes() {
getResult()->setType(getOperand()->getType());
}
//===----------------------------------------------------------------------===//
// Add
/// Infer the output shape of the ONNXAddOp. This method is required by the
@@ -144,6 +176,32 @@ void ONNXXorOp::inferShapes() {
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// Sum
/// Infer the output shape of the ONNXSumOp. This method is required by the
/// shape inference interface.
void ONNXSumOp::inferShapes() {
getResult()->setType(getOperand(0)->getType());
}
//===----------------------------------------------------------------------===//
// Max
/// Infer the output shape of the ONNXMaxOp. This method is required by the
/// shape inference interface.
void ONNXMaxOp::inferShapes() {
getResult()->setType(getOperand(0)->getType());
}
//===----------------------------------------------------------------------===//
// Min
/// Infer the output shape of the ONNXMinOp. This method is required by the
/// shape inference interface.
void ONNXMinOp::inferShapes() {
getResult()->setType(getOperand(0)->getType());
}
//===----------------------------------------------------------------------===//
// MatMul
void ONNXMatMulOp::inferShapes() {


@@ -531,7 +531,7 @@ def ONNXDynamicQuantizeLinearOp:ONNX_Op<"DynamicQuantizeLinear",
}
def ONNXEluOp:ONNX_Op<"Elu",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Elu operation";
let description = [{
"Elu takes one input data (Tensor<T>) and produces one output data"
@@ -991,7 +991,7 @@ def ONNXGreaterOp:ONNX_Op<"Greater",
}
def ONNXHardSigmoidOp:ONNX_Op<"HardSigmoid",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX HardSigmoid operation";
let description = [{
"HardSigmoid takes one input data (Tensor<T>) and produces one output data"
@@ -1191,7 +1191,7 @@ def ONNXLSTMOp:ONNX_Op<"LSTM",
}
def ONNXLeakyReluOp:ONNX_Op<"LeakyRelu",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX LeakyRelu operation";
let description = [{
"LeakyRelu takes input data (Tensor<T>) and an argument alpha, and produces one"
@@ -1436,7 +1436,7 @@ def ONNXMatMulIntegerOp:ONNX_Op<"MatMulInteger",
}
def ONNXMaxOp:ONNX_Op<"Max",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Max operation";
let description = [{
"Element-wise max of each of the input tensors (with Numpy-style broadcasting support)."
@@ -1548,7 +1548,7 @@ def ONNXMeanVarianceNormalizationOp:ONNX_Op<"MeanVarianceNormalization",
}
def ONNXMinOp:ONNX_Op<"Min",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Min operation";
let description = [{
"Element-wise min of each of the input tensors (with Numpy-style broadcasting support)."
@@ -2625,7 +2625,7 @@ def ONNXScatterNDOp:ONNX_Op<"ScatterND",
}
def ONNXSeluOp:ONNX_Op<"Selu",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Selu operation";
let description = [{
"Selu takes one input data (Tensor<T>) and produces one output data"
@@ -2972,7 +2972,7 @@ def ONNXSubOp:ONNX_Op<"Sub",
}
def ONNXSumOp:ONNX_Op<"Sum",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Sum operation";
let description = [{
"Element-wise sum of each of the input tensors (with Numpy-style broadcasting support)."


@@ -148,6 +148,12 @@ struct ScalarOp<ONNXExpOp> {
using IOp = ExpOp; // not use
};
template <>
struct ScalarOp<ONNXSumOp> {
using FOp = AddFOp;
using IOp = AddIOp;
};
template <typename ElementwiseNaryOp>
using ScalarFOp = typename ScalarOp<ElementwiseNaryOp>::FOp;
template <typename ElementwiseNaryOp>
@@ -157,11 +163,11 @@ using ScalarIOp = typename ScalarOp<ElementwiseNaryOp>::IOp;
// Scalar unary ops for lowering to Krnl dialect.
//===----------------------------------------------------------------------===//
template <typename UnaryOp>
Value* mapToLowerScalarOp(Operation* op, ArrayRef<Type> result_types,
ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
/* Lower UnaryOp to Ops in the Standard dialect.
*/
auto loc = op->getLoc();
Type element_type = operands.front()->getType();
if (element_type.isa<IntegerType>()) {
return rewriter.create<ScalarIOp<UnaryOp>>(
@@ -179,11 +185,14 @@ Value* mapToLowerScalarOp(Location loc, ArrayRef<Type> result_types,
// Scalar unary ops for lowering ONNXTanhOp
//===----------------------------------------------------------------------===//
template <>
Value* mapToLowerScalarOp<ONNXTanhOp>(Operation* op,
ArrayRef<Type> result_types, ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) {
// ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
// AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
auto loc = op->getLoc();
Value* operand = operands[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
auto exp = rewriter.create<ExpOp>(loc, operand);
@@ -191,6 +200,7 @@ Value* mapToLowerScalarOp<ONNXTanhOp>(Location loc, ArrayRef<Type> result_types,
auto result =
rewriter.create<DivFOp>(loc, rewriter.create<SubFOp>(loc, exp, negExp),
rewriter.create<AddFOp>(loc, exp, negExp));
return result;
}
@@ -198,11 +208,14 @@ Value* mapToLowerScalarOp<ONNXTanhOp>(Location loc, ArrayRef<Type> result_types,
// Scalar unary ops for lowering ONNXSinhOp
//===----------------------------------------------------------------------===//
template <>
Value* mapToLowerScalarOp<ONNXSinhOp>(Operation* op,
ArrayRef<Type> result_types, ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) {
// ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
// ConstantOp 2)
auto loc = op->getLoc();
Value* operand = operands[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
@@ -210,6 +223,7 @@ Value* mapToLowerScalarOp<ONNXSinhOp>(Location loc, ArrayRef<Type> result_types,
auto negExp = rewriter.create<ExpOp>(loc, neg);
auto result = rewriter.create<DivFOp>(
loc, rewriter.create<SubFOp>(loc, exp, negExp), two);
return result;
}
@@ -217,11 +231,14 @@ Value* mapToLowerScalarOp<ONNXSinhOp>(Location loc, ArrayRef<Type> result_types,
// Scalar unary ops for lowering ONNXCoshOp
//===----------------------------------------------------------------------===//
template <>
Value* mapToLowerScalarOp<ONNXCoshOp>(Operation* op,
ArrayRef<Type> result_types, ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) {
// ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
// ConstantOp 2)
auto loc = op->getLoc();
Value* operand = operands[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
@@ -229,6 +246,7 @@ Value* mapToLowerScalarOp<ONNXCoshOp>(Location loc, ArrayRef<Type> result_types,
auto negExp = rewriter.create<ExpOp>(loc, neg);
auto result = rewriter.create<DivFOp>(
loc, rewriter.create<AddFOp>(loc, exp, negExp), two);
return result;
}
@@ -236,18 +254,84 @@ Value* mapToLowerScalarOp<ONNXCoshOp>(Location loc, ArrayRef<Type> result_types,
// Scalar unary ops for lowering ONNXSigmoidOp
//===----------------------------------------------------------------------===//
template <>
Value* mapToLowerScalarOp<ONNXSigmoidOp>(Operation* op,
ArrayRef<Type> result_types, ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) {
// ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1,
// AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
auto loc = op->getLoc();
Value* operand = operands[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
auto negExp = rewriter.create<ExpOp>(loc, neg);
auto result = rewriter.create<DivFOp>(
loc, one, rewriter.create<AddFOp>(loc, one, negExp));
return result;
}
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXHardSigmoidOp
//===----------------------------------------------------------------------===//
template <>
Value* mapToLowerScalarOp<ONNXHardSigmoidOp>(Operation* op,
ArrayRef<Type> result_types, ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) {
// %Y = AddFOp(MulFOp(alpha, %X), beta)
// %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0),
// %Y,
// Constant 0)
// ONNXHardSigmoidOp(%X) = SelectOp(CmpFOp(OLT, %Z, Constant 1),
// %Z,
// Constant 1)
auto loc = op->getLoc();
Value* operand = operands[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.alpha");
auto betaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.beta");
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
auto add = rewriter.create<AddFOp>(
loc, rewriter.create<MulFOp>(loc, alpha, operand), beta);
auto maxPredicate =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, add, zero);
auto max = rewriter.create<SelectOp>(loc, maxPredicate, add, zero);
auto minPredicate =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, max, one);
auto result = rewriter.create<SelectOp>(loc, minPredicate, max, one);
return result;
}
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXEluOp
//===----------------------------------------------------------------------===//
template <>
Value* mapToLowerScalarOp<ONNXEluOp>(Operation* op, ArrayRef<Type> result_types,
ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
// ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
// MulFOp(alpha, SubFOp(ExpOp(%X), 1)),
// %X)
auto loc = op->getLoc();
Value* operand = operands[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("Elu.alpha");
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto exp = rewriter.create<ExpOp>(loc, operand);
auto lessThanZero =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
auto result = rewriter.create<SelectOp>(loc, lessThanZero,
rewriter.create<MulFOp>(
loc, alpha, rewriter.create<SubFOp>(loc, exp, one)),
operand);
return result;
}
@@ -255,30 +339,122 @@ Value* mapToLowerScalarOp<ONNXSigmoidOp>(Location loc,
// Scalar unary ops for lowering ONNXReluOp
//===----------------------------------------------------------------------===//
template <>
Value* mapToLowerScalarOp<ONNXReluOp>(Operation* op,
ArrayRef<Type> result_types, ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) {
// ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
// ConstantOp 0,
// %X)
auto loc = op->getLoc();
Value* operand = operands[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto lessThanZero =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
auto result = rewriter.create<SelectOp>(loc, lessThanZero, zero, operand);
return result;
}
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXLeakyReluOp
//===----------------------------------------------------------------------===//
template <>
Value* mapToLowerScalarOp<ONNXLeakyReluOp>(Operation* op,
ArrayRef<Type> result_types, ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) {
// ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
// MulFOp(alpha, %X),
// %X)
auto loc = op->getLoc();
Value* operand = operands[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("LeakyRelu.alpha");
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto lessThanZero =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
auto result = rewriter.create<SelectOp>(
loc, lessThanZero, rewriter.create<MulFOp>(loc, alpha, operand), operand);
return result;
}
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSeluOp
//===----------------------------------------------------------------------===//
template <>
Value* mapToLowerScalarOp<ONNXSeluOp>(Operation* op,
ArrayRef<Type> result_types, ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) {
// ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
// MulFOp(gamma, %X),
// MulFOp(gamma,
// SubFOp(MulFOp(alpha, ExpOp(%X)),
// alpha)))
auto loc = op->getLoc();
Value* operand = operands[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("Selu.alpha");
auto gammaAttr = op->getAttrOfType<FloatAttr>("Selu.gamma");
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto gamma = rewriter.create<ConstantOp>(loc, gammaAttr);
auto exp = rewriter.create<ExpOp>(loc, operand);
auto greaterThanZero =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
auto select = rewriter.create<SelectOp>(loc, greaterThanZero, operand,
rewriter.create<SubFOp>(
loc, rewriter.create<MulFOp>(loc, alpha, exp), alpha));
auto result = rewriter.create<MulFOp>(loc, gamma, select);
return result;
}
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXMaxOp
//===----------------------------------------------------------------------===//
template <>
Value* mapToLowerScalarOp<ONNXMaxOp>(Operation* op, ArrayRef<Type> result_types,
ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
// ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y),
// %X,
// %Y)
auto loc = op->getLoc();
Value* lhs = operands[0];
Value* rhs = operands[1];
auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
return result;
}
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXMinOp
//===----------------------------------------------------------------------===//
template <>
Value* mapToLowerScalarOp<ONNXMinOp>(Operation* op, ArrayRef<Type> result_types,
ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
// ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y),
// %X,
// %Y)
auto loc = op->getLoc();
Value* lhs = operands[0];
Value* rhs = operands[1];
auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
return result;
}
// Element-wise unary ops lowering to Krnl dialect.
//===----------------------------------------------------------------------===//
template <typename ElementwiseUnaryOp>
struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
ONNXElementwiseUnaryOpLowering(MLIRContext* ctx)
: ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {}
PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) const final {
// TODO: Check that the types are valid.
// An element-wise unary operation must have all operands and the result of
// the same type. This should have been verified by the verifier.
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc();
@@ -287,7 +463,7 @@ struct ONNXElementwiseNaryOpLowering : public ConversionPattern {
auto memRefType = convertTensorToMemRef(tensorType);
// If the output has a dynamic dimension, pass the operands required for
// each dynamic dimension to the AllocOp. The first operand of the
// operation is used. The operands of the op need to match in terms of
// dimensions with the result at this pre-optimization phase.
// TODO: verify that dimensions match.
@@ -359,15 +535,9 @@ struct ONNXElementwiseNaryOpLowering : public ConversionPattern {
for (auto arg : iterationBlock.getArguments())
loopIVs.push_back(arg);
auto loadedVal = rewriter.create<LoadOp>(loc, operands[0], loopIVs);
auto loweredOpResult = mapToLowerScalarOp<ElementwiseUnaryOp>(
op, memRefType.getElementType(), {loadedVal}, rewriter);
// Store result in the resulting array.
rewriter.create<StoreOp>(loc, loweredOpResult, alloc, loopIVs);
@@ -377,12 +547,113 @@ struct ONNXElementwiseNaryOpLowering : public ConversionPattern {
}
};
// Element-wise variadic ops lowering to Krnl dialect.
//===----------------------------------------------------------------------===//
template <typename ElementwiseVariadicOp>
struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
ONNXElementwiseVariadicOpLowering(MLIRContext* ctx)
: ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {}
PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) const final {
// TODO: Check that the types are valid.
// An element-wise variadic operation must have all operands and the result
// of the same type. This should have been verified by the verifier.
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc();
auto numArgs = op->getNumOperands();
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
// If the output has a dynamic dimension, pass the operands required for
// each dynamic dimension to the AllocOp. The first operand of the
// operation is used. The operands of the op need to match in terms of
// dimensions with the result at this pre-optimization phase.
// TODO: verify that dimensions match.
// TODO: can the dimension of the result differ after optimizations?
Value* alloc;
bool insertDealloc = checkInsertDealloc(op);
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else
alloc = insertAllocAndDealloc(
memRefType, loc, rewriter, insertDealloc, operands[0]);
// Number of loops
auto memRefShape = memRefType.getShape();
int64_t rank = memRefShape.size();
// Define loops.
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
std::vector<Value*> originalLoops;
originalLoops.reserve(rank);
for (auto result : loopsOp.getResults()) {
originalLoops.push_back(result);
}
// Define loop optimization.
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
std::vector<Value*> optimizedLoops;
optimizedLoops.reserve(rank);
for (auto result : optimizedLoopsOp.getResults()) {
optimizedLoops.push_back(result);
}
Block& optimizationBlock = optimizedLoopsOp.region().front();
KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
// Iterate over the loop nest.
// TODO (Tian): move this logic inside KrnlIterateOp. Pass MemRefShape
// to KrnlIterateOp instead.
for (int i = 0; i < rank; ++i) {
if (memRefShape[i] < 0) {
pack.pushConstantBound(0);
pack.pushOperandBound(
rewriter.create<DimOp>(loc, operands[0], i).getResult());
} else {
pack.pushConstantBound(0);
pack.pushConstantBound(memRefShape[i]);
}
}
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
Block& iterationBlock = iterateOp.bodyRegion().front();
// Now perform the insertions into the body of the
// just generated instructions:
// 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
rewriter.setInsertionPointToEnd(&optimizationBlock);
// Return from KrnlOptimizeLoopsOp body.
// When no optimizations are present we just return the loops
// unchanged.
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
rewriter.setInsertionPoint(optimizedLoopsOp);
// 2. Insert instructions inside the KernelIterateOp body.
rewriter.setInsertionPointToStart(&iterationBlock);
// Handle the operation:
SmallVector<Value*, 4> loopIVs;
for (auto arg : iterationBlock.getArguments())
loopIVs.push_back(arg);
// Fold over operands for each of their scalar values
Value *accumulated, *next;
accumulated = rewriter.create<LoadOp>(loc, operands[0], loopIVs);
for (unsigned i = 1; i < numArgs; i++) {
next = rewriter.create<LoadOp>(loc, operands[i], loopIVs);
accumulated = mapToLowerScalarOp<ElementwiseVariadicOp>(
op, memRefType.getElementType(), {accumulated, next}, rewriter);
}
// Store result in the resulting array.
rewriter.create<StoreOp>(loc, accumulated, alloc, loopIVs);
rewriter.replaceOp(op, alloc);
return matchSuccess();
}
};
//===----------------------------------------------------------------------===//
// Conversion from Tensor type to the Standard dialect MemRef type.
@@ -469,14 +740,21 @@ void FrontendToKrnlLoweringPass::runOnModule() {
ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXHardSigmoidOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXEluOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXLeakyReluOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXSeluOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXSubOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXAndOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXOrOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>>(&getContext());
// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`
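The variadic pattern above folds the scalar lowering pairwise over the operands: the first operand is loaded, then each remaining operand is combined with the accumulated value inside the loop nest. A minimal stand-alone C++ sketch of that accumulation, with no MLIR dependencies (the function and variable names are illustrative only):

```cpp
// Stand-alone sketch of the pairwise fold used by
// ONNXElementwiseVariadicOpLowering: take the first operand, then combine it
// with each remaining operand through the op's scalar lowering.
// Here the "scalar lowering" is just a callable on two floats.
#include <cstdio>
#include <functional>
#include <vector>

float foldVariadic(const std::vector<float>& operands,
                   const std::function<float(float, float)>& scalarOp) {
  // Mirrors: accumulated = load(operands[0]);
  //          for i in 1..numArgs: accumulated = scalarOp(accumulated, next);
  float accumulated = operands.front();
  for (size_t i = 1; i < operands.size(); ++i)
    accumulated = scalarOp(accumulated, operands[i]);
  return accumulated;
}

int main() {
  std::vector<float> vals = {3.0f, -1.0f, 7.0f};
  // onnx.Sum-like fold (an AddFOp per pair).
  float sum = foldVariadic(vals, [](float a, float b) { return a + b; });
  // onnx.Max-like fold (a CmpFOp OGT plus SelectOp per pair).
  float mx = foldVariadic(vals, [](float a, float b) { return a > b ? a : b; });
  std::printf("sum = %f, max = %f\n", sum, mx);
  return 0;
}
```

Instantiated with an add it behaves like onnx.Sum; with a greater-than select it behaves like onnx.Max, which is how the Sum/Max/Min scalar specializations are reused by the single variadic lowering template.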


@@ -93,7 +93,11 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
op->getName().getStringRef() != "onnx.Sinh" &&
op->getName().getStringRef() != "onnx.Cosh" &&
op->getName().getStringRef() != "onnx.Sigmoid" &&
op->getName().getStringRef() != "onnx.HardSigmoid" &&
op->getName().getStringRef() != "onnx.Elu" &&
op->getName().getStringRef() != "onnx.Relu" &&
op->getName().getStringRef() != "onnx.LeakyRelu" &&
op->getName().getStringRef() != "onnx.Selu" &&
op->getName().getStringRef() != "onnx.Mul" &&
op->getName().getStringRef() != "onnx.Add" &&
op->getName().getStringRef() != "onnx.Div" &&
@@ -101,6 +105,9 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
op->getName().getStringRef() != "onnx.And" &&
op->getName().getStringRef() != "onnx.Or" &&
op->getName().getStringRef() != "onnx.Xor" &&
op->getName().getStringRef() != "onnx.Sum" &&
op->getName().getStringRef() != "onnx.Max" &&
op->getName().getStringRef() != "onnx.Min" &&
op->getName().getStringRef() != "onnx.MatMul" &&
op->getName().getStringRef() != "onnx.Gemm" &&
op->getName().getStringRef() != "onnx.FullGemm")
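The pass decides which ops still need shape inference through a growing chain of string comparisons. A hedged sketch of an equivalent set-based check follows; it is purely illustrative and not part of this commit, and it only lists the op names visible in the two hunks above (the actual whitelist in the pass is longer).

```cpp
// Illustrative alternative to the string-comparison chain above: keep the op
// names expected to provide shape inference in a set. Names are copied from
// the hunks in this commit only.
#include <cstdio>
#include <string>
#include <unordered_set>

static bool expectsShapeInference(const std::string& opName) {
  static const std::unordered_set<std::string> kOps = {
      "onnx.Sinh", "onnx.Cosh", "onnx.Sigmoid", "onnx.HardSigmoid", "onnx.Elu",
      "onnx.Relu", "onnx.LeakyRelu", "onnx.Selu", "onnx.Mul", "onnx.Add",
      "onnx.Div", "onnx.And", "onnx.Or", "onnx.Xor", "onnx.Sum", "onnx.Max",
      "onnx.Min", "onnx.MatMul", "onnx.Gemm", "onnx.FullGemm"};
  return kOps.count(opName) != 0;
}

int main() {
  std::printf("%d\n", expectsShapeInference("onnx.HardSigmoid") ? 1 : 0);
  return 0;
}
```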


@@ -278,3 +278,169 @@ func @test_relu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
// CHECK: store [[RELU_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
// CHECK: return [[RES]] : memref<?x10xf32>
}
func @test_sum(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Sum"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_sum
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[ADD:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[ADD]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: return [[RES]] : memref<?x10xf32>
}
func @test_max(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Max"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_max
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[MAX:%.+]] = cmpf "ogt", [[LOAD1]], [[LOAD2]] : f32
// CHECK: [[RELU_RES:%.+]] = select [[MAX]], [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[RELU_RES]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: return [[RES]] : memref<?x10xf32>
}
func @test_min(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Min"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_min
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[MIN:%.+]] = cmpf "olt", [[LOAD1]], [[LOAD2]] : f32
// CHECK: [[RELU_RES:%.+]] = select [[MIN]], [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[RELU_RES]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: return [[RES]] : memref<?x10xf32>
}
func @test_elu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Elu"(%arg0) {Elu.alpha=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_elu
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[ONE:%.+]] = constant {{1.+}} : f32
// CHECK: [[ALPHA:%.+]] = constant {{2.+}} : f32
// CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
// CHECK: [[CMP:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32
// CHECK: [[SUB:%.+]] = subf [[EXP]], [[ONE]] : f32
// CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[SUB]] : f32
// CHECK: [[SELECT:%.+]] = select [[CMP]], [[MUL]], [[LOAD]] : f32
// CHECK: store [[SELECT]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
// CHECK: return [[RES]] : memref<?x10xf32>
}
func @test_leakyrelu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.LeakyRelu"(%arg0) {LeakyRelu.alpha=1.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_leakyrelu
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[ALPHA:%.+]] = constant {{1.+}} : f32
// CHECK: [[CMP:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32
// CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[LOAD]] : f32
// CHECK: [[SELECT:%.+]] = select [[CMP]], [[MUL]], [[LOAD]] : f32
// CHECK: store [[SELECT]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
// CHECK: return [[RES]] : memref<?x10xf32>
}
func @test_selu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Selu"(%arg0) {Selu.alpha=1.0:f32, Selu.gamma=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_selu
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[ALPHA:%.+]] = constant {{1.+}} : f32
// CHECK: [[GAMMA:%.+]] = constant {{2.+}} : f32
// CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
// CHECK: [[CMP:%.+]] = cmpf "ogt", [[LOAD]], [[ZERO]] : f32
// CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[EXP]] : f32
// CHECK: [[SUB:%.+]] = subf [[MUL]], [[ALPHA]] : f32
// CHECK: [[SELECT:%.+]] = select [[CMP]], [[LOAD]], [[SUB]] : f32
// CHECK: [[SELU_RES:%.+]] = mulf [[GAMMA]], [[SELECT]] : f32
// CHECK: store [[SELU_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
// CHECK: return [[RES]] : memref<?x10xf32>
}
func @test_hardsigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.HardSigmoid"(%arg0) {HardSigmoid.alpha=1.0:f32, HardSigmoid.beta=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_hardsigmoid
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[ONE:%.+]] = constant {{1.+}} : f32
// CHECK: [[ALPHA:%.+]] = constant {{1.+}} : f32
// CHECK: [[BETA:%.+]] = constant {{2.+}} : f32
// CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[LOAD]] : f32
// CHECK: [[ADD:%.+]] = addf [[MUL]], [[BETA]] : f32
// CHECK: [[CMP1:%.+]] = cmpf "ogt", [[ADD]], [[ZERO]] : f32
// CHECK: [[SELECT1:%.+]] = select [[CMP1]], [[ADD]], [[ZERO]] : f32
// CHECK: [[CMP2:%.+]] = cmpf "olt", [[SELECT1]], [[ONE]] : f32
// CHECK: [[SELECT2:%.+]] = select [[CMP2]], [[SELECT1]], [[ONE]] : f32
// CHECK: store [[SELECT2]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
// CHECK: return [[RES]] : memref<?x10xf32>
}


@@ -571,3 +571,342 @@ func @test_relu_relu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
// CHECK: return [[RET_RES]] : memref<?x10xf32>
}
func @test_sum_sum(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Sum"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.Sum"(%0, %arg1) : (tensor<*xf32>, tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_sum_sum
/// First Sum
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[ADD:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[ADD]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
/// Second Sum
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[ADD:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[ADD]], [[RET_RES]][%arg2, %arg3] : memref<?x10xf32>
/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<?x10xf32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
// CHECK: return [[RET_RES]] : memref<?x10xf32>
}
func @test_max_max(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Max"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.Max"(%0, %arg1) : (tensor<*xf32>, tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_max_max
/// First Max
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[MAX:%.+]] = cmpf "ogt", [[LOAD1]], [[LOAD2]] : f32
// CHECK: [[RELU_RES:%.+]] = select [[MAX]], [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[RELU_RES]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
/// Second Max
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[MAX:%.+]] = cmpf "ogt", [[LOAD1]], [[LOAD2]] : f32
// CHECK: [[RELU_RES:%.+]] = select [[MAX]], [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[RELU_RES]], [[RET_RES]][%arg2, %arg3] : memref<?x10xf32>
/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<?x10xf32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
// CHECK: return [[RET_RES]] : memref<?x10xf32>
}
func @test_min_min(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Min"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.Min"(%0, %arg1) : (tensor<*xf32>, tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_min_min
/// First Min
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[MIN:%.+]] = cmpf "olt", [[LOAD1]], [[LOAD2]] : f32
// CHECK: [[RELU_RES:%.+]] = select [[MIN]], [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[RELU_RES]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
/// Second Min
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[MIN:%.+]] = cmpf "olt", [[LOAD1]], [[LOAD2]] : f32
// CHECK: [[RELU_RES:%.+]] = select [[MIN]], [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[RELU_RES]], [[RET_RES]][%arg2, %arg3] : memref<?x10xf32>
/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<?x10xf32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
// CHECK: return [[RET_RES]] : memref<?x10xf32>
}
func @test_elu_elu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Elu"(%arg0) {Elu.alpha=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.Elu"(%0) {Elu.alpha=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_elu_elu
/// First Elu
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[ONE:%.+]] = constant {{1.+}} : f32
// CHECK: [[ALPHA:%.+]] = constant {{2.+}} : f32
// CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
// CHECK: [[CMP:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32
// CHECK: [[SUB:%.+]] = subf [[EXP]], [[ONE]] : f32
// CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[SUB]] : f32
// CHECK: [[SELECT:%.+]] = select [[CMP]], [[MUL]], [[LOAD]] : f32
// CHECK: store [[SELECT]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
/// Second Elu
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[ONE:%.+]] = constant {{1.+}} : f32
// CHECK: [[ALPHA:%.+]] = constant {{2.+}} : f32
// CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
// CHECK: [[CMP:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32
// CHECK: [[SUB:%.+]] = subf [[EXP]], [[ONE]] : f32
// CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[SUB]] : f32
// CHECK: [[SELECT:%.+]] = select [[CMP]], [[MUL]], [[LOAD]] : f32
// CHECK: store [[SELECT]], [[RET_RES]][%arg1, %arg2] : memref<?x10xf32>
/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<?x10xf32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
// CHECK: return [[RET_RES]] : memref<?x10xf32>
}
func @test_leakyrelu_leakyrelu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.LeakyRelu"(%arg0) {LeakyRelu.alpha=1.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.LeakyRelu"(%0) {LeakyRelu.alpha=1.0:f32} : (tensor<*xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_leakyrelu_leakyrelu
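/// LeakyRelu(x) = x for x >= 0, alpha * x otherwise; lowered below as select(cmpf "olt"(x, 0), alpha * x, x).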
/// First LeakyRelu
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[ALPHA:%.+]] = constant {{1.+}} : f32
// CHECK: [[CMP:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32
// CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[LOAD]] : f32
// CHECK: [[SELECT:%.+]] = select [[CMP]], [[MUL]], [[LOAD]] : f32
// CHECK: store [[SELECT]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
/// Second LeakyRelu
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[ALPHA:%.+]] = constant {{1.+}} : f32
// CHECK: [[CMP:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32
// CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[LOAD]] : f32
// CHECK: [[SELECT:%.+]] = select [[CMP]], [[MUL]], [[LOAD]] : f32
// CHECK: store [[SELECT]], [[RET_RES]][%arg1, %arg2] : memref<?x10xf32>
/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<?x10xf32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
// CHECK: return [[RET_RES]] : memref<?x10xf32>
}
func @test_selu_selu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Selu"(%arg0) {Selu.alpha=1.0:f32, Selu.gamma=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.Selu"(%0) {Selu.alpha=1.0:f32, Selu.gamma=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_selu_selu
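/// Selu(x) = gamma * (x for x > 0, alpha * exp(x) - alpha otherwise); lowered below as exp, mulf, subf, a select on cmpf "ogt", and a final mulf by gamma.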
/// First Selu
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[ALPHA:%.+]] = constant {{1.+}} : f32
// CHECK: [[GAMMA:%.+]] = constant {{2.+}} : f32
// CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
// CHECK: [[CMP:%.+]] = cmpf "ogt", [[LOAD]], [[ZERO]] : f32
// CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[EXP]] : f32
// CHECK: [[SUB:%.+]] = subf [[MUL]], [[ALPHA]] : f32
// CHECK: [[SELECT:%.+]] = select [[CMP]], [[LOAD]], [[SUB]] : f32
// CHECK: [[SELU_RES:%.+]] = mulf [[GAMMA]], [[SELECT]] : f32
// CHECK: store [[SELU_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
/// Second Selu
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[ALPHA:%.+]] = constant {{1.+}} : f32
// CHECK: [[GAMMA:%.+]] = constant {{2.+}} : f32
// CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
// CHECK: [[CMP:%.+]] = cmpf "ogt", [[LOAD]], [[ZERO]] : f32
// CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[EXP]] : f32
// CHECK: [[SUB:%.+]] = subf [[MUL]], [[ALPHA]] : f32
// CHECK: [[SELECT:%.+]] = select [[CMP]], [[LOAD]], [[SUB]] : f32
// CHECK: [[SELU_RES:%.+]] = mulf [[GAMMA]], [[SELECT]] : f32
// CHECK: store [[SELU_RES]], [[RET_RES]][%arg1, %arg2] : memref<?x10xf32>
/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<?x10xf32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
// CHECK: return [[RET_RES]] : memref<?x10xf32>
}
func @test_hardsigmoid_hardsigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.HardSigmoid"(%arg0) {HardSigmoid.alpha=1.0:f32, HardSigmoid.beta=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.HardSigmoid"(%0) {HardSigmoid.alpha=1.0:f32, HardSigmoid.beta=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_hardsigmoid_hardsigmoid
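/// HardSigmoid(x) = max(0, min(1, alpha * x + beta)); the max and min are lowered below as cmpf/select pairs.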
/// First HardSigmoid
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[ONE:%.+]] = constant {{1.+}} : f32
// CHECK: [[ALPHA:%.+]] = constant {{1.+}} : f32
// CHECK: [[BETA:%.+]] = constant {{2.+}} : f32
// CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[LOAD]] : f32
// CHECK: [[ADD:%.+]] = addf [[MUL]], [[BETA]] : f32
// CHECK: [[CMP1:%.+]] = cmpf "ogt", [[ADD]], [[ZERO]] : f32
// CHECK: [[SELECT1:%.+]] = select [[CMP1]], [[ADD]], [[ZERO]] : f32
// CHECK: [[CMP2:%.+]] = cmpf "olt", [[SELECT1]], [[ONE]] : f32
// CHECK: [[SELECT2:%.+]] = select [[CMP2]], [[SELECT1]], [[ONE]] : f32
// CHECK: store [[SELECT2]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
/// Second HardSigmoid
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[ONE:%.+]] = constant {{1.+}} : f32
// CHECK: [[ALPHA:%.+]] = constant {{1.+}} : f32
// CHECK: [[BETA:%.+]] = constant {{2.+}} : f32
// CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[LOAD]] : f32
// CHECK: [[ADD:%.+]] = addf [[MUL]], [[BETA]] : f32
// CHECK: [[CMP1:%.+]] = cmpf "ogt", [[ADD]], [[ZERO]] : f32
// CHECK: [[SELECT1:%.+]] = select [[CMP1]], [[ADD]], [[ZERO]] : f32
// CHECK: [[CMP2:%.+]] = cmpf "olt", [[SELECT1]], [[ONE]] : f32
// CHECK: [[SELECT2:%.+]] = select [[CMP2]], [[SELECT1]], [[ONE]] : f32
// CHECK: store [[SELECT2]], [[RET_RES]][%arg1, %arg2] : memref<?x10xf32>
/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<?x10xf32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
// CHECK: return [[RET_RES]] : memref<?x10xf32>
}