From 95cf939c5cb5679a77b302cb215224a941a2c5ef Mon Sep 17 00:00:00 2001 From: Gheorghe-Teodor Bercea Date: Mon, 27 Jan 2020 11:35:45 -0500 Subject: [PATCH] Fix end-to-end tests. (#52) * Fix end-to-end tests. * Use dyn_cast. --- src/pass/lower_frontend_to_krnl.cpp | 32 +++++++++++++++++------------ 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/src/pass/lower_frontend_to_krnl.cpp b/src/pass/lower_frontend_to_krnl.cpp index 020204a..1402748 100644 --- a/src/pass/lower_frontend_to_krnl.cpp +++ b/src/pass/lower_frontend_to_krnl.cpp @@ -420,14 +420,16 @@ Value mapToLowerScalarOp( // Constant 1) auto loc = op->getLoc(); Value operand = operands[0]; - auto alphaAttr = op->getAttrOfType("alpha"); - auto betaAttr = op->getAttrOfType("beta"); + auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), + llvm::dyn_cast(op).alpha().convertToFloat()); + auto betaAttribute = FloatAttr::get(rewriter.getF32Type(), + llvm::dyn_cast(op).beta().convertToFloat()); auto elementType = result_types[0]; auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); - auto alpha = rewriter.create(loc, alphaAttr); - auto beta = rewriter.create(loc, betaAttr); + auto alpha = rewriter.create(loc, alphaAttribute); + auto beta = rewriter.create(loc, betaAttribute); auto add = rewriter.create( loc, rewriter.create(loc, alpha, operand), beta); @@ -455,10 +457,11 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, Value operand = operands[0]; auto elementType = result_types[0]; - auto alphaAttr = op->getAttrOfType("alpha"); + auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), + llvm::dyn_cast(op).alpha().convertToFloat()); auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); - auto alpha = rewriter.create(loc, alphaAttr); + auto alpha = rewriter.create(loc, alphaAttribute); auto exp = rewriter.create(loc, operand); auto lessThanZero = rewriter.create(loc, CmpFPredicate::OLT, operand, zero); @@ -508,9 +511,10 @@ Value mapToLowerScalarOp(Operation *op, Value operand = operands[0]; auto elementType = result_types[0]; - auto alphaAttr = op->getAttrOfType("alpha"); + auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), + llvm::dyn_cast(op).alpha().convertToFloat()); auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); - auto alpha = rewriter.create(loc, alphaAttr); + auto alpha = rewriter.create(loc, alphaAttribute); auto lessThanZero = rewriter.create(loc, CmpFPredicate::OLT, operand, zero); auto result = rewriter.create( @@ -533,13 +537,15 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, // alpha))) auto loc = op->getLoc(); Value operand = operands[0]; - auto alphaAttr = op->getAttrOfType("alpha"); - auto gammaAttr = op->getAttrOfType("gamma"); + auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), + llvm::dyn_cast(op).alpha().convertToFloat()); + auto gammaAttribute = FloatAttr::get(rewriter.getF32Type(), + llvm::dyn_cast(op).gamma().convertToFloat()); auto elementType = result_types[0]; auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); - auto alpha = rewriter.create(loc, alphaAttr); - auto gamma = rewriter.create(loc, gammaAttr); + auto alpha = rewriter.create(loc, alphaAttribute); + auto gamma = rewriter.create(loc, gammaAttribute); auto exp = rewriter.create(loc, operand); auto greaterThanZero = rewriter.create(loc, CmpFPredicate::OGT, operand, zero); @@ -876,7 +882,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern { // exp_x / sum auto tensorType = (*op->result_type_begin()).cast(); int64_t rank = tensorType.getRank(); - int64_t axis = op->getAttrOfType("axis").getInt(); + int64_t axis = llvm::dyn_cast(op).axis().getSExtValue(); axis = axis >= 0 ? axis : rank + axis; assert(axis >= -rank && axis <= rank - 1);