Fix end-to-end tests. (#52)

* Fix end-to-end tests.

* Use dyn_cast.
Gheorghe-Teodor Bercea 2020-01-27 11:35:45 -05:00 committed by GitHub
parent c74f814f64
commit 95cf939c5c
1 changed file with 19 additions and 13 deletions
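Every hunk below applies the same change: instead of looking an attribute up
by its string name on the generic Operation, the lowering dyn_casts to the
concrete ONNX op and reads the value through the op's generated accessor,
then rebuilds the FloatAttr at an explicit f32 type. A minimal sketch of the
pattern, using names from the first hunk (the accessors come from onnx-mlir's
TableGen'd op definitions):

  // Before: stringly-typed lookup; the attribute keeps whatever float
  // type it happened to be stored with.
  auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");

  // After: cast to the concrete op and use its typed accessor, then
  // re-wrap the value at an explicit f32 type for the constant.
  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXHardSigmoidOp>(op).alpha().convertToFloat());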


@@ -420,14 +420,16 @@ Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
   // Constant 1)
   auto loc = op->getLoc();
   Value operand = operands[0];
-  auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
-  auto betaAttr = op->getAttrOfType<FloatAttr>("beta");
+  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
+      llvm::dyn_cast<ONNXHardSigmoidOp>(op).alpha().convertToFloat());
+  auto betaAttribute = FloatAttr::get(rewriter.getF32Type(),
+      llvm::dyn_cast<ONNXHardSigmoidOp>(op).beta().convertToFloat());
   auto elementType = result_types[0];
   auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
   auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
-  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
-  auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
+  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
+  auto beta = rewriter.create<ConstantOp>(loc, betaAttribute);
   auto add = rewriter.create<AddFOp>(
       loc, rewriter.create<MulFOp>(loc, alpha, operand), beta);
@@ -455,10 +457,11 @@ Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
   Value operand = operands[0];
   auto elementType = result_types[0];
-  auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
+  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
+      llvm::dyn_cast<ONNXEluOp>(op).alpha().convertToFloat());
   auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
   auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
-  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
+  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
   auto exp = rewriter.create<ExpOp>(loc, operand);
   auto lessThanZero =
       rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
@@ -508,9 +511,10 @@ Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op,
   Value operand = operands[0];
   auto elementType = result_types[0];
-  auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
+  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
+      llvm::dyn_cast<ONNXLeakyReluOp>(op).alpha().convertToFloat());
   auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
-  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
+  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
   auto lessThanZero =
       rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
   auto result = rewriter.create<SelectOp>(
@@ -533,13 +537,15 @@ Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types,
   // alpha)))
   auto loc = op->getLoc();
   Value operand = operands[0];
-  auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
-  auto gammaAttr = op->getAttrOfType<FloatAttr>("gamma");
+  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
+      llvm::dyn_cast<ONNXSeluOp>(op).alpha().convertToFloat());
+  auto gammaAttribute = FloatAttr::get(rewriter.getF32Type(),
+      llvm::dyn_cast<ONNXSeluOp>(op).gamma().convertToFloat());
   auto elementType = result_types[0];
   auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
-  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
-  auto gamma = rewriter.create<ConstantOp>(loc, gammaAttr);
+  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
+  auto gamma = rewriter.create<ConstantOp>(loc, gammaAttribute);
   auto exp = rewriter.create<ExpOp>(loc, operand);
   auto greaterThanZero =
       rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
@@ -876,7 +882,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
   // exp_x / sum
   auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
   int64_t rank = tensorType.getRank();
-  int64_t axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
+  int64_t axis = llvm::dyn_cast<ONNXSoftmaxOp>(op).axis().getSExtValue();
   axis = axis >= 0 ? axis : rank + axis;
   assert(axis >= -rank && axis <= rank - 1);
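In the softmax hunk the typed accessor's value is read with getSExtValue(),
and the following line maps a negative ONNX axis onto a non-negative
dimension index. A small illustrative example of that normalization (the
values are hypothetical, not from the patch):

  // For a rank-3 tensor, axis = -1 selects the last dimension:
  int64_t rank = 3;
  int64_t axis = -1;
  axis = axis >= 0 ? axis : rank + axis; // axis is now 2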