Fix end-to-end tests. (#52)

* Fix end-to-end tests.

* Use dyn_cast.
This commit is contained in:
Gheorghe-Teodor Bercea 2020-01-27 11:35:45 -05:00 committed by GitHub
parent c74f814f64
commit 95cf939c5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 19 additions and 13 deletions

View File

@ -420,14 +420,16 @@ Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
// Constant 1)
auto loc = op->getLoc();
Value operand = operands[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
auto betaAttr = op->getAttrOfType<FloatAttr>("beta");
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
llvm::dyn_cast<ONNXHardSigmoidOp>(op).alpha().convertToFloat());
auto betaAttribute = FloatAttr::get(rewriter.getF32Type(),
llvm::dyn_cast<ONNXHardSigmoidOp>(op).beta().convertToFloat());
auto elementType = result_types[0];
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
auto beta = rewriter.create<ConstantOp>(loc, betaAttribute);
auto add = rewriter.create<AddFOp>(
loc, rewriter.create<MulFOp>(loc, alpha, operand), beta);
@ -455,10 +457,11 @@ Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
Value operand = operands[0];
auto elementType = result_types[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
llvm::dyn_cast<ONNXEluOp>(op).alpha().convertToFloat());
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
auto exp = rewriter.create<ExpOp>(loc, operand);
auto lessThanZero =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
@ -508,9 +511,10 @@ Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op,
Value operand = operands[0];
auto elementType = result_types[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
llvm::dyn_cast<ONNXLeakyReluOp>(op).alpha().convertToFloat());
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
auto lessThanZero =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
auto result = rewriter.create<SelectOp>(
@ -533,13 +537,15 @@ Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types,
// alpha)))
auto loc = op->getLoc();
Value operand = operands[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
auto gammaAttr = op->getAttrOfType<FloatAttr>("gamma");
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
llvm::dyn_cast<ONNXSeluOp>(op).alpha().convertToFloat());
auto gammaAttribute = FloatAttr::get(rewriter.getF32Type(),
llvm::dyn_cast<ONNXSeluOp>(op).gamma().convertToFloat());
auto elementType = result_types[0];
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto gamma = rewriter.create<ConstantOp>(loc, gammaAttr);
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
auto gamma = rewriter.create<ConstantOp>(loc, gammaAttribute);
auto exp = rewriter.create<ExpOp>(loc, operand);
auto greaterThanZero =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
@ -876,7 +882,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
// exp_x / sum
auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
int64_t rank = tensorType.getRank();
int64_t axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
int64_t axis = llvm::dyn_cast<ONNXSoftmaxOp>(op).axis().getSExtValue();
axis = axis >= 0 ? axis : rank + axis;
assert(axis >= -rank && axis <= rank - 1);