Merge branch 'master' into tanh_cos_log

This commit is contained in:
Tung D. Le 2020-01-08 13:39:24 +09:00 committed by GitHub
commit edcd506dde
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 24 additions and 16 deletions

View File

@ -181,7 +181,7 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
auto dim = rewriter.create<DimOp>(loc, operands[i], j).getResult(); auto dim = rewriter.create<DimOp>(loc, operands[i], j).getResult();
auto one = rewriter.create<ConstantIndexOp>(loc, 1); auto one = rewriter.create<ConstantIndexOp>(loc, 1);
auto isBroadcasted = auto isBroadcasted =
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one); rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
broadcastedDims.insert(std::make_pair(j, isBroadcasted)); broadcastedDims.insert(std::make_pair(j, isBroadcasted));
} }
} }
@ -332,7 +332,6 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
} }
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSinhOp // Scalar unary ops for lowering ONNXSinhOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -344,9 +343,10 @@ Value mapToLowerScalarOp<ONNXSinhOp>(Operation *op, ArrayRef<Type> result_types,
// ConstantOp 2) // ConstantOp 2)
auto loc = op->getLoc(); auto loc = op->getLoc();
Value operand = operands[0]; Value operand = operands[0];
auto elementType = result_types[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f)); auto two = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 2));
auto neg = rewriter.create<SubFOp>(loc, zero, operand); auto neg = rewriter.create<SubFOp>(loc, zero, operand);
auto exp = rewriter.create<ExpOp>(loc, operand); auto exp = rewriter.create<ExpOp>(loc, operand);
auto negExp = rewriter.create<ExpOp>(loc, neg); auto negExp = rewriter.create<ExpOp>(loc, neg);
@ -367,9 +367,10 @@ Value mapToLowerScalarOp<ONNXCoshOp>(Operation *op, ArrayRef<Type> result_types,
// ConstantOp 2) // ConstantOp 2)
auto loc = op->getLoc(); auto loc = op->getLoc();
Value operand = operands[0]; Value operand = operands[0];
auto elementType = result_types[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f)); auto two = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 2));
auto neg = rewriter.create<SubFOp>(loc, zero, operand); auto neg = rewriter.create<SubFOp>(loc, zero, operand);
auto exp = rewriter.create<ExpOp>(loc, operand); auto exp = rewriter.create<ExpOp>(loc, operand);
auto negExp = rewriter.create<ExpOp>(loc, neg); auto negExp = rewriter.create<ExpOp>(loc, neg);
@ -391,9 +392,10 @@ Value mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
// AddFOp(ConstantOp 1, ExpOp(NegFOp(%X)))) // AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
auto loc = op->getLoc(); auto loc = op->getLoc();
Value operand = operands[0]; Value operand = operands[0];
auto elementType = result_types[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f)); auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
auto neg = rewriter.create<SubFOp>(loc, zero, operand); auto neg = rewriter.create<SubFOp>(loc, zero, operand);
auto negExp = rewriter.create<ExpOp>(loc, neg); auto negExp = rewriter.create<ExpOp>(loc, neg);
auto result = rewriter.create<DivFOp>( auto result = rewriter.create<DivFOp>(
@ -420,9 +422,10 @@ Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
Value operand = operands[0]; Value operand = operands[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.alpha"); auto alphaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.alpha");
auto betaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.beta"); auto betaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.beta");
auto elementType = result_types[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f)); auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr); auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto beta = rewriter.create<ConstantOp>(loc, betaAttr); auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
@ -450,10 +453,11 @@ Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
// %X) // %X)
auto loc = op->getLoc(); auto loc = op->getLoc();
Value operand = operands[0]; Value operand = operands[0];
auto elementType = result_types[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("Elu.alpha"); auto alphaAttr = op->getAttrOfType<FloatAttr>("Elu.alpha");
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f)); auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr); auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto exp = rewriter.create<ExpOp>(loc, operand); auto exp = rewriter.create<ExpOp>(loc, operand);
auto lessThanZero = auto lessThanZero =
@ -479,8 +483,9 @@ Value mapToLowerScalarOp<ONNXReluOp>(Operation *op, ArrayRef<Type> result_types,
// %X) // %X)
auto loc = op->getLoc(); auto loc = op->getLoc();
Value operand = operands[0]; Value operand = operands[0];
auto elementType = result_types[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto lessThanZero = auto lessThanZero =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero); rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
auto result = rewriter.create<SelectOp>(loc, lessThanZero, zero, operand); auto result = rewriter.create<SelectOp>(loc, lessThanZero, zero, operand);
@ -501,9 +506,10 @@ Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op,
// %X) // %X)
auto loc = op->getLoc(); auto loc = op->getLoc();
Value operand = operands[0]; Value operand = operands[0];
auto elementType = result_types[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("LeakyRelu.alpha"); auto alphaAttr = op->getAttrOfType<FloatAttr>("LeakyRelu.alpha");
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr); auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto lessThanZero = auto lessThanZero =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero); rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
@ -529,8 +535,9 @@ Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types,
Value operand = operands[0]; Value operand = operands[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("Selu.alpha"); auto alphaAttr = op->getAttrOfType<FloatAttr>("Selu.alpha");
auto gammaAttr = op->getAttrOfType<FloatAttr>("Selu.gamma"); auto gammaAttr = op->getAttrOfType<FloatAttr>("Selu.gamma");
auto elementType = result_types[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr); auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto gamma = rewriter.create<ConstantOp>(loc, gammaAttr); auto gamma = rewriter.create<ConstantOp>(loc, gammaAttr);
auto exp = rewriter.create<ExpOp>(loc, operand); auto exp = rewriter.create<ExpOp>(loc, operand);
@ -555,8 +562,9 @@ Value mapToLowerScalarOp<ONNXReciprocalOp>(
// ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X) // ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
auto loc = op->getLoc(); auto loc = op->getLoc();
Value operand = operands[0]; Value operand = operands[0];
auto elementType = result_types[0];
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f)); auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
auto result = rewriter.create<DivFOp>(loc, one, operand); auto result = rewriter.create<DivFOp>(loc, one, operand);
return result; return result;