parent
c74f814f64
commit
95cf939c5c
@@ -420,14 +420,16 @@ Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
   // Constant 1)
   auto loc = op->getLoc();
   Value operand = operands[0];
-  auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
-  auto betaAttr = op->getAttrOfType<FloatAttr>("beta");
+  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
+      llvm::dyn_cast<ONNXHardSigmoidOp>(op).alpha().convertToFloat());
+  auto betaAttribute = FloatAttr::get(rewriter.getF32Type(),
+      llvm::dyn_cast<ONNXHardSigmoidOp>(op).beta().convertToFloat());
   auto elementType = result_types[0];

   auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
   auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
-  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
-  auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
+  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
+  auto beta = rewriter.create<ConstantOp>(loc, betaAttribute);

   auto add = rewriter.create<AddFOp>(
       loc, rewriter.create<MulFOp>(loc, alpha, operand), beta);
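The MulFOp/AddFOp pair above computes alpha * x + beta, which the rest of the pattern clamps to [0, 1] per the ONNX definition HardSigmoid(x) = max(0, min(1, alpha * x + beta)). A minimal scalar reference in plain C++; hardSigmoid is a hypothetical helper for illustration, not part of this patch, and the defaults follow the ONNX spec:

    #include <algorithm>

    // Scalar reference for ONNX HardSigmoid: max(0, min(1, alpha * x + beta)).
    // Hypothetical helper; alpha/beta defaults are the ONNX spec values.
    float hardSigmoid(float x, float alpha = 0.2f, float beta = 0.5f) {
      return std::max(0.0f, std::min(1.0f, alpha * x + beta));
    }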
@@ -455,10 +457,11 @@ Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
   Value operand = operands[0];
   auto elementType = result_types[0];

-  auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
+  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
+      llvm::dyn_cast<ONNXEluOp>(op).alpha().convertToFloat());
   auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
   auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
-  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
+  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
   auto exp = rewriter.create<ExpOp>(loc, operand);
   auto lessThanZero =
       rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
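The ExpOp and the OLT comparison above feed a select implementing ONNX Elu(x) = x for x >= 0 and alpha * (e^x - 1) otherwise. A minimal scalar sketch, with elu as a hypothetical reference helper (default alpha = 1.0 per the spec):

    #include <cmath>

    // Scalar reference for ONNX Elu: x if x >= 0, else alpha * (exp(x) - 1).
    // Hypothetical helper; not part of the lowering code.
    float elu(float x, float alpha = 1.0f) {
      return x >= 0.0f ? x : alpha * (std::exp(x) - 1.0f);
    }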
@@ -508,9 +511,10 @@ Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op,
   Value operand = operands[0];
   auto elementType = result_types[0];

-  auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
+  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
+      llvm::dyn_cast<ONNXLeakyReluOp>(op).alpha().convertToFloat());
   auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
-  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
+  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
   auto lessThanZero =
       rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
   auto result = rewriter.create<SelectOp>(
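Here the OLT comparison drives a select between x and alpha * x, i.e. ONNX LeakyRelu(x) = x for x >= 0 and alpha * x otherwise. A scalar sketch (leakyRelu is a hypothetical helper; default alpha = 0.01 per the spec):

    // Scalar reference for ONNX LeakyRelu: x if x >= 0, else alpha * x.
    // Hypothetical helper; not part of the lowering code.
    float leakyRelu(float x, float alpha = 0.01f) {
      return x >= 0.0f ? x : alpha * x;
    }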
@@ -533,13 +537,15 @@ Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types,
   // alpha)))
   auto loc = op->getLoc();
   Value operand = operands[0];
-  auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
-  auto gammaAttr = op->getAttrOfType<FloatAttr>("gamma");
+  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
+      llvm::dyn_cast<ONNXSeluOp>(op).alpha().convertToFloat());
+  auto gammaAttribute = FloatAttr::get(rewriter.getF32Type(),
+      llvm::dyn_cast<ONNXSeluOp>(op).gamma().convertToFloat());
   auto elementType = result_types[0];

   auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
-  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
-  auto gamma = rewriter.create<ConstantOp>(loc, gammaAttr);
+  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
+  auto gamma = rewriter.create<ConstantOp>(loc, gammaAttribute);
   auto exp = rewriter.create<ExpOp>(loc, operand);
   auto greaterThanZero =
       rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
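The OGT comparison and ExpOp implement ONNX Selu(x) = gamma * x for x > 0 and gamma * (alpha * e^x - alpha) otherwise. A scalar sketch with a hypothetical selu helper; the defaults below are the ONNX spec values rounded to float precision:

    #include <cmath>

    // Scalar reference for ONNX Selu:
    //   gamma * x                        if x > 0
    //   gamma * (alpha * exp(x) - alpha) otherwise
    // Hypothetical helper; defaults approximate the ONNX spec values.
    float selu(float x, float alpha = 1.6732632f, float gamma = 1.0507010f) {
      return x > 0.0f ? gamma * x : gamma * (alpha * std::exp(x) - alpha);
    }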
@@ -876,7 +882,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
     // exp_x / sum
     auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
     int64_t rank = tensorType.getRank();
-    int64_t axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
+    int64_t axis = llvm::dyn_cast<ONNXSoftmaxOp>(op).axis().getSExtValue();
    axis = axis >= 0 ? axis : rank + axis;
    assert(axis >= -rank && axis <= rank - 1);

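The generated accessor returns the axis attribute as an APInt, so getSExtValue() is needed to recover the signed integer before normalization. The normalization itself maps a negative ONNX axis (counting from the back) into [0, rank). A standalone sketch of that step, with normalizeAxis as a hypothetical helper:

    #include <cassert>
    #include <cstdint>

    // Map an ONNX axis in [-rank, rank - 1] to a non-negative index,
    // mirroring the normalization in the Softmax lowering above.
    int64_t normalizeAxis(int64_t axis, int64_t rank) {
      assert(axis >= -rank && axis <= rank - 1);
      return axis >= 0 ? axis : rank + axis;
    }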