Do not get float attributes with fixed precision

This commit is contained in:
Tung D. Le 2020-01-07 17:39:34 +09:00
parent 322002f509
commit becb2add4a
1 changed files with 26 additions and 16 deletions

View File

@ -181,7 +181,7 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
auto dim = rewriter.create<DimOp>(loc, operands[i], j).getResult();
auto one = rewriter.create<ConstantIndexOp>(loc, 1);
auto isBroadcasted =
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
broadcastedDims.insert(std::make_pair(j, isBroadcasted));
}
}
@ -325,8 +325,9 @@ Value mapToLowerScalarOp<ONNXTanhOp>(Operation *op, ArrayRef<Type> result_types,
// AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
auto loc = op->getLoc();
Value operand = operands[0];
auto elementType = result_types[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
auto exp = rewriter.create<ExpOp>(loc, operand);
auto negExp = rewriter.create<ExpOp>(loc, neg);
@ -348,9 +349,10 @@ Value mapToLowerScalarOp<ONNXSinhOp>(Operation *op, ArrayRef<Type> result_types,
// ConstantOp 2)
auto loc = op->getLoc();
Value operand = operands[0];
auto elementType = result_types[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto two = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 2));
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
auto exp = rewriter.create<ExpOp>(loc, operand);
auto negExp = rewriter.create<ExpOp>(loc, neg);
@ -371,9 +373,10 @@ Value mapToLowerScalarOp<ONNXCoshOp>(Operation *op, ArrayRef<Type> result_types,
// ConstantOp 2)
auto loc = op->getLoc();
Value operand = operands[0];
auto elementType = result_types[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto two = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 2));
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
auto exp = rewriter.create<ExpOp>(loc, operand);
auto negExp = rewriter.create<ExpOp>(loc, neg);
@ -395,9 +398,10 @@ Value mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
// AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
auto loc = op->getLoc();
Value operand = operands[0];
auto elementType = result_types[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
auto negExp = rewriter.create<ExpOp>(loc, neg);
auto result = rewriter.create<DivFOp>(
@ -424,9 +428,10 @@ Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
Value operand = operands[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.alpha");
auto betaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.beta");
auto elementType = result_types[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
@ -454,10 +459,11 @@ Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
// %X)
auto loc = op->getLoc();
Value operand = operands[0];
auto elementType = result_types[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("Elu.alpha");
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto exp = rewriter.create<ExpOp>(loc, operand);
auto lessThanZero =
@ -483,8 +489,9 @@ Value mapToLowerScalarOp<ONNXReluOp>(Operation *op, ArrayRef<Type> result_types,
// %X)
auto loc = op->getLoc();
Value operand = operands[0];
auto elementType = result_types[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto lessThanZero =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
auto result = rewriter.create<SelectOp>(loc, lessThanZero, zero, operand);
@ -505,9 +512,10 @@ Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op,
// %X)
auto loc = op->getLoc();
Value operand = operands[0];
auto elementType = result_types[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("LeakyRelu.alpha");
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto lessThanZero =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
@ -533,8 +541,9 @@ Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types,
Value operand = operands[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("Selu.alpha");
auto gammaAttr = op->getAttrOfType<FloatAttr>("Selu.gamma");
auto elementType = result_types[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto gamma = rewriter.create<ConstantOp>(loc, gammaAttr);
auto exp = rewriter.create<ExpOp>(loc, operand);
@ -559,8 +568,9 @@ Value mapToLowerScalarOp<ONNXReciprocalOp>(
// ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
auto loc = op->getLoc();
Value operand = operands[0];
auto elementType = result_types[0];
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
auto result = rewriter.create<DivFOp>(loc, one, operand);
return result;