diff --git a/SharingWork.md b/SharingWork.md index 6120d04..fa12559 100644 --- a/SharingWork.md +++ b/SharingWork.md @@ -6,27 +6,34 @@ ONNX operations for which some work is needed. * M for multi-broadcast, U for unidir-broadcast -ONNX Oper | Person working on it | ONNX 2 KRNL | Basic functionality | Extended functionality (e.g. broadcast) -----------|----------------------|--------------|---------------------|---------------------------------------- -Add | ? | v | v | noM -And | ? | v | v | noM -Cosh | ? | v | v | noM -Div | ? | v | v | -Exp | ? | v | v | -FullGemm | | | | noU -Gemm | | | | noU -MatMul | | | | noM -Mul | ? | v | v | noM -Or | ? | v | v | noM -Relu | ? | v | v | -Sigmoid | ? | v | v | -Sinh | ? | v | v | -Sub | ? | v | v | noM -Tanh | ? | v | v | -Xor | ? | v | v | noM +| ONNX Oper | Person working on it | ONNX 2 KRNL | Basic functionality | Extended functionality (e.g. broadcast) | +| ---------- | --------------------- | -------------- | --------------------- | ---------------------------------------- | +| Add | Tung (updated) | v | v | noM | +| And | Tung | v | v | noM | +| Cosh | Tung | v | v | noM | +| Div | Tung | v | v | | +| Elu | Tung | v | v | | +| Exp | Tung | v | v | | +| FullGemm | | | | noU | +| Gemm | | | | noU | +| HardSigmoid | Tung | v | v | | +| LeakyRelu | Tung | v | v | | +| MatMul | | | | noM | +| Max | Tung | v | v | noM | +| Min | Tung | v | v | noM | +| Mul | Tung | v | v | noM | +| Or | Tung | v | v | noM | +| Relu | Tung | v | v | | +| Selu | Tung | v | v | | +| Sigmoid | Tung | v | v | | +| Sinh | Tung | v | v | | +| Sub | Tung | v | v | noM | +| Sum | Tung | v | v | noM | +| Tanh | Tung | v | v | | +| Xor | Tung | v | v | noM | ONNX operations for which the work is completed (full functionality) and tested -ONNX Oper | Person working on it | Initial work | Basic functionality | Extended functionality (e.g. broadcast) -----------|----------------------|--------------|---------------------|---------------------------------------- +| ONNX Oper | Person working on it | Initial work | Basic functionality | Extended functionality (e.g. broadcast) | +| ---------- | ---------------------- | -------------- | --------------------- | ---------------------------------------- | diff --git a/src/compiler/dialect/onnx/gen_doc.py b/src/compiler/dialect/onnx/gen_doc.py index 530ac9a..f5c0ded 100644 --- a/src/compiler/dialect/onnx/gen_doc.py +++ b/src/compiler/dialect/onnx/gen_doc.py @@ -265,7 +265,8 @@ def collect_types(schema, input) : def gen_schema(schema) : ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu', 'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor', - 'MatMul', 'Gemm'] + 'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu', + 'Elu', 'Selu', 'HardSigmoid'] CanonicalList=['Add', 'Identity'] line_indent = ' ' diff --git a/src/compiler/dialect/onnx/onnx_ops.cpp b/src/compiler/dialect/onnx/onnx_ops.cpp index 0650a24..a900e18 100644 --- a/src/compiler/dialect/onnx/onnx_ops.cpp +++ b/src/compiler/dialect/onnx/onnx_ops.cpp @@ -70,6 +70,14 @@ void ONNXCoshOp::inferShapes() { getResult()->setType(getOperand()->getType()); } +//===----------------------------------------------------------------------===// +// HardSigmoid +/// Infer the output shape of the ONNXHardSigmoidOp. This method is required by +/// the shape inference interface. 
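+/// HardSigmoid (max(0, min(1, alpha * x + beta))) is applied element-wise, so
+/// the result type is simply the operand type; alpha and beta only affect values.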
+void ONNXHardSigmoidOp::inferShapes() { + getResult()->setType(getOperand()->getType()); +} + //===----------------------------------------------------------------------===// // Sigmoid /// Infer the output shape of the ONNXSigmoidOp. This method is required by the @@ -78,6 +86,14 @@ void ONNXSigmoidOp::inferShapes() { getResult()->setType(getOperand()->getType()); } +//===----------------------------------------------------------------------===// +// Elu +/// Infer the output shape of the ONNXEluOp. This method is required by the +/// shape inference interface. +void ONNXEluOp::inferShapes() { + getResult()->setType(getOperand()->getType()); +} + //===----------------------------------------------------------------------===// // Relu /// Infer the output shape of the ONNXReluOp. This method is required by the @@ -86,6 +102,22 @@ void ONNXReluOp::inferShapes() { getResult()->setType(getOperand()->getType()); } +//===----------------------------------------------------------------------===// +// LeakyRelu +/// Infer the output shape of the ONNXLeakyReluOp. This method is required by +/// the shape inference interface. +void ONNXLeakyReluOp::inferShapes() { + getResult()->setType(getOperand()->getType()); +} + +//===----------------------------------------------------------------------===// +// Selu +/// Infer the output shape of the ONNXSeluOp. This method is required by +/// the shape inference interface. +void ONNXSeluOp::inferShapes() { + getResult()->setType(getOperand()->getType()); +} + //===----------------------------------------------------------------------===// // Add /// Infer the output shape of the ONNXAddOp. This method is required by the @@ -144,6 +176,32 @@ void ONNXXorOp::inferShapes() { //===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// Sum +/// Infer the output shape of the ONNXSumOp. This method is required by the +/// shape inference interface. +void ONNXSumOp::inferShapes() { + getResult()->setType(getOperand(0)->getType()); +} + +//===----------------------------------------------------------------------===// +// Max +/// Infer the output shape of the ONNXMaxOp. This method is required by the +/// shape inference interface. +void ONNXMaxOp::inferShapes() { + getResult()->setType(getOperand(0)->getType()); +} + +//===----------------------------------------------------------------------===// +// Min +/// Infer the output shape of the ONNXMinOp. This method is required by the +/// shape inference interface. 
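+/// Min is variadic; as for Sum and Max above, the result type is taken from the
+/// first operand, since multidirectional broadcasting is not supported yet (noM).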
+void ONNXMinOp::inferShapes() { + getResult()->setType(getOperand(0)->getType()); +} + +//===----------------------------------------------------------------------===// + // MatMul void ONNXMatMulOp::inferShapes() { diff --git a/src/compiler/dialect/onnx/onnxop.inc b/src/compiler/dialect/onnx/onnxop.inc index 6aa42ae..7d18c13 100644 --- a/src/compiler/dialect/onnx/onnxop.inc +++ b/src/compiler/dialect/onnx/onnxop.inc @@ -531,7 +531,7 @@ def ONNXDynamicQuantizeLinearOp:ONNX_Op<"DynamicQuantizeLinear", } def ONNXEluOp:ONNX_Op<"Elu", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX Elu operation"; let description = [{ "Elu takes one input data (Tensor) and produces one output data" @@ -991,7 +991,7 @@ def ONNXGreaterOp:ONNX_Op<"Greater", } def ONNXHardSigmoidOp:ONNX_Op<"HardSigmoid", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX HardSigmoid operation"; let description = [{ "HardSigmoid takes one input data (Tensor) and produces one output data" @@ -1191,7 +1191,7 @@ def ONNXLSTMOp:ONNX_Op<"LSTM", } def ONNXLeakyReluOp:ONNX_Op<"LeakyRelu", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX LeakyRelu operation"; let description = [{ "LeakyRelu takes input data (Tensor) and an argument alpha, and produces one" @@ -1436,7 +1436,7 @@ def ONNXMatMulIntegerOp:ONNX_Op<"MatMulInteger", } def ONNXMaxOp:ONNX_Op<"Max", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX Max operation"; let description = [{ "Element-wise max of each of the input tensors (with Numpy-style broadcasting support)." @@ -1548,7 +1548,7 @@ def ONNXMeanVarianceNormalizationOp:ONNX_Op<"MeanVarianceNormalization", } def ONNXMinOp:ONNX_Op<"Min", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX Min operation"; let description = [{ "Element-wise min of each of the input tensors (with Numpy-style broadcasting support)." @@ -2625,7 +2625,7 @@ def ONNXScatterNDOp:ONNX_Op<"ScatterND", } def ONNXSeluOp:ONNX_Op<"Selu", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX Selu operation"; let description = [{ "Selu takes one input data (Tensor) and produces one output data" @@ -2972,7 +2972,7 @@ def ONNXSubOp:ONNX_Op<"Sub", } def ONNXSumOp:ONNX_Op<"Sum", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX Sum operation"; let description = [{ "Element-wise sum of each of the input tensors (with Numpy-style broadcasting support)." diff --git a/src/compiler/pass/lower_frontend_to_krnl.cpp b/src/compiler/pass/lower_frontend_to_krnl.cpp index 6e28d19..9416934 100644 --- a/src/compiler/pass/lower_frontend_to_krnl.cpp +++ b/src/compiler/pass/lower_frontend_to_krnl.cpp @@ -148,6 +148,12 @@ struct ScalarOp { using IOp = ExpOp; // not use }; +template <> +struct ScalarOp { + using FOp = AddFOp; + using IOp = AddIOp; +}; + template using ScalarFOp = typename ScalarOp::FOp; template @@ -157,11 +163,11 @@ using ScalarIOp = typename ScalarOp::IOp; // Scalar unary ops for lowering to Krnl dialect. //===----------------------------------------------------------------------===// template -Value* mapToLowerScalarOp(Location loc, ArrayRef result_types, +Value* mapToLowerScalarOp(Operation* op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter& rewriter) { /* Lower UnaryOp to Ops in the Standard dialect. 
*/ - + auto loc = op->getLoc(); Type element_type = operands.front()->getType(); if (element_type.isa()) { return rewriter.create>( @@ -179,11 +185,14 @@ Value* mapToLowerScalarOp(Location loc, ArrayRef result_types, // Scalar unary ops for lowering ONNXTanhOp //===----------------------------------------------------------------------===// template <> -Value* mapToLowerScalarOp(Location loc, ArrayRef result_types, - ArrayRef operands, ConversionPatternRewriter& rewriter) { +Value* mapToLowerScalarOp(Operation* op, + ArrayRef result_types, ArrayRef operands, + ConversionPatternRewriter& rewriter) { // ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))), // AddFOp(ExpOp(%X), ExpOp(NegFOp(%X)))) + auto loc = op->getLoc(); Value* operand = operands[0]; + auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); auto neg = rewriter.create(loc, zero, operand); auto exp = rewriter.create(loc, operand); @@ -191,6 +200,7 @@ Value* mapToLowerScalarOp(Location loc, ArrayRef result_types, auto result = rewriter.create(loc, rewriter.create(loc, exp, negExp), rewriter.create(loc, exp, negExp)); + return result; } @@ -198,11 +208,14 @@ Value* mapToLowerScalarOp(Location loc, ArrayRef result_types, // Scalar unary ops for lowering ONNXSinhOp //===----------------------------------------------------------------------===// template <> -Value* mapToLowerScalarOp(Location loc, ArrayRef result_types, - ArrayRef operands, ConversionPatternRewriter& rewriter) { +Value* mapToLowerScalarOp(Operation* op, + ArrayRef result_types, ArrayRef operands, + ConversionPatternRewriter& rewriter) { // ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))), // ConstantOp 2) + auto loc = op->getLoc(); Value* operand = operands[0]; + auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); auto two = rewriter.create(loc, rewriter.getF32FloatAttr(2.0f)); auto neg = rewriter.create(loc, zero, operand); @@ -210,6 +223,7 @@ Value* mapToLowerScalarOp(Location loc, ArrayRef result_types, auto negExp = rewriter.create(loc, neg); auto result = rewriter.create( loc, rewriter.create(loc, exp, negExp), two); + return result; } @@ -217,11 +231,14 @@ Value* mapToLowerScalarOp(Location loc, ArrayRef result_types, // Scalar unary ops for lowering ONNXCoshOp //===----------------------------------------------------------------------===// template <> -Value* mapToLowerScalarOp(Location loc, ArrayRef result_types, - ArrayRef operands, ConversionPatternRewriter& rewriter) { +Value* mapToLowerScalarOp(Operation* op, + ArrayRef result_types, ArrayRef operands, + ConversionPatternRewriter& rewriter) { // ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))), // ConstantOp 2) + auto loc = op->getLoc(); Value* operand = operands[0]; + auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); auto two = rewriter.create(loc, rewriter.getF32FloatAttr(2.0f)); auto neg = rewriter.create(loc, zero, operand); @@ -229,6 +246,7 @@ Value* mapToLowerScalarOp(Location loc, ArrayRef result_types, auto negExp = rewriter.create(loc, neg); auto result = rewriter.create( loc, rewriter.create(loc, exp, negExp), two); + return result; } @@ -236,18 +254,84 @@ Value* mapToLowerScalarOp(Location loc, ArrayRef result_types, // Scalar unary ops for lowering ONNXSigmoidOp //===----------------------------------------------------------------------===// template <> -Value* mapToLowerScalarOp(Location loc, +Value* mapToLowerScalarOp(Operation* op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter& rewriter) { // 
ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1, // AddFOp(ConstantOp 1, ExpOp(NegFOp(%X)))) + auto loc = op->getLoc(); Value* operand = operands[0]; + auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); auto one = rewriter.create(loc, rewriter.getF32FloatAttr(1.0f)); auto neg = rewriter.create(loc, zero, operand); auto negExp = rewriter.create(loc, neg); auto result = rewriter.create( loc, one, rewriter.create(loc, one, negExp)); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXHardSigmoidOp +//===----------------------------------------------------------------------===// +template <> +Value* mapToLowerScalarOp(Operation* op, + ArrayRef result_types, ArrayRef operands, + ConversionPatternRewriter& rewriter) { + // %Y = AddFOp(MulFOp(alpha, %X), beta) + // %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0), + // %Y, + // Constant 0) + // ONNXHardSigmoidOp(%X) = SelectOp(CmpFOp(OLT, %Z, Constant 1), + // %Z, + // Constant 1) + auto loc = op->getLoc(); + Value* operand = operands[0]; + auto alphaAttr = op->getAttrOfType("HardSigmoid.alpha"); + auto betaAttr = op->getAttrOfType("HardSigmoid.beta"); + + auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); + auto one = rewriter.create(loc, rewriter.getF32FloatAttr(1.0f)); + auto alpha = rewriter.create(loc, alphaAttr); + auto beta = rewriter.create(loc, betaAttr); + + auto add = rewriter.create( + loc, rewriter.create(loc, alpha, operand), beta); + auto maxPredicate = + rewriter.create(loc, CmpFPredicate::OGT, add, zero); + auto max = rewriter.create(loc, maxPredicate, add, zero); + auto minPredicate = + rewriter.create(loc, CmpFPredicate::OLT, max, one); + auto result = rewriter.create(loc, minPredicate, max, one); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXEluOp +//===----------------------------------------------------------------------===// +template <> +Value* mapToLowerScalarOp(Operation* op, ArrayRef result_types, + ArrayRef operands, ConversionPatternRewriter& rewriter) { + // ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), + // MulFOp(alpha, SubFOp(ExpOp(%X), 1)), + // %X) + auto loc = op->getLoc(); + Value* operand = operands[0]; + + auto alphaAttr = op->getAttrOfType("Elu.alpha"); + auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); + auto one = rewriter.create(loc, rewriter.getF32FloatAttr(1.0f)); + auto alpha = rewriter.create(loc, alphaAttr); + auto exp = rewriter.create(loc, operand); + auto lessThanZero = + rewriter.create(loc, CmpFPredicate::OLT, operand, zero); + auto result = rewriter.create(loc, lessThanZero, + rewriter.create( + loc, alpha, rewriter.create(loc, exp, one)), + operand); + return result; } @@ -255,30 +339,122 @@ Value* mapToLowerScalarOp(Location loc, // Scalar unary ops for lowering ONNXReluOp //===----------------------------------------------------------------------===// template <> -Value* mapToLowerScalarOp(Location loc, ArrayRef result_types, - ArrayRef operands, ConversionPatternRewriter& rewriter) { +Value* mapToLowerScalarOp(Operation* op, + ArrayRef result_types, ArrayRef operands, + ConversionPatternRewriter& rewriter) { // ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), // ConstantOp 0, // %X) + auto loc = op->getLoc(); Value* operand = operands[0]; + auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); auto lessThanZero = rewriter.create(loc, 
CmpFPredicate::OLT, operand, zero); auto result = rewriter.create(loc, lessThanZero, zero, operand); + return result; } //===----------------------------------------------------------------------===// -// Element-wise n-ary ops lowering to Krnl dialect. +// Scalar unary ops for lowering ONNXLeakyReluOp //===----------------------------------------------------------------------===// -template -struct ONNXElementwiseNaryOpLowering : public ConversionPattern { - ONNXElementwiseNaryOpLowering(MLIRContext* ctx) - : ConversionPattern(ElementwiseNaryOp::getOperationName(), 1, ctx) {} +template <> +Value* mapToLowerScalarOp(Operation* op, + ArrayRef result_types, ArrayRef operands, + ConversionPatternRewriter& rewriter) { + // ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), + // MulFOp(alpha, %X), + // %X) + auto loc = op->getLoc(); + Value* operand = operands[0]; + + auto alphaAttr = op->getAttrOfType("LeakyRelu.alpha"); + auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); + auto alpha = rewriter.create(loc, alphaAttr); + auto lessThanZero = + rewriter.create(loc, CmpFPredicate::OLT, operand, zero); + auto result = rewriter.create( + loc, lessThanZero, rewriter.create(loc, alpha, operand), operand); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXSeluOp +//===----------------------------------------------------------------------===// +template <> +Value* mapToLowerScalarOp(Operation* op, + ArrayRef result_types, ArrayRef operands, + ConversionPatternRewriter& rewriter) { + // ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0), + // MulFOp(gamma, %X), + // MulFOp(gamma, + // SubFOp(MulFOp(alpha, ExpOp(%X)), + // alpha))) + auto loc = op->getLoc(); + Value* operand = operands[0]; + auto alphaAttr = op->getAttrOfType("Selu.alpha"); + auto gammaAttr = op->getAttrOfType("Selu.gamma"); + + auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); + auto alpha = rewriter.create(loc, alphaAttr); + auto gamma = rewriter.create(loc, gammaAttr); + auto exp = rewriter.create(loc, operand); + auto greaterThanZero = + rewriter.create(loc, CmpFPredicate::OGT, operand, zero); + auto select = rewriter.create(loc, greaterThanZero, operand, + rewriter.create( + loc, rewriter.create(loc, alpha, exp), alpha)); + auto result = rewriter.create(loc, gamma, select); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXMaxOp +//===----------------------------------------------------------------------===// +template <> +Value* mapToLowerScalarOp(Operation* op, ArrayRef result_types, + ArrayRef operands, ConversionPatternRewriter& rewriter) { + // ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y), + // %X, + // %Y) + auto loc = op->getLoc(); + Value* lhs = operands[0]; + Value* rhs = operands[1]; + auto max = rewriter.create(loc, CmpFPredicate::OGT, lhs, rhs); + auto result = rewriter.create(loc, max, lhs, rhs); + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXMinOp +//===----------------------------------------------------------------------===// +template <> +Value* mapToLowerScalarOp(Operation* op, ArrayRef result_types, + ArrayRef operands, ConversionPatternRewriter& rewriter) { + // ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y), + // %X, + // %Y) + auto loc = op->getLoc(); + Value* lhs = operands[0]; + 
Value* rhs = operands[1]; + auto min = rewriter.create(loc, CmpFPredicate::OLT, lhs, rhs); + auto result = rewriter.create(loc, min, lhs, rhs); + return result; +} + +// Element-wise unary ops lowering to Krnl dialect. +//===----------------------------------------------------------------------===// +template +struct ONNXElementwiseUnaryOpLowering : public ConversionPattern { + ONNXElementwiseUnaryOpLowering(MLIRContext* ctx) + : ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {} PatternMatchResult matchAndRewrite(Operation* op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { // TODO: Check that the types are valid. - // An element-wise binary operation must have all operands and the result of + // An element-wise unary operation must have all operands and the result of // the same type. This should have been verified by the verifier. auto tensorType = (*op->result_type_begin()).cast(); auto loc = op->getLoc(); @@ -287,7 +463,7 @@ struct ONNXElementwiseNaryOpLowering : public ConversionPattern { auto memRefType = convertTensorToMemRef(tensorType); // If the output has a dynamic dimension, pass the operands required for - // each dynamic dimension to the AllocOp. The first operand of the binary + // each dynamic dimension to the AllocOp. The first operand of the // operation is used. The operands of the op need to match in terms of // dimensions with the result at this pre-optimization phase. // TODO: verify that dimensions match. @@ -359,15 +535,9 @@ struct ONNXElementwiseNaryOpLowering : public ConversionPattern { for (auto arg : iterationBlock.getArguments()) loopIVs.push_back(arg); - SmallVector loadedVals; - for (unsigned i = 0; i < numArgs; i++) { - auto loadedVal = rewriter.create(loc, operands[i], loopIVs); - loadedVals.push_back(loadedVal); - } - - auto loweredOpResult = mapToLowerScalarOp( - loc, memRefType.getElementType(), loadedVals, rewriter); - + auto loadedVal = rewriter.create(loc, operands[0], loopIVs); + auto loweredOpResult = mapToLowerScalarOp( + op, memRefType.getElementType(), {loadedVal}, rewriter); // Store result in the resulting array. rewriter.create(loc, loweredOpResult, alloc, loopIVs); @@ -377,12 +547,113 @@ struct ONNXElementwiseNaryOpLowering : public ConversionPattern { } }; -template -using ONNXElementwiseUnaryOpLowering = - ONNXElementwiseNaryOpLowering; -template -using ONNXElementwiseBinaryOpLowering = - ONNXElementwiseNaryOpLowering; +// Element-wise variadic ops lowering to Krnl dialect. +//===----------------------------------------------------------------------===// +template +struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { + ONNXElementwiseVariadicOpLowering(MLIRContext* ctx) + : ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {} + PatternMatchResult matchAndRewrite(Operation* op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + // TODO: Check that the types are valid. + // An element-wise variadic operation must have all operands and the result + // of the same type. This should have been verified by the verifier. + auto tensorType = (*op->result_type_begin()).cast(); + auto loc = op->getLoc(); + auto numArgs = op->getNumOperands(); + + // Insert an allocation and deallocation for the result of this operation. + auto memRefType = convertTensorToMemRef(tensorType); + + // If the output has a dynamic dimension, pass the operands required for + // each dynamic dimension to the AllocOp. The first operand of the + // operation is used. 
The operands of the op need to match in terms of + // dimensions with the result at this pre-optimization phase. + // TODO: verify that dimensions match. + // TODO: can the dimension of the result differ after optimizations? + Value* alloc; + bool insertDealloc = checkInsertDealloc(op); + + if (hasAllConstantDimensions(memRefType)) + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + else + alloc = insertAllocAndDealloc( + memRefType, loc, rewriter, insertDealloc, operands[0]); + + // Number of loops + auto memRefShape = memRefType.getShape(); + int64_t rank = memRefShape.size(); + + // Define loops. + auto loopsOp = rewriter.create(loc, rank); + std::vector originalLoops; + originalLoops.reserve(rank); + for (auto result : loopsOp.getResults()) { + originalLoops.push_back(result); + } + + // Define loop optimization. + auto optimizedLoopsOp = rewriter.create(loc, rank); + std::vector optimizedLoops; + optimizedLoops.reserve(rank); + for (auto result : optimizedLoopsOp.getResults()) { + optimizedLoops.push_back(result); + } + Block& optimizationBlock = optimizedLoopsOp.region().front(); + + KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); + // Iterate over the loop nest. + // TODO (Tian): move this logic inside KrnlIterateOp. Pass MemRefShape + // to KrnlIterateOp instead. + for (int i = 0; i < rank; ++i) { + if (memRefShape[i] < 0) { + pack.pushConstantBound(0); + pack.pushOperandBound( + rewriter.create(loc, operands[0], i).getResult()); + } else { + pack.pushConstantBound(0); + pack.pushConstantBound(memRefShape[i]); + } + } + + auto iterateOp = rewriter.create(loc, pack); + Block& iterationBlock = iterateOp.bodyRegion().front(); + + // Now perform the insertions into the body of the + // just generated instructions: + + // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body. + rewriter.setInsertionPointToEnd(&optimizationBlock); + // Return from KrnlOptimizeLoopsOp body. + // When no optimizations are present we just return the loops + // unchaged. + rewriter.create(loc, originalLoops); + rewriter.setInsertionPoint(optimizedLoopsOp); + + // 2. Insert instructions inside the KernelIterateOp body. + rewriter.setInsertionPointToStart(&iterationBlock); + + // Handle the operation: + SmallVector loopIVs; + for (auto arg : iterationBlock.getArguments()) + loopIVs.push_back(arg); + + // Fold over operands for each of their scalar values + Value *accumulated, *next; + accumulated = rewriter.create(loc, operands[0], loopIVs); + for (unsigned i = 1; i < numArgs; i++) { + next = rewriter.create(loc, operands[i], loopIVs); + accumulated = mapToLowerScalarOp( + op, memRefType.getElementType(), {accumulated, next}, rewriter); + } + // Store result in the resulting array. + rewriter.create(loc, accumulated, alloc, loopIVs); + + rewriter.replaceOp(op, alloc); + + return matchSuccess(); + } +}; //===----------------------------------------------------------------------===// // Conversion from Tensor type to the Standard dialect MemRef type. 
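For orientation, a minimal sketch of the loop body this fold produces for a three-operand "onnx.Sum" (the SSA names and the fixed 10x10xf32 shape are illustrative, not taken from the tests); each additional operand contributes one load plus one scalar addf onto the accumulator:

    // Innermost krnl.iterate body for %s = "onnx.Sum"(%a, %b, %c)
    %0 = load %a[%i, %j] : memref<10x10xf32>   // accumulated = operand 0
    %1 = load %b[%i, %j] : memref<10x10xf32>   // next = operand 1
    %2 = addf %0, %1 : f32                     // accumulated = Sum(accumulated, next)
    %3 = load %c[%i, %j] : memref<10x10xf32>   // next = operand 2
    %4 = addf %2, %3 : f32                     // fold once more
    store %4, %res[%i, %j] : memref<10x10xf32>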
@@ -469,14 +740,21 @@ void FrontendToKrnlLoweringPass::runOnModule() { ONNXElementwiseUnaryOpLowering, ONNXElementwiseUnaryOpLowering, ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, ONNXElementwiseUnaryOpLowering, - ONNXElementwiseBinaryOpLowering, - ONNXElementwiseBinaryOpLowering, - ONNXElementwiseBinaryOpLowering, - ONNXElementwiseBinaryOpLowering, - ONNXElementwiseBinaryOpLowering, - ONNXElementwiseBinaryOpLowering, - ONNXElementwiseBinaryOpLowering>(&getContext()); + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering>(&getContext()); // With the target and rewrite patterns defined, we can now attempt the // conversion. The conversion will signal failure if any of our `illegal` diff --git a/src/compiler/pass/shape_inference_pass.cpp b/src/compiler/pass/shape_inference_pass.cpp index 138a793..678f861 100644 --- a/src/compiler/pass/shape_inference_pass.cpp +++ b/src/compiler/pass/shape_inference_pass.cpp @@ -93,7 +93,11 @@ class ShapeInferencePass : public mlir::FunctionPass { op->getName().getStringRef() != "onnx.Sinh" && op->getName().getStringRef() != "onnx.Cosh" && op->getName().getStringRef() != "onnx.Sigmoid" && + op->getName().getStringRef() != "onnx.HardSigmoid" && + op->getName().getStringRef() != "onnx.Elu" && op->getName().getStringRef() != "onnx.Relu" && + op->getName().getStringRef() != "onnx.LeakyRelu" && + op->getName().getStringRef() != "onnx.Selu" && op->getName().getStringRef() != "onnx.Mul" && op->getName().getStringRef() != "onnx.Add" && op->getName().getStringRef() != "onnx.Div" && @@ -101,6 +105,9 @@ class ShapeInferencePass : public mlir::FunctionPass { op->getName().getStringRef() != "onnx.And" && op->getName().getStringRef() != "onnx.Or" && op->getName().getStringRef() != "onnx.Xor" && + op->getName().getStringRef() != "onnx.Sum" && + op->getName().getStringRef() != "onnx.Max" && + op->getName().getStringRef() != "onnx.Min" && op->getName().getStringRef() != "onnx.MatMul" && op->getName().getStringRef() != "onnx.Gemm" && op->getName().getStringRef() != "onnx.FullGemm") diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index 7f1df67..4ee9453 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -278,3 +278,169 @@ func @test_relu(%arg0 : tensor) -> tensor<*xf32> { // CHECK: store [[RELU_RES]], [[RES]][%arg1, %arg2] : memref // CHECK: return [[RES]] : memref } + +func @test_sum(%arg0 : tensor, %arg1 : tensor) -> tensor<*xf32> { + %0 = "onnx.Sum"(%arg0, %arg1) : (tensor, tensor) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_sum + // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref + // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 
10) { + // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: [[ADD:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32 + // CHECK: store [[ADD]], [[RES]][%arg2, %arg3] : memref + // CHECK: return [[RES]] : memref +} + +func @test_max(%arg0 : tensor, %arg1 : tensor) -> tensor<*xf32> { + %0 = "onnx.Max"(%arg0, %arg1) : (tensor, tensor) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_max + // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref + // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: [[MAX:%.+]] = cmpf "ogt", [[LOAD1]], [[LOAD2]] : f32 + // CHECK: [[RELU_RES:%.+]] = select [[MAX]], [[LOAD1]], [[LOAD2]] : f32 + // CHECK: store [[RELU_RES]], [[RES]][%arg2, %arg3] : memref + // CHECK: return [[RES]] : memref +} + +func @test_min(%arg0 : tensor, %arg1 : tensor) -> tensor<*xf32> { + %0 = "onnx.Min"(%arg0, %arg1) : (tensor, tensor) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_min + // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref + // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: [[MIN:%.+]] = cmpf "olt", [[LOAD1]], [[LOAD2]] : f32 + // CHECK: [[RELU_RES:%.+]] = select [[MIN]], [[LOAD1]], [[LOAD2]] : f32 + // CHECK: store [[RELU_RES]], [[RES]][%arg2, %arg3] : memref + // CHECK: return [[RES]] : memref +} + +func @test_elu(%arg0 : tensor) -> tensor<*xf32> { + %0 = "onnx.Elu"(%arg0) {Elu.alpha=2.0:f32} : (tensor) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_elu + // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref + // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) { + // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref + // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32 + // CHECK: [[ONE:%.+]] = constant {{1.+}} : f32 + // CHECK: [[ALPHA:%.+]] = constant {{2.+}} : f32 + // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32 + // CHECK: [[CMP:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32 + // CHECK: [[SUB:%.+]] = subf [[EXP]], [[ONE]] : 
f32 + // CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[SUB]] : f32 + // CHECK: [[SELECT:%.+]] = select [[CMP]], [[MUL]], [[LOAD]] : f32 + // CHECK: store [[SELECT]], [[RES]][%arg1, %arg2] : memref + // CHECK: return [[RES]] : memref +} + +func @test_leakyrelu(%arg0 : tensor) -> tensor<*xf32> { + %0 = "onnx.LeakyRelu"(%arg0) {LeakyRelu.alpha=1.0:f32} : (tensor) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_leakyrelu + // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref + // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) { + // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref + // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32 + // CHECK: [[ALPHA:%.+]] = constant {{1.+}} : f32 + // CHECK: [[CMP:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32 + // CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[LOAD]] : f32 + // CHECK: [[SELECT:%.+]] = select [[CMP]], [[MUL]], [[LOAD]] : f32 + // CHECK: store [[SELECT]], [[RES]][%arg1, %arg2] : memref + // CHECK: return [[RES]] : memref +} + +func @test_selu(%arg0 : tensor) -> tensor<*xf32> { + %0 = "onnx.Selu"(%arg0) {Selu.alpha=1.0:f32, Selu.gamma=2.0:f32} : (tensor) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_selu + // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref + // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) { + // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref + // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32 + // CHECK: [[ALPHA:%.+]] = constant {{1.+}} : f32 + // CHECK: [[GAMMA:%.+]] = constant {{2.+}} : f32 + // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32 + // CHECK: [[CMP:%.+]] = cmpf "ogt", [[LOAD]], [[ZERO]] : f32 + // CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[EXP]] : f32 + // CHECK: [[SUB:%.+]] = subf [[MUL]], [[ALPHA]] : f32 + // CHECK: [[SELECT:%.+]] = select [[CMP]], [[LOAD]], [[SUB]] : f32 + // CHECK: [[SELU_RES:%.+]] = mulf [[GAMMA]], [[SELECT]] : f32 + // CHECK: store [[SELU_RES]], [[RES]][%arg1, %arg2] : memref + // CHECK: return [[RES]] : memref +} + +func @test_hardsigmoid(%arg0 : tensor) -> tensor<*xf32> { + %0 = "onnx.HardSigmoid"(%arg0) {HardSigmoid.alpha=1.0:f32, HardSigmoid.beta=2.0:f32} : (tensor) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_hardsigmoid + // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref + // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to 
[[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) { + // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref + // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32 + // CHECK: [[ONE:%.+]] = constant {{1.+}} : f32 + // CHECK: [[ALPHA:%.+]] = constant {{1.+}} : f32 + // CHECK: [[BETA:%.+]] = constant {{2.+}} : f32 + // CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[LOAD]] : f32 + // CHECK: [[ADD:%.+]] = addf [[MUL]], [[BETA]] : f32 + // CHECK: [[CMP1:%.+]] = cmpf "ogt", [[ADD]], [[ZERO]] : f32 + // CHECK: [[SELECT1:%.+]] = select [[CMP1]], [[ADD]], [[ZERO]] : f32 + // CHECK: [[CMP2:%.+]] = cmpf "olt", [[SELECT1]], [[ONE]] : f32 + // CHECK: [[SELECT2:%.+]] = select [[CMP2]], [[SELECT1]], [[ONE]] : f32 + // CHECK: store [[SELECT2]], [[RES]][%arg1, %arg2] : memref + // CHECK: return [[RES]] : memref +} diff --git a/test/mlir/onnx/onnx_lowering_with_dealloc.mlir b/test/mlir/onnx/onnx_lowering_with_dealloc.mlir index 6310516..749829e 100644 --- a/test/mlir/onnx/onnx_lowering_with_dealloc.mlir +++ b/test/mlir/onnx/onnx_lowering_with_dealloc.mlir @@ -571,3 +571,342 @@ func @test_relu_relu(%arg0 : tensor) -> tensor<*xf32> { // CHECK: return [[RET_RES]] : memref } + +func @test_sum_sum(%arg0 : tensor, %arg1 : tensor) -> tensor<*xf32> { + %0 = "onnx.Sum"(%arg0, %arg1) : (tensor, tensor) -> tensor<*xf32> + %1 = "onnx.Sum"(%0, %arg1) : (tensor<*xf32>, tensor) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_sum_sum + /// First Sum + // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref + // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: [[ADD:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32 + // CHECK: store [[ADD]], [[RES]][%arg2, %arg3] : memref + + /// Second Sum + // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref + // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: [[ADD:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32 + // CHECK: store [[ADD]], [[RET_RES]][%arg2, %arg3] : memref + + /// Dealloc of first result. 
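+  /// Only the intermediate buffer is freed; the buffer that is returned must
+  /// stay allocated, hence the CHECK-NOT below.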
+ // CHECK: dealloc [[RES]] : memref + // CHECK-NOT: dealloc [[RET_RES]] : memref + + // CHECK: return [[RET_RES]] : memref +} + +func @test_max_max(%arg0 : tensor, %arg1 : tensor) -> tensor<*xf32> { + %0 = "onnx.Max"(%arg0, %arg1) : (tensor, tensor) -> tensor<*xf32> + %1 = "onnx.Max"(%0, %arg1) : (tensor<*xf32>, tensor) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_max_max + /// First Max + // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref + // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: [[MAX:%.+]] = cmpf "ogt", [[LOAD1]], [[LOAD2]] : f32 + // CHECK: [[RELU_RES:%.+]] = select [[MAX]], [[LOAD1]], [[LOAD2]] : f32 + // CHECK: store [[RELU_RES]], [[RES]][%arg2, %arg3] : memref + + /// Second Max + // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref + // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: [[MAX:%.+]] = cmpf "ogt", [[LOAD1]], [[LOAD2]] : f32 + // CHECK: [[RELU_RES:%.+]] = select [[MAX]], [[LOAD1]], [[LOAD2]] : f32 + // CHECK: store [[RELU_RES]], [[RET_RES]][%arg2, %arg3] : memref + + /// Dealloc of first result. 
+ // CHECK: dealloc [[RES]] : memref + // CHECK-NOT: dealloc [[RET_RES]] : memref + + // CHECK: return [[RET_RES]] : memref +} + +func @test_min_min(%arg0 : tensor, %arg1 : tensor) -> tensor<*xf32> { + %0 = "onnx.Min"(%arg0, %arg1) : (tensor, tensor) -> tensor<*xf32> + %1 = "onnx.Min"(%0, %arg1) : (tensor<*xf32>, tensor) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_min_min + /// First Min + // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref + // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: [[MIN:%.+]] = cmpf "olt", [[LOAD1]], [[LOAD2]] : f32 + // CHECK: [[RELU_RES:%.+]] = select [[MIN]], [[LOAD1]], [[LOAD2]] : f32 + // CHECK: store [[RELU_RES]], [[RES]][%arg2, %arg3] : memref + + /// Second Min + // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref + // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: [[MIN:%.+]] = cmpf "olt", [[LOAD1]], [[LOAD2]] : f32 + // CHECK: [[RELU_RES:%.+]] = select [[MIN]], [[LOAD1]], [[LOAD2]] : f32 + // CHECK: store [[RELU_RES]], [[RET_RES]][%arg2, %arg3] : memref + + /// Dealloc of first result. 
+ // CHECK: dealloc [[RES]] : memref + // CHECK-NOT: dealloc [[RET_RES]] : memref + + // CHECK: return [[RET_RES]] : memref +} + +func @test_elu_elu(%arg0 : tensor) -> tensor<*xf32> { + %0 = "onnx.Elu"(%arg0) {Elu.alpha=2.0:f32} : (tensor) -> tensor<*xf32> + %1 = "onnx.Elu"(%0) {Elu.alpha=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_elu_elu + /// First Elu + // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref + // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) { + // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref + // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32 + // CHECK: [[ONE:%.+]] = constant {{1.+}} : f32 + // CHECK: [[ALPHA:%.+]] = constant {{2.+}} : f32 + // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32 + // CHECK: [[CMP:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32 + // CHECK: [[SUB:%.+]] = subf [[EXP]], [[ONE]] : f32 + // CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[SUB]] : f32 + // CHECK: [[SELECT:%.+]] = select [[CMP]], [[MUL]], [[LOAD]] : f32 + // CHECK: store [[SELECT]], [[RES]][%arg1, %arg2] : memref + + /// Second Elu + // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref + // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) { + // CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref + // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32 + // CHECK: [[ONE:%.+]] = constant {{1.+}} : f32 + // CHECK: [[ALPHA:%.+]] = constant {{2.+}} : f32 + // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32 + // CHECK: [[CMP:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32 + // CHECK: [[SUB:%.+]] = subf [[EXP]], [[ONE]] : f32 + // CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[SUB]] : f32 + // CHECK: [[SELECT:%.+]] = select [[CMP]], [[MUL]], [[LOAD]] : f32 + // CHECK: store [[SELECT]], [[RET_RES]][%arg1, %arg2] : memref + + /// Dealloc of first result. 
+ // CHECK: dealloc [[RES]] : memref + // CHECK-NOT: dealloc [[RET_RES]] : memref + + // CHECK: return [[RET_RES]] : memref +} + +func @test_leakyrelu_leakyrelu(%arg0 : tensor) -> tensor<*xf32> { + %0 = "onnx.LeakyRelu"(%arg0) {LeakyRelu.alpha=1.0:f32} : (tensor) -> tensor<*xf32> + %1 = "onnx.LeakyRelu"(%0) {LeakyRelu.alpha=1.0:f32} : (tensor<*xf32>) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_leakyrelu_leakyrelu + /// First LeakyRelu + // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref + // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) { + // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref + // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32 + // CHECK: [[ALPHA:%.+]] = constant {{1.+}} : f32 + // CHECK: [[CMP:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32 + // CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[LOAD]] : f32 + // CHECK: [[SELECT:%.+]] = select [[CMP]], [[MUL]], [[LOAD]] : f32 + // CHECK: store [[SELECT]], [[RES]][%arg1, %arg2] : memref + + /// Second LeakyRelu + // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref + // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) { + // CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref + // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32 + // CHECK: [[ALPHA:%.+]] = constant {{1.+}} : f32 + // CHECK: [[CMP:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32 + // CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[LOAD]] : f32 + // CHECK: [[SELECT:%.+]] = select [[CMP]], [[MUL]], [[LOAD]] : f32 + // CHECK: store [[SELECT]], [[RET_RES]][%arg1, %arg2] : memref + + /// Dealloc of first result. 
+ // CHECK: dealloc [[RES]] : memref + // CHECK-NOT: dealloc [[RET_RES]] : memref + + // CHECK: return [[RET_RES]] : memref +} + +func @test_selu_selu(%arg0 : tensor) -> tensor<*xf32> { + %0 = "onnx.Selu"(%arg0) {Selu.alpha=1.0:f32, Selu.gamma=2.0:f32} : (tensor) -> tensor<*xf32> + %1 = "onnx.Selu"(%0) {Selu.alpha=1.0:f32, Selu.gamma=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_selu_selu + /// First Selu + // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref + // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) { + // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref + // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32 + // CHECK: [[ALPHA:%.+]] = constant {{1.+}} : f32 + // CHECK: [[GAMMA:%.+]] = constant {{2.+}} : f32 + // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32 + // CHECK: [[CMP:%.+]] = cmpf "ogt", [[LOAD]], [[ZERO]] : f32 + // CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[EXP]] : f32 + // CHECK: [[SUB:%.+]] = subf [[MUL]], [[ALPHA]] : f32 + // CHECK: [[SELECT:%.+]] = select [[CMP]], [[LOAD]], [[SUB]] : f32 + // CHECK: [[SELU_RES:%.+]] = mulf [[GAMMA]], [[SELECT]] : f32 + // CHECK: store [[SELU_RES]], [[RES]][%arg1, %arg2] : memref + + /// Second Selu + // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref + // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) { + // CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref + // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32 + // CHECK: [[ALPHA:%.+]] = constant {{1.+}} : f32 + // CHECK: [[GAMMA:%.+]] = constant {{2.+}} : f32 + // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32 + // CHECK: [[CMP:%.+]] = cmpf "ogt", [[LOAD]], [[ZERO]] : f32 + // CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[EXP]] : f32 + // CHECK: [[SUB:%.+]] = subf [[MUL]], [[ALPHA]] : f32 + // CHECK: [[SELECT:%.+]] = select [[CMP]], [[LOAD]], [[SUB]] : f32 + // CHECK: [[SELU_RES:%.+]] = mulf [[GAMMA]], [[SELECT]] : f32 + // CHECK: store [[SELU_RES]], [[RET_RES]][%arg1, %arg2] : memref + + /// Dealloc of first result. 
+ // CHECK: dealloc [[RES]] : memref + // CHECK-NOT: dealloc [[RET_RES]] : memref + + // CHECK: return [[RET_RES]] : memref +} + +func @test_hardsigmoid_hardsigmoid(%arg0 : tensor) -> tensor<*xf32> { + %0 = "onnx.HardSigmoid"(%arg0) {HardSigmoid.alpha=1.0:f32, HardSigmoid.beta=2.0:f32} : (tensor) -> tensor<*xf32> + %1 = "onnx.HardSigmoid"(%0) {HardSigmoid.alpha=1.0:f32, HardSigmoid.beta=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_hardsigmoid_hardsigmoid + /// First HardSigmoid + // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref + // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) { + // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref + // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32 + // CHECK: [[ONE:%.+]] = constant {{1.+}} : f32 + // CHECK: [[ALPHA:%.+]] = constant {{1.+}} : f32 + // CHECK: [[BETA:%.+]] = constant {{2.+}} : f32 + // CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[LOAD]] : f32 + // CHECK: [[ADD:%.+]] = addf [[MUL]], [[BETA]] : f32 + // CHECK: [[CMP1:%.+]] = cmpf "ogt", [[ADD]], [[ZERO]] : f32 + // CHECK: [[SELECT1:%.+]] = select [[CMP1]], [[ADD]], [[ZERO]] : f32 + // CHECK: [[CMP2:%.+]] = cmpf "olt", [[SELECT1]], [[ONE]] : f32 + // CHECK: [[SELECT2:%.+]] = select [[CMP2]], [[SELECT1]], [[ONE]] : f32 + // CHECK: store [[SELECT2]], [[RES]][%arg1, %arg2] : memref + + /// Second HardSigmoid + // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref + // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) { + // CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref + // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32 + // CHECK: [[ONE:%.+]] = constant {{1.+}} : f32 + // CHECK: [[ALPHA:%.+]] = constant {{1.+}} : f32 + // CHECK: [[BETA:%.+]] = constant {{2.+}} : f32 + // CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[LOAD]] : f32 + // CHECK: [[ADD:%.+]] = addf [[MUL]], [[BETA]] : f32 + // CHECK: [[CMP1:%.+]] = cmpf "ogt", [[ADD]], [[ZERO]] : f32 + // CHECK: [[SELECT1:%.+]] = select [[CMP1]], [[ADD]], [[ZERO]] : f32 + // CHECK: [[CMP2:%.+]] = cmpf "olt", [[SELECT1]], [[ONE]] : f32 + // CHECK: [[SELECT2:%.+]] = select [[CMP2]], [[SELECT1]], [[ONE]] : f32 + // CHECK: store [[SELECT2]], [[RET_RES]][%arg1, %arg2] : memref + + /// Dealloc of first result. + // CHECK: dealloc [[RES]] : memref + // CHECK-NOT: dealloc [[RET_RES]] : memref + + // CHECK: return [[RET_RES]] : memref +}
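As a quick check on the HardSigmoid lowering exercised above: with the test attributes alpha = 1.0 and beta = 2.0, the two selects compute min(1, max(0, x + 2)), so an input of -1.5 maps to 0.5 and an input of 0.0 saturates to 1.0.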