Change the name and signature of mapToLowerScalarOp (#67)
* Revise mapToLowerScalarOp() * Update TanhOp Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
		
							parent
							
								
									f4fefcf713
								
							
						
					
					
						commit
						4e66488ad3
					
				|  | @ -88,17 +88,15 @@ struct ScalarOp<ONNXSqrtOp> { | |||
| // Scalar unary ops for lowering ONNXSinhOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value mapToLowerScalarOp<ONNXSinhOp>(Operation *op, ArrayRef<Type> result_types, | ||||
|                                      ArrayRef<Value> operands, | ||||
|                                      ConversionPatternRewriter &rewriter) { | ||||
| Value emitScalarOpFor<ONNXSinhOp>(ConversionPatternRewriter &rewriter, | ||||
|     Location loc, Operation *op, Type elementType, | ||||
|     ArrayRef<Value> scalarOperands) { | ||||
|   // ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
 | ||||
|   //                         ConstantOp 2)
 | ||||
|   auto loc = op->getLoc(); | ||||
|   Value operand = operands[0]; | ||||
|   auto elementType = result_types[0]; | ||||
|   Value operand = scalarOperands[0]; | ||||
| 
 | ||||
|   auto zero = emitConstantOp(rewriter, loc, elementType, 0); | ||||
|   auto two =  emitConstantOp(rewriter, loc, elementType, 2); | ||||
|   auto two = emitConstantOp(rewriter, loc, elementType, 2); | ||||
|   auto neg = rewriter.create<SubFOp>(loc, zero, operand); | ||||
|   auto exp = rewriter.create<ExpOp>(loc, operand); | ||||
|   auto negExp = rewriter.create<ExpOp>(loc, neg); | ||||
|  | @ -112,17 +110,15 @@ Value mapToLowerScalarOp<ONNXSinhOp>(Operation *op, ArrayRef<Type> result_types, | |||
| // Scalar unary ops for lowering ONNXCoshOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value mapToLowerScalarOp<ONNXCoshOp>(Operation *op, ArrayRef<Type> result_types, | ||||
|                                      ArrayRef<Value> operands, | ||||
|                                      ConversionPatternRewriter &rewriter) { | ||||
| Value emitScalarOpFor<ONNXCoshOp>(ConversionPatternRewriter &rewriter, | ||||
|     Location loc, Operation *op, Type elementType, | ||||
|     ArrayRef<Value> scalarOperands) { | ||||
|   // ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
 | ||||
|   //                         ConstantOp 2)
 | ||||
|   auto loc = op->getLoc(); | ||||
|   Value operand = operands[0]; | ||||
|   auto elementType = result_types[0]; | ||||
|   Value operand = scalarOperands[0]; | ||||
| 
 | ||||
|   auto zero = emitConstantOp(rewriter, loc, elementType, 0); | ||||
|   auto two =  emitConstantOp(rewriter, loc, elementType, 2); | ||||
|   auto two = emitConstantOp(rewriter, loc, elementType, 2); | ||||
|   auto neg = rewriter.create<SubFOp>(loc, zero, operand); | ||||
|   auto exp = rewriter.create<ExpOp>(loc, operand); | ||||
|   auto negExp = rewriter.create<ExpOp>(loc, neg); | ||||
|  | @ -136,14 +132,12 @@ Value mapToLowerScalarOp<ONNXCoshOp>(Operation *op, ArrayRef<Type> result_types, | |||
| // Scalar unary ops for lowering ONNXTanhOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value mapToLowerScalarOp<ONNXTanhOp>(Operation *op, ArrayRef<Type> result_types, | ||||
|                                      ArrayRef<Value> operands, | ||||
|                                      ConversionPatternRewriter &rewriter) { | ||||
| Value emitScalarOpFor<ONNXTanhOp>(ConversionPatternRewriter &rewriter, | ||||
|     Location loc, Operation *op, Type elementType, | ||||
|     ArrayRef<Value> scalarOperands) { | ||||
|   // ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
 | ||||
|   //                         AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
 | ||||
|   auto loc = op->getLoc(); | ||||
|   Value operand = operands[0]; | ||||
|   auto elementType = result_types[0]; | ||||
|   Value operand = scalarOperands[0]; | ||||
| 
 | ||||
|   auto zero = emitConstantOp(rewriter, loc, elementType, 0); | ||||
|   auto neg = rewriter.create<SubFOp>(loc, zero, operand); | ||||
|  | @ -160,15 +154,12 @@ Value mapToLowerScalarOp<ONNXTanhOp>(Operation *op, ArrayRef<Type> result_types, | |||
| // Scalar unary ops for lowering ONNXSigmoidOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op, | ||||
|                                         ArrayRef<Type> result_types, | ||||
|                                         ArrayRef<Value> operands, | ||||
|                                         ConversionPatternRewriter &rewriter) { | ||||
| Value emitScalarOpFor<ONNXSigmoidOp>(ConversionPatternRewriter &rewriter, | ||||
|     Location loc, Operation *op, Type elementType, | ||||
|     ArrayRef<Value> scalarOperands) { | ||||
|   // ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1,
 | ||||
|   //                            AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
 | ||||
|   auto loc = op->getLoc(); | ||||
|   Value operand = operands[0]; | ||||
|   auto elementType = result_types[0]; | ||||
|   Value operand = scalarOperands[0]; | ||||
| 
 | ||||
|   auto zero = emitConstantOp(rewriter, loc, elementType, 0); | ||||
|   auto one = emitConstantOp(rewriter, loc, elementType, 1); | ||||
|  | @ -184,9 +175,9 @@ Value mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op, | |||
| // Scalar unary ops for lowering ONNXHardSigmoidOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value mapToLowerScalarOp<ONNXHardSigmoidOp>( | ||||
|     Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands, | ||||
|     ConversionPatternRewriter &rewriter) { | ||||
| Value emitScalarOpFor<ONNXHardSigmoidOp>(ConversionPatternRewriter &rewriter, | ||||
|     Location loc, Operation *op, Type elementType, | ||||
|     ArrayRef<Value> scalarOperands) { | ||||
|   // %Y = AddFOp(MulFOp(alpha, %X), beta)
 | ||||
|   // %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0),
 | ||||
|   //               %Y,
 | ||||
|  | @ -194,13 +185,11 @@ Value mapToLowerScalarOp<ONNXHardSigmoidOp>( | |||
|   // ONNXHardSigmoidOp(%X) = SelectOp(CmpFOp(OLT, %Z, Constant 1),
 | ||||
|   //                                  %Z,
 | ||||
|   //                                  Constant 1)
 | ||||
|   auto loc = op->getLoc(); | ||||
|   Value operand = operands[0]; | ||||
|   Value operand = scalarOperands[0]; | ||||
|   auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), | ||||
|       llvm::dyn_cast<ONNXHardSigmoidOp>(op).alpha().convertToFloat()); | ||||
|   auto betaAttribute = FloatAttr::get(rewriter.getF32Type(), | ||||
|       llvm::dyn_cast<ONNXHardSigmoidOp>(op).beta().convertToFloat()); | ||||
|   auto elementType = result_types[0]; | ||||
| 
 | ||||
|   auto zero = emitConstantOp(rewriter, loc, elementType, 0); | ||||
|   auto one = emitConstantOp(rewriter, loc, elementType, 1); | ||||
|  | @ -223,15 +212,13 @@ Value mapToLowerScalarOp<ONNXHardSigmoidOp>( | |||
| // Scalar unary ops for lowering ONNXEluOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types, | ||||
|                                     ArrayRef<Value> operands, | ||||
|                                     ConversionPatternRewriter &rewriter) { | ||||
| Value emitScalarOpFor<ONNXEluOp>(ConversionPatternRewriter &rewriter, | ||||
|     Location loc, Operation *op, Type elementType, | ||||
|     ArrayRef<Value> scalarOperands) { | ||||
|   // ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
 | ||||
|   //                          MulFOp(alpha, SubFOp(ExpOp(%X), 1)),
 | ||||
|   //                          %X)
 | ||||
|   auto loc = op->getLoc(); | ||||
|   Value operand = operands[0]; | ||||
|   auto elementType = result_types[0]; | ||||
|   Value operand = scalarOperands[0]; | ||||
| 
 | ||||
|   auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), | ||||
|       llvm::dyn_cast<ONNXEluOp>(op).alpha().convertToFloat()); | ||||
|  | @ -241,10 +228,9 @@ Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types, | |||
|   auto exp = rewriter.create<ExpOp>(loc, operand); | ||||
|   auto lessThanZero = | ||||
|       rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero); | ||||
|   auto result = rewriter.create<SelectOp>( | ||||
|       loc, lessThanZero, | ||||
|       rewriter.create<MulFOp>(loc, alpha, | ||||
|                               rewriter.create<SubFOp>(loc, exp, one)), | ||||
|   auto result = rewriter.create<SelectOp>(loc, lessThanZero, | ||||
|       rewriter.create<MulFOp>( | ||||
|           loc, alpha, rewriter.create<SubFOp>(loc, exp, one)), | ||||
|       operand); | ||||
| 
 | ||||
|   return result; | ||||
|  | @ -254,15 +240,13 @@ Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types, | |||
| // Scalar unary ops for lowering ONNXReluOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value mapToLowerScalarOp<ONNXReluOp>(Operation *op, ArrayRef<Type> result_types, | ||||
|                                      ArrayRef<Value> operands, | ||||
|                                      ConversionPatternRewriter &rewriter) { | ||||
| Value emitScalarOpFor<ONNXReluOp>(ConversionPatternRewriter &rewriter, | ||||
|     Location loc, Operation *op, Type elementType, | ||||
|     ArrayRef<Value> scalarOperands) { | ||||
|   // ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
 | ||||
|   //                           ConstantOp 0,
 | ||||
|   //                           %X)
 | ||||
|   auto loc = op->getLoc(); | ||||
|   Value operand = operands[0]; | ||||
|   auto elementType = result_types[0]; | ||||
|   Value operand = scalarOperands[0]; | ||||
| 
 | ||||
|   auto zero = emitConstantOp(rewriter, loc, elementType, 0); | ||||
|   auto lessThanZero = | ||||
|  | @ -276,16 +260,13 @@ Value mapToLowerScalarOp<ONNXReluOp>(Operation *op, ArrayRef<Type> result_types, | |||
| // Scalar unary ops for lowering ONNXLeakyReluOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op, | ||||
|                                           ArrayRef<Type> result_types, | ||||
|                                           ArrayRef<Value> operands, | ||||
|                                           ConversionPatternRewriter &rewriter) { | ||||
| Value emitScalarOpFor<ONNXLeakyReluOp>(ConversionPatternRewriter &rewriter, | ||||
|     Location loc, Operation *op, Type elementType, | ||||
|     ArrayRef<Value> scalarOperands) { | ||||
|   // ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
 | ||||
|   //                                MulFOp(alpha, %X),
 | ||||
|   //                                %X)
 | ||||
|   auto loc = op->getLoc(); | ||||
|   Value operand = operands[0]; | ||||
|   auto elementType = result_types[0]; | ||||
|   Value operand = scalarOperands[0]; | ||||
| 
 | ||||
|   auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), | ||||
|       llvm::dyn_cast<ONNXLeakyReluOp>(op).alpha().convertToFloat()); | ||||
|  | @ -303,21 +284,19 @@ Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op, | |||
| // Scalar unary ops for lowering ONNXSeluOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types, | ||||
|                                      ArrayRef<Value> operands, | ||||
|                                      ConversionPatternRewriter &rewriter) { | ||||
| Value emitScalarOpFor<ONNXSeluOp>(ConversionPatternRewriter &rewriter, | ||||
|     Location loc, Operation *op, Type elementType, | ||||
|     ArrayRef<Value> scalarOperands) { | ||||
|   // ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
 | ||||
|   //                           MulFOp(gamma, %X),
 | ||||
|   //                           MulFOp(gamma,
 | ||||
|   //                                  SubFOp(MulFOp(alpha, ExpOp(%X)),
 | ||||
|   //                                         alpha)))
 | ||||
|   auto loc = op->getLoc(); | ||||
|   Value operand = operands[0]; | ||||
|   Value operand = scalarOperands[0]; | ||||
|   auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), | ||||
|       llvm::dyn_cast<ONNXSeluOp>(op).alpha().convertToFloat()); | ||||
|   auto gammaAttribute = FloatAttr::get(rewriter.getF32Type(), | ||||
|       llvm::dyn_cast<ONNXSeluOp>(op).gamma().convertToFloat()); | ||||
|   auto elementType = result_types[0]; | ||||
| 
 | ||||
|   auto zero = emitConstantOp(rewriter, loc, elementType, 0); | ||||
|   auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute); | ||||
|  | @ -325,10 +304,9 @@ Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types, | |||
|   auto exp = rewriter.create<ExpOp>(loc, operand); | ||||
|   auto greaterThanZero = | ||||
|       rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero); | ||||
|   auto select = rewriter.create<SelectOp>( | ||||
|       loc, greaterThanZero, operand, | ||||
|       rewriter.create<SubFOp>(loc, rewriter.create<MulFOp>(loc, alpha, exp), | ||||
|                               alpha)); | ||||
|   auto select = rewriter.create<SelectOp>(loc, greaterThanZero, operand, | ||||
|       rewriter.create<SubFOp>( | ||||
|           loc, rewriter.create<MulFOp>(loc, alpha, exp), alpha)); | ||||
|   auto result = rewriter.create<MulFOp>(loc, gamma, select); | ||||
| 
 | ||||
|   return result; | ||||
|  | @ -338,14 +316,11 @@ Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types, | |||
| // Scalar unary ops for lowering ONNXReciprocalOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value mapToLowerScalarOp<ONNXReciprocalOp>( | ||||
|     Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands, | ||||
|     ConversionPatternRewriter &rewriter) { | ||||
| Value emitScalarOpFor<ONNXReciprocalOp>(ConversionPatternRewriter &rewriter, | ||||
|     Location loc, Operation *op, Type elementType, | ||||
|     ArrayRef<Value> scalarOperands) { | ||||
|   // ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
 | ||||
|   auto loc = op->getLoc(); | ||||
|   Value operand = operands[0]; | ||||
|   auto elementType = result_types[0]; | ||||
| 
 | ||||
|   Value operand = scalarOperands[0]; | ||||
|   auto one = emitConstantOp(rewriter, loc, elementType, 1); | ||||
|   auto result = rewriter.create<DivFOp>(loc, one, operand); | ||||
| 
 | ||||
|  | @ -356,13 +331,11 @@ Value mapToLowerScalarOp<ONNXReciprocalOp>( | |||
| // Scalar unary ops for lowering ONNXSoftplusOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value mapToLowerScalarOp<ONNXSoftplusOp>( | ||||
|     Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands, | ||||
|     ConversionPatternRewriter &rewriter) { | ||||
| Value emitScalarOpFor<ONNXSoftplusOp>(ConversionPatternRewriter &rewriter, | ||||
|     Location loc, Operation *op, Type elementType, | ||||
|     ArrayRef<Value> scalarOperands) { | ||||
|   // ONNXSoftplusOp(%X) = LogOp(AddFOp(ExpOp(%X), ConstantOp 1))
 | ||||
|   auto loc = op->getLoc(); | ||||
|   Value operand = operands[0]; | ||||
|   auto elementType = result_types[0]; | ||||
|   Value operand = scalarOperands[0]; | ||||
| 
 | ||||
|   auto exp = rewriter.create<ExpOp>(loc, operand); | ||||
|   auto one = emitConstantOp(rewriter, loc, elementType, 1); | ||||
|  | @ -376,13 +349,11 @@ Value mapToLowerScalarOp<ONNXSoftplusOp>( | |||
| // Scalar unary ops for lowering ONNXSoftsignOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value mapToLowerScalarOp<ONNXSoftsignOp>( | ||||
|     Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands, | ||||
|     ConversionPatternRewriter &rewriter) { | ||||
| Value emitScalarOpFor<ONNXSoftsignOp>(ConversionPatternRewriter &rewriter, | ||||
|     Location loc, Operation *op, Type elementType, | ||||
|     ArrayRef<Value> scalarOperands) { | ||||
|   // ONNXSoftsignOp(%X) = DivFOp(ConstantOp 1, %X)
 | ||||
|   auto loc = op->getLoc(); | ||||
|   Value operand = operands[0]; | ||||
|   auto elementType = result_types[0]; | ||||
|   Value operand = scalarOperands[0]; | ||||
| 
 | ||||
|   auto abs = rewriter.create<AbsFOp>(loc, operand); | ||||
|   auto one = emitConstantOp(rewriter, loc, elementType, 1); | ||||
|  | @ -396,13 +367,10 @@ Value mapToLowerScalarOp<ONNXSoftsignOp>( | |||
| // Scalar unary ops for lowering ONNXSignOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value mapToLowerScalarOp<ONNXSignOp>(Operation *op, ArrayRef<Type> result_types, | ||||
|                                      ArrayRef<Value> operands, | ||||
|                                      ConversionPatternRewriter &rewriter) { | ||||
| 
 | ||||
|   auto loc = op->getLoc(); | ||||
|   Value operand = operands[0]; | ||||
|   Type elementType = operands.front().getType(); | ||||
| Value emitScalarOpFor<ONNXSignOp>(ConversionPatternRewriter &rewriter, | ||||
|     Location loc, Operation *op, Type elementType, | ||||
|     ArrayRef<Value> scalarOperands) { | ||||
|   Value operand = scalarOperands[0]; | ||||
|   // TODO: unsigned int should be supported separately?
 | ||||
|   if (elementType.isa<IntegerType>()) { | ||||
|     // %Y = SelectOP(CmpIOp(GT, %X, ConstantOp 0),
 | ||||
|  | @ -451,15 +419,14 @@ Value mapToLowerScalarOp<ONNXSignOp>(Operation *op, ArrayRef<Type> result_types, | |||
| // Scalar unary ops for lowering ONNXMaxOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types, | ||||
|                                     ArrayRef<Value> operands, | ||||
|                                     ConversionPatternRewriter &rewriter) { | ||||
| Value emitScalarOpFor<ONNXMaxOp>(ConversionPatternRewriter &rewriter, | ||||
|     Location loc, Operation *op, Type elementType, | ||||
|     ArrayRef<Value> scalarOperands) { | ||||
|   // ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y),
 | ||||
|   //                              %X,
 | ||||
|   //                              %Y)
 | ||||
|   auto loc = op->getLoc(); | ||||
|   Value lhs = operands[0]; | ||||
|   Value rhs = operands[1]; | ||||
|   Value lhs = scalarOperands[0]; | ||||
|   Value rhs = scalarOperands[1]; | ||||
|   auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs); | ||||
|   auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs); | ||||
|   return result; | ||||
|  | @ -469,15 +436,14 @@ Value mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types, | |||
| // Scalar unary ops for lowering ONNXMinOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types, | ||||
|                                     ArrayRef<Value> operands, | ||||
|                                     ConversionPatternRewriter &rewriter) { | ||||
| Value emitScalarOpFor<ONNXMinOp>(ConversionPatternRewriter &rewriter, | ||||
|     Location loc, Operation *op, Type elementType, | ||||
|     ArrayRef<Value> scalarOperands) { | ||||
|   // ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y),
 | ||||
|   //                              %X,
 | ||||
|   //                              %Y)
 | ||||
|   auto loc = op->getLoc(); | ||||
|   Value lhs = operands[0]; | ||||
|   Value rhs = operands[1]; | ||||
|   Value lhs = scalarOperands[0]; | ||||
|   Value rhs = scalarOperands[1]; | ||||
|   auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs); | ||||
|   auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs); | ||||
|   return result; | ||||
|  | @ -487,11 +453,10 @@ Value mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types, | |||
| // Scalar unary ops for lowering ONNXAbsOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value mapToLowerScalarOp<ONNXAbsOp>(Operation *op, ArrayRef<Type> result_types, | ||||
|     ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) { | ||||
|   auto loc = op->getLoc(); | ||||
|   Value operand = operands[0]; | ||||
|   auto elementType = result_types[0]; | ||||
| Value emitScalarOpFor<ONNXAbsOp>(ConversionPatternRewriter &rewriter, | ||||
|     Location loc, Operation *op, Type elementType, | ||||
|     ArrayRef<Value> scalarOperands) { | ||||
|   Value operand = scalarOperands[0]; | ||||
| 
 | ||||
|   if (elementType.isa<FloatType>()) { | ||||
|     return rewriter.create<AbsFOp>(loc, operand); | ||||
|  | @ -536,15 +501,14 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern { | |||
|     if (hasAllConstantDimensions(memRefType)) | ||||
|       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); | ||||
|     else | ||||
|       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, | ||||
|                                     {operands[0]}); | ||||
|       alloc = insertAllocAndDealloc( | ||||
|           memRefType, loc, rewriter, insertDealloc, {operands[0]}); | ||||
| 
 | ||||
|     std::vector<Value> originalLoops; | ||||
|     KrnlOptimizeLoopsOp optimizedLoopsOp; | ||||
|     KrnlIterateOp iterateOp; | ||||
|     emitKrnlLoopsAndIterationForOperand( | ||||
|         rewriter, loc, operands[0], originalLoops, | ||||
|         optimizedLoopsOp, iterateOp); | ||||
|         rewriter, loc, operands[0], originalLoops, optimizedLoopsOp, iterateOp); | ||||
|     Block &optimizationBlock = optimizedLoopsOp.region().front(); | ||||
|     Block &iterationBlock = iterateOp.bodyRegion().front(); | ||||
| 
 | ||||
|  | @ -564,8 +528,8 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern { | |||
|       loopIVs.push_back(arg); | ||||
| 
 | ||||
|     auto loadedVal = rewriter.create<LoadOp>(loc, operands[0], loopIVs); | ||||
|     auto loweredOpResult = mapToLowerScalarOp<ElementwiseUnaryOp>( | ||||
|         op, memRefType.getElementType(), {loadedVal}, rewriter); | ||||
|     auto loweredOpResult = emitScalarOpFor<ElementwiseUnaryOp>( | ||||
|         rewriter, loc, op, memRefType.getElementType(), {loadedVal}); | ||||
|     // Store result in the resulting array.
 | ||||
|     rewriter.create<StoreOp>(loc, loweredOpResult, alloc, loopIVs); | ||||
| 
 | ||||
|  | @ -603,8 +567,8 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { | |||
|     if (hasAllConstantDimensions(memRefType)) | ||||
|       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); | ||||
|     else | ||||
|       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, | ||||
|                                     operands); | ||||
|       alloc = insertAllocAndDealloc( | ||||
|           memRefType, loc, rewriter, insertDealloc, operands); | ||||
| 
 | ||||
|     // Get run-time dimension information for unknown dimensions used for
 | ||||
|     // broadcasting.
 | ||||
|  | @ -615,8 +579,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { | |||
|     KrnlOptimizeLoopsOp optimizedLoopsOp; | ||||
|     KrnlIterateOp iterateOp; | ||||
|     emitKrnlLoopsAndIterationForOperand( | ||||
|         rewriter, loc, alloc, originalLoops, | ||||
|         optimizedLoopsOp, iterateOp); | ||||
|         rewriter, loc, alloc, originalLoops, optimizedLoopsOp, iterateOp); | ||||
|     Block &optimizationBlock = optimizedLoopsOp.region().front(); | ||||
|     Block &iterationBlock = iterateOp.bodyRegion().front(); | ||||
| 
 | ||||
|  | @ -643,8 +606,8 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { | |||
|       auto nextLoopIVs = getLoopIVsForBroadcasting( | ||||
|           loc, rewriter, loopIVs, operands[i], broadcastedDimInfo[i]); | ||||
|       next = rewriter.create<LoadOp>(loc, operands[i], nextLoopIVs); | ||||
|       accumulated = mapToLowerScalarOp<ElementwiseVariadicOp>( | ||||
|           op, memRefType.getElementType(), {accumulated, next}, rewriter); | ||||
|       accumulated = emitScalarOpFor<ElementwiseVariadicOp>( | ||||
|           rewriter, loc, op, memRefType.getElementType(), {accumulated, next}); | ||||
|     } | ||||
|     // Store result in the resulting array.
 | ||||
|     rewriter.create<StoreOp>(loc, accumulated, alloc, loopIVs); | ||||
|  | @ -658,31 +621,31 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { | |||
| void populateLoweringONNXElementwiseOpPattern( | ||||
|     OwningRewritePatternList &patterns, MLIRContext *ctx) { | ||||
|   patterns.insert<ONNXElementwiseUnaryOpLowering<mlir::ONNXAbsOp>, | ||||
|                   ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>, | ||||
|                   ONNXElementwiseVariadicOpLowering<mlir::ONNXAndOp>, | ||||
|                   ONNXElementwiseUnaryOpLowering<mlir::ONNXCosOp>, | ||||
|                   ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>, | ||||
|                   ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>, | ||||
|                   ONNXElementwiseUnaryOpLowering<mlir::ONNXEluOp>, | ||||
|                   ONNXElementwiseUnaryOpLowering<mlir::ONNXExpOp>, | ||||
|                   ONNXElementwiseUnaryOpLowering<mlir::ONNXHardSigmoidOp>, | ||||
|                   ONNXElementwiseUnaryOpLowering<mlir::ONNXLeakyReluOp>, | ||||
|                   ONNXElementwiseUnaryOpLowering<mlir::ONNXLogOp>, | ||||
|                   ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>, | ||||
|                   ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>, | ||||
|                   ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>, | ||||
|                   ONNXElementwiseVariadicOpLowering<mlir::ONNXOrOp>, | ||||
|                   ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>, | ||||
|                   ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>, | ||||
|                   ONNXElementwiseUnaryOpLowering<mlir::ONNXSeluOp>, | ||||
|                   ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>, | ||||
|                   ONNXElementwiseUnaryOpLowering<mlir::ONNXSignOp>, | ||||
|                   ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>, | ||||
|                   ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftplusOp>, | ||||
|                   ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftsignOp>, | ||||
|                   ONNXElementwiseUnaryOpLowering<mlir::ONNXSqrtOp>, | ||||
|                   ONNXElementwiseVariadicOpLowering<mlir::ONNXSubOp>, | ||||
|                   ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>, | ||||
|                   ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>, | ||||
|                   ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>>(ctx); | ||||
|       ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>, | ||||
|       ONNXElementwiseVariadicOpLowering<mlir::ONNXAndOp>, | ||||
|       ONNXElementwiseUnaryOpLowering<mlir::ONNXCosOp>, | ||||
|       ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>, | ||||
|       ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>, | ||||
|       ONNXElementwiseUnaryOpLowering<mlir::ONNXEluOp>, | ||||
|       ONNXElementwiseUnaryOpLowering<mlir::ONNXExpOp>, | ||||
|       ONNXElementwiseUnaryOpLowering<mlir::ONNXHardSigmoidOp>, | ||||
|       ONNXElementwiseUnaryOpLowering<mlir::ONNXLeakyReluOp>, | ||||
|       ONNXElementwiseUnaryOpLowering<mlir::ONNXLogOp>, | ||||
|       ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>, | ||||
|       ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>, | ||||
|       ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>, | ||||
|       ONNXElementwiseVariadicOpLowering<mlir::ONNXOrOp>, | ||||
|       ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>, | ||||
|       ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>, | ||||
|       ONNXElementwiseUnaryOpLowering<mlir::ONNXSeluOp>, | ||||
|       ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>, | ||||
|       ONNXElementwiseUnaryOpLowering<mlir::ONNXSignOp>, | ||||
|       ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>, | ||||
|       ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftplusOp>, | ||||
|       ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftsignOp>, | ||||
|       ONNXElementwiseUnaryOpLowering<mlir::ONNXSqrtOp>, | ||||
|       ONNXElementwiseVariadicOpLowering<mlir::ONNXSubOp>, | ||||
|       ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>, | ||||
|       ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>, | ||||
|       ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>>(ctx); | ||||
| } | ||||
|  |  | |||
|  | @ -54,13 +54,11 @@ struct ScalarOp<ONNXReduceSumOp> { | |||
| // Scalar unary ops for lowering ONNXReduceMaxOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value mapToLowerScalarOp<ONNXReduceMaxOp>(Operation *op, | ||||
|                                           ArrayRef<Type> result_types, | ||||
|                                           ArrayRef<Value> operands, | ||||
|                                           ConversionPatternRewriter &rewriter) { | ||||
|   auto loc = op->getLoc(); | ||||
|   Value lhs = operands[0]; | ||||
|   Value rhs = operands[1]; | ||||
| Value emitScalarOpFor<ONNXReduceMaxOp>(ConversionPatternRewriter &rewriter, | ||||
|     Location loc, Operation *op, Type elementType, | ||||
|     ArrayRef<Value> scalarOperands) { | ||||
|   Value lhs = scalarOperands[0]; | ||||
|   Value rhs = scalarOperands[1]; | ||||
|   Type element_type = lhs.getType(); | ||||
|   if (element_type.isa<IntegerType>()) { | ||||
|     auto max = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs, rhs); | ||||
|  | @ -80,19 +78,16 @@ Value mapToLowerScalarOp<ONNXReduceMaxOp>(Operation *op, | |||
| // Scalar unary ops for lowering ONNXReduceMinOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value mapToLowerScalarOp<ONNXReduceMinOp>(Operation *op, | ||||
|                                           ArrayRef<Type> result_types, | ||||
|                                           ArrayRef<Value> operands, | ||||
|                                           ConversionPatternRewriter &rewriter) { | ||||
|   auto loc = op->getLoc(); | ||||
|   Value lhs = operands[0]; | ||||
|   Value rhs = operands[1]; | ||||
|   Type element_type = lhs.getType(); | ||||
|   if (element_type.isa<IntegerType>()) { | ||||
| Value emitScalarOpFor<ONNXReduceMinOp>(ConversionPatternRewriter &rewriter, | ||||
|     Location loc, Operation *op, Type elementType, | ||||
|     ArrayRef<Value> scalarOperands) { | ||||
|   Value lhs = scalarOperands[0]; | ||||
|   Value rhs = scalarOperands[1]; | ||||
|   if (elementType.isa<IntegerType>()) { | ||||
|     auto min = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, rhs); | ||||
|     auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs); | ||||
|     return result; | ||||
|   } else if (element_type.isa<FloatType>()) { | ||||
|   } else if (elementType.isa<FloatType>()) { | ||||
|     auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs); | ||||
|     auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs); | ||||
|     return result; | ||||
|  | @ -129,7 +124,7 @@ struct ONNXReductionOpLowering : public ConversionPattern { | |||
|      *   Y(i1) += X(i0, i1, i2) | ||||
|      * } | ||||
|      * | ||||
|     */ | ||||
|      */ | ||||
|     auto loc = op->getLoc(); | ||||
|     auto memRefInType = operands[0].getType().cast<MemRefType>(); | ||||
|     auto memRefInShape = memRefInType.getShape(); | ||||
|  | @ -154,8 +149,7 @@ struct ONNXReductionOpLowering : public ConversionPattern { | |||
|       } | ||||
|     } | ||||
|     // KeepDims
 | ||||
|     auto keepdims = | ||||
|         llvm::dyn_cast<ONNXReductionOp>(op).keepdims(); | ||||
|     auto keepdims = llvm::dyn_cast<ONNXReductionOp>(op).keepdims(); | ||||
|     bool isKeepdims = (keepdims == 1) ? true : false; | ||||
| 
 | ||||
|     // Get type information
 | ||||
|  | @ -168,7 +162,8 @@ struct ONNXReductionOpLowering : public ConversionPattern { | |||
|     Value alloc; | ||||
|     bool insertDealloc = checkInsertDealloc(op); | ||||
|     if (hasAllConstantDimensions(memRefOutType)) { | ||||
|       alloc = insertAllocAndDealloc(memRefOutType, loc, rewriter, insertDealloc); | ||||
|       alloc = | ||||
|           insertAllocAndDealloc(memRefOutType, loc, rewriter, insertDealloc); | ||||
|     } else { | ||||
|       SmallVector<Value, 2> allocOperands; | ||||
|       for (decltype(outRank) i = 0; i < outRank; ++i) { | ||||
|  | @ -192,12 +187,12 @@ struct ONNXReductionOpLowering : public ConversionPattern { | |||
|     // Define loops to initialize the result.
 | ||||
|     std::vector<Value> originalLoopsInit; | ||||
|     std::vector<Value> optimizedLoopsInit; | ||||
|     Block *optimizationBlockInit = defineLoops(rewriter, loc, originalLoopsInit, | ||||
|             optimizedLoopsInit, outRank); | ||||
|     Block *optimizationBlockInit = defineLoops( | ||||
|         rewriter, loc, originalLoopsInit, optimizedLoopsInit, outRank); | ||||
| 
 | ||||
|     // Iteration information
 | ||||
|     KrnlIterateOperandPack packInit(rewriter, originalLoopsInit, | ||||
|         optimizedLoopsInit); | ||||
|     KrnlIterateOperandPack packInit( | ||||
|         rewriter, originalLoopsInit, optimizedLoopsInit); | ||||
|     for (decltype(outRank) i = 0; i < outRank; ++i) { | ||||
|       addDimensionToPack(rewriter, loc, packInit, alloc, i); | ||||
|     } | ||||
|  | @ -225,8 +220,8 @@ struct ONNXReductionOpLowering : public ConversionPattern { | |||
|     // Define an Krnl loop to do reduction.
 | ||||
|     rewriter.setInsertionPointAfter(iterateOpInit); | ||||
|     std::vector<Value> originalLoops, optimizedLoops; | ||||
|     Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, | ||||
|             optimizedLoops, inRank); | ||||
|     Block *optimizationBlock = | ||||
|         defineLoops(rewriter, loc, originalLoops, optimizedLoops, inRank); | ||||
|     // Iteration information
 | ||||
|     KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); | ||||
|     for (decltype(inRank) i = 0; i < inRank; ++i) { | ||||
|  | @ -266,8 +261,8 @@ struct ONNXReductionOpLowering : public ConversionPattern { | |||
|     Value next, accumulated; | ||||
|     next = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs); | ||||
|     accumulated = rewriter.create<LoadOp>(loc, alloc, outLoopIVs); | ||||
|     accumulated = mapToLowerScalarOp<ONNXReductionOp>( | ||||
|         op, memRefOutType.getElementType(), {accumulated, next}, rewriter); | ||||
|     accumulated = emitScalarOpFor<ONNXReductionOp>( | ||||
|         rewriter, loc, op, memRefOutType.getElementType(), {accumulated, next}); | ||||
|     rewriter.create<StoreOp>(loc, accumulated, alloc, outLoopIVs); | ||||
| 
 | ||||
|     rewriter.replaceOp(op, alloc); | ||||
|  | @ -278,7 +273,7 @@ struct ONNXReductionOpLowering : public ConversionPattern { | |||
| void populateLoweringONNXReductionOpPattern( | ||||
|     OwningRewritePatternList &patterns, MLIRContext *ctx) { | ||||
|   patterns.insert<ONNXReductionOpLowering<mlir::ONNXReduceMaxOp>, | ||||
|                   ONNXReductionOpLowering<mlir::ONNXReduceMinOp>, | ||||
|                   ONNXReductionOpLowering<mlir::ONNXReduceProdOp>, | ||||
|                   ONNXReductionOpLowering<mlir::ONNXReduceSumOp>>(ctx); | ||||
|       ONNXReductionOpLowering<mlir::ONNXReduceMinOp>, | ||||
|       ONNXReductionOpLowering<mlir::ONNXReduceProdOp>, | ||||
|       ONNXReductionOpLowering<mlir::ONNXReduceSumOp>>(ctx); | ||||
| } | ||||
|  |  | |||
|  | @ -20,12 +20,11 @@ Value getIdentityValue<ONNXMaxPoolSingleOutOp>( | |||
| } | ||||
| 
 | ||||
| template <> | ||||
| Value mapToLowerScalarOp<ONNXMaxPoolSingleOutOp>(Operation *op, | ||||
|     ArrayRef<Type> result_types, ArrayRef<Value> operands, | ||||
|     ConversionPatternRewriter &rewriter) { | ||||
|   auto loc = op->getLoc(); | ||||
|   Value lhs = operands[0]; | ||||
|   Value rhs = operands[1]; | ||||
| Value emitScalarOpFor<ONNXMaxPoolSingleOutOp>( | ||||
|     ConversionPatternRewriter &rewriter, Location loc, Operation *op, | ||||
|     Type elementType, ArrayRef<Value> scalarOperands) { | ||||
|   Value lhs = scalarOperands[0]; | ||||
|   Value rhs = scalarOperands[1]; | ||||
|   auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs); | ||||
|   auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs); | ||||
|   return result; | ||||
|  | @ -308,8 +307,8 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern { | |||
|         auto loadData = rewriter.create<LoadOp>(loc, inputOperand, dataIndices); | ||||
|         auto loadPartialResult = | ||||
|             rewriter.create<LoadOp>(loc, alloc, resultIndices); | ||||
|         Value result = mapToLowerScalarOp<ONNXMaxPoolSingleOutOp>( | ||||
|             op, resultElementType, {loadPartialResult, loadData}, rewriter); | ||||
|         Value result = emitScalarOpFor<ONNXMaxPoolSingleOutOp>(rewriter, loc, | ||||
|             op, resultElementType, {loadPartialResult, loadData}); | ||||
|         rewriter.create<StoreOp>(loc, result, alloc, resultIndices); | ||||
|       } | ||||
|     } | ||||
|  |  | |||
|  | @ -148,17 +148,14 @@ Value getIdentityValue( | |||
| // Use template specialization for each of different ONNX operations.
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <typename Op> | ||||
| Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types, | ||||
|                          ArrayRef<Value> operands, | ||||
|                          ConversionPatternRewriter &rewriter) { | ||||
|   auto loc = op->getLoc(); | ||||
|   Type element_type = operands.front().getType(); | ||||
|   if (element_type.isa<IntegerType>()) { | ||||
|     return rewriter.create<ScalarIOp<Op>>(loc, result_types, operands, | ||||
|                                           mlir::None); | ||||
|   } else if (element_type.isa<FloatType>()) { | ||||
|     return rewriter.create<ScalarFOp<Op>>(loc, result_types, operands, | ||||
|                                           mlir::None); | ||||
| Value emitScalarOpFor(ConversionPatternRewriter &rewriter, Location loc, | ||||
|     Operation *op, Type elementType, ArrayRef<Value> scalarOperands) { | ||||
|   if (elementType.isa<IntegerType>()) { | ||||
|     return rewriter.create<ScalarIOp<Op>>( | ||||
|         loc, elementType, scalarOperands, mlir::None); | ||||
|   } else if (elementType.isa<FloatType>()) { | ||||
|     return rewriter.create<ScalarFOp<Op>>( | ||||
|         loc, elementType, scalarOperands, mlir::None); | ||||
|   } else { | ||||
|     emitError(loc, "unsupported element type"); | ||||
|     return nullptr; | ||||
|  | @ -247,4 +244,3 @@ void populateLoweringONNXIdentityOpPattern( | |||
| 
 | ||||
| void populateLoweringONNXConstantOpPattern( | ||||
|     OwningRewritePatternList &patterns, MLIRContext *ctx); | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue