Lowering ReductionMax, ReductionMin, ReductionProd and ReductionSum (#31)
* Shape inference for reduction * Lower ReduceSum * Support list-like attributes * Add ReduceMax, ReduceMin, ReduceProd * Add tests * Emit errors for unsupported types * Fix typos * Add backend test * Fix axis computation * Update the use of attributes * Use SmallVector * Address stylistic comments * Change type from int to int64_t for indices
This commit is contained in:
		
							parent
							
								
									0272451521
								
							
						
					
					
						commit
						2c7046ff5f
					
				| 
						 | 
				
			
			@ -46,6 +46,7 @@ ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
 | 
			
		|||
                   'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
 | 
			
		||||
                   'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
 | 
			
		||||
                   'Identity', 'Cos', 'Log', 'Transpose', 'Softmax',
 | 
			
		||||
                   'ReduceMax', 'ReduceMin', 'ReduceProd', 'ReduceSum',
 | 
			
		||||
                   'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze', 'Sign']
 | 
			
		||||
 | 
			
		||||
CanonicalList=['Add', 'Identity', 'ReduceL1', 'ReduceL2', 'ReduceLogSum',
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -24,6 +24,54 @@
 | 
			
		|||
using namespace mlir;
 | 
			
		||||
using namespace mlir::OpTrait::util;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Get reduction type
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
/// Compute the result type of a reduction op (ReduceMax/Min/Prod/Sum).
///
/// \param operandTy  Ranked type of the input tensor.
/// \param axesAttrs  Optional `axes` attribute; axes may be negative
///                   (counted from the back, per the ONNX spec). If absent,
///                   all dimensions are reduced.
/// \param keepdims   Nonzero keeps reduced axes as size-1 dimensions;
///                   zero drops them from the result shape.
/// \return The ranked tensor type of the reduction result.
RankedTensorType getReductionOutputType(RankedTensorType operandTy,
                                        Optional<ArrayAttr> axesAttrs,
                                        APInt keepdims) {
  int64_t rank = operandTy.getRank();

  // Collect the (normalized, de-duplicated) reduction axes.
  SmallVector<int64_t, 4> axes;
  if (axesAttrs != llvm::None) {
    for (auto axisAttr : axesAttrs.getValue()) {
      int64_t axis = axisAttr.cast<IntegerAttr>().getInt();
      // Validate BEFORE normalization: checking afterwards would let
      // out-of-range negative axes (e.g. -rank - 1) slip through, since
      // rank + axis can still satisfy `axis >= -rank`.
      assert(axis >= -rank && axis <= rank - 1);
      axis = axis >= 0 ? axis : (rank + axis);
      if (std::find(axes.begin(), axes.end(), axis) == axes.end())
        axes.emplace_back(axis);
    }
  } else {
    // No axes attribute: reduce over all dimensions (ONNX default).
    for (decltype(rank) i = 0; i < rank; ++i)
      axes.emplace_back(i);
  }

  // Mark reduction axes.
  SmallVector<bool, 4> isReductionAxis;
  for (decltype(rank) i = 0; i < rank; ++i) {
    if (std::find(axes.begin(), axes.end(), i) != axes.end())
      isReductionAxis.emplace_back(true);
    else
      isReductionAxis.emplace_back(false);
  }

  // KeepDims
  bool isKeepdims = (keepdims == 1);

  // Build the output shape: reduced axes become 1 (if kept) or disappear;
  // other axes keep the input extent.
  SmallVector<int64_t, 4> dims;
  for (decltype(rank) i = 0; i < rank; ++i) {
    if (isReductionAxis[i]) {
      if (isKeepdims)
        dims.emplace_back(1); // reduction dimension
    } else {
      dims.emplace_back(operandTy.getShape()[i]);
    }
  }

  return RankedTensorType::get(dims, operandTy.getElementType());
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// ONNXOpsDialect
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
| 
						 | 
				
			
			@ -608,6 +656,60 @@ void ONNXTransposeOp::inferShapes() {
 | 
			
		|||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
// ReduceMax
 | 
			
		||||
 | 
			
		||||
void ONNXReduceMaxOp::inferShapes() {
 | 
			
		||||
  if (!getOperand().getType().isa<RankedTensorType>()) {
 | 
			
		||||
    emitError("Shape tensor not ranked.");
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  auto operandTy = getOperand().getType().cast<RankedTensorType>();
 | 
			
		||||
  getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
// ReduceMin
 | 
			
		||||
 | 
			
		||||
void ONNXReduceMinOp::inferShapes() {
 | 
			
		||||
  if (!getOperand().getType().isa<RankedTensorType>()) {
 | 
			
		||||
    emitError("Shape tensor not ranked.");
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  auto operandTy = getOperand().getType().cast<RankedTensorType>();
 | 
			
		||||
  getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
// ReduceProd
 | 
			
		||||
 | 
			
		||||
void ONNXReduceProdOp::inferShapes() {
 | 
			
		||||
  if (!getOperand().getType().isa<RankedTensorType>()) {
 | 
			
		||||
    emitError("Shape tensor not ranked.");
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  auto operandTy = getOperand().getType().cast<RankedTensorType>();
 | 
			
		||||
  getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
// ReduceSum
 | 
			
		||||
 | 
			
		||||
void ONNXReduceSumOp::inferShapes() {
 | 
			
		||||
  if (!getOperand().getType().isa<RankedTensorType>()) {
 | 
			
		||||
    emitError("Shape tensor not ranked.");
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  auto operandTy = getOperand().getType().cast<RankedTensorType>();
 | 
			
		||||
  getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Conv
 | 
			
		||||
 | 
			
		||||
// For this operation, we define the attributes once in the original Conv
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -2349,7 +2349,7 @@ def ONNXReduceLogSumExpOp:ONNX_Op<"ReduceLogSumExp",
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
def ONNXReduceMaxOp:ONNX_Op<"ReduceMax", 
 | 
			
		||||
    [NoSideEffect]> {
 | 
			
		||||
    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
 | 
			
		||||
  let summary = "ONNX ReduceMax operation";
 | 
			
		||||
  let description = [{
 | 
			
		||||
    "Computes the max of the input tensor's element along the provided axes. The resulted"
 | 
			
		||||
| 
						 | 
				
			
			@ -2383,7 +2383,7 @@ def ONNXReduceMeanOp:ONNX_Op<"ReduceMean",
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
def ONNXReduceMinOp:ONNX_Op<"ReduceMin", 
 | 
			
		||||
    [NoSideEffect]> {
 | 
			
		||||
    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
 | 
			
		||||
  let summary = "ONNX ReduceMin operation";
 | 
			
		||||
  let description = [{
 | 
			
		||||
    "Computes the min of the input tensor's element along the provided axes. The resulted"
 | 
			
		||||
| 
						 | 
				
			
			@ -2400,7 +2400,7 @@ def ONNXReduceMinOp:ONNX_Op<"ReduceMin",
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
def ONNXReduceProdOp:ONNX_Op<"ReduceProd", 
 | 
			
		||||
    [NoSideEffect]> {
 | 
			
		||||
    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
 | 
			
		||||
  let summary = "ONNX ReduceProd operation";
 | 
			
		||||
  let description = [{
 | 
			
		||||
    "Computes the product of the input tensor's element along the provided axes. The resulted"
 | 
			
		||||
| 
						 | 
				
			
			@ -2417,7 +2417,7 @@ def ONNXReduceProdOp:ONNX_Op<"ReduceProd",
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
def ONNXReduceSumOp:ONNX_Op<"ReduceSum", 
 | 
			
		||||
    [NoSideEffect]> {
 | 
			
		||||
    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
 | 
			
		||||
  let summary = "ONNX ReduceSum operation";
 | 
			
		||||
  let description = [{
 | 
			
		||||
    "Computes the sum of the input tensor's element along the provided axes. The resulted"
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -130,6 +130,37 @@ static bool checkInsertDealloc(Operation *currentOp) {
 | 
			
		|||
  return insertDealloc;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Create a mapping from result type's dimensions to input type's dimensions,
 | 
			
		||||
// given that the result type is the result of a reduction op over the input
 | 
			
		||||
// type.
 | 
			
		||||
std::map<int64_t, int64_t>
 | 
			
		||||
getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
 | 
			
		||||
  std::map<int64_t, int64_t> OutInDimMap;
 | 
			
		||||
  int64_t rank = inputTy.getRank();
 | 
			
		||||
 | 
			
		||||
  // Mark reduction axes.
 | 
			
		||||
  std::vector<bool> isReductionAxis;
 | 
			
		||||
  for (decltype(rank) i = 0; i < rank; ++i) {
 | 
			
		||||
    if (std::find(axes.begin(), axes.end(), i) != axes.end())
 | 
			
		||||
      isReductionAxis.push_back(true);
 | 
			
		||||
    else
 | 
			
		||||
      isReductionAxis.push_back(false);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  for (decltype(rank) inIndex = 0, outIndex = 0; inIndex < rank; ++inIndex) {
 | 
			
		||||
    // If it is a reduction axis, there is no relationship among dimensions.
 | 
			
		||||
    if (isReductionAxis[inIndex]) {
 | 
			
		||||
      if (keepdims)
 | 
			
		||||
        outIndex++;
 | 
			
		||||
    } else {
 | 
			
		||||
      OutInDimMap.insert(std::make_pair(outIndex, inIndex));
 | 
			
		||||
      outIndex++;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return OutInDimMap;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Add bounds associated with the op operand to the KRNL iteration pack.
 | 
			
		||||
// Dynamic dimenions are supported.
 | 
			
		||||
static void addDimensionToPack(ConversionPatternRewriter &rewriter,
 | 
			
		||||
| 
						 | 
				
			
			@ -376,6 +407,18 @@ struct ScalarOp<ONNXLogOp> {
 | 
			
		|||
  using IOp = LogOp; // not use
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Scalar combiner used when lowering ONNXReduceProdOp: elements are
// accumulated with multiplication (float / integer variants).
template <>
struct ScalarOp<ONNXReduceProdOp> {
  using FOp = MulFOp;
  using IOp = MulIOp;
};

// Scalar combiner used when lowering ONNXReduceSumOp: elements are
// accumulated with addition (float / integer variants).
template <>
struct ScalarOp<ONNXReduceSumOp> {
  using FOp = AddFOp;
  using IOp = AddIOp;
};
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
struct ScalarOp<ONNXSqrtOp> {
 | 
			
		||||
  using FOp = KrnlSqrtOp;
 | 
			
		||||
| 
						 | 
				
			
			@ -387,6 +430,53 @@ using ScalarFOp = typename ScalarOp<ElementwiseNaryOp>::FOp;
 | 
			
		|||
template <typename ElementwiseNaryOp>
 | 
			
		||||
using ScalarIOp = typename ScalarOp<ElementwiseNaryOp>::IOp;
 | 
			
		||||
 | 
			
		||||
// Get the identity element of a operation.
 | 
			
		||||
// Return NULL if the function does not have identity.
 | 
			
		||||
template <typename DataType, typename Op>
 | 
			
		||||
DataType getIdentityValue() {
 | 
			
		||||
  return NULL;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
float getIdentityValue<float, ONNXReduceMaxOp>(){
 | 
			
		||||
  return (float)-std::numeric_limits<float>::infinity();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
int getIdentityValue<int, ONNXReduceMaxOp>(){
 | 
			
		||||
  return std::numeric_limits<int>::min();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
float getIdentityValue<float, ONNXReduceMinOp>(){
 | 
			
		||||
  return (float)std::numeric_limits<float>::infinity();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
int getIdentityValue<int, ONNXReduceMinOp>(){
 | 
			
		||||
  return std::numeric_limits<int>::max();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
float getIdentityValue<float, ONNXReduceProdOp>(){
 | 
			
		||||
  return (float)1.0;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
int getIdentityValue<int, ONNXReduceProdOp>(){
 | 
			
		||||
  return 1;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
float getIdentityValue<float, ONNXReduceSumOp>(){
 | 
			
		||||
  return (float)0;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
int getIdentityValue<int, ONNXReduceSumOp>(){
 | 
			
		||||
  return 0;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Scalar unary ops for lowering to Krnl dialect.
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
| 
						 | 
				
			
			@ -788,6 +878,58 @@ Value mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types,
 | 
			
		|||
  return result;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Scalar unary ops for lowering ONNXReduceMaxOp
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
template <>
 | 
			
		||||
Value mapToLowerScalarOp<ONNXReduceMaxOp>(Operation *op,
 | 
			
		||||
                                          ArrayRef<Type> result_types,
 | 
			
		||||
                                          ArrayRef<Value> operands,
 | 
			
		||||
                                          ConversionPatternRewriter &rewriter) {
 | 
			
		||||
  auto loc = op->getLoc();
 | 
			
		||||
  Value lhs = operands[0];
 | 
			
		||||
  Value rhs = operands[1];
 | 
			
		||||
  Type element_type = lhs.getType();
 | 
			
		||||
  if (element_type.isa<IntegerType>()) {
 | 
			
		||||
    auto max = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs, rhs);
 | 
			
		||||
    auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
 | 
			
		||||
    return result;
 | 
			
		||||
  } else if (element_type.isa<FloatType>()) {
 | 
			
		||||
    auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
 | 
			
		||||
    auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
 | 
			
		||||
    return result;
 | 
			
		||||
  } else {
 | 
			
		||||
    emitError(loc, "unsupported element type");
 | 
			
		||||
    return nullptr;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Scalar unary ops for lowering ONNXReduceMinOp
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
template <>
 | 
			
		||||
Value mapToLowerScalarOp<ONNXReduceMinOp>(Operation *op,
 | 
			
		||||
                                          ArrayRef<Type> result_types,
 | 
			
		||||
                                          ArrayRef<Value> operands,
 | 
			
		||||
                                          ConversionPatternRewriter &rewriter) {
 | 
			
		||||
  auto loc = op->getLoc();
 | 
			
		||||
  Value lhs = operands[0];
 | 
			
		||||
  Value rhs = operands[1];
 | 
			
		||||
  Type element_type = lhs.getType();
 | 
			
		||||
  if (element_type.isa<IntegerType>()) {
 | 
			
		||||
    auto min = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, rhs);
 | 
			
		||||
    auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
 | 
			
		||||
    return result;
 | 
			
		||||
  } else if (element_type.isa<FloatType>()) {
 | 
			
		||||
    auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
 | 
			
		||||
    auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
 | 
			
		||||
    return result;
 | 
			
		||||
  } else {
 | 
			
		||||
    emitError(loc, "unsupported element type");
 | 
			
		||||
    return nullptr;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Element-wise unary ops lowering to Krnl dialect.
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
template <typename ElementwiseUnaryOp>
 | 
			
		||||
| 
						 | 
				
			
			@ -1823,6 +1965,193 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
 | 
			
		|||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Reduction ops lowering to Krnl dialect.
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
template <typename ONNXReductionOp>
 | 
			
		||||
struct ONNXReductionOpLowering : public ConversionPattern {
 | 
			
		||||
  ONNXReductionOpLowering(MLIRContext *ctx)
 | 
			
		||||
      : ConversionPattern(ONNXReductionOp::getOperationName(), 1, ctx) {}
 | 
			
		||||
 | 
			
		||||
  PatternMatchResult
 | 
			
		||||
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
 | 
			
		||||
                  ConversionPatternRewriter &rewriter) const final {
 | 
			
		||||
    /*
 | 
			
		||||
     * Condition: reduction function must be associative and commutative.
 | 
			
		||||
     *
 | 
			
		||||
     * Example 1 (here, reduction function is `+`):
 | 
			
		||||
     * Induction variables: (i0, i1, i2)
 | 
			
		||||
     * axes = [0, 2]
 | 
			
		||||
     * keepdims = true
 | 
			
		||||
     * krnl.iterate() with (i0, i1, i2) {
 | 
			
		||||
     *   Y(0, i1, 0) += X(i0, i1, i2)
 | 
			
		||||
     * }
 | 
			
		||||
     *
 | 
			
		||||
     * Example 2 (here, reduction function is `+`):
 | 
			
		||||
     * Induction variables: (i0, i1, i2)
 | 
			
		||||
     * axes = [0, 2]
 | 
			
		||||
     * keepdims = false
 | 
			
		||||
     * krnl.iterate() with (i0, i1, i2) {
 | 
			
		||||
     *   Y(i1) += X(i0, i1, i2)
 | 
			
		||||
     * }
 | 
			
		||||
     *
 | 
			
		||||
    */
 | 
			
		||||
    auto loc = op->getLoc();
 | 
			
		||||
    auto memRefInType = operands[0].getType().cast<MemRefType>();
 | 
			
		||||
    auto memRefInShape = memRefInType.getShape();
 | 
			
		||||
    auto tensorOutType = (*op->result_type_begin()).cast<TensorType>();
 | 
			
		||||
    int64_t inRank = memRefInType.getRank();
 | 
			
		||||
    int64_t outRank = tensorOutType.getRank();
 | 
			
		||||
 | 
			
		||||
    // Get attributes
 | 
			
		||||
    ArrayAttr axisAttrs = llvm::dyn_cast<ONNXReductionOp>(op).axesAttr();
 | 
			
		||||
    std::vector<int64_t> axes;
 | 
			
		||||
    if (axisAttrs) {
 | 
			
		||||
      for (auto axisAttr : axisAttrs.getValue()) {
 | 
			
		||||
        int64_t axis = axisAttr.cast<IntegerAttr>().getInt();
 | 
			
		||||
        axis = axis >= 0 ? axis : (inRank + axis);
 | 
			
		||||
        assert(axis >= -inRank && axis <= inRank - 1);
 | 
			
		||||
        if (std::find(axes.begin(), axes.end(), axis) == axes.end())
 | 
			
		||||
          axes.push_back(axis);
 | 
			
		||||
      }
 | 
			
		||||
    } else {
 | 
			
		||||
      for (decltype(inRank) i = 0; i < inRank; ++i) {
 | 
			
		||||
        axes.push_back(i);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    // KeepDims
 | 
			
		||||
    auto keepdims =
 | 
			
		||||
        llvm::dyn_cast<ONNXReductionOp>(op).keepdims();
 | 
			
		||||
    bool isKeepdims = (keepdims == 1) ? true : false;
 | 
			
		||||
 | 
			
		||||
    // Get type information
 | 
			
		||||
    auto memRefOutType = convertTensorToMemRef(tensorOutType);
 | 
			
		||||
    auto memRefOutShape = memRefOutType.getShape();
 | 
			
		||||
    auto elementOutType = memRefOutType.getElementType();
 | 
			
		||||
    std::map<int64_t, int64_t> outInDimMap =
 | 
			
		||||
        getReductionMapping(memRefInType, axes, isKeepdims);
 | 
			
		||||
 | 
			
		||||
    // Insert an allocation and deallocation for the result of this operation.
 | 
			
		||||
    Value alloc;
 | 
			
		||||
    bool insertDealloc = checkInsertDealloc(op);
 | 
			
		||||
    if (hasAllConstantDimensions(memRefOutType)) {
 | 
			
		||||
      alloc = insertAllocAndDealloc(memRefOutType, loc, rewriter, insertDealloc);
 | 
			
		||||
    } else {
 | 
			
		||||
      SmallVector<Value, 2> allocOperands;
 | 
			
		||||
      for (decltype(outRank) i = 0; i < outRank; ++i) {
 | 
			
		||||
        if (memRefOutShape[i] < 0) {
 | 
			
		||||
          auto dim = rewriter.create<DimOp>(loc, operands[0], outInDimMap[i]);
 | 
			
		||||
          allocOperands.push_back(dim);
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
      alloc = rewriter.create<AllocOp>(loc, memRefOutType, allocOperands);
 | 
			
		||||
      if (insertDealloc) {
 | 
			
		||||
        auto *parentBlock = alloc.getDefiningOp()->getBlock();
 | 
			
		||||
        auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
 | 
			
		||||
        dealloc.getOperation()->moveBefore(&parentBlock->back());
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // There are two Krnl loops:
 | 
			
		||||
    // - One to initialize the result memref, and
 | 
			
		||||
    // - One to do reduction
 | 
			
		||||
 | 
			
		||||
    // Define loops to initialize the result.
 | 
			
		||||
    std::vector<Value> originalLoopsInit;
 | 
			
		||||
    std::vector<Value> optimizedLoopsInit;
 | 
			
		||||
    Block *optimizationBlockInit = defineLoops(rewriter, loc, originalLoopsInit,
 | 
			
		||||
            optimizedLoopsInit, outRank);
 | 
			
		||||
 | 
			
		||||
    // Iteration information
 | 
			
		||||
    KrnlIterateOperandPack packInit(rewriter, originalLoopsInit,
 | 
			
		||||
        optimizedLoopsInit);
 | 
			
		||||
    for (decltype(outRank) i = 0; i < outRank; ++i) {
 | 
			
		||||
      addDimensionToPack(rewriter, loc, packInit, alloc, i);
 | 
			
		||||
    }
 | 
			
		||||
    auto iterateOpInit = rewriter.create<KrnlIterateOp>(loc, packInit);
 | 
			
		||||
    Block &iterationBlockInit = iterateOpInit.bodyRegion().front();
 | 
			
		||||
 | 
			
		||||
    // Perform the insertions into the body of the initialization loop.
 | 
			
		||||
    // No optimization
 | 
			
		||||
    rewriter.setInsertionPointToEnd(optimizationBlockInit);
 | 
			
		||||
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoopsInit);
 | 
			
		||||
 | 
			
		||||
    // Insert instructions inside the KernelIterateOp body.
 | 
			
		||||
    rewriter.setInsertionPointToStart(&iterationBlockInit);
 | 
			
		||||
 | 
			
		||||
    // Handle the operation:
 | 
			
		||||
    SmallVector<Value, 4> loopIVs;
 | 
			
		||||
    for (auto arg : iterationBlockInit.getArguments()) {
 | 
			
		||||
      loopIVs.push_back(arg);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Value identity;
 | 
			
		||||
    if (elementOutType.isa<FloatType>()) {
 | 
			
		||||
      identity = rewriter.create<ConstantOp>(
 | 
			
		||||
          loc, FloatAttr::get(elementOutType,
 | 
			
		||||
                              getIdentityValue<float, ONNXReductionOp>()));
 | 
			
		||||
    } else if (elementOutType.isa<IntegerType>()) {
 | 
			
		||||
      identity = rewriter.create<ConstantOp>(
 | 
			
		||||
          loc, IntegerAttr::get(elementOutType,
 | 
			
		||||
                                getIdentityValue<int, ONNXReductionOp>()));
 | 
			
		||||
    } else {
 | 
			
		||||
      emitError(loc, "unsupported element type");
 | 
			
		||||
    }
 | 
			
		||||
    rewriter.create<StoreOp>(loc, identity, alloc, loopIVs);
 | 
			
		||||
 | 
			
		||||
    // Define an Krnl loop to do reduction.
 | 
			
		||||
    rewriter.setInsertionPointAfter(iterateOpInit);
 | 
			
		||||
    std::vector<Value> originalLoops, optimizedLoops;
 | 
			
		||||
    Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
 | 
			
		||||
            optimizedLoops, inRank);
 | 
			
		||||
    // Iteration information
 | 
			
		||||
    KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
 | 
			
		||||
    for (decltype(inRank) i = 0; i < inRank; ++i) {
 | 
			
		||||
      addDimensionToPack(rewriter, loc, pack, operands[0], i);
 | 
			
		||||
    }
 | 
			
		||||
    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
 | 
			
		||||
    Block &iterationBlock = iterateOp.bodyRegion().front();
 | 
			
		||||
 | 
			
		||||
    // Perform the insertions into the body of the reduction loop.
 | 
			
		||||
    // No optimization
 | 
			
		||||
    rewriter.setInsertionPointToEnd(optimizationBlock);
 | 
			
		||||
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
 | 
			
		||||
 | 
			
		||||
    // Insert instructions inside the KernelIterateOp body.
 | 
			
		||||
    rewriter.setInsertionPointToStart(&iterationBlock);
 | 
			
		||||
 | 
			
		||||
    // Handle the operation:
 | 
			
		||||
    SmallVector<Value, 4> inLoopIVs, outLoopIVs;
 | 
			
		||||
    auto args = iterationBlock.getArguments();
 | 
			
		||||
    for (int i = 0; i < args.size(); ++i) {
 | 
			
		||||
      inLoopIVs.push_back(args[i]);
 | 
			
		||||
    }
 | 
			
		||||
    Value zeroIndex = nullptr;
 | 
			
		||||
    for (decltype(inRank) i = 0; i < outRank; ++i) {
 | 
			
		||||
      if (outInDimMap.find(i) != outInDimMap.end()) {
 | 
			
		||||
        outLoopIVs.push_back(inLoopIVs[outInDimMap[i]]);
 | 
			
		||||
      } else {
 | 
			
		||||
        if (zeroIndex) {
 | 
			
		||||
          outLoopIVs.push_back(zeroIndex);
 | 
			
		||||
        } else {
 | 
			
		||||
          zeroIndex = rewriter.create<ConstantIndexOp>(loc, 0);
 | 
			
		||||
          outLoopIVs.push_back(zeroIndex);
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Value next, accumulated;
 | 
			
		||||
    next = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
 | 
			
		||||
    accumulated = rewriter.create<LoadOp>(loc, alloc, outLoopIVs);
 | 
			
		||||
    accumulated = mapToLowerScalarOp<ONNXReductionOp>(
 | 
			
		||||
        op, memRefOutType.getElementType(), {accumulated, next}, rewriter);
 | 
			
		||||
    rewriter.create<StoreOp>(loc, accumulated, alloc, outLoopIVs);
 | 
			
		||||
 | 
			
		||||
    rewriter.replaceOp(op, alloc);
 | 
			
		||||
    return matchSuccess();
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// EntryPoint Op lowering to Krnl Entry Point.
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
| 
						 | 
				
			
			@ -1952,6 +2281,10 @@ void FrontendToKrnlLoweringPass::runOnModule() {
 | 
			
		|||
                  ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
 | 
			
		||||
                  ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
 | 
			
		||||
                  ONNXReshapeOpLowering, ONNXEntryPointLowering,
 | 
			
		||||
                  ONNXReductionOpLowering<mlir::ONNXReduceMaxOp>,
 | 
			
		||||
                  ONNXReductionOpLowering<mlir::ONNXReduceMinOp>,
 | 
			
		||||
                  ONNXReductionOpLowering<mlir::ONNXReduceProdOp>,
 | 
			
		||||
                  ONNXReductionOpLowering<mlir::ONNXReduceSumOp>,
 | 
			
		||||
                  ONNXSoftmaxOpLowering, ONNXGemmOpLowering,
 | 
			
		||||
                  ONNXUnsqueezeOpLowering, ONNXTransposeOpLowering,
 | 
			
		||||
                  ONNXIdentityOpLowering, ONNXConvNoBiasOpLowering
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -121,6 +121,10 @@ public:
 | 
			
		|||
        op->getName().getStringRef() != "onnx.GemmNoBias" &&
 | 
			
		||||
        op->getName().getStringRef() != "onnx.Reshape" &&
 | 
			
		||||
        op->getName().getStringRef() != "onnx.Transpose" &&
 | 
			
		||||
        op->getName().getStringRef() != "onnx.ReduceMax" &&
 | 
			
		||||
        op->getName().getStringRef() != "onnx.ReduceMin" &&
 | 
			
		||||
        op->getName().getStringRef() != "onnx.ReduceProd" &&
 | 
			
		||||
        op->getName().getStringRef() != "onnx.ReduceSum" &&
 | 
			
		||||
        op->getName().getStringRef() != "onnx.Softmax" &&
 | 
			
		||||
        op->getName().getStringRef() != "onnx.Sqrt" &&
 | 
			
		||||
        op->getName().getStringRef() != "onnx.ConvNoBias" &&
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -134,6 +134,46 @@ test_to_enable = [
 | 
			
		|||
    # Relu Op:
 | 
			
		||||
    "test_relu_cpu",
 | 
			
		||||
 | 
			
		||||
    # ReduceMax Op:
 | 
			
		||||
    "test_reduce_max_default_axes_keepdim_example_cpu",
 | 
			
		||||
    "test_reduce_max_default_axes_keepdims_random_cpu",
 | 
			
		||||
    "test_reduce_max_do_not_keepdims_example_cpu",
 | 
			
		||||
    "test_reduce_max_do_not_keepdims_random_cpu",
 | 
			
		||||
    "test_reduce_max_keepdims_example_cpu",
 | 
			
		||||
    "test_reduce_max_keepdims_random_cpu",
 | 
			
		||||
    "test_reduce_max_negative_axes_keepdims_example_cpu",
 | 
			
		||||
    "test_reduce_max_negative_axes_keepdims_random_cpu",
 | 
			
		||||
 | 
			
		||||
    # ReduceMin Op:
 | 
			
		||||
    "test_reduce_min_default_axes_keepdims_example_cpu",
 | 
			
		||||
    "test_reduce_min_default_axes_keepdims_random_cpu",
 | 
			
		||||
    "test_reduce_min_do_not_keepdims_example_cpu",
 | 
			
		||||
    "test_reduce_min_do_not_keepdims_random_cpu",
 | 
			
		||||
    "test_reduce_min_keepdims_example_cpu",
 | 
			
		||||
    "test_reduce_min_keepdims_random_cpu",
 | 
			
		||||
    "test_reduce_min_negative_axes_keepdims_example_cpu",
 | 
			
		||||
    "test_reduce_min_negative_axes_keepdims_random_cpu",
 | 
			
		||||
 | 
			
		||||
    # ReduceProd Op:
 | 
			
		||||
    "test_reduce_prod_default_axes_keepdims_example_cpu",
 | 
			
		||||
    "test_reduce_prod_default_axes_keepdims_random_cpu",
 | 
			
		||||
    "test_reduce_prod_do_not_keepdims_example_cpu",
 | 
			
		||||
    "test_reduce_prod_do_not_keepdims_random_cpu",
 | 
			
		||||
    "test_reduce_prod_keepdims_example_cpu",
 | 
			
		||||
    "test_reduce_prod_keepdims_random_cpu",
 | 
			
		||||
    "test_reduce_prod_negative_axes_keepdims_example_cpu",
 | 
			
		||||
    "test_reduce_prod_negative_axes_keepdims_random_cpu",
 | 
			
		||||
 | 
			
		||||
    # ReduceSum Op:
 | 
			
		||||
    "test_reduce_sum_default_axes_keepdims_example_cpu",
 | 
			
		||||
    "test_reduce_sum_default_axes_keepdims_random_cpu",
 | 
			
		||||
    "test_reduce_sum_do_not_keepdims_example_cpu",
 | 
			
		||||
    "test_reduce_sum_do_not_keepdims_random_cpu",
 | 
			
		||||
    "test_reduce_sum_keepdims_example_cpu",
 | 
			
		||||
    "test_reduce_sum_keepdims_random_cpu",
 | 
			
		||||
    "test_reduce_sum_negative_axes_keepdims_example_cpu",
 | 
			
		||||
    "test_reduce_sum_negative_axes_keepdims_random_cpu",
 | 
			
		||||
 | 
			
		||||
    # Selu Op:
 | 
			
		||||
    "test_selu_cpu",
 | 
			
		||||
    "test_selu_default_cpu",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -587,6 +587,116 @@ func @test_add_with_broadcasting(%arg0 : tensor<?xf32>, %arg1 : tensor<?x10xf32>
 | 
			
		|||
  // CHECK: return [[RES]] : memref<?x10xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Lower onnx.ReduceMax (axes=[1], keepdims=0) on a 3x2x2 tensor: expect a
// 3x2 result buffer initialized to -inf, then a max-reduction loop nest
// implemented as cmpf "ogt" + select.
func @test_reducemax(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> {
  %0 ="onnx.ReduceMax"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<3x2x2xf32>)-> tensor<*xf32>
  "std.return"(%0) : (tensor<*xf32>) -> ()

  // CHECK-LABEL: test_reducemax
  // CHECK: [[RES:%.+]] = alloc() : memref<3x2xf32>
  // CHECK: [[DEF_LOOPS1:%.+]]:2 = krnl.define_loops 2
  // CHECK: [[OPT_LOOPS1:%.+]]:2 = krnl.optimize_loops  {
  // CHECK: krnl.return_loops [[DEF_LOOPS1]]#0, [[DEF_LOOPS1]]#1
  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
  // CHECK: krnl.iterate([[OPT_LOOPS1]]#0, [[OPT_LOOPS1]]#1) with ([[DEF_LOOPS1]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS1]]#1 -> %arg2 = 0 to 2) {
  // Identity element for max is -infinity (0xFF800000).
  // CHECK: [[IDENTITY:%.+]] = constant 0xFF800000 : f32
  // CHECK: store [[IDENTITY]], [[RES]][%arg1, %arg2] : memref<3x2xf32>

  // CHECK: [[DEF_LOOPS2:%.+]]:3 = krnl.define_loops 3
  // CHECK: [[OPT_LOOPS2:%.+]]:3 = krnl.optimize_loops  {
  // CHECK: krnl.return_loops [[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1, [[DEF_LOOPS2]]#2
  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
  // CHECK: krnl.iterate([[OPT_LOOPS2]]#0, [[OPT_LOOPS2]]#1, [[OPT_LOOPS2]]#2) with ([[DEF_LOOPS2]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS2]]#1 -> %arg2 = 0 to 2, [[DEF_LOOPS2]]#2 -> %arg3 = 0 to 2) {
  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg1, %arg2, %arg3] : memref<3x2x2xf32>
  // Use the captured [[RES]] / [[CMP]] / [[LOAD*]] variables instead of raw
  // SSA numbers (%0, %5, %6, %7), which are brittle against lowering changes.
  // CHECK: [[LOAD2:%.+]] = load [[RES]][%arg1, %arg3] : memref<3x2xf32>
  // CHECK: [[CMP:%.+]] = cmpf "ogt", [[LOAD2]], [[LOAD1]] : f32
  // CHECK: [[SELECT:%.+]] = select [[CMP]], [[LOAD2]], [[LOAD1]] : f32
  // CHECK: store [[SELECT]], [[RES]][%arg1, %arg3] : memref<3x2xf32>
  // CHECK: }
  // CHECK: return [[RES]] : memref<3x2xf32>
}
 | 
			
		||||
 | 
			
		||||
// Lower onnx.ReduceMin (axes=[1], keepdims=0) on a 3x2x2 tensor: expect a
// 3x2 result buffer initialized to +inf, then a min-reduction loop nest
// implemented as cmpf "olt" + select.
func @test_reducemin(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> {
  %0 ="onnx.ReduceMin"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<3x2x2xf32>)-> tensor<*xf32>
  "std.return"(%0) : (tensor<*xf32>) -> ()

  // CHECK-LABEL: test_reducemin
  // CHECK: [[RES:%.+]] = alloc() : memref<3x2xf32>
  // CHECK: [[DEF_LOOPS1:%.+]]:2 = krnl.define_loops 2
  // CHECK: [[OPT_LOOPS1:%.+]]:2 = krnl.optimize_loops  {
  // CHECK: krnl.return_loops [[DEF_LOOPS1]]#0, [[DEF_LOOPS1]]#1
  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
  // CHECK: krnl.iterate([[OPT_LOOPS1]]#0, [[OPT_LOOPS1]]#1) with ([[DEF_LOOPS1]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS1]]#1 -> %arg2 = 0 to 2) {
  // Identity element for min is +infinity (0x7F800000).
  // CHECK: [[IDENTITY:%.+]] = constant 0x7F800000 : f32
  // CHECK: store [[IDENTITY]], [[RES]][%arg1, %arg2] : memref<3x2xf32>

  // CHECK: [[DEF_LOOPS2:%.+]]:3 = krnl.define_loops 3
  // CHECK: [[OPT_LOOPS2:%.+]]:3 = krnl.optimize_loops  {
  // CHECK: krnl.return_loops [[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1, [[DEF_LOOPS2]]#2
  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
  // CHECK: krnl.iterate([[OPT_LOOPS2]]#0, [[OPT_LOOPS2]]#1, [[OPT_LOOPS2]]#2) with ([[DEF_LOOPS2]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS2]]#1 -> %arg2 = 0 to 2, [[DEF_LOOPS2]]#2 -> %arg3 = 0 to 2) {
  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg1, %arg2, %arg3] : memref<3x2x2xf32>
  // Use the captured [[RES]] / [[CMP]] / [[LOAD*]] variables instead of raw
  // SSA numbers (%0, %5, %6, %7), which are brittle against lowering changes.
  // CHECK: [[LOAD2:%.+]] = load [[RES]][%arg1, %arg3] : memref<3x2xf32>
  // CHECK: [[CMP:%.+]] = cmpf "olt", [[LOAD2]], [[LOAD1]] : f32
  // CHECK: [[SELECT:%.+]] = select [[CMP]], [[LOAD2]], [[LOAD1]] : f32
  // CHECK: store [[SELECT]], [[RES]][%arg1, %arg3] : memref<3x2xf32>
  // CHECK: }
  // CHECK: return [[RES]] : memref<3x2xf32>
}
 | 
			
		||||
 | 
			
		||||
// Lower onnx.ReduceProd (axes=[1], keepdims=0) on a 3x2x2 tensor: expect a
// 3x2 result buffer initialized to 1.0, then a product-reduction loop nest
// implemented with mulf.
func @test_reduceprod(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> {
  %0 ="onnx.ReduceProd"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<3x2x2xf32>)-> tensor<*xf32>
  "std.return"(%0) : (tensor<*xf32>) -> ()

  // CHECK-LABEL: test_reduceprod
  // CHECK: [[RES:%.+]] = alloc() : memref<3x2xf32>
  // CHECK: [[DEF_LOOPS1:%.+]]:2 = krnl.define_loops 2
  // CHECK: [[OPT_LOOPS1:%.+]]:2 = krnl.optimize_loops  {
  // CHECK: krnl.return_loops [[DEF_LOOPS1]]#0, [[DEF_LOOPS1]]#1
  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
  // CHECK: krnl.iterate([[OPT_LOOPS1]]#0, [[OPT_LOOPS1]]#1) with ([[DEF_LOOPS1]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS1]]#1 -> %arg2 = 0 to 2) {
  // Identity element for product is 1.0.
  // CHECK: [[IDENTITY:%.+]] = constant 1.000000e+00 : f32
  // CHECK: store [[IDENTITY]], [[RES]][%arg1, %arg2] : memref<3x2xf32>

  // CHECK: [[DEF_LOOPS2:%.+]]:3 = krnl.define_loops 3
  // CHECK: [[OPT_LOOPS2:%.+]]:3 = krnl.optimize_loops  {
  // CHECK: krnl.return_loops [[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1, [[DEF_LOOPS2]]#2
  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
  // CHECK: krnl.iterate([[OPT_LOOPS2]]#0, [[OPT_LOOPS2]]#1, [[OPT_LOOPS2]]#2) with ([[DEF_LOOPS2]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS2]]#1 -> %arg2 = 0 to 2, [[DEF_LOOPS2]]#2 -> %arg3 = 0 to 2) {
  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg1, %arg2, %arg3] : memref<3x2x2xf32>
  // Use the captured [[RES]] / [[LOAD*]] variables instead of raw SSA
  // numbers (%0, %5, %6), which are brittle against lowering changes.
  // CHECK: [[LOAD2:%.+]] = load [[RES]][%arg1, %arg3] : memref<3x2xf32>
  // CHECK: [[REDUCE:%.+]] = mulf [[LOAD2]], [[LOAD1]] : f32
  // CHECK: store [[REDUCE]], [[RES]][%arg1, %arg3] : memref<3x2xf32>
  // CHECK: }
  // CHECK: return [[RES]] : memref<3x2xf32>
}
 | 
			
		||||
 | 
			
		||||
// Lower onnx.ReduceSum (axes=[1], keepdims=0) on a 3x2x2 tensor: expect a
// 3x2 result buffer initialized to 0.0, then a sum-reduction loop nest
// implemented with addf.
func @test_reducesum(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> {
  %0 ="onnx.ReduceSum"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<3x2x2xf32>)-> tensor<*xf32>
  "std.return"(%0) : (tensor<*xf32>) -> ()

  // CHECK-LABEL: test_reducesum
  // CHECK: [[RES:%.+]] = alloc() : memref<3x2xf32>
  // CHECK: [[DEF_LOOPS1:%.+]]:2 = krnl.define_loops 2
  // CHECK: [[OPT_LOOPS1:%.+]]:2 = krnl.optimize_loops  {
  // CHECK: krnl.return_loops [[DEF_LOOPS1]]#0, [[DEF_LOOPS1]]#1
  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
  // CHECK: krnl.iterate([[OPT_LOOPS1]]#0, [[OPT_LOOPS1]]#1) with ([[DEF_LOOPS1]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS1]]#1 -> %arg2 = 0 to 2) {
  // Identity element for sum is 0.0.
  // CHECK: [[IDENTITY:%.+]] = constant 0.000000e+00 : f32
  // CHECK: store [[IDENTITY]], [[RES]][%arg1, %arg2] : memref<3x2xf32>

  // CHECK: [[DEF_LOOPS2:%.+]]:3 = krnl.define_loops 3
  // CHECK: [[OPT_LOOPS2:%.+]]:3 = krnl.optimize_loops  {
  // CHECK: krnl.return_loops [[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1, [[DEF_LOOPS2]]#2
  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
  // CHECK: krnl.iterate([[OPT_LOOPS2]]#0, [[OPT_LOOPS2]]#1, [[OPT_LOOPS2]]#2) with ([[DEF_LOOPS2]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS2]]#1 -> %arg2 = 0 to 2, [[DEF_LOOPS2]]#2 -> %arg3 = 0 to 2) {
  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg1, %arg2, %arg3] : memref<3x2x2xf32>
  // Use the captured [[RES]] / [[LOAD*]] variables instead of raw SSA
  // numbers (%0, %5, %6), which are brittle against lowering changes.
  // CHECK: [[LOAD2:%.+]] = load [[RES]][%arg1, %arg3] : memref<3x2xf32>
  // CHECK: [[REDUCE:%.+]] = addf [[LOAD2]], [[LOAD1]] : f32
  // CHECK: store [[REDUCE]], [[RES]][%arg1, %arg3] : memref<3x2xf32>
  // CHECK: }
  // CHECK: return [[RES]] : memref<3x2xf32>
}
 | 
			
		||||
  
 | 
			
		||||
func @test_softmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {
 | 
			
		||||
  %0 = "onnx.Softmax"(%arg0) {axis=1:i64} : (tensor<10x10xf32>) -> tensor<*xf32>
 | 
			
		||||
  "std.return"(%0) : (tensor<*xf32>) -> ()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue