Merge pull request #13 from tungld/get_float_for_constants
Do not get float attributes using a fixed precision
This commit is contained in:
commit
5bc0967175
|
@ -325,8 +325,9 @@ Value mapToLowerScalarOp<ONNXTanhOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
// AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
|
// AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value operand = operands[0];
|
Value operand = operands[0];
|
||||||
|
auto elementType = result_types[0];
|
||||||
|
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||||
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
|
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
|
||||||
auto exp = rewriter.create<ExpOp>(loc, operand);
|
auto exp = rewriter.create<ExpOp>(loc, operand);
|
||||||
auto negExp = rewriter.create<ExpOp>(loc, neg);
|
auto negExp = rewriter.create<ExpOp>(loc, neg);
|
||||||
|
@ -348,9 +349,10 @@ Value mapToLowerScalarOp<ONNXSinhOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
// ConstantOp 2)
|
// ConstantOp 2)
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value operand = operands[0];
|
Value operand = operands[0];
|
||||||
|
auto elementType = result_types[0];
|
||||||
|
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||||
auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
|
auto two = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 2));
|
||||||
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
|
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
|
||||||
auto exp = rewriter.create<ExpOp>(loc, operand);
|
auto exp = rewriter.create<ExpOp>(loc, operand);
|
||||||
auto negExp = rewriter.create<ExpOp>(loc, neg);
|
auto negExp = rewriter.create<ExpOp>(loc, neg);
|
||||||
|
@ -371,9 +373,10 @@ Value mapToLowerScalarOp<ONNXCoshOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
// ConstantOp 2)
|
// ConstantOp 2)
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value operand = operands[0];
|
Value operand = operands[0];
|
||||||
|
auto elementType = result_types[0];
|
||||||
|
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||||
auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
|
auto two = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 2));
|
||||||
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
|
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
|
||||||
auto exp = rewriter.create<ExpOp>(loc, operand);
|
auto exp = rewriter.create<ExpOp>(loc, operand);
|
||||||
auto negExp = rewriter.create<ExpOp>(loc, neg);
|
auto negExp = rewriter.create<ExpOp>(loc, neg);
|
||||||
|
@ -395,9 +398,10 @@ Value mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
|
||||||
// AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
|
// AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value operand = operands[0];
|
Value operand = operands[0];
|
||||||
|
auto elementType = result_types[0];
|
||||||
|
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||||
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
|
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
|
||||||
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
|
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
|
||||||
auto negExp = rewriter.create<ExpOp>(loc, neg);
|
auto negExp = rewriter.create<ExpOp>(loc, neg);
|
||||||
auto result = rewriter.create<DivFOp>(
|
auto result = rewriter.create<DivFOp>(
|
||||||
|
@ -424,9 +428,10 @@ Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
|
||||||
Value operand = operands[0];
|
Value operand = operands[0];
|
||||||
auto alphaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.alpha");
|
auto alphaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.alpha");
|
||||||
auto betaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.beta");
|
auto betaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.beta");
|
||||||
|
auto elementType = result_types[0];
|
||||||
|
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||||
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
|
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
|
||||||
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
|
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
|
||||||
auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
|
auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
|
||||||
|
|
||||||
|
@ -454,10 +459,11 @@ Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
// %X)
|
// %X)
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value operand = operands[0];
|
Value operand = operands[0];
|
||||||
|
auto elementType = result_types[0];
|
||||||
|
|
||||||
auto alphaAttr = op->getAttrOfType<FloatAttr>("Elu.alpha");
|
auto alphaAttr = op->getAttrOfType<FloatAttr>("Elu.alpha");
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||||
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
|
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
|
||||||
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
|
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
|
||||||
auto exp = rewriter.create<ExpOp>(loc, operand);
|
auto exp = rewriter.create<ExpOp>(loc, operand);
|
||||||
auto lessThanZero =
|
auto lessThanZero =
|
||||||
|
@ -483,8 +489,9 @@ Value mapToLowerScalarOp<ONNXReluOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
// %X)
|
// %X)
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value operand = operands[0];
|
Value operand = operands[0];
|
||||||
|
auto elementType = result_types[0];
|
||||||
|
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||||
auto lessThanZero =
|
auto lessThanZero =
|
||||||
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
|
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
|
||||||
auto result = rewriter.create<SelectOp>(loc, lessThanZero, zero, operand);
|
auto result = rewriter.create<SelectOp>(loc, lessThanZero, zero, operand);
|
||||||
|
@ -505,9 +512,10 @@ Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op,
|
||||||
// %X)
|
// %X)
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value operand = operands[0];
|
Value operand = operands[0];
|
||||||
|
auto elementType = result_types[0];
|
||||||
|
|
||||||
auto alphaAttr = op->getAttrOfType<FloatAttr>("LeakyRelu.alpha");
|
auto alphaAttr = op->getAttrOfType<FloatAttr>("LeakyRelu.alpha");
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||||
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
|
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
|
||||||
auto lessThanZero =
|
auto lessThanZero =
|
||||||
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
|
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
|
||||||
|
@ -533,8 +541,9 @@ Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
Value operand = operands[0];
|
Value operand = operands[0];
|
||||||
auto alphaAttr = op->getAttrOfType<FloatAttr>("Selu.alpha");
|
auto alphaAttr = op->getAttrOfType<FloatAttr>("Selu.alpha");
|
||||||
auto gammaAttr = op->getAttrOfType<FloatAttr>("Selu.gamma");
|
auto gammaAttr = op->getAttrOfType<FloatAttr>("Selu.gamma");
|
||||||
|
auto elementType = result_types[0];
|
||||||
|
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||||
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
|
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
|
||||||
auto gamma = rewriter.create<ConstantOp>(loc, gammaAttr);
|
auto gamma = rewriter.create<ConstantOp>(loc, gammaAttr);
|
||||||
auto exp = rewriter.create<ExpOp>(loc, operand);
|
auto exp = rewriter.create<ExpOp>(loc, operand);
|
||||||
|
@ -559,8 +568,9 @@ Value mapToLowerScalarOp<ONNXReciprocalOp>(
|
||||||
// ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
|
// ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value operand = operands[0];
|
Value operand = operands[0];
|
||||||
|
auto elementType = result_types[0];
|
||||||
|
|
||||||
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
|
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
|
||||||
auto result = rewriter.create<DivFOp>(loc, one, operand);
|
auto result = rewriter.create<DivFOp>(loc, one, operand);
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
|
Loading…
Reference in New Issue