parent
c74f814f64
commit
95cf939c5c
|
@ -420,14 +420,16 @@ Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
|
||||||
// Constant 1)
|
// Constant 1)
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value operand = operands[0];
|
Value operand = operands[0];
|
||||||
auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
|
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
|
||||||
auto betaAttr = op->getAttrOfType<FloatAttr>("beta");
|
llvm::dyn_cast<ONNXHardSigmoidOp>(op).alpha().convertToFloat());
|
||||||
|
auto betaAttribute = FloatAttr::get(rewriter.getF32Type(),
|
||||||
|
llvm::dyn_cast<ONNXHardSigmoidOp>(op).beta().convertToFloat());
|
||||||
auto elementType = result_types[0];
|
auto elementType = result_types[0];
|
||||||
|
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||||
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
|
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
|
||||||
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
|
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
|
||||||
auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
|
auto beta = rewriter.create<ConstantOp>(loc, betaAttribute);
|
||||||
|
|
||||||
auto add = rewriter.create<AddFOp>(
|
auto add = rewriter.create<AddFOp>(
|
||||||
loc, rewriter.create<MulFOp>(loc, alpha, operand), beta);
|
loc, rewriter.create<MulFOp>(loc, alpha, operand), beta);
|
||||||
|
@ -455,10 +457,11 @@ Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
Value operand = operands[0];
|
Value operand = operands[0];
|
||||||
auto elementType = result_types[0];
|
auto elementType = result_types[0];
|
||||||
|
|
||||||
auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
|
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
|
||||||
|
llvm::dyn_cast<ONNXEluOp>(op).alpha().convertToFloat());
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||||
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
|
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
|
||||||
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
|
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
|
||||||
auto exp = rewriter.create<ExpOp>(loc, operand);
|
auto exp = rewriter.create<ExpOp>(loc, operand);
|
||||||
auto lessThanZero =
|
auto lessThanZero =
|
||||||
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
|
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
|
||||||
|
@ -508,9 +511,10 @@ Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op,
|
||||||
Value operand = operands[0];
|
Value operand = operands[0];
|
||||||
auto elementType = result_types[0];
|
auto elementType = result_types[0];
|
||||||
|
|
||||||
auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
|
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
|
||||||
|
llvm::dyn_cast<ONNXLeakyReluOp>(op).alpha().convertToFloat());
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||||
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
|
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
|
||||||
auto lessThanZero =
|
auto lessThanZero =
|
||||||
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
|
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
|
||||||
auto result = rewriter.create<SelectOp>(
|
auto result = rewriter.create<SelectOp>(
|
||||||
|
@ -533,13 +537,15 @@ Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
// alpha)))
|
// alpha)))
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value operand = operands[0];
|
Value operand = operands[0];
|
||||||
auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
|
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
|
||||||
auto gammaAttr = op->getAttrOfType<FloatAttr>("gamma");
|
llvm::dyn_cast<ONNXSeluOp>(op).alpha().convertToFloat());
|
||||||
|
auto gammaAttribute = FloatAttr::get(rewriter.getF32Type(),
|
||||||
|
llvm::dyn_cast<ONNXSeluOp>(op).gamma().convertToFloat());
|
||||||
auto elementType = result_types[0];
|
auto elementType = result_types[0];
|
||||||
|
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||||
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
|
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
|
||||||
auto gamma = rewriter.create<ConstantOp>(loc, gammaAttr);
|
auto gamma = rewriter.create<ConstantOp>(loc, gammaAttribute);
|
||||||
auto exp = rewriter.create<ExpOp>(loc, operand);
|
auto exp = rewriter.create<ExpOp>(loc, operand);
|
||||||
auto greaterThanZero =
|
auto greaterThanZero =
|
||||||
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
|
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
|
||||||
|
@ -876,7 +882,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
|
||||||
// exp_x / sum
|
// exp_x / sum
|
||||||
auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
|
auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
|
||||||
int64_t rank = tensorType.getRank();
|
int64_t rank = tensorType.getRank();
|
||||||
int64_t axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
|
int64_t axis = llvm::dyn_cast<ONNXSoftmaxOp>(op).axis().getSExtValue();
|
||||||
axis = axis >= 0 ? axis : rank + axis;
|
axis = axis >= 0 ? axis : rank + axis;
|
||||||
assert(axis >= -rank && axis <= rank - 1);
|
assert(axis >= -rank && axis <= rank - 1);
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue