Change the name and signature of mapToLowerScalarOp (#67)

* Revise mapToLowerScalarOp()

* Update TanhOp

Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
Tung D. Le 2020-04-09 17:06:56 +09:00 committed by GitHub
parent f4fefcf713
commit 4e66488ad3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 153 additions and 200 deletions

View File

@ -88,14 +88,12 @@ struct ScalarOp<ONNXSqrtOp> {
// Scalar unary ops for lowering ONNXSinhOp // Scalar unary ops for lowering ONNXSinhOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value mapToLowerScalarOp<ONNXSinhOp>(Operation *op, ArrayRef<Type> result_types, Value emitScalarOpFor<ONNXSinhOp>(ConversionPatternRewriter &rewriter,
ArrayRef<Value> operands, Location loc, Operation *op, Type elementType,
ConversionPatternRewriter &rewriter) { ArrayRef<Value> scalarOperands) {
// ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))), // ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
// ConstantOp 2) // ConstantOp 2)
auto loc = op->getLoc(); Value operand = scalarOperands[0];
Value operand = operands[0];
auto elementType = result_types[0];
auto zero = emitConstantOp(rewriter, loc, elementType, 0); auto zero = emitConstantOp(rewriter, loc, elementType, 0);
auto two = emitConstantOp(rewriter, loc, elementType, 2); auto two = emitConstantOp(rewriter, loc, elementType, 2);
@ -112,14 +110,12 @@ Value mapToLowerScalarOp<ONNXSinhOp>(Operation *op, ArrayRef<Type> result_types,
// Scalar unary ops for lowering ONNXCoshOp // Scalar unary ops for lowering ONNXCoshOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value mapToLowerScalarOp<ONNXCoshOp>(Operation *op, ArrayRef<Type> result_types, Value emitScalarOpFor<ONNXCoshOp>(ConversionPatternRewriter &rewriter,
ArrayRef<Value> operands, Location loc, Operation *op, Type elementType,
ConversionPatternRewriter &rewriter) { ArrayRef<Value> scalarOperands) {
// ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))), // ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
// ConstantOp 2) // ConstantOp 2)
auto loc = op->getLoc(); Value operand = scalarOperands[0];
Value operand = operands[0];
auto elementType = result_types[0];
auto zero = emitConstantOp(rewriter, loc, elementType, 0); auto zero = emitConstantOp(rewriter, loc, elementType, 0);
auto two = emitConstantOp(rewriter, loc, elementType, 2); auto two = emitConstantOp(rewriter, loc, elementType, 2);
@ -136,14 +132,12 @@ Value mapToLowerScalarOp<ONNXCoshOp>(Operation *op, ArrayRef<Type> result_types,
// Scalar unary ops for lowering ONNXTanhOp // Scalar unary ops for lowering ONNXTanhOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value mapToLowerScalarOp<ONNXTanhOp>(Operation *op, ArrayRef<Type> result_types, Value emitScalarOpFor<ONNXTanhOp>(ConversionPatternRewriter &rewriter,
ArrayRef<Value> operands, Location loc, Operation *op, Type elementType,
ConversionPatternRewriter &rewriter) { ArrayRef<Value> scalarOperands) {
// ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))), // ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
// AddFOp(ExpOp(%X), ExpOp(NegFOp(%X)))) // AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
auto loc = op->getLoc(); Value operand = scalarOperands[0];
Value operand = operands[0];
auto elementType = result_types[0];
auto zero = emitConstantOp(rewriter, loc, elementType, 0); auto zero = emitConstantOp(rewriter, loc, elementType, 0);
auto neg = rewriter.create<SubFOp>(loc, zero, operand); auto neg = rewriter.create<SubFOp>(loc, zero, operand);
@ -160,15 +154,12 @@ Value mapToLowerScalarOp<ONNXTanhOp>(Operation *op, ArrayRef<Type> result_types,
// Scalar unary ops for lowering ONNXSigmoidOp // Scalar unary ops for lowering ONNXSigmoidOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op, Value emitScalarOpFor<ONNXSigmoidOp>(ConversionPatternRewriter &rewriter,
ArrayRef<Type> result_types, Location loc, Operation *op, Type elementType,
ArrayRef<Value> operands, ArrayRef<Value> scalarOperands) {
ConversionPatternRewriter &rewriter) {
// ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1, // ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1,
// AddFOp(ConstantOp 1, ExpOp(NegFOp(%X)))) // AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
auto loc = op->getLoc(); Value operand = scalarOperands[0];
Value operand = operands[0];
auto elementType = result_types[0];
auto zero = emitConstantOp(rewriter, loc, elementType, 0); auto zero = emitConstantOp(rewriter, loc, elementType, 0);
auto one = emitConstantOp(rewriter, loc, elementType, 1); auto one = emitConstantOp(rewriter, loc, elementType, 1);
@ -184,9 +175,9 @@ Value mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
// Scalar unary ops for lowering ONNXHardSigmoidOp // Scalar unary ops for lowering ONNXHardSigmoidOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value mapToLowerScalarOp<ONNXHardSigmoidOp>( Value emitScalarOpFor<ONNXHardSigmoidOp>(ConversionPatternRewriter &rewriter,
Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands, Location loc, Operation *op, Type elementType,
ConversionPatternRewriter &rewriter) { ArrayRef<Value> scalarOperands) {
// %Y = AddFOp(MulFOp(alpha, %X), beta) // %Y = AddFOp(MulFOp(alpha, %X), beta)
// %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0), // %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0),
// %Y, // %Y,
@ -194,13 +185,11 @@ Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
// ONNXHardSigmoidOp(%X) = SelectOp(CmpFOp(OLT, %Z, Constant 1), // ONNXHardSigmoidOp(%X) = SelectOp(CmpFOp(OLT, %Z, Constant 1),
// %Z, // %Z,
// Constant 1) // Constant 1)
auto loc = op->getLoc(); Value operand = scalarOperands[0];
Value operand = operands[0];
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
llvm::dyn_cast<ONNXHardSigmoidOp>(op).alpha().convertToFloat()); llvm::dyn_cast<ONNXHardSigmoidOp>(op).alpha().convertToFloat());
auto betaAttribute = FloatAttr::get(rewriter.getF32Type(), auto betaAttribute = FloatAttr::get(rewriter.getF32Type(),
llvm::dyn_cast<ONNXHardSigmoidOp>(op).beta().convertToFloat()); llvm::dyn_cast<ONNXHardSigmoidOp>(op).beta().convertToFloat());
auto elementType = result_types[0];
auto zero = emitConstantOp(rewriter, loc, elementType, 0); auto zero = emitConstantOp(rewriter, loc, elementType, 0);
auto one = emitConstantOp(rewriter, loc, elementType, 1); auto one = emitConstantOp(rewriter, loc, elementType, 1);
@ -223,15 +212,13 @@ Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
// Scalar unary ops for lowering ONNXEluOp // Scalar unary ops for lowering ONNXEluOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types, Value emitScalarOpFor<ONNXEluOp>(ConversionPatternRewriter &rewriter,
ArrayRef<Value> operands, Location loc, Operation *op, Type elementType,
ConversionPatternRewriter &rewriter) { ArrayRef<Value> scalarOperands) {
// ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), // ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
// MulFOp(alpha, SubFOp(ExpOp(%X), 1)), // MulFOp(alpha, SubFOp(ExpOp(%X), 1)),
// %X) // %X)
auto loc = op->getLoc(); Value operand = scalarOperands[0];
Value operand = operands[0];
auto elementType = result_types[0];
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
llvm::dyn_cast<ONNXEluOp>(op).alpha().convertToFloat()); llvm::dyn_cast<ONNXEluOp>(op).alpha().convertToFloat());
@ -241,10 +228,9 @@ Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
auto exp = rewriter.create<ExpOp>(loc, operand); auto exp = rewriter.create<ExpOp>(loc, operand);
auto lessThanZero = auto lessThanZero =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero); rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
auto result = rewriter.create<SelectOp>( auto result = rewriter.create<SelectOp>(loc, lessThanZero,
loc, lessThanZero, rewriter.create<MulFOp>(
rewriter.create<MulFOp>(loc, alpha, loc, alpha, rewriter.create<SubFOp>(loc, exp, one)),
rewriter.create<SubFOp>(loc, exp, one)),
operand); operand);
return result; return result;
@ -254,15 +240,13 @@ Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
// Scalar unary ops for lowering ONNXReluOp // Scalar unary ops for lowering ONNXReluOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value mapToLowerScalarOp<ONNXReluOp>(Operation *op, ArrayRef<Type> result_types, Value emitScalarOpFor<ONNXReluOp>(ConversionPatternRewriter &rewriter,
ArrayRef<Value> operands, Location loc, Operation *op, Type elementType,
ConversionPatternRewriter &rewriter) { ArrayRef<Value> scalarOperands) {
// ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), // ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
// ConstantOp 0, // ConstantOp 0,
// %X) // %X)
auto loc = op->getLoc(); Value operand = scalarOperands[0];
Value operand = operands[0];
auto elementType = result_types[0];
auto zero = emitConstantOp(rewriter, loc, elementType, 0); auto zero = emitConstantOp(rewriter, loc, elementType, 0);
auto lessThanZero = auto lessThanZero =
@ -276,16 +260,13 @@ Value mapToLowerScalarOp<ONNXReluOp>(Operation *op, ArrayRef<Type> result_types,
// Scalar unary ops for lowering ONNXLeakyReluOp // Scalar unary ops for lowering ONNXLeakyReluOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op, Value emitScalarOpFor<ONNXLeakyReluOp>(ConversionPatternRewriter &rewriter,
ArrayRef<Type> result_types, Location loc, Operation *op, Type elementType,
ArrayRef<Value> operands, ArrayRef<Value> scalarOperands) {
ConversionPatternRewriter &rewriter) {
// ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), // ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
// MulFOp(alpha, %X), // MulFOp(alpha, %X),
// %X) // %X)
auto loc = op->getLoc(); Value operand = scalarOperands[0];
Value operand = operands[0];
auto elementType = result_types[0];
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
llvm::dyn_cast<ONNXLeakyReluOp>(op).alpha().convertToFloat()); llvm::dyn_cast<ONNXLeakyReluOp>(op).alpha().convertToFloat());
@ -303,21 +284,19 @@ Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op,
// Scalar unary ops for lowering ONNXSeluOp // Scalar unary ops for lowering ONNXSeluOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types, Value emitScalarOpFor<ONNXSeluOp>(ConversionPatternRewriter &rewriter,
ArrayRef<Value> operands, Location loc, Operation *op, Type elementType,
ConversionPatternRewriter &rewriter) { ArrayRef<Value> scalarOperands) {
// ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0), // ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
// MulFOp(gamma, %X), // MulFOp(gamma, %X),
// MulFOp(gamma, // MulFOp(gamma,
// SubFOp(MulFOp(alpha, ExpOp(%X)), // SubFOp(MulFOp(alpha, ExpOp(%X)),
// alpha))) // alpha)))
auto loc = op->getLoc(); Value operand = scalarOperands[0];
Value operand = operands[0];
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
llvm::dyn_cast<ONNXSeluOp>(op).alpha().convertToFloat()); llvm::dyn_cast<ONNXSeluOp>(op).alpha().convertToFloat());
auto gammaAttribute = FloatAttr::get(rewriter.getF32Type(), auto gammaAttribute = FloatAttr::get(rewriter.getF32Type(),
llvm::dyn_cast<ONNXSeluOp>(op).gamma().convertToFloat()); llvm::dyn_cast<ONNXSeluOp>(op).gamma().convertToFloat());
auto elementType = result_types[0];
auto zero = emitConstantOp(rewriter, loc, elementType, 0); auto zero = emitConstantOp(rewriter, loc, elementType, 0);
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute); auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
@ -325,10 +304,9 @@ Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types,
auto exp = rewriter.create<ExpOp>(loc, operand); auto exp = rewriter.create<ExpOp>(loc, operand);
auto greaterThanZero = auto greaterThanZero =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero); rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
auto select = rewriter.create<SelectOp>( auto select = rewriter.create<SelectOp>(loc, greaterThanZero, operand,
loc, greaterThanZero, operand, rewriter.create<SubFOp>(
rewriter.create<SubFOp>(loc, rewriter.create<MulFOp>(loc, alpha, exp), loc, rewriter.create<MulFOp>(loc, alpha, exp), alpha));
alpha));
auto result = rewriter.create<MulFOp>(loc, gamma, select); auto result = rewriter.create<MulFOp>(loc, gamma, select);
return result; return result;
@ -338,14 +316,11 @@ Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types,
// Scalar unary ops for lowering ONNXReciprocalOp // Scalar unary ops for lowering ONNXReciprocalOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value mapToLowerScalarOp<ONNXReciprocalOp>( Value emitScalarOpFor<ONNXReciprocalOp>(ConversionPatternRewriter &rewriter,
Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands, Location loc, Operation *op, Type elementType,
ConversionPatternRewriter &rewriter) { ArrayRef<Value> scalarOperands) {
// ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X) // ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
auto loc = op->getLoc(); Value operand = scalarOperands[0];
Value operand = operands[0];
auto elementType = result_types[0];
auto one = emitConstantOp(rewriter, loc, elementType, 1); auto one = emitConstantOp(rewriter, loc, elementType, 1);
auto result = rewriter.create<DivFOp>(loc, one, operand); auto result = rewriter.create<DivFOp>(loc, one, operand);
@ -356,13 +331,11 @@ Value mapToLowerScalarOp<ONNXReciprocalOp>(
// Scalar unary ops for lowering ONNXSoftplusOp // Scalar unary ops for lowering ONNXSoftplusOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value mapToLowerScalarOp<ONNXSoftplusOp>( Value emitScalarOpFor<ONNXSoftplusOp>(ConversionPatternRewriter &rewriter,
Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands, Location loc, Operation *op, Type elementType,
ConversionPatternRewriter &rewriter) { ArrayRef<Value> scalarOperands) {
// ONNXSoftplusOp(%X) = LogOp(AddFOp(ExpOp(%X), ConstantOp 1)) // ONNXSoftplusOp(%X) = LogOp(AddFOp(ExpOp(%X), ConstantOp 1))
auto loc = op->getLoc(); Value operand = scalarOperands[0];
Value operand = operands[0];
auto elementType = result_types[0];
auto exp = rewriter.create<ExpOp>(loc, operand); auto exp = rewriter.create<ExpOp>(loc, operand);
auto one = emitConstantOp(rewriter, loc, elementType, 1); auto one = emitConstantOp(rewriter, loc, elementType, 1);
@ -376,13 +349,11 @@ Value mapToLowerScalarOp<ONNXSoftplusOp>(
// Scalar unary ops for lowering ONNXSoftsignOp // Scalar unary ops for lowering ONNXSoftsignOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value mapToLowerScalarOp<ONNXSoftsignOp>( Value emitScalarOpFor<ONNXSoftsignOp>(ConversionPatternRewriter &rewriter,
Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands, Location loc, Operation *op, Type elementType,
ConversionPatternRewriter &rewriter) { ArrayRef<Value> scalarOperands) {
// ONNXSoftsignOp(%X) = DivFOp(ConstantOp 1, %X) // ONNXSoftsignOp(%X) = DivFOp(ConstantOp 1, %X)
auto loc = op->getLoc(); Value operand = scalarOperands[0];
Value operand = operands[0];
auto elementType = result_types[0];
auto abs = rewriter.create<AbsFOp>(loc, operand); auto abs = rewriter.create<AbsFOp>(loc, operand);
auto one = emitConstantOp(rewriter, loc, elementType, 1); auto one = emitConstantOp(rewriter, loc, elementType, 1);
@ -396,13 +367,10 @@ Value mapToLowerScalarOp<ONNXSoftsignOp>(
// Scalar unary ops for lowering ONNXSignOp // Scalar unary ops for lowering ONNXSignOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value mapToLowerScalarOp<ONNXSignOp>(Operation *op, ArrayRef<Type> result_types, Value emitScalarOpFor<ONNXSignOp>(ConversionPatternRewriter &rewriter,
ArrayRef<Value> operands, Location loc, Operation *op, Type elementType,
ConversionPatternRewriter &rewriter) { ArrayRef<Value> scalarOperands) {
Value operand = scalarOperands[0];
auto loc = op->getLoc();
Value operand = operands[0];
Type elementType = operands.front().getType();
// TODO: unsigned int should be supported separately? // TODO: unsigned int should be supported separately?
if (elementType.isa<IntegerType>()) { if (elementType.isa<IntegerType>()) {
// %Y = SelectOP(CmpIOp(GT, %X, ConstantOp 0), // %Y = SelectOP(CmpIOp(GT, %X, ConstantOp 0),
@ -451,15 +419,14 @@ Value mapToLowerScalarOp<ONNXSignOp>(Operation *op, ArrayRef<Type> result_types,
// Scalar unary ops for lowering ONNXMaxOp // Scalar unary ops for lowering ONNXMaxOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types, Value emitScalarOpFor<ONNXMaxOp>(ConversionPatternRewriter &rewriter,
ArrayRef<Value> operands, Location loc, Operation *op, Type elementType,
ConversionPatternRewriter &rewriter) { ArrayRef<Value> scalarOperands) {
// ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y), // ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y),
// %X, // %X,
// %Y) // %Y)
auto loc = op->getLoc(); Value lhs = scalarOperands[0];
Value lhs = operands[0]; Value rhs = scalarOperands[1];
Value rhs = operands[1];
auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs); auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs); auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
return result; return result;
@ -469,15 +436,14 @@ Value mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types,
// Scalar unary ops for lowering ONNXMinOp // Scalar unary ops for lowering ONNXMinOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types, Value emitScalarOpFor<ONNXMinOp>(ConversionPatternRewriter &rewriter,
ArrayRef<Value> operands, Location loc, Operation *op, Type elementType,
ConversionPatternRewriter &rewriter) { ArrayRef<Value> scalarOperands) {
// ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y), // ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y),
// %X, // %X,
// %Y) // %Y)
auto loc = op->getLoc(); Value lhs = scalarOperands[0];
Value lhs = operands[0]; Value rhs = scalarOperands[1];
Value rhs = operands[1];
auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs); auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs); auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
return result; return result;
@ -487,11 +453,10 @@ Value mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types,
// Scalar unary ops for lowering ONNXAbsOp // Scalar unary ops for lowering ONNXAbsOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value mapToLowerScalarOp<ONNXAbsOp>(Operation *op, ArrayRef<Type> result_types, Value emitScalarOpFor<ONNXAbsOp>(ConversionPatternRewriter &rewriter,
ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) { Location loc, Operation *op, Type elementType,
auto loc = op->getLoc(); ArrayRef<Value> scalarOperands) {
Value operand = operands[0]; Value operand = scalarOperands[0];
auto elementType = result_types[0];
if (elementType.isa<FloatType>()) { if (elementType.isa<FloatType>()) {
return rewriter.create<AbsFOp>(loc, operand); return rewriter.create<AbsFOp>(loc, operand);
@ -536,15 +501,14 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
if (hasAllConstantDimensions(memRefType)) if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else else
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, alloc = insertAllocAndDealloc(
{operands[0]}); memRefType, loc, rewriter, insertDealloc, {operands[0]});
std::vector<Value> originalLoops; std::vector<Value> originalLoops;
KrnlOptimizeLoopsOp optimizedLoopsOp; KrnlOptimizeLoopsOp optimizedLoopsOp;
KrnlIterateOp iterateOp; KrnlIterateOp iterateOp;
emitKrnlLoopsAndIterationForOperand( emitKrnlLoopsAndIterationForOperand(
rewriter, loc, operands[0], originalLoops, rewriter, loc, operands[0], originalLoops, optimizedLoopsOp, iterateOp);
optimizedLoopsOp, iterateOp);
Block &optimizationBlock = optimizedLoopsOp.region().front(); Block &optimizationBlock = optimizedLoopsOp.region().front();
Block &iterationBlock = iterateOp.bodyRegion().front(); Block &iterationBlock = iterateOp.bodyRegion().front();
@ -564,8 +528,8 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
loopIVs.push_back(arg); loopIVs.push_back(arg);
auto loadedVal = rewriter.create<LoadOp>(loc, operands[0], loopIVs); auto loadedVal = rewriter.create<LoadOp>(loc, operands[0], loopIVs);
auto loweredOpResult = mapToLowerScalarOp<ElementwiseUnaryOp>( auto loweredOpResult = emitScalarOpFor<ElementwiseUnaryOp>(
op, memRefType.getElementType(), {loadedVal}, rewriter); rewriter, loc, op, memRefType.getElementType(), {loadedVal});
// Store result in the resulting array. // Store result in the resulting array.
rewriter.create<StoreOp>(loc, loweredOpResult, alloc, loopIVs); rewriter.create<StoreOp>(loc, loweredOpResult, alloc, loopIVs);
@ -603,8 +567,8 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
if (hasAllConstantDimensions(memRefType)) if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else else
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, alloc = insertAllocAndDealloc(
operands); memRefType, loc, rewriter, insertDealloc, operands);
// Get run-time dimension information for unknown dimensions used for // Get run-time dimension information for unknown dimensions used for
// broadcasting. // broadcasting.
@ -615,8 +579,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
KrnlOptimizeLoopsOp optimizedLoopsOp; KrnlOptimizeLoopsOp optimizedLoopsOp;
KrnlIterateOp iterateOp; KrnlIterateOp iterateOp;
emitKrnlLoopsAndIterationForOperand( emitKrnlLoopsAndIterationForOperand(
rewriter, loc, alloc, originalLoops, rewriter, loc, alloc, originalLoops, optimizedLoopsOp, iterateOp);
optimizedLoopsOp, iterateOp);
Block &optimizationBlock = optimizedLoopsOp.region().front(); Block &optimizationBlock = optimizedLoopsOp.region().front();
Block &iterationBlock = iterateOp.bodyRegion().front(); Block &iterationBlock = iterateOp.bodyRegion().front();
@ -643,8 +606,8 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
auto nextLoopIVs = getLoopIVsForBroadcasting( auto nextLoopIVs = getLoopIVsForBroadcasting(
loc, rewriter, loopIVs, operands[i], broadcastedDimInfo[i]); loc, rewriter, loopIVs, operands[i], broadcastedDimInfo[i]);
next = rewriter.create<LoadOp>(loc, operands[i], nextLoopIVs); next = rewriter.create<LoadOp>(loc, operands[i], nextLoopIVs);
accumulated = mapToLowerScalarOp<ElementwiseVariadicOp>( accumulated = emitScalarOpFor<ElementwiseVariadicOp>(
op, memRefType.getElementType(), {accumulated, next}, rewriter); rewriter, loc, op, memRefType.getElementType(), {accumulated, next});
} }
// Store result in the resulting array. // Store result in the resulting array.
rewriter.create<StoreOp>(loc, accumulated, alloc, loopIVs); rewriter.create<StoreOp>(loc, accumulated, alloc, loopIVs);

View File

@ -54,13 +54,11 @@ struct ScalarOp<ONNXReduceSumOp> {
// Scalar unary ops for lowering ONNXReduceMaxOp // Scalar unary ops for lowering ONNXReduceMaxOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value mapToLowerScalarOp<ONNXReduceMaxOp>(Operation *op, Value emitScalarOpFor<ONNXReduceMaxOp>(ConversionPatternRewriter &rewriter,
ArrayRef<Type> result_types, Location loc, Operation *op, Type elementType,
ArrayRef<Value> operands, ArrayRef<Value> scalarOperands) {
ConversionPatternRewriter &rewriter) { Value lhs = scalarOperands[0];
auto loc = op->getLoc(); Value rhs = scalarOperands[1];
Value lhs = operands[0];
Value rhs = operands[1];
Type element_type = lhs.getType(); Type element_type = lhs.getType();
if (element_type.isa<IntegerType>()) { if (element_type.isa<IntegerType>()) {
auto max = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs, rhs); auto max = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs, rhs);
@ -80,19 +78,16 @@ Value mapToLowerScalarOp<ONNXReduceMaxOp>(Operation *op,
// Scalar unary ops for lowering ONNXReduceMinOp // Scalar unary ops for lowering ONNXReduceMinOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value mapToLowerScalarOp<ONNXReduceMinOp>(Operation *op, Value emitScalarOpFor<ONNXReduceMinOp>(ConversionPatternRewriter &rewriter,
ArrayRef<Type> result_types, Location loc, Operation *op, Type elementType,
ArrayRef<Value> operands, ArrayRef<Value> scalarOperands) {
ConversionPatternRewriter &rewriter) { Value lhs = scalarOperands[0];
auto loc = op->getLoc(); Value rhs = scalarOperands[1];
Value lhs = operands[0]; if (elementType.isa<IntegerType>()) {
Value rhs = operands[1];
Type element_type = lhs.getType();
if (element_type.isa<IntegerType>()) {
auto min = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, rhs); auto min = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, rhs);
auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs); auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
return result; return result;
} else if (element_type.isa<FloatType>()) { } else if (elementType.isa<FloatType>()) {
auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs); auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs); auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
return result; return result;
@ -154,8 +149,7 @@ struct ONNXReductionOpLowering : public ConversionPattern {
} }
} }
// KeepDims // KeepDims
auto keepdims = auto keepdims = llvm::dyn_cast<ONNXReductionOp>(op).keepdims();
llvm::dyn_cast<ONNXReductionOp>(op).keepdims();
bool isKeepdims = (keepdims == 1) ? true : false; bool isKeepdims = (keepdims == 1) ? true : false;
// Get type information // Get type information
@ -168,7 +162,8 @@ struct ONNXReductionOpLowering : public ConversionPattern {
Value alloc; Value alloc;
bool insertDealloc = checkInsertDealloc(op); bool insertDealloc = checkInsertDealloc(op);
if (hasAllConstantDimensions(memRefOutType)) { if (hasAllConstantDimensions(memRefOutType)) {
alloc = insertAllocAndDealloc(memRefOutType, loc, rewriter, insertDealloc); alloc =
insertAllocAndDealloc(memRefOutType, loc, rewriter, insertDealloc);
} else { } else {
SmallVector<Value, 2> allocOperands; SmallVector<Value, 2> allocOperands;
for (decltype(outRank) i = 0; i < outRank; ++i) { for (decltype(outRank) i = 0; i < outRank; ++i) {
@ -192,12 +187,12 @@ struct ONNXReductionOpLowering : public ConversionPattern {
// Define loops to initialize the result. // Define loops to initialize the result.
std::vector<Value> originalLoopsInit; std::vector<Value> originalLoopsInit;
std::vector<Value> optimizedLoopsInit; std::vector<Value> optimizedLoopsInit;
Block *optimizationBlockInit = defineLoops(rewriter, loc, originalLoopsInit, Block *optimizationBlockInit = defineLoops(
optimizedLoopsInit, outRank); rewriter, loc, originalLoopsInit, optimizedLoopsInit, outRank);
// Iteration information // Iteration information
KrnlIterateOperandPack packInit(rewriter, originalLoopsInit, KrnlIterateOperandPack packInit(
optimizedLoopsInit); rewriter, originalLoopsInit, optimizedLoopsInit);
for (decltype(outRank) i = 0; i < outRank; ++i) { for (decltype(outRank) i = 0; i < outRank; ++i) {
addDimensionToPack(rewriter, loc, packInit, alloc, i); addDimensionToPack(rewriter, loc, packInit, alloc, i);
} }
@ -225,8 +220,8 @@ struct ONNXReductionOpLowering : public ConversionPattern {
// Define an Krnl loop to do reduction. // Define an Krnl loop to do reduction.
rewriter.setInsertionPointAfter(iterateOpInit); rewriter.setInsertionPointAfter(iterateOpInit);
std::vector<Value> originalLoops, optimizedLoops; std::vector<Value> originalLoops, optimizedLoops;
Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, Block *optimizationBlock =
optimizedLoops, inRank); defineLoops(rewriter, loc, originalLoops, optimizedLoops, inRank);
// Iteration information // Iteration information
KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
for (decltype(inRank) i = 0; i < inRank; ++i) { for (decltype(inRank) i = 0; i < inRank; ++i) {
@ -266,8 +261,8 @@ struct ONNXReductionOpLowering : public ConversionPattern {
Value next, accumulated; Value next, accumulated;
next = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs); next = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
accumulated = rewriter.create<LoadOp>(loc, alloc, outLoopIVs); accumulated = rewriter.create<LoadOp>(loc, alloc, outLoopIVs);
accumulated = mapToLowerScalarOp<ONNXReductionOp>( accumulated = emitScalarOpFor<ONNXReductionOp>(
op, memRefOutType.getElementType(), {accumulated, next}, rewriter); rewriter, loc, op, memRefOutType.getElementType(), {accumulated, next});
rewriter.create<StoreOp>(loc, accumulated, alloc, outLoopIVs); rewriter.create<StoreOp>(loc, accumulated, alloc, outLoopIVs);
rewriter.replaceOp(op, alloc); rewriter.replaceOp(op, alloc);

View File

@ -20,12 +20,11 @@ Value getIdentityValue<ONNXMaxPoolSingleOutOp>(
} }
template <> template <>
Value mapToLowerScalarOp<ONNXMaxPoolSingleOutOp>(Operation *op, Value emitScalarOpFor<ONNXMaxPoolSingleOutOp>(
ArrayRef<Type> result_types, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter, Location loc, Operation *op,
ConversionPatternRewriter &rewriter) { Type elementType, ArrayRef<Value> scalarOperands) {
auto loc = op->getLoc(); Value lhs = scalarOperands[0];
Value lhs = operands[0]; Value rhs = scalarOperands[1];
Value rhs = operands[1];
auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs); auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs); auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
return result; return result;
@ -308,8 +307,8 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
auto loadData = rewriter.create<LoadOp>(loc, inputOperand, dataIndices); auto loadData = rewriter.create<LoadOp>(loc, inputOperand, dataIndices);
auto loadPartialResult = auto loadPartialResult =
rewriter.create<LoadOp>(loc, alloc, resultIndices); rewriter.create<LoadOp>(loc, alloc, resultIndices);
Value result = mapToLowerScalarOp<ONNXMaxPoolSingleOutOp>( Value result = emitScalarOpFor<ONNXMaxPoolSingleOutOp>(rewriter, loc,
op, resultElementType, {loadPartialResult, loadData}, rewriter); op, resultElementType, {loadPartialResult, loadData});
rewriter.create<StoreOp>(loc, result, alloc, resultIndices); rewriter.create<StoreOp>(loc, result, alloc, resultIndices);
} }
} }

View File

@ -148,17 +148,14 @@ Value getIdentityValue(
// Use template specialization for each of different ONNX operations. // Use template specialization for each of different ONNX operations.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <typename Op> template <typename Op>
Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types, Value emitScalarOpFor(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> operands, Operation *op, Type elementType, ArrayRef<Value> scalarOperands) {
ConversionPatternRewriter &rewriter) { if (elementType.isa<IntegerType>()) {
auto loc = op->getLoc(); return rewriter.create<ScalarIOp<Op>>(
Type element_type = operands.front().getType(); loc, elementType, scalarOperands, mlir::None);
if (element_type.isa<IntegerType>()) { } else if (elementType.isa<FloatType>()) {
return rewriter.create<ScalarIOp<Op>>(loc, result_types, operands, return rewriter.create<ScalarFOp<Op>>(
mlir::None); loc, elementType, scalarOperands, mlir::None);
} else if (element_type.isa<FloatType>()) {
return rewriter.create<ScalarFOp<Op>>(loc, result_types, operands,
mlir::None);
} else { } else {
emitError(loc, "unsupported element type"); emitError(loc, "unsupported element type");
return nullptr; return nullptr;
@ -247,4 +244,3 @@ void populateLoweringONNXIdentityOpPattern(
void populateLoweringONNXConstantOpPattern( void populateLoweringONNXConstantOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx); OwningRewritePatternList &patterns, MLIRContext *ctx);