diff --git a/doc/gen_doc.py b/doc/gen_doc.py
index 2b150d6..6bb0d7e 100644
--- a/doc/gen_doc.py
+++ b/doc/gen_doc.py
@@ -46,6 +46,7 @@ ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
                    'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
                    'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
                    'Identity', 'Cos', 'Log', 'Transpose', 'Softmax',
+                   'ReduceMax', 'ReduceMin', 'ReduceProd', 'ReduceSum',
                    'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze', 'Sign']
 
 CanonicalList=['Add', 'Identity', 'ReduceL1', 'ReduceL2', 'ReduceLogSum',
diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp
index 0946a65..fafc834 100644
--- a/src/dialect/onnx/onnx_ops.cpp
+++ b/src/dialect/onnx/onnx_ops.cpp
@@ -24,6 +24,54 @@
 using namespace mlir;
 using namespace mlir::OpTrait::util;
 
+//===----------------------------------------------------------------------===//
+// Get reduction type
+//===----------------------------------------------------------------------===//
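+// Compute the result type of a reduction from the operand type, the optional
+// `axes` attribute and the `keepdims` flag. For example, reducing
+// tensor<3x2x5xf32> over axes = [0, 2] yields tensor<1x2x1xf32> with
+// keepdims = 1 and tensor<2xf32> with keepdims = 0.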
+RankedTensorType getReductionOutputType(RankedTensorType operandTy,
+                                        Optional<ArrayAttr> axesAttrs,
+                                        APInt keepdims) {
+  int64_t rank = operandTy.getRank();
+
+  SmallVector<int64_t, 4> axes;
+  if (axesAttrs != llvm::None) {
+    for (auto axisAttr : axesAttrs.getValue()) {
+      int64_t axis = axisAttr.cast<IntegerAttr>().getInt();
+      assert(axis >= -rank && axis <= rank - 1);
+      axis = axis >= 0 ? axis : (rank + axis);
+      if (std::find(axes.begin(), axes.end(), axis) == axes.end())
+        axes.emplace_back(axis);
+    }
+  } else {
+    for (decltype(rank) i = 0; i < rank; ++i) {
+      axes.emplace_back(i);
+    }
+  }
+
+  // Mark reduction axes.
+  SmallVector<bool, 4> isReductionAxis;
+  for (decltype(rank) i = 0; i < rank; ++i) {
+    if (std::find(axes.begin(), axes.end(), i) != axes.end())
+      isReductionAxis.emplace_back(true);
+    else
+      isReductionAxis.emplace_back(false);
+  }
+
+  // KeepDims
+  bool isKeepdims = (keepdims == 1) ? true : false;
+
+  SmallVector<int64_t, 4> dims;
+  for (decltype(rank) i = 0; i < rank; ++i) {
+    if (isReductionAxis[i]) {
+      if (isKeepdims)
+        dims.emplace_back(1); // reduction dimension
+    } else {
+      dims.emplace_back(operandTy.getShape()[i]);
+    }
+  }
+
+  return RankedTensorType::get(dims, operandTy.getElementType());
+}
+
 //===----------------------------------------------------------------------===//
 // ONNXOpsDialect
 //===----------------------------------------------------------------------===//
@@ -608,6 +656,60 @@ void ONNXTransposeOp::inferShapes() {
 
 //===----------------------------------------------------------------------===//
 
+// ReduceMax
+
+void ONNXReduceMaxOp::inferShapes() {
+  if (!getOperand().getType().isa<RankedTensorType>()) {
+    emitError("Input tensor not ranked.");
+    return;
+  }
+
+  auto operandTy = getOperand().getType().cast<RankedTensorType>();
+  getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
+}
+
+//===----------------------------------------------------------------------===//
+
+// ReduceMin
+
+void ONNXReduceMinOp::inferShapes() {
+  if (!getOperand().getType().isa<RankedTensorType>()) {
+    emitError("Input tensor not ranked.");
+    return;
+  }
+
+  auto operandTy = getOperand().getType().cast<RankedTensorType>();
+  getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
+}
+
+//===----------------------------------------------------------------------===//
+
+// ReduceProd
+
+void ONNXReduceProdOp::inferShapes() {
+  if (!getOperand().getType().isa<RankedTensorType>()) {
+    emitError("Input tensor not ranked.");
+    return;
+  }
+
+  auto operandTy = getOperand().getType().cast<RankedTensorType>();
+  getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
+}
+
+//===----------------------------------------------------------------------===//
+
+// ReduceSum
+
+void ONNXReduceSumOp::inferShapes() {
+  if (!getOperand().getType().isa<RankedTensorType>()) {
+    emitError("Input tensor not ranked.");
+    return;
+  }
+
+  auto operandTy = getOperand().getType().cast<RankedTensorType>();
+  getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
+}
+
 // Conv
 
 // For this operation, we define the attributes once in the original Conv
diff --git a/src/dialect/onnx/onnxop.inc b/src/dialect/onnx/onnxop.inc
index 2956238..b293000 100644
--- a/src/dialect/onnx/onnxop.inc
+++ b/src/dialect/onnx/onnxop.inc
@@ -2349,7 +2349,7 @@ def ONNXReduceLogSumExpOp:ONNX_Op<"ReduceLogSumExp",
 }
 
 def ONNXReduceMaxOp:ONNX_Op<"ReduceMax",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX ReduceMax operation";
   let description = [{
     "Computes the max of the input tensor's element along the provided axes. The resulted"
@@ -2383,7 +2383,7 @@ def ONNXReduceMeanOp:ONNX_Op<"ReduceMean",
 }
 
 def ONNXReduceMinOp:ONNX_Op<"ReduceMin",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX ReduceMin operation";
   let description = [{
     "Computes the min of the input tensor's element along the provided axes. The resulted"
@@ -2400,7 +2400,7 @@ def ONNXReduceMinOp:ONNX_Op<"ReduceMin",
 }
 
 def ONNXReduceProdOp:ONNX_Op<"ReduceProd",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX ReduceProd operation";
   let description = [{
     "Computes the product of the input tensor's element along the provided axes. The resulted"
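+// For example, with inputTy = memref<3x2x5xf32>, axes = [0, 2] and
+// keepdims = false, the output has rank 1 and the map is {0 -> 1}: output
+// dimension 0 takes its extent from input dimension 1.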
@@ -2417,7 +2417,7 @@ def ONNXReduceProdOp:ONNX_Op<"ReduceProd",
 }
 
 def ONNXReduceSumOp:ONNX_Op<"ReduceSum",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX ReduceSum operation";
   let description = [{
     "Computes the sum of the input tensor's element along the provided axes. The resulted"
diff --git a/src/pass/lower_frontend_to_krnl.cpp b/src/pass/lower_frontend_to_krnl.cpp
index 040b1be..e2354c9 100644
--- a/src/pass/lower_frontend_to_krnl.cpp
+++ b/src/pass/lower_frontend_to_krnl.cpp
@@ -130,6 +130,37 @@ static bool checkInsertDealloc(Operation *currentOp) {
   return insertDealloc;
 }
 
+// Create a mapping from result type's dimensions to input type's dimensions,
+// given that the result type is the result of a reduction op over the input
+// type.
+std::map<int64_t, int64_t>
+getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
+  std::map<int64_t, int64_t> OutInDimMap;
+  int64_t rank = inputTy.getRank();
+
+  // Mark reduction axes.
+  std::vector<bool> isReductionAxis;
+  for (decltype(rank) i = 0; i < rank; ++i) {
+    if (std::find(axes.begin(), axes.end(), i) != axes.end())
+      isReductionAxis.push_back(true);
+    else
+      isReductionAxis.push_back(false);
+  }
+
+  for (decltype(rank) inIndex = 0, outIndex = 0; inIndex < rank; ++inIndex) {
+    // A reduction axis has no corresponding input dimension in the output
+    // (though with keepdims it still occupies an output slot).
+    if (isReductionAxis[inIndex]) {
+      if (keepdims)
+        outIndex++;
+    } else {
+      OutInDimMap.insert(std::make_pair(outIndex, inIndex));
+      outIndex++;
+    }
+  }
+
+  return OutInDimMap;
+}
+
 // Add bounds associated with the op operand to the KRNL iteration pack.
 // Dynamic dimensions are supported.
 static void addDimensionToPack(ConversionPatternRewriter &rewriter,
@@ -376,6 +407,18 @@ struct ScalarOp<ONNXLogOp> {
   using IOp = LogOp; // not used
 };
 
+template <>
+struct ScalarOp<ONNXReduceProdOp> {
+  using FOp = MulFOp;
+  using IOp = MulIOp;
+};
+
+template <>
+struct ScalarOp<ONNXReduceSumOp> {
+  using FOp = AddFOp;
+  using IOp = AddIOp;
+};
+
 template <>
 struct ScalarOp<ONNXSqrtOp> {
   using FOp = KrnlSqrtOp;
@@ -387,6 +430,53 @@ using ScalarFOp = typename ScalarOp<Op>::FOp;
 template <typename Op>
 using ScalarIOp = typename ScalarOp<Op>::IOp;
 
+// Get the identity element of an operation.
+// Return NULL if the operation does not have an identity element.
+template <typename DataType, typename Op>
+DataType getIdentityValue() {
+  return NULL;
+}
+
+template <>
+float getIdentityValue<float, ONNXReduceMaxOp>() {
+  return (float)-std::numeric_limits<float>::infinity();
+}
+
+template <>
+int getIdentityValue<int, ONNXReduceMaxOp>() {
+  return std::numeric_limits<int>::min();
+}
+
+template <>
+float getIdentityValue<float, ONNXReduceMinOp>() {
+  return (float)std::numeric_limits<float>::infinity();
+}
+
+template <>
+int getIdentityValue<int, ONNXReduceMinOp>() {
+  return std::numeric_limits<int>::max();
+}
+
+template <>
+float getIdentityValue<float, ONNXReduceProdOp>() {
+  return (float)1.0;
+}
+
+template <>
+int getIdentityValue<int, ONNXReduceProdOp>() {
+  return 1;
+}
+
+template <>
+float getIdentityValue<float, ONNXReduceSumOp>() {
+  return (float)0;
+}
+
+template <>
+int getIdentityValue<int, ONNXReduceSumOp>() {
+  return 0;
+}
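+
+// These identities initialize the result buffer: folding any value x against
+// them leaves x unchanged (max(x, -inf) == x, min(x, +inf) == x, x * 1 == x,
+// x + 0 == x).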
+
 //===----------------------------------------------------------------------===//
 // Scalar unary ops for lowering to Krnl dialect.
 //===----------------------------------------------------------------------===//
@@ -788,6 +878,58 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
   return result;
 }
 
+//===----------------------------------------------------------------------===//
+// Scalar binary op for lowering ONNXReduceMaxOp
+//===----------------------------------------------------------------------===//
+template <>
+Value mapToLowerScalarOp<ONNXReduceMaxOp>(Operation *op,
+                                          ArrayRef<Type> result_types,
+                                          ArrayRef<Value> operands,
+                                          ConversionPatternRewriter &rewriter) {
+  auto loc = op->getLoc();
+  Value lhs = operands[0];
+  Value rhs = operands[1];
+  Type element_type = lhs.getType();
+  if (element_type.isa<IntegerType>()) {
+    auto max = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs, rhs);
+    auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
+    return result;
+  } else if (element_type.isa<FloatType>()) {
+    auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
+    auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
+    return result;
+  } else {
+    emitError(loc, "unsupported element type");
+    return nullptr;
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Scalar binary op for lowering ONNXReduceMinOp
+//===----------------------------------------------------------------------===//
+template <>
+Value mapToLowerScalarOp<ONNXReduceMinOp>(Operation *op,
+                                          ArrayRef<Type> result_types,
+                                          ArrayRef<Value> operands,
+                                          ConversionPatternRewriter &rewriter) {
+  auto loc = op->getLoc();
+  Value lhs = operands[0];
+  Value rhs = operands[1];
+  Type element_type = lhs.getType();
+  if (element_type.isa<IntegerType>()) {
+    auto min = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, rhs);
+    auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
+    return result;
+  } else if (element_type.isa<FloatType>()) {
+    auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
+    auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
+    return result;
+  } else {
+    emitError(loc, "unsupported element type");
+    return nullptr;
+  }
+}
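+
+// Note: the standard dialect has no dedicated min/max operations, so both
+// reductions lower to a compare followed by a select.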
+
 // Element-wise unary ops lowering to Krnl dialect.
 //===----------------------------------------------------------------------===//
 template <typename ElementwiseUnaryOp>
@@ -1823,6 +1965,193 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// Reduction ops lowering to Krnl dialect.
+//===----------------------------------------------------------------------===//
+template <typename ONNXReductionOp>
+struct ONNXReductionOpLowering : public ConversionPattern {
+  ONNXReductionOpLowering(MLIRContext *ctx)
+      : ConversionPattern(ONNXReductionOp::getOperationName(), 1, ctx) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    /*
+     * Condition: the reduction function must be associative and commutative.
+     *
+     * Example 1 (here, the reduction function is `+`):
+     * Induction variables: (i0, i1, i2)
+     * axes = [0, 2]
+     * keepdims = true
+     * krnl.iterate() with (i0, i1, i2) {
+     *   Y(0, i1, 0) += X(i0, i1, i2)
+     * }
+     *
+     * Example 2 (here, the reduction function is `+`):
+     * Induction variables: (i0, i1, i2)
+     * axes = [0, 2]
+     * keepdims = false
+     * krnl.iterate() with (i0, i1, i2) {
+     *   Y(i1) += X(i0, i1, i2)
+     * }
+     */
+    auto loc = op->getLoc();
+    auto memRefInType = operands[0].getType().cast<MemRefType>();
+    auto memRefInShape = memRefInType.getShape();
+    auto tensorOutType = (*op->result_type_begin()).cast<RankedTensorType>();
+    int64_t inRank = memRefInType.getRank();
+    int64_t outRank = tensorOutType.getRank();
+
+    // Get attributes
+    ArrayAttr axisAttrs = llvm::dyn_cast<ONNXReductionOp>(op).axesAttr();
+    std::vector<int64_t> axes;
+    if (axisAttrs) {
+      for (auto axisAttr : axisAttrs.getValue()) {
+        int64_t axis = axisAttr.cast<IntegerAttr>().getInt();
+        assert(axis >= -inRank && axis <= inRank - 1);
+        axis = axis >= 0 ? axis : (inRank + axis);
+        if (std::find(axes.begin(), axes.end(), axis) == axes.end())
+          axes.push_back(axis);
+      }
+    } else {
+      for (decltype(inRank) i = 0; i < inRank; ++i) {
+        axes.push_back(i);
+      }
+    }
+    // KeepDims
+    auto keepdims = llvm::dyn_cast<ONNXReductionOp>(op).keepdims();
+    bool isKeepdims = (keepdims == 1) ? true : false;
+
+    // Get type information
+    auto memRefOutType = convertTensorToMemRef(tensorOutType);
+    auto memRefOutShape = memRefOutType.getShape();
+    auto elementOutType = memRefOutType.getElementType();
+    std::map<int64_t, int64_t> outInDimMap =
+        getReductionMapping(memRefInType, axes, isKeepdims);
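+
+    // Note: a dynamic output dimension can only originate from a
+    // non-reduction input dimension, so outInDimMap identifies the input
+    // dimension to query for each dynamic output dimension below.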
+
+    // Insert an allocation and deallocation for the result of this operation.
+    Value alloc;
+    bool insertDealloc = checkInsertDealloc(op);
+    if (hasAllConstantDimensions(memRefOutType)) {
+      alloc =
+          insertAllocAndDealloc(memRefOutType, loc, rewriter, insertDealloc);
+    } else {
+      SmallVector<Value, 2> allocOperands;
+      for (decltype(outRank) i = 0; i < outRank; ++i) {
+        if (memRefOutShape[i] < 0) {
+          auto dim = rewriter.create<DimOp>(loc, operands[0], outInDimMap[i]);
+          allocOperands.push_back(dim);
+        }
+      }
+      alloc = rewriter.create<AllocOp>(loc, memRefOutType, allocOperands);
+      if (insertDealloc) {
+        auto *parentBlock = alloc.getDefiningOp()->getBlock();
+        auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
+        dealloc.getOperation()->moveBefore(&parentBlock->back());
+      }
+    }
+
+    // There are two Krnl loops:
+    // - one to initialize the result memref, and
+    // - one to do the reduction.
+
+    // Define loops to initialize the result.
+    std::vector<Value> originalLoopsInit;
+    std::vector<Value> optimizedLoopsInit;
+    Block *optimizationBlockInit = defineLoops(
+        rewriter, loc, originalLoopsInit, optimizedLoopsInit, outRank);
+
+    // Iteration information
+    KrnlIterateOperandPack packInit(rewriter, originalLoopsInit,
+                                    optimizedLoopsInit);
+    for (decltype(outRank) i = 0; i < outRank; ++i) {
+      addDimensionToPack(rewriter, loc, packInit, alloc, i);
+    }
+    auto iterateOpInit = rewriter.create<KrnlIterateOp>(loc, packInit);
+    Block &iterationBlockInit = iterateOpInit.bodyRegion().front();
+
+    // Perform the insertions into the body of the initialization loop.
+    // No optimization
+    rewriter.setInsertionPointToEnd(optimizationBlockInit);
+    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoopsInit);
+
+    // Insert instructions inside the KrnlIterateOp body.
+    rewriter.setInsertionPointToStart(&iterationBlockInit);
+
+    // Handle the operation:
+    SmallVector<Value, 4> loopIVs;
+    for (auto arg : iterationBlockInit.getArguments()) {
+      loopIVs.push_back(arg);
+    }
+
+    Value identity;
+    if (elementOutType.isa<FloatType>()) {
+      identity = rewriter.create<ConstantOp>(
+          loc, FloatAttr::get(elementOutType,
+                              getIdentityValue<float, ONNXReductionOp>()));
+    } else if (elementOutType.isa<IntegerType>()) {
+      identity = rewriter.create<ConstantOp>(
+          loc, IntegerAttr::get(elementOutType,
+                                getIdentityValue<int, ONNXReductionOp>()));
+    } else {
+      emitError(loc, "unsupported element type");
+      return matchFailure();
+    }
+    rewriter.create<StoreOp>(loc, identity, alloc, loopIVs);
+
+    // Define a Krnl loop to do the reduction.
+    rewriter.setInsertionPointAfter(iterateOpInit);
+    std::vector<Value> originalLoops, optimizedLoops;
+    Block *optimizationBlock =
+        defineLoops(rewriter, loc, originalLoops, optimizedLoops, inRank);
+    // Iteration information
+    KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
+    for (decltype(inRank) i = 0; i < inRank; ++i) {
+      addDimensionToPack(rewriter, loc, pack, operands[0], i);
+    }
+    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
+    Block &iterationBlock = iterateOp.bodyRegion().front();
+
+    // Perform the insertions into the body of the reduction loop.
+    // No optimization
+    rewriter.setInsertionPointToEnd(optimizationBlock);
+    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
+
+    // Insert instructions inside the KrnlIterateOp body.
+    rewriter.setInsertionPointToStart(&iterationBlock);
+
+    // Handle the operation:
+    SmallVector<Value, 4> inLoopIVs, outLoopIVs;
+    auto args = iterationBlock.getArguments();
+    for (unsigned i = 0; i < args.size(); ++i) {
+      inLoopIVs.push_back(args[i]);
+    }
+    Value zeroIndex = nullptr;
+    for (decltype(inRank) i = 0; i < outRank; ++i) {
+      if (outInDimMap.find(i) != outInDimMap.end()) {
+        outLoopIVs.push_back(inLoopIVs[outInDimMap[i]]);
+      } else {
+        if (zeroIndex) {
+          outLoopIVs.push_back(zeroIndex);
+        } else {
+          zeroIndex = rewriter.create<ConstantIndexOp>(loc, 0);
+          outLoopIVs.push_back(zeroIndex);
+        }
+      }
+    }
+
+    Value next, accumulated;
+    next = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
+    accumulated = rewriter.create<LoadOp>(loc, alloc, outLoopIVs);
+    accumulated = mapToLowerScalarOp<ONNXReductionOp>(
+        op, memRefOutType.getElementType(), {accumulated, next}, rewriter);
+    rewriter.create<StoreOp>(loc, accumulated, alloc, outLoopIVs);
+
+    rewriter.replaceOp(op, alloc);
+    return matchSuccess();
+  }
+};
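+
+// The pattern above is instantiated once per reduction op, e.g. as
+// ONNXReductionOpLowering<mlir::ONNXReduceSumOp>; see the pattern list in
+// FrontendToKrnlLoweringPass::runOnModule below.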
+
 //===----------------------------------------------------------------------===//
 // EntryPoint Op lowering to Krnl Entry Point.
 //===----------------------------------------------------------------------===//
@@ -1952,6 +2281,10 @@ void FrontendToKrnlLoweringPass::runOnModule() {
                   ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
                   ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
                   ONNXReshapeOpLowering, ONNXEntryPointLowering,
+                  ONNXReductionOpLowering<mlir::ONNXReduceMaxOp>,
+                  ONNXReductionOpLowering<mlir::ONNXReduceMinOp>,
+                  ONNXReductionOpLowering<mlir::ONNXReduceProdOp>,
+                  ONNXReductionOpLowering<mlir::ONNXReduceSumOp>,
                   ONNXSoftmaxOpLowering, ONNXGemmOpLowering,
                   ONNXUnsqueezeOpLowering, ONNXTransposeOpLowering,
                   ONNXIdentityOpLowering, ONNXConvNoBiasOpLowering
diff --git a/src/pass/shape_inference_pass.cpp b/src/pass/shape_inference_pass.cpp
index daf0224..d62069a 100644
--- a/src/pass/shape_inference_pass.cpp
+++ b/src/pass/shape_inference_pass.cpp
@@ -121,6 +121,10 @@ public:
             op->getName().getStringRef() != "onnx.GemmNoBias" &&
             op->getName().getStringRef() != "onnx.Reshape" &&
             op->getName().getStringRef() != "onnx.Transpose" &&
+            op->getName().getStringRef() != "onnx.ReduceMax" &&
+            op->getName().getStringRef() != "onnx.ReduceMin" &&
+            op->getName().getStringRef() != "onnx.ReduceProd" &&
+            op->getName().getStringRef() != "onnx.ReduceSum" &&
             op->getName().getStringRef() != "onnx.Softmax" &&
             op->getName().getStringRef() != "onnx.Sqrt" &&
             op->getName().getStringRef() != "onnx.ConvNoBias" &&
diff --git a/test/backend/test.py b/test/backend/test.py
index d7ae639..1c520aa 100644
--- a/test/backend/test.py
+++ b/test/backend/test.py
@@ -134,6 +134,46 @@ test_to_enable = [
     # Relu Op:
     "test_relu_cpu",
 
+    # ReduceMax Op:
+    "test_reduce_max_default_axes_keepdim_example_cpu",
+    "test_reduce_max_default_axes_keepdims_random_cpu",
+    "test_reduce_max_do_not_keepdims_example_cpu",
+    "test_reduce_max_do_not_keepdims_random_cpu",
+    "test_reduce_max_keepdims_example_cpu",
+    "test_reduce_max_keepdims_random_cpu",
+    "test_reduce_max_negative_axes_keepdims_example_cpu",
+    "test_reduce_max_negative_axes_keepdims_random_cpu",
+
+    # ReduceMin Op:
+    "test_reduce_min_default_axes_keepdims_example_cpu",
+    "test_reduce_min_default_axes_keepdims_random_cpu",
+    "test_reduce_min_do_not_keepdims_example_cpu",
+    "test_reduce_min_do_not_keepdims_random_cpu",
+    "test_reduce_min_keepdims_example_cpu",
+    "test_reduce_min_keepdims_random_cpu",
+    "test_reduce_min_negative_axes_keepdims_example_cpu",
+    "test_reduce_min_negative_axes_keepdims_random_cpu",
+
+    # ReduceProd Op:
+    "test_reduce_prod_default_axes_keepdims_example_cpu",
+    "test_reduce_prod_default_axes_keepdims_random_cpu",
+    "test_reduce_prod_do_not_keepdims_example_cpu",
+    "test_reduce_prod_do_not_keepdims_random_cpu",
+    "test_reduce_prod_keepdims_example_cpu",
+    "test_reduce_prod_keepdims_random_cpu",
+    "test_reduce_prod_negative_axes_keepdims_example_cpu",
+    "test_reduce_prod_negative_axes_keepdims_random_cpu",
+
+    # ReduceSum Op:
+    "test_reduce_sum_default_axes_keepdims_example_cpu",
+    "test_reduce_sum_default_axes_keepdims_random_cpu",
+    "test_reduce_sum_do_not_keepdims_example_cpu",
+    "test_reduce_sum_do_not_keepdims_random_cpu",
+    "test_reduce_sum_keepdims_example_cpu",
+    "test_reduce_sum_keepdims_random_cpu",
+    "test_reduce_sum_negative_axes_keepdims_example_cpu",
+    "test_reduce_sum_negative_axes_keepdims_random_cpu",
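+
+    # Note: the *_negative_axes_* cases exercise the axis normalization
+    # (axis = rank + axis) in both shape inference and lowering.
+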
     # Selu Op:
     "test_selu_cpu",
     "test_selu_default_cpu",
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index b5a7dd4..e1724b6 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -587,6 +587,116 @@ func @test_add_with_broadcasting(%arg0 : tensor<?xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
   // CHECK: return [[RES]] : memref<?x10xf32>
 }
 
+func @test_reducemax(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> {
+  %0 = "onnx.ReduceMax"(%arg0) {axes = [1], keepdims = 0 : i64} : (tensor<3x2x2xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_reducemax
+  // CHECK: [[RES:%.+]] = alloc() : memref<3x2xf32>
+  // CHECK: [[DEF_LOOPS1:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS1:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS1]]#0, [[DEF_LOOPS1]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOPS1]]#0, [[OPT_LOOPS1]]#1) with ([[DEF_LOOPS1]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS1]]#1 -> %arg2 = 0 to 2) {
+  // CHECK: [[IDENTITY:%.+]] = constant 0xFF800000 : f32
+  // CHECK: store [[IDENTITY]], [[RES]][%arg1, %arg2] : memref<3x2xf32>
+
+  // CHECK: [[DEF_LOOPS2:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[OPT_LOOPS2:%.+]]:3 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1, [[DEF_LOOPS2]]#2
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOPS2]]#0, [[OPT_LOOPS2]]#1, [[OPT_LOOPS2]]#2) with ([[DEF_LOOPS2]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS2]]#1 -> %arg2 = 0 to 2, [[DEF_LOOPS2]]#2 -> %arg3 = 0 to 2) {
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg1, %arg2, %arg3] : memref<3x2x2xf32>
+  // CHECK: [[LOAD2:%.+]] = load [[RES]][%arg1, %arg3] : memref<3x2xf32>
+  // CHECK: [[CMP:%.+]] = cmpf "ogt", [[LOAD2]], [[LOAD1]] : f32
+  // CHECK: [[SELECT:%.+]] = select [[CMP]], [[LOAD2]], [[LOAD1]] : f32
+  // CHECK: store [[SELECT]], [[RES]][%arg1, %arg3] : memref<3x2xf32>
+  // CHECK: }
+  // CHECK: return [[RES]] : memref<3x2xf32>
+}
+
+func @test_reducemin(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> {
+  %0 = "onnx.ReduceMin"(%arg0) {axes = [1], keepdims = 0 : i64} : (tensor<3x2x2xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_reducemin
+  // CHECK: [[RES:%.+]] = alloc() : memref<3x2xf32>
+  // CHECK: [[DEF_LOOPS1:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS1:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS1]]#0, [[DEF_LOOPS1]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOPS1]]#0, [[OPT_LOOPS1]]#1) with ([[DEF_LOOPS1]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS1]]#1 -> %arg2 = 0 to 2) {
+  // CHECK: [[IDENTITY:%.+]] = constant 0x7F800000 : f32
+  // CHECK: store [[IDENTITY]], [[RES]][%arg1, %arg2] : memref<3x2xf32>
+
+  // CHECK: [[DEF_LOOPS2:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[OPT_LOOPS2:%.+]]:3 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1, [[DEF_LOOPS2]]#2
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOPS2]]#0, [[OPT_LOOPS2]]#1, [[OPT_LOOPS2]]#2) with ([[DEF_LOOPS2]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS2]]#1 -> %arg2 = 0 to 2, [[DEF_LOOPS2]]#2 -> %arg3 = 0 to 2) {
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg1, %arg2, %arg3] : memref<3x2x2xf32>
+  // CHECK: [[LOAD2:%.+]] = load [[RES]][%arg1, %arg3] : memref<3x2xf32>
+  // CHECK: [[CMP:%.+]] = cmpf "olt", [[LOAD2]], [[LOAD1]] : f32
+  // CHECK: [[SELECT:%.+]] = select [[CMP]], [[LOAD2]], [[LOAD1]] : f32
+  // CHECK: store [[SELECT]], [[RES]][%arg1, %arg3] : memref<3x2xf32>
+  // CHECK: }
+  // CHECK: return [[RES]] : memref<3x2xf32>
+}
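+
+// Note: 0xFF800000 and 0x7F800000 above are the f32 bit patterns for -inf
+// and +inf, the identity values of max and min respectively.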
+
+func @test_reduceprod(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> {
+  %0 = "onnx.ReduceProd"(%arg0) {axes = [1], keepdims = 0 : i64} : (tensor<3x2x2xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_reduceprod
+  // CHECK: [[RES:%.+]] = alloc() : memref<3x2xf32>
+  // CHECK: [[DEF_LOOPS1:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS1:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS1]]#0, [[DEF_LOOPS1]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOPS1]]#0, [[OPT_LOOPS1]]#1) with ([[DEF_LOOPS1]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS1]]#1 -> %arg2 = 0 to 2) {
+  // CHECK: [[IDENTITY:%.+]] = constant 1.000000e+00 : f32
+  // CHECK: store [[IDENTITY]], [[RES]][%arg1, %arg2] : memref<3x2xf32>
+
+  // CHECK: [[DEF_LOOPS2:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[OPT_LOOPS2:%.+]]:3 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1, [[DEF_LOOPS2]]#2
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOPS2]]#0, [[OPT_LOOPS2]]#1, [[OPT_LOOPS2]]#2) with ([[DEF_LOOPS2]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS2]]#1 -> %arg2 = 0 to 2, [[DEF_LOOPS2]]#2 -> %arg3 = 0 to 2) {
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg1, %arg2, %arg3] : memref<3x2x2xf32>
+  // CHECK: [[LOAD2:%.+]] = load [[RES]][%arg1, %arg3] : memref<3x2xf32>
+  // CHECK: [[REDUCE:%.+]] = mulf [[LOAD2]], [[LOAD1]] : f32
+  // CHECK: store [[REDUCE]], [[RES]][%arg1, %arg3] : memref<3x2xf32>
+  // CHECK: }
+  // CHECK: return [[RES]] : memref<3x2xf32>
+}
+
+func @test_reducesum(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> {
+  %0 = "onnx.ReduceSum"(%arg0) {axes = [1], keepdims = 0 : i64} : (tensor<3x2x2xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_reducesum
+  // CHECK: [[RES:%.+]] = alloc() : memref<3x2xf32>
+  // CHECK: [[DEF_LOOPS1:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS1:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS1]]#0, [[DEF_LOOPS1]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOPS1]]#0, [[OPT_LOOPS1]]#1) with ([[DEF_LOOPS1]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS1]]#1 -> %arg2 = 0 to 2) {
+  // CHECK: [[IDENTITY:%.+]] = constant 0.000000e+00 : f32
+  // CHECK: store [[IDENTITY]], [[RES]][%arg1, %arg2] : memref<3x2xf32>
+
+  // CHECK: [[DEF_LOOPS2:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[OPT_LOOPS2:%.+]]:3 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1, [[DEF_LOOPS2]]#2
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOPS2]]#0, [[OPT_LOOPS2]]#1, [[OPT_LOOPS2]]#2) with ([[DEF_LOOPS2]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS2]]#1 -> %arg2 = 0 to 2, [[DEF_LOOPS2]]#2 -> %arg3 = 0 to 2) {
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg1, %arg2, %arg3] : memref<3x2x2xf32>
+  // CHECK: [[LOAD2:%.+]] = load [[RES]][%arg1, %arg3] : memref<3x2xf32>
+  // CHECK: [[REDUCE:%.+]] = addf [[LOAD2]], [[LOAD1]] : f32
+  // CHECK: store [[REDUCE]], [[RES]][%arg1, %arg3] : memref<3x2xf32>
+  // CHECK: }
+  // CHECK: return [[RES]] : memref<3x2xf32>
+}
+
 func @test_softmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {
   %0 = "onnx.Softmax"(%arg0) {axis=1:i64} : (tensor<10x10xf32>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()