Change the name and signature of mapToLowerScalarOp (#67)
* Revise mapToLowerScalarOp() * Update TanhOp Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
f4fefcf713
commit
4e66488ad3
|
@ -88,17 +88,15 @@ struct ScalarOp<ONNXSqrtOp> {
|
||||||
// Scalar unary ops for lowering ONNXSinhOp
|
// Scalar unary ops for lowering ONNXSinhOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value mapToLowerScalarOp<ONNXSinhOp>(Operation *op, ArrayRef<Type> result_types,
|
Value emitScalarOpFor<ONNXSinhOp>(ConversionPatternRewriter &rewriter,
|
||||||
ArrayRef<Value> operands,
|
Location loc, Operation *op, Type elementType,
|
||||||
ConversionPatternRewriter &rewriter) {
|
ArrayRef<Value> scalarOperands) {
|
||||||
// ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
|
// ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
|
||||||
// ConstantOp 2)
|
// ConstantOp 2)
|
||||||
auto loc = op->getLoc();
|
Value operand = scalarOperands[0];
|
||||||
Value operand = operands[0];
|
|
||||||
auto elementType = result_types[0];
|
|
||||||
|
|
||||||
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
|
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
|
||||||
auto two = emitConstantOp(rewriter, loc, elementType, 2);
|
auto two = emitConstantOp(rewriter, loc, elementType, 2);
|
||||||
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
|
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
|
||||||
auto exp = rewriter.create<ExpOp>(loc, operand);
|
auto exp = rewriter.create<ExpOp>(loc, operand);
|
||||||
auto negExp = rewriter.create<ExpOp>(loc, neg);
|
auto negExp = rewriter.create<ExpOp>(loc, neg);
|
||||||
|
@ -112,17 +110,15 @@ Value mapToLowerScalarOp<ONNXSinhOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
// Scalar unary ops for lowering ONNXCoshOp
|
// Scalar unary ops for lowering ONNXCoshOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value mapToLowerScalarOp<ONNXCoshOp>(Operation *op, ArrayRef<Type> result_types,
|
Value emitScalarOpFor<ONNXCoshOp>(ConversionPatternRewriter &rewriter,
|
||||||
ArrayRef<Value> operands,
|
Location loc, Operation *op, Type elementType,
|
||||||
ConversionPatternRewriter &rewriter) {
|
ArrayRef<Value> scalarOperands) {
|
||||||
// ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
|
// ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
|
||||||
// ConstantOp 2)
|
// ConstantOp 2)
|
||||||
auto loc = op->getLoc();
|
Value operand = scalarOperands[0];
|
||||||
Value operand = operands[0];
|
|
||||||
auto elementType = result_types[0];
|
|
||||||
|
|
||||||
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
|
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
|
||||||
auto two = emitConstantOp(rewriter, loc, elementType, 2);
|
auto two = emitConstantOp(rewriter, loc, elementType, 2);
|
||||||
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
|
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
|
||||||
auto exp = rewriter.create<ExpOp>(loc, operand);
|
auto exp = rewriter.create<ExpOp>(loc, operand);
|
||||||
auto negExp = rewriter.create<ExpOp>(loc, neg);
|
auto negExp = rewriter.create<ExpOp>(loc, neg);
|
||||||
|
@ -136,14 +132,12 @@ Value mapToLowerScalarOp<ONNXCoshOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
// Scalar unary ops for lowering ONNXTanhOp
|
// Scalar unary ops for lowering ONNXTanhOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value mapToLowerScalarOp<ONNXTanhOp>(Operation *op, ArrayRef<Type> result_types,
|
Value emitScalarOpFor<ONNXTanhOp>(ConversionPatternRewriter &rewriter,
|
||||||
ArrayRef<Value> operands,
|
Location loc, Operation *op, Type elementType,
|
||||||
ConversionPatternRewriter &rewriter) {
|
ArrayRef<Value> scalarOperands) {
|
||||||
// ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
|
// ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
|
||||||
// AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
|
// AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
|
||||||
auto loc = op->getLoc();
|
Value operand = scalarOperands[0];
|
||||||
Value operand = operands[0];
|
|
||||||
auto elementType = result_types[0];
|
|
||||||
|
|
||||||
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
|
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
|
||||||
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
|
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
|
||||||
|
@ -160,15 +154,12 @@ Value mapToLowerScalarOp<ONNXTanhOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
// Scalar unary ops for lowering ONNXSigmoidOp
|
// Scalar unary ops for lowering ONNXSigmoidOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
|
Value emitScalarOpFor<ONNXSigmoidOp>(ConversionPatternRewriter &rewriter,
|
||||||
ArrayRef<Type> result_types,
|
Location loc, Operation *op, Type elementType,
|
||||||
ArrayRef<Value> operands,
|
ArrayRef<Value> scalarOperands) {
|
||||||
ConversionPatternRewriter &rewriter) {
|
|
||||||
// ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1,
|
// ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1,
|
||||||
// AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
|
// AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
|
||||||
auto loc = op->getLoc();
|
Value operand = scalarOperands[0];
|
||||||
Value operand = operands[0];
|
|
||||||
auto elementType = result_types[0];
|
|
||||||
|
|
||||||
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
|
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
|
||||||
auto one = emitConstantOp(rewriter, loc, elementType, 1);
|
auto one = emitConstantOp(rewriter, loc, elementType, 1);
|
||||||
|
@ -184,9 +175,9 @@ Value mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
|
||||||
// Scalar unary ops for lowering ONNXHardSigmoidOp
|
// Scalar unary ops for lowering ONNXHardSigmoidOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
|
Value emitScalarOpFor<ONNXHardSigmoidOp>(ConversionPatternRewriter &rewriter,
|
||||||
Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
|
Location loc, Operation *op, Type elementType,
|
||||||
ConversionPatternRewriter &rewriter) {
|
ArrayRef<Value> scalarOperands) {
|
||||||
// %Y = AddFOp(MulFOp(alpha, %X), beta)
|
// %Y = AddFOp(MulFOp(alpha, %X), beta)
|
||||||
// %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0),
|
// %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0),
|
||||||
// %Y,
|
// %Y,
|
||||||
|
@ -194,13 +185,11 @@ Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
|
||||||
// ONNXHardSigmoidOp(%X) = SelectOp(CmpFOp(OLT, %Z, Constant 1),
|
// ONNXHardSigmoidOp(%X) = SelectOp(CmpFOp(OLT, %Z, Constant 1),
|
||||||
// %Z,
|
// %Z,
|
||||||
// Constant 1)
|
// Constant 1)
|
||||||
auto loc = op->getLoc();
|
Value operand = scalarOperands[0];
|
||||||
Value operand = operands[0];
|
|
||||||
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
|
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
|
||||||
llvm::dyn_cast<ONNXHardSigmoidOp>(op).alpha().convertToFloat());
|
llvm::dyn_cast<ONNXHardSigmoidOp>(op).alpha().convertToFloat());
|
||||||
auto betaAttribute = FloatAttr::get(rewriter.getF32Type(),
|
auto betaAttribute = FloatAttr::get(rewriter.getF32Type(),
|
||||||
llvm::dyn_cast<ONNXHardSigmoidOp>(op).beta().convertToFloat());
|
llvm::dyn_cast<ONNXHardSigmoidOp>(op).beta().convertToFloat());
|
||||||
auto elementType = result_types[0];
|
|
||||||
|
|
||||||
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
|
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
|
||||||
auto one = emitConstantOp(rewriter, loc, elementType, 1);
|
auto one = emitConstantOp(rewriter, loc, elementType, 1);
|
||||||
|
@ -223,15 +212,13 @@ Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
|
||||||
// Scalar unary ops for lowering ONNXEluOp
|
// Scalar unary ops for lowering ONNXEluOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
|
Value emitScalarOpFor<ONNXEluOp>(ConversionPatternRewriter &rewriter,
|
||||||
ArrayRef<Value> operands,
|
Location loc, Operation *op, Type elementType,
|
||||||
ConversionPatternRewriter &rewriter) {
|
ArrayRef<Value> scalarOperands) {
|
||||||
// ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
|
// ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
|
||||||
// MulFOp(alpha, SubFOp(ExpOp(%X), 1)),
|
// MulFOp(alpha, SubFOp(ExpOp(%X), 1)),
|
||||||
// %X)
|
// %X)
|
||||||
auto loc = op->getLoc();
|
Value operand = scalarOperands[0];
|
||||||
Value operand = operands[0];
|
|
||||||
auto elementType = result_types[0];
|
|
||||||
|
|
||||||
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
|
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
|
||||||
llvm::dyn_cast<ONNXEluOp>(op).alpha().convertToFloat());
|
llvm::dyn_cast<ONNXEluOp>(op).alpha().convertToFloat());
|
||||||
|
@ -241,10 +228,9 @@ Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
auto exp = rewriter.create<ExpOp>(loc, operand);
|
auto exp = rewriter.create<ExpOp>(loc, operand);
|
||||||
auto lessThanZero =
|
auto lessThanZero =
|
||||||
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
|
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
|
||||||
auto result = rewriter.create<SelectOp>(
|
auto result = rewriter.create<SelectOp>(loc, lessThanZero,
|
||||||
loc, lessThanZero,
|
rewriter.create<MulFOp>(
|
||||||
rewriter.create<MulFOp>(loc, alpha,
|
loc, alpha, rewriter.create<SubFOp>(loc, exp, one)),
|
||||||
rewriter.create<SubFOp>(loc, exp, one)),
|
|
||||||
operand);
|
operand);
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
@ -254,15 +240,13 @@ Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
// Scalar unary ops for lowering ONNXReluOp
|
// Scalar unary ops for lowering ONNXReluOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value mapToLowerScalarOp<ONNXReluOp>(Operation *op, ArrayRef<Type> result_types,
|
Value emitScalarOpFor<ONNXReluOp>(ConversionPatternRewriter &rewriter,
|
||||||
ArrayRef<Value> operands,
|
Location loc, Operation *op, Type elementType,
|
||||||
ConversionPatternRewriter &rewriter) {
|
ArrayRef<Value> scalarOperands) {
|
||||||
// ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
|
// ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
|
||||||
// ConstantOp 0,
|
// ConstantOp 0,
|
||||||
// %X)
|
// %X)
|
||||||
auto loc = op->getLoc();
|
Value operand = scalarOperands[0];
|
||||||
Value operand = operands[0];
|
|
||||||
auto elementType = result_types[0];
|
|
||||||
|
|
||||||
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
|
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
|
||||||
auto lessThanZero =
|
auto lessThanZero =
|
||||||
|
@ -276,16 +260,13 @@ Value mapToLowerScalarOp<ONNXReluOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
// Scalar unary ops for lowering ONNXLeakyReluOp
|
// Scalar unary ops for lowering ONNXLeakyReluOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op,
|
Value emitScalarOpFor<ONNXLeakyReluOp>(ConversionPatternRewriter &rewriter,
|
||||||
ArrayRef<Type> result_types,
|
Location loc, Operation *op, Type elementType,
|
||||||
ArrayRef<Value> operands,
|
ArrayRef<Value> scalarOperands) {
|
||||||
ConversionPatternRewriter &rewriter) {
|
|
||||||
// ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
|
// ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
|
||||||
// MulFOp(alpha, %X),
|
// MulFOp(alpha, %X),
|
||||||
// %X)
|
// %X)
|
||||||
auto loc = op->getLoc();
|
Value operand = scalarOperands[0];
|
||||||
Value operand = operands[0];
|
|
||||||
auto elementType = result_types[0];
|
|
||||||
|
|
||||||
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
|
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
|
||||||
llvm::dyn_cast<ONNXLeakyReluOp>(op).alpha().convertToFloat());
|
llvm::dyn_cast<ONNXLeakyReluOp>(op).alpha().convertToFloat());
|
||||||
|
@ -303,21 +284,19 @@ Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op,
|
||||||
// Scalar unary ops for lowering ONNXSeluOp
|
// Scalar unary ops for lowering ONNXSeluOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types,
|
Value emitScalarOpFor<ONNXSeluOp>(ConversionPatternRewriter &rewriter,
|
||||||
ArrayRef<Value> operands,
|
Location loc, Operation *op, Type elementType,
|
||||||
ConversionPatternRewriter &rewriter) {
|
ArrayRef<Value> scalarOperands) {
|
||||||
// ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
|
// ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
|
||||||
// MulFOp(gamma, %X),
|
// MulFOp(gamma, %X),
|
||||||
// MulFOp(gamma,
|
// MulFOp(gamma,
|
||||||
// SubFOp(MulFOp(alpha, ExpOp(%X)),
|
// SubFOp(MulFOp(alpha, ExpOp(%X)),
|
||||||
// alpha)))
|
// alpha)))
|
||||||
auto loc = op->getLoc();
|
Value operand = scalarOperands[0];
|
||||||
Value operand = operands[0];
|
|
||||||
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
|
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
|
||||||
llvm::dyn_cast<ONNXSeluOp>(op).alpha().convertToFloat());
|
llvm::dyn_cast<ONNXSeluOp>(op).alpha().convertToFloat());
|
||||||
auto gammaAttribute = FloatAttr::get(rewriter.getF32Type(),
|
auto gammaAttribute = FloatAttr::get(rewriter.getF32Type(),
|
||||||
llvm::dyn_cast<ONNXSeluOp>(op).gamma().convertToFloat());
|
llvm::dyn_cast<ONNXSeluOp>(op).gamma().convertToFloat());
|
||||||
auto elementType = result_types[0];
|
|
||||||
|
|
||||||
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
|
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
|
||||||
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
|
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
|
||||||
|
@ -325,10 +304,9 @@ Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
auto exp = rewriter.create<ExpOp>(loc, operand);
|
auto exp = rewriter.create<ExpOp>(loc, operand);
|
||||||
auto greaterThanZero =
|
auto greaterThanZero =
|
||||||
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
|
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
|
||||||
auto select = rewriter.create<SelectOp>(
|
auto select = rewriter.create<SelectOp>(loc, greaterThanZero, operand,
|
||||||
loc, greaterThanZero, operand,
|
rewriter.create<SubFOp>(
|
||||||
rewriter.create<SubFOp>(loc, rewriter.create<MulFOp>(loc, alpha, exp),
|
loc, rewriter.create<MulFOp>(loc, alpha, exp), alpha));
|
||||||
alpha));
|
|
||||||
auto result = rewriter.create<MulFOp>(loc, gamma, select);
|
auto result = rewriter.create<MulFOp>(loc, gamma, select);
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
@ -338,14 +316,11 @@ Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
// Scalar unary ops for lowering ONNXReciprocalOp
|
// Scalar unary ops for lowering ONNXReciprocalOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value mapToLowerScalarOp<ONNXReciprocalOp>(
|
Value emitScalarOpFor<ONNXReciprocalOp>(ConversionPatternRewriter &rewriter,
|
||||||
Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
|
Location loc, Operation *op, Type elementType,
|
||||||
ConversionPatternRewriter &rewriter) {
|
ArrayRef<Value> scalarOperands) {
|
||||||
// ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
|
// ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
|
||||||
auto loc = op->getLoc();
|
Value operand = scalarOperands[0];
|
||||||
Value operand = operands[0];
|
|
||||||
auto elementType = result_types[0];
|
|
||||||
|
|
||||||
auto one = emitConstantOp(rewriter, loc, elementType, 1);
|
auto one = emitConstantOp(rewriter, loc, elementType, 1);
|
||||||
auto result = rewriter.create<DivFOp>(loc, one, operand);
|
auto result = rewriter.create<DivFOp>(loc, one, operand);
|
||||||
|
|
||||||
|
@ -356,13 +331,11 @@ Value mapToLowerScalarOp<ONNXReciprocalOp>(
|
||||||
// Scalar unary ops for lowering ONNXSoftplusOp
|
// Scalar unary ops for lowering ONNXSoftplusOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value mapToLowerScalarOp<ONNXSoftplusOp>(
|
Value emitScalarOpFor<ONNXSoftplusOp>(ConversionPatternRewriter &rewriter,
|
||||||
Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
|
Location loc, Operation *op, Type elementType,
|
||||||
ConversionPatternRewriter &rewriter) {
|
ArrayRef<Value> scalarOperands) {
|
||||||
// ONNXSoftplusOp(%X) = LogOp(AddFOp(ExpOp(%X), ConstantOp 1))
|
// ONNXSoftplusOp(%X) = LogOp(AddFOp(ExpOp(%X), ConstantOp 1))
|
||||||
auto loc = op->getLoc();
|
Value operand = scalarOperands[0];
|
||||||
Value operand = operands[0];
|
|
||||||
auto elementType = result_types[0];
|
|
||||||
|
|
||||||
auto exp = rewriter.create<ExpOp>(loc, operand);
|
auto exp = rewriter.create<ExpOp>(loc, operand);
|
||||||
auto one = emitConstantOp(rewriter, loc, elementType, 1);
|
auto one = emitConstantOp(rewriter, loc, elementType, 1);
|
||||||
|
@ -376,13 +349,11 @@ Value mapToLowerScalarOp<ONNXSoftplusOp>(
|
||||||
// Scalar unary ops for lowering ONNXSoftsignOp
|
// Scalar unary ops for lowering ONNXSoftsignOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value mapToLowerScalarOp<ONNXSoftsignOp>(
|
Value emitScalarOpFor<ONNXSoftsignOp>(ConversionPatternRewriter &rewriter,
|
||||||
Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
|
Location loc, Operation *op, Type elementType,
|
||||||
ConversionPatternRewriter &rewriter) {
|
ArrayRef<Value> scalarOperands) {
|
||||||
// ONNXSoftsignOp(%X) = DivFOp(ConstantOp 1, %X)
|
// ONNXSoftsignOp(%X) = DivFOp(ConstantOp 1, %X)
|
||||||
auto loc = op->getLoc();
|
Value operand = scalarOperands[0];
|
||||||
Value operand = operands[0];
|
|
||||||
auto elementType = result_types[0];
|
|
||||||
|
|
||||||
auto abs = rewriter.create<AbsFOp>(loc, operand);
|
auto abs = rewriter.create<AbsFOp>(loc, operand);
|
||||||
auto one = emitConstantOp(rewriter, loc, elementType, 1);
|
auto one = emitConstantOp(rewriter, loc, elementType, 1);
|
||||||
|
@ -396,13 +367,10 @@ Value mapToLowerScalarOp<ONNXSoftsignOp>(
|
||||||
// Scalar unary ops for lowering ONNXSignOp
|
// Scalar unary ops for lowering ONNXSignOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value mapToLowerScalarOp<ONNXSignOp>(Operation *op, ArrayRef<Type> result_types,
|
Value emitScalarOpFor<ONNXSignOp>(ConversionPatternRewriter &rewriter,
|
||||||
ArrayRef<Value> operands,
|
Location loc, Operation *op, Type elementType,
|
||||||
ConversionPatternRewriter &rewriter) {
|
ArrayRef<Value> scalarOperands) {
|
||||||
|
Value operand = scalarOperands[0];
|
||||||
auto loc = op->getLoc();
|
|
||||||
Value operand = operands[0];
|
|
||||||
Type elementType = operands.front().getType();
|
|
||||||
// TODO: unsigned int should be supported separately?
|
// TODO: unsigned int should be supported separately?
|
||||||
if (elementType.isa<IntegerType>()) {
|
if (elementType.isa<IntegerType>()) {
|
||||||
// %Y = SelectOP(CmpIOp(GT, %X, ConstantOp 0),
|
// %Y = SelectOP(CmpIOp(GT, %X, ConstantOp 0),
|
||||||
|
@ -451,15 +419,14 @@ Value mapToLowerScalarOp<ONNXSignOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
// Scalar unary ops for lowering ONNXMaxOp
|
// Scalar unary ops for lowering ONNXMaxOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types,
|
Value emitScalarOpFor<ONNXMaxOp>(ConversionPatternRewriter &rewriter,
|
||||||
ArrayRef<Value> operands,
|
Location loc, Operation *op, Type elementType,
|
||||||
ConversionPatternRewriter &rewriter) {
|
ArrayRef<Value> scalarOperands) {
|
||||||
// ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y),
|
// ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y),
|
||||||
// %X,
|
// %X,
|
||||||
// %Y)
|
// %Y)
|
||||||
auto loc = op->getLoc();
|
Value lhs = scalarOperands[0];
|
||||||
Value lhs = operands[0];
|
Value rhs = scalarOperands[1];
|
||||||
Value rhs = operands[1];
|
|
||||||
auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
|
auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
|
||||||
auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
|
auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
|
||||||
return result;
|
return result;
|
||||||
|
@ -469,15 +436,14 @@ Value mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
// Scalar unary ops for lowering ONNXMinOp
|
// Scalar unary ops for lowering ONNXMinOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types,
|
Value emitScalarOpFor<ONNXMinOp>(ConversionPatternRewriter &rewriter,
|
||||||
ArrayRef<Value> operands,
|
Location loc, Operation *op, Type elementType,
|
||||||
ConversionPatternRewriter &rewriter) {
|
ArrayRef<Value> scalarOperands) {
|
||||||
// ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y),
|
// ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y),
|
||||||
// %X,
|
// %X,
|
||||||
// %Y)
|
// %Y)
|
||||||
auto loc = op->getLoc();
|
Value lhs = scalarOperands[0];
|
||||||
Value lhs = operands[0];
|
Value rhs = scalarOperands[1];
|
||||||
Value rhs = operands[1];
|
|
||||||
auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
|
auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
|
||||||
auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
|
auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
|
||||||
return result;
|
return result;
|
||||||
|
@ -487,11 +453,10 @@ Value mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
// Scalar unary ops for lowering ONNXAbsOp
|
// Scalar unary ops for lowering ONNXAbsOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value mapToLowerScalarOp<ONNXAbsOp>(Operation *op, ArrayRef<Type> result_types,
|
Value emitScalarOpFor<ONNXAbsOp>(ConversionPatternRewriter &rewriter,
|
||||||
ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) {
|
Location loc, Operation *op, Type elementType,
|
||||||
auto loc = op->getLoc();
|
ArrayRef<Value> scalarOperands) {
|
||||||
Value operand = operands[0];
|
Value operand = scalarOperands[0];
|
||||||
auto elementType = result_types[0];
|
|
||||||
|
|
||||||
if (elementType.isa<FloatType>()) {
|
if (elementType.isa<FloatType>()) {
|
||||||
return rewriter.create<AbsFOp>(loc, operand);
|
return rewriter.create<AbsFOp>(loc, operand);
|
||||||
|
@ -536,15 +501,14 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
|
||||||
if (hasAllConstantDimensions(memRefType))
|
if (hasAllConstantDimensions(memRefType))
|
||||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||||
else
|
else
|
||||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
|
alloc = insertAllocAndDealloc(
|
||||||
{operands[0]});
|
memRefType, loc, rewriter, insertDealloc, {operands[0]});
|
||||||
|
|
||||||
std::vector<Value> originalLoops;
|
std::vector<Value> originalLoops;
|
||||||
KrnlOptimizeLoopsOp optimizedLoopsOp;
|
KrnlOptimizeLoopsOp optimizedLoopsOp;
|
||||||
KrnlIterateOp iterateOp;
|
KrnlIterateOp iterateOp;
|
||||||
emitKrnlLoopsAndIterationForOperand(
|
emitKrnlLoopsAndIterationForOperand(
|
||||||
rewriter, loc, operands[0], originalLoops,
|
rewriter, loc, operands[0], originalLoops, optimizedLoopsOp, iterateOp);
|
||||||
optimizedLoopsOp, iterateOp);
|
|
||||||
Block &optimizationBlock = optimizedLoopsOp.region().front();
|
Block &optimizationBlock = optimizedLoopsOp.region().front();
|
||||||
Block &iterationBlock = iterateOp.bodyRegion().front();
|
Block &iterationBlock = iterateOp.bodyRegion().front();
|
||||||
|
|
||||||
|
@ -564,8 +528,8 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
|
||||||
loopIVs.push_back(arg);
|
loopIVs.push_back(arg);
|
||||||
|
|
||||||
auto loadedVal = rewriter.create<LoadOp>(loc, operands[0], loopIVs);
|
auto loadedVal = rewriter.create<LoadOp>(loc, operands[0], loopIVs);
|
||||||
auto loweredOpResult = mapToLowerScalarOp<ElementwiseUnaryOp>(
|
auto loweredOpResult = emitScalarOpFor<ElementwiseUnaryOp>(
|
||||||
op, memRefType.getElementType(), {loadedVal}, rewriter);
|
rewriter, loc, op, memRefType.getElementType(), {loadedVal});
|
||||||
// Store result in the resulting array.
|
// Store result in the resulting array.
|
||||||
rewriter.create<StoreOp>(loc, loweredOpResult, alloc, loopIVs);
|
rewriter.create<StoreOp>(loc, loweredOpResult, alloc, loopIVs);
|
||||||
|
|
||||||
|
@ -603,8 +567,8 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
||||||
if (hasAllConstantDimensions(memRefType))
|
if (hasAllConstantDimensions(memRefType))
|
||||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||||
else
|
else
|
||||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
|
alloc = insertAllocAndDealloc(
|
||||||
operands);
|
memRefType, loc, rewriter, insertDealloc, operands);
|
||||||
|
|
||||||
// Get run-time dimension information for unknown dimensions used for
|
// Get run-time dimension information for unknown dimensions used for
|
||||||
// broadcasting.
|
// broadcasting.
|
||||||
|
@ -615,8 +579,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
||||||
KrnlOptimizeLoopsOp optimizedLoopsOp;
|
KrnlOptimizeLoopsOp optimizedLoopsOp;
|
||||||
KrnlIterateOp iterateOp;
|
KrnlIterateOp iterateOp;
|
||||||
emitKrnlLoopsAndIterationForOperand(
|
emitKrnlLoopsAndIterationForOperand(
|
||||||
rewriter, loc, alloc, originalLoops,
|
rewriter, loc, alloc, originalLoops, optimizedLoopsOp, iterateOp);
|
||||||
optimizedLoopsOp, iterateOp);
|
|
||||||
Block &optimizationBlock = optimizedLoopsOp.region().front();
|
Block &optimizationBlock = optimizedLoopsOp.region().front();
|
||||||
Block &iterationBlock = iterateOp.bodyRegion().front();
|
Block &iterationBlock = iterateOp.bodyRegion().front();
|
||||||
|
|
||||||
|
@ -643,8 +606,8 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
||||||
auto nextLoopIVs = getLoopIVsForBroadcasting(
|
auto nextLoopIVs = getLoopIVsForBroadcasting(
|
||||||
loc, rewriter, loopIVs, operands[i], broadcastedDimInfo[i]);
|
loc, rewriter, loopIVs, operands[i], broadcastedDimInfo[i]);
|
||||||
next = rewriter.create<LoadOp>(loc, operands[i], nextLoopIVs);
|
next = rewriter.create<LoadOp>(loc, operands[i], nextLoopIVs);
|
||||||
accumulated = mapToLowerScalarOp<ElementwiseVariadicOp>(
|
accumulated = emitScalarOpFor<ElementwiseVariadicOp>(
|
||||||
op, memRefType.getElementType(), {accumulated, next}, rewriter);
|
rewriter, loc, op, memRefType.getElementType(), {accumulated, next});
|
||||||
}
|
}
|
||||||
// Store result in the resulting array.
|
// Store result in the resulting array.
|
||||||
rewriter.create<StoreOp>(loc, accumulated, alloc, loopIVs);
|
rewriter.create<StoreOp>(loc, accumulated, alloc, loopIVs);
|
||||||
|
@ -658,31 +621,31 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
||||||
void populateLoweringONNXElementwiseOpPattern(
|
void populateLoweringONNXElementwiseOpPattern(
|
||||||
OwningRewritePatternList &patterns, MLIRContext *ctx) {
|
OwningRewritePatternList &patterns, MLIRContext *ctx) {
|
||||||
patterns.insert<ONNXElementwiseUnaryOpLowering<mlir::ONNXAbsOp>,
|
patterns.insert<ONNXElementwiseUnaryOpLowering<mlir::ONNXAbsOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXAndOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXAndOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXCosOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXCosOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXEluOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXEluOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXExpOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXExpOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXHardSigmoidOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXHardSigmoidOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXLeakyReluOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXLeakyReluOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXLogOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXLogOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXOrOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXOrOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXSeluOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXSeluOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXSignOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXSignOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftplusOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftplusOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftsignOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftsignOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXSqrtOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXSqrtOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXSubOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXSubOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>>(ctx);
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>>(ctx);
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,13 +54,11 @@ struct ScalarOp<ONNXReduceSumOp> {
|
||||||
// Scalar unary ops for lowering ONNXReduceMaxOp
|
// Scalar unary ops for lowering ONNXReduceMaxOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value mapToLowerScalarOp<ONNXReduceMaxOp>(Operation *op,
|
Value emitScalarOpFor<ONNXReduceMaxOp>(ConversionPatternRewriter &rewriter,
|
||||||
ArrayRef<Type> result_types,
|
Location loc, Operation *op, Type elementType,
|
||||||
ArrayRef<Value> operands,
|
ArrayRef<Value> scalarOperands) {
|
||||||
ConversionPatternRewriter &rewriter) {
|
Value lhs = scalarOperands[0];
|
||||||
auto loc = op->getLoc();
|
Value rhs = scalarOperands[1];
|
||||||
Value lhs = operands[0];
|
|
||||||
Value rhs = operands[1];
|
|
||||||
Type element_type = lhs.getType();
|
Type element_type = lhs.getType();
|
||||||
if (element_type.isa<IntegerType>()) {
|
if (element_type.isa<IntegerType>()) {
|
||||||
auto max = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs, rhs);
|
auto max = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs, rhs);
|
||||||
|
@ -80,19 +78,16 @@ Value mapToLowerScalarOp<ONNXReduceMaxOp>(Operation *op,
|
||||||
// Scalar unary ops for lowering ONNXReduceMinOp
|
// Scalar unary ops for lowering ONNXReduceMinOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value mapToLowerScalarOp<ONNXReduceMinOp>(Operation *op,
|
Value emitScalarOpFor<ONNXReduceMinOp>(ConversionPatternRewriter &rewriter,
|
||||||
ArrayRef<Type> result_types,
|
Location loc, Operation *op, Type elementType,
|
||||||
ArrayRef<Value> operands,
|
ArrayRef<Value> scalarOperands) {
|
||||||
ConversionPatternRewriter &rewriter) {
|
Value lhs = scalarOperands[0];
|
||||||
auto loc = op->getLoc();
|
Value rhs = scalarOperands[1];
|
||||||
Value lhs = operands[0];
|
if (elementType.isa<IntegerType>()) {
|
||||||
Value rhs = operands[1];
|
|
||||||
Type element_type = lhs.getType();
|
|
||||||
if (element_type.isa<IntegerType>()) {
|
|
||||||
auto min = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, rhs);
|
auto min = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, rhs);
|
||||||
auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
|
auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
|
||||||
return result;
|
return result;
|
||||||
} else if (element_type.isa<FloatType>()) {
|
} else if (elementType.isa<FloatType>()) {
|
||||||
auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
|
auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
|
||||||
auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
|
auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
|
||||||
return result;
|
return result;
|
||||||
|
@ -129,7 +124,7 @@ struct ONNXReductionOpLowering : public ConversionPattern {
|
||||||
* Y(i1) += X(i0, i1, i2)
|
* Y(i1) += X(i0, i1, i2)
|
||||||
* }
|
* }
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
auto memRefInType = operands[0].getType().cast<MemRefType>();
|
auto memRefInType = operands[0].getType().cast<MemRefType>();
|
||||||
auto memRefInShape = memRefInType.getShape();
|
auto memRefInShape = memRefInType.getShape();
|
||||||
|
@ -154,8 +149,7 @@ struct ONNXReductionOpLowering : public ConversionPattern {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// KeepDims
|
// KeepDims
|
||||||
auto keepdims =
|
auto keepdims = llvm::dyn_cast<ONNXReductionOp>(op).keepdims();
|
||||||
llvm::dyn_cast<ONNXReductionOp>(op).keepdims();
|
|
||||||
bool isKeepdims = (keepdims == 1) ? true : false;
|
bool isKeepdims = (keepdims == 1) ? true : false;
|
||||||
|
|
||||||
// Get type information
|
// Get type information
|
||||||
|
@ -168,7 +162,8 @@ struct ONNXReductionOpLowering : public ConversionPattern {
|
||||||
Value alloc;
|
Value alloc;
|
||||||
bool insertDealloc = checkInsertDealloc(op);
|
bool insertDealloc = checkInsertDealloc(op);
|
||||||
if (hasAllConstantDimensions(memRefOutType)) {
|
if (hasAllConstantDimensions(memRefOutType)) {
|
||||||
alloc = insertAllocAndDealloc(memRefOutType, loc, rewriter, insertDealloc);
|
alloc =
|
||||||
|
insertAllocAndDealloc(memRefOutType, loc, rewriter, insertDealloc);
|
||||||
} else {
|
} else {
|
||||||
SmallVector<Value, 2> allocOperands;
|
SmallVector<Value, 2> allocOperands;
|
||||||
for (decltype(outRank) i = 0; i < outRank; ++i) {
|
for (decltype(outRank) i = 0; i < outRank; ++i) {
|
||||||
|
@ -192,12 +187,12 @@ struct ONNXReductionOpLowering : public ConversionPattern {
|
||||||
// Define loops to initialize the result.
|
// Define loops to initialize the result.
|
||||||
std::vector<Value> originalLoopsInit;
|
std::vector<Value> originalLoopsInit;
|
||||||
std::vector<Value> optimizedLoopsInit;
|
std::vector<Value> optimizedLoopsInit;
|
||||||
Block *optimizationBlockInit = defineLoops(rewriter, loc, originalLoopsInit,
|
Block *optimizationBlockInit = defineLoops(
|
||||||
optimizedLoopsInit, outRank);
|
rewriter, loc, originalLoopsInit, optimizedLoopsInit, outRank);
|
||||||
|
|
||||||
// Iteration information
|
// Iteration information
|
||||||
KrnlIterateOperandPack packInit(rewriter, originalLoopsInit,
|
KrnlIterateOperandPack packInit(
|
||||||
optimizedLoopsInit);
|
rewriter, originalLoopsInit, optimizedLoopsInit);
|
||||||
for (decltype(outRank) i = 0; i < outRank; ++i) {
|
for (decltype(outRank) i = 0; i < outRank; ++i) {
|
||||||
addDimensionToPack(rewriter, loc, packInit, alloc, i);
|
addDimensionToPack(rewriter, loc, packInit, alloc, i);
|
||||||
}
|
}
|
||||||
|
@ -225,8 +220,8 @@ struct ONNXReductionOpLowering : public ConversionPattern {
|
||||||
// Define an Krnl loop to do reduction.
|
// Define an Krnl loop to do reduction.
|
||||||
rewriter.setInsertionPointAfter(iterateOpInit);
|
rewriter.setInsertionPointAfter(iterateOpInit);
|
||||||
std::vector<Value> originalLoops, optimizedLoops;
|
std::vector<Value> originalLoops, optimizedLoops;
|
||||||
Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
|
Block *optimizationBlock =
|
||||||
optimizedLoops, inRank);
|
defineLoops(rewriter, loc, originalLoops, optimizedLoops, inRank);
|
||||||
// Iteration information
|
// Iteration information
|
||||||
KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
|
KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
|
||||||
for (decltype(inRank) i = 0; i < inRank; ++i) {
|
for (decltype(inRank) i = 0; i < inRank; ++i) {
|
||||||
|
@ -266,8 +261,8 @@ struct ONNXReductionOpLowering : public ConversionPattern {
|
||||||
Value next, accumulated;
|
Value next, accumulated;
|
||||||
next = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
|
next = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
|
||||||
accumulated = rewriter.create<LoadOp>(loc, alloc, outLoopIVs);
|
accumulated = rewriter.create<LoadOp>(loc, alloc, outLoopIVs);
|
||||||
accumulated = mapToLowerScalarOp<ONNXReductionOp>(
|
accumulated = emitScalarOpFor<ONNXReductionOp>(
|
||||||
op, memRefOutType.getElementType(), {accumulated, next}, rewriter);
|
rewriter, loc, op, memRefOutType.getElementType(), {accumulated, next});
|
||||||
rewriter.create<StoreOp>(loc, accumulated, alloc, outLoopIVs);
|
rewriter.create<StoreOp>(loc, accumulated, alloc, outLoopIVs);
|
||||||
|
|
||||||
rewriter.replaceOp(op, alloc);
|
rewriter.replaceOp(op, alloc);
|
||||||
|
@ -278,7 +273,7 @@ struct ONNXReductionOpLowering : public ConversionPattern {
|
||||||
void populateLoweringONNXReductionOpPattern(
|
void populateLoweringONNXReductionOpPattern(
|
||||||
OwningRewritePatternList &patterns, MLIRContext *ctx) {
|
OwningRewritePatternList &patterns, MLIRContext *ctx) {
|
||||||
patterns.insert<ONNXReductionOpLowering<mlir::ONNXReduceMaxOp>,
|
patterns.insert<ONNXReductionOpLowering<mlir::ONNXReduceMaxOp>,
|
||||||
ONNXReductionOpLowering<mlir::ONNXReduceMinOp>,
|
ONNXReductionOpLowering<mlir::ONNXReduceMinOp>,
|
||||||
ONNXReductionOpLowering<mlir::ONNXReduceProdOp>,
|
ONNXReductionOpLowering<mlir::ONNXReduceProdOp>,
|
||||||
ONNXReductionOpLowering<mlir::ONNXReduceSumOp>>(ctx);
|
ONNXReductionOpLowering<mlir::ONNXReduceSumOp>>(ctx);
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,12 +20,11 @@ Value getIdentityValue<ONNXMaxPoolSingleOutOp>(
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
Value mapToLowerScalarOp<ONNXMaxPoolSingleOutOp>(Operation *op,
|
Value emitScalarOpFor<ONNXMaxPoolSingleOutOp>(
|
||||||
ArrayRef<Type> result_types, ArrayRef<Value> operands,
|
ConversionPatternRewriter &rewriter, Location loc, Operation *op,
|
||||||
ConversionPatternRewriter &rewriter) {
|
Type elementType, ArrayRef<Value> scalarOperands) {
|
||||||
auto loc = op->getLoc();
|
Value lhs = scalarOperands[0];
|
||||||
Value lhs = operands[0];
|
Value rhs = scalarOperands[1];
|
||||||
Value rhs = operands[1];
|
|
||||||
auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
|
auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
|
||||||
auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
|
auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
|
||||||
return result;
|
return result;
|
||||||
|
@ -308,8 +307,8 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
|
||||||
auto loadData = rewriter.create<LoadOp>(loc, inputOperand, dataIndices);
|
auto loadData = rewriter.create<LoadOp>(loc, inputOperand, dataIndices);
|
||||||
auto loadPartialResult =
|
auto loadPartialResult =
|
||||||
rewriter.create<LoadOp>(loc, alloc, resultIndices);
|
rewriter.create<LoadOp>(loc, alloc, resultIndices);
|
||||||
Value result = mapToLowerScalarOp<ONNXMaxPoolSingleOutOp>(
|
Value result = emitScalarOpFor<ONNXMaxPoolSingleOutOp>(rewriter, loc,
|
||||||
op, resultElementType, {loadPartialResult, loadData}, rewriter);
|
op, resultElementType, {loadPartialResult, loadData});
|
||||||
rewriter.create<StoreOp>(loc, result, alloc, resultIndices);
|
rewriter.create<StoreOp>(loc, result, alloc, resultIndices);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -148,17 +148,14 @@ Value getIdentityValue(
|
||||||
// Use template specialization for each of different ONNX operations.
|
// Use template specialization for each of different ONNX operations.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <typename Op>
|
template <typename Op>
|
||||||
Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
|
Value emitScalarOpFor(ConversionPatternRewriter &rewriter, Location loc,
|
||||||
ArrayRef<Value> operands,
|
Operation *op, Type elementType, ArrayRef<Value> scalarOperands) {
|
||||||
ConversionPatternRewriter &rewriter) {
|
if (elementType.isa<IntegerType>()) {
|
||||||
auto loc = op->getLoc();
|
return rewriter.create<ScalarIOp<Op>>(
|
||||||
Type element_type = operands.front().getType();
|
loc, elementType, scalarOperands, mlir::None);
|
||||||
if (element_type.isa<IntegerType>()) {
|
} else if (elementType.isa<FloatType>()) {
|
||||||
return rewriter.create<ScalarIOp<Op>>(loc, result_types, operands,
|
return rewriter.create<ScalarFOp<Op>>(
|
||||||
mlir::None);
|
loc, elementType, scalarOperands, mlir::None);
|
||||||
} else if (element_type.isa<FloatType>()) {
|
|
||||||
return rewriter.create<ScalarFOp<Op>>(loc, result_types, operands,
|
|
||||||
mlir::None);
|
|
||||||
} else {
|
} else {
|
||||||
emitError(loc, "unsupported element type");
|
emitError(loc, "unsupported element type");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -247,4 +244,3 @@ void populateLoweringONNXIdentityOpPattern(
|
||||||
|
|
||||||
void populateLoweringONNXConstantOpPattern(
|
void populateLoweringONNXConstantOpPattern(
|
||||||
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue