diff --git a/src/pass/lower_frontend_to_krnl.cpp b/src/pass/lower_frontend_to_krnl.cpp index d38954b..19c0b9c 100644 --- a/src/pass/lower_frontend_to_krnl.cpp +++ b/src/pass/lower_frontend_to_krnl.cpp @@ -181,7 +181,7 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter, auto dim = rewriter.create(loc, operands[i], j).getResult(); auto one = rewriter.create(loc, 1); auto isBroadcasted = - rewriter.create(loc, CmpIPredicate::eq, dim, one); + rewriter.create(loc, CmpIPredicate::eq, dim, one); broadcastedDims.insert(std::make_pair(j, isBroadcasted)); } } @@ -325,8 +325,9 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, // AddFOp(ExpOp(%X), ExpOp(NegFOp(%X)))) auto loc = op->getLoc(); Value operand = operands[0]; + auto elementType = result_types[0]; - auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); auto neg = rewriter.create(loc, zero, operand); auto exp = rewriter.create(loc, operand); auto negExp = rewriter.create(loc, neg); @@ -348,9 +349,10 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, // ConstantOp 2) auto loc = op->getLoc(); Value operand = operands[0]; + auto elementType = result_types[0]; - auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); - auto two = rewriter.create(loc, rewriter.getF32FloatAttr(2.0f)); + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto two = rewriter.create(loc, FloatAttr::get(elementType, 2)); auto neg = rewriter.create(loc, zero, operand); auto exp = rewriter.create(loc, operand); auto negExp = rewriter.create(loc, neg); @@ -371,9 +373,10 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, // ConstantOp 2) auto loc = op->getLoc(); Value operand = operands[0]; + auto elementType = result_types[0]; - auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); - auto two = rewriter.create(loc, rewriter.getF32FloatAttr(2.0f)); + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto two = rewriter.create(loc, FloatAttr::get(elementType, 2)); auto neg = rewriter.create(loc, zero, operand); auto exp = rewriter.create(loc, operand); auto negExp = rewriter.create(loc, neg); @@ -395,9 +398,10 @@ Value mapToLowerScalarOp(Operation *op, // AddFOp(ConstantOp 1, ExpOp(NegFOp(%X)))) auto loc = op->getLoc(); Value operand = operands[0]; + auto elementType = result_types[0]; - auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); - auto one = rewriter.create(loc, rewriter.getF32FloatAttr(1.0f)); + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); auto neg = rewriter.create(loc, zero, operand); auto negExp = rewriter.create(loc, neg); auto result = rewriter.create( @@ -424,9 +428,10 @@ Value mapToLowerScalarOp( Value operand = operands[0]; auto alphaAttr = op->getAttrOfType("HardSigmoid.alpha"); auto betaAttr = op->getAttrOfType("HardSigmoid.beta"); + auto elementType = result_types[0]; - auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); - auto one = rewriter.create(loc, rewriter.getF32FloatAttr(1.0f)); + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); auto alpha = rewriter.create(loc, alphaAttr); auto beta = rewriter.create(loc, betaAttr); @@ -454,10 +459,11 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, // %X) auto loc = op->getLoc(); Value operand = operands[0]; + auto elementType = result_types[0]; auto alphaAttr = op->getAttrOfType("Elu.alpha"); - auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); - auto one = rewriter.create(loc, rewriter.getF32FloatAttr(1.0f)); + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); auto alpha = rewriter.create(loc, alphaAttr); auto exp = rewriter.create(loc, operand); auto lessThanZero = @@ -483,8 +489,9 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, // %X) auto loc = op->getLoc(); Value operand = operands[0]; + auto elementType = result_types[0]; - auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); auto lessThanZero = rewriter.create(loc, CmpFPredicate::OLT, operand, zero); auto result = rewriter.create(loc, lessThanZero, zero, operand); @@ -505,9 +512,10 @@ Value mapToLowerScalarOp(Operation *op, // %X) auto loc = op->getLoc(); Value operand = operands[0]; + auto elementType = result_types[0]; auto alphaAttr = op->getAttrOfType("LeakyRelu.alpha"); - auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); auto alpha = rewriter.create(loc, alphaAttr); auto lessThanZero = rewriter.create(loc, CmpFPredicate::OLT, operand, zero); @@ -533,8 +541,9 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, Value operand = operands[0]; auto alphaAttr = op->getAttrOfType("Selu.alpha"); auto gammaAttr = op->getAttrOfType("Selu.gamma"); + auto elementType = result_types[0]; - auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); auto alpha = rewriter.create(loc, alphaAttr); auto gamma = rewriter.create(loc, gammaAttr); auto exp = rewriter.create(loc, operand); @@ -559,8 +568,9 @@ Value mapToLowerScalarOp( // ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X) auto loc = op->getLoc(); Value operand = operands[0]; + auto elementType = result_types[0]; - auto one = rewriter.create(loc, rewriter.getF32FloatAttr(1.0f)); + auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); auto result = rewriter.create(loc, one, operand); return result;