Lowering ReductionMax, ReductionMin, ReductionProd and ReductionSum (#31)

* Shape inference for reduction

* Lower ReduceSum

* Support list-like attributes

* Add ReduceMax, ReduceMin, ReduceProd

* Add tests

* Emit errors for unsupported types

* Typos

* Add backend test

* Fix axis computation

* Update the use of attributes

* Use SmallVector

* Address stylistic comments

* Change type from int to int64_t for indices

Author: Tung D. Le, 2020-02-10 22:38:19 +09:00 (committed by GitHub)
parent 0272451521
commit 2c7046ff5f
7 changed files with 594 additions and 4 deletions

View File

@@ -46,6 +46,7 @@ ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax',
'ReduceMax', 'ReduceMin', 'ReduceProd', 'ReduceSum',
'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze', 'Sign']
CanonicalList=['Add', 'Identity', 'ReduceL1', 'ReduceL2', 'ReduceLogSum',

View File

@@ -24,6 +24,54 @@
using namespace mlir;
using namespace mlir::OpTrait::util;
//===----------------------------------------------------------------------===//
// Get reduction type
//===----------------------------------------------------------------------===//
RankedTensorType getReductionOutputType(RankedTensorType operandTy,
Optional<ArrayAttr> axesAttrs,
APInt keepdims) {
int64_t rank = operandTy.getRank();
SmallVector<int64_t, 4> axes;
if (axesAttrs != llvm::None) {
for (auto axisAttr : axesAttrs.getValue()) {
int64_t axis = axisAttr.cast<IntegerAttr>().getInt();
axis = axis >= 0 ? axis : (rank + axis);
assert(axis >= -rank && axis <= rank - 1);
if (std::find(axes.begin(), axes.end(), axis) == axes.end())
axes.emplace_back(axis);
}
} else {
for (decltype(rank) i = 0; i < rank; ++i) {
axes.emplace_back(i);
}
}
// Mark reduction axes.
SmallVector<bool, 4> isReductionAxis;
for (decltype(rank) i = 0; i < rank; ++i) {
if (std::find(axes.begin(), axes.end(), i) != axes.end())
isReductionAxis.emplace_back(true);
else
isReductionAxis.emplace_back(false);
}
// KeepDims
bool isKeepdims = (keepdims == 1) ? true : false;
SmallVector<int64_t, 4> dims;
for (decltype(rank) i = 0; i < rank; ++i) {
if (isReductionAxis[i]) {
if (isKeepdims)
dims.emplace_back(1); // reduction dimension
} else {
dims.emplace_back(operandTy.getShape()[i]);
}
}
return RankedTensorType::get(dims, operandTy.getElementType());
}
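// Worked example (illustrative; shapes taken from the lit tests added below):
// for an operand of type tensor<3x2x2xf32>,
//   getReductionOutputType(operandTy, /*axes=*/[1], /*keepdims=*/1) -> tensor<3x1x2xf32>
//   getReductionOutputType(operandTy, /*axes=*/[1], /*keepdims=*/0) -> tensor<3x2xf32>
// A negative axis such as -2 is normalized first: rank + (-2) = 1.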
//===----------------------------------------------------------------------===//
// ONNXOpsDialect
//===----------------------------------------------------------------------===//
@@ -608,6 +656,60 @@ void ONNXTransposeOp::inferShapes() {
//===----------------------------------------------------------------------===//
// ReduceMax
void ONNXReduceMaxOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked.");
return;
}
auto operandTy = getOperand().getType().cast<RankedTensorType>();
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
}
//===----------------------------------------------------------------------===//
// ReduceMin
void ONNXReduceMinOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked.");
return;
}
auto operandTy = getOperand().getType().cast<RankedTensorType>();
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
}
//===----------------------------------------------------------------------===//
// ReduceProd
void ONNXReduceProdOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked.");
return;
}
auto operandTy = getOperand().getType().cast<RankedTensorType>();
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
}
//===----------------------------------------------------------------------===//
// ReduceSum
void ONNXReduceSumOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked.");
return;
}
auto operandTy = getOperand().getType().cast<RankedTensorType>();
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
}
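// Example of the effect of these inferShapes() methods (illustrative): an op such as
//   %0 = "onnx.ReduceSum"(%x) {axes = [1], keepdims = 0 : i64}
//          : (tensor<3x2x2xf32>) -> tensor<*xf32>
// has its result type refined to tensor<3x2xf32> by the shape inference pass.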
// Conv
// For this operation, we define the attributes once in the original Conv

View File

@@ -2349,7 +2349,7 @@ def ONNXReduceLogSumExpOp:ONNX_Op<"ReduceLogSumExp",
}
def ONNXReduceMaxOp:ONNX_Op<"ReduceMax",
[NoSideEffect]> {
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX ReduceMax operation";
let description = [{
"Computes the max of the input tensor's element along the provided axes. The resulted"
@@ -2383,7 +2383,7 @@ def ONNXReduceMeanOp:ONNX_Op<"ReduceMean",
}
def ONNXReduceMinOp:ONNX_Op<"ReduceMin",
[NoSideEffect]> {
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX ReduceMin operation";
let description = [{
"Computes the min of the input tensor's element along the provided axes. The resulted"
@@ -2400,7 +2400,7 @@ def ONNXReduceMinOp:ONNX_Op<"ReduceMin",
}
def ONNXReduceProdOp:ONNX_Op<"ReduceProd",
[NoSideEffect]> {
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX ReduceProd operation";
let description = [{
"Computes the product of the input tensor's element along the provided axes. The resulted"
@@ -2417,7 +2417,7 @@ def ONNXReduceProdOp:ONNX_Op<"ReduceProd",
}
def ONNXReduceSumOp:ONNX_Op<"ReduceSum",
[NoSideEffect]> {
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX ReduceSum operation";
let description = [{
"Computes the sum of the input tensor's element along the provided axes. The resulted"

View File

@@ -130,6 +130,37 @@ static bool checkInsertDealloc(Operation *currentOp) {
return insertDealloc;
}
// Create a mapping from result type's dimensions to input type's dimensions,
// given that the result type is the result of a reduction op over the input
// type.
std::map<int64_t, int64_t>
getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
std::map<int64_t, int64_t> OutInDimMap;
int64_t rank = inputTy.getRank();
// Mark reduction axes.
std::vector<bool> isReductionAxis;
for (decltype(rank) i = 0; i < rank; ++i) {
if (std::find(axes.begin(), axes.end(), i) != axes.end())
isReductionAxis.push_back(true);
else
isReductionAxis.push_back(false);
}
for (decltype(rank) inIndex = 0, outIndex = 0; inIndex < rank; ++inIndex) {
// If it is a reduction axis, there is no relationship among dimensions.
if (isReductionAxis[inIndex]) {
if (keepdims)
outIndex++;
} else {
OutInDimMap.insert(std::make_pair(outIndex, inIndex));
outIndex++;
}
}
return OutInDimMap;
}
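// Worked example (illustrative): for an input of rank 3 reduced over axes = {0, 2},
//   keepdims = false  ->  OutInDimMap = {0 -> 1}   (output dim 0 comes from input dim 1)
//   keepdims = true   ->  OutInDimMap = {1 -> 1}   (output dims 0 and 2 are the kept size-1 axes)
// Reduced axes get no entry because their output extent (1 or dropped) never
// depends on the corresponding input extent.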
// Add bounds associated with the op operand to the KRNL iteration pack.
// Dynamic dimensions are supported.
static void addDimensionToPack(ConversionPatternRewriter &rewriter,
@@ -376,6 +407,18 @@ struct ScalarOp<ONNXLogOp> {
using IOp = LogOp; // not used
};
template <>
struct ScalarOp<ONNXReduceProdOp> {
using FOp = MulFOp;
using IOp = MulIOp;
};
template <>
struct ScalarOp<ONNXReduceSumOp> {
using FOp = AddFOp;
using IOp = AddIOp;
};
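// ReduceSum and ReduceProd need no mapToLowerScalarOp specialization: the default
// implementation picks ScalarFOp/ScalarIOp from these traits, so the accumulation
// step becomes addf/addi and mulf/muli respectively. Roughly (a sketch, assuming
// float elements):
//   Value sum = rewriter.create<ScalarFOp<ONNXReduceSumOp>>(loc, accumulated, next);   // addf
//   Value prod = rewriter.create<ScalarFOp<ONNXReduceProdOp>>(loc, accumulated, next); // mulf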
template <>
struct ScalarOp<ONNXSqrtOp> {
using FOp = KrnlSqrtOp;
@@ -387,6 +430,53 @@ using ScalarFOp = typename ScalarOp<ElementwiseNaryOp>::FOp;
template <typename ElementwiseNaryOp>
using ScalarIOp = typename ScalarOp<ElementwiseNaryOp>::IOp;
// Get the identity element of an operation.
// Return NULL if the operation does not have an identity element.
template <typename DataType, typename Op>
DataType getIdentityValue() {
return NULL;
}
template <>
float getIdentityValue<float, ONNXReduceMaxOp>(){
return (float)-std::numeric_limits<float>::infinity();
}
template <>
int getIdentityValue<int, ONNXReduceMaxOp>(){
return std::numeric_limits<int>::min();
}
template <>
float getIdentityValue<float, ONNXReduceMinOp>(){
return (float)std::numeric_limits<float>::infinity();
}
template <>
int getIdentityValue<int, ONNXReduceMinOp>(){
return std::numeric_limits<int>::max();
}
template <>
float getIdentityValue<float, ONNXReduceProdOp>(){
return (float)1.0;
}
template <>
int getIdentityValue<int, ONNXReduceProdOp>(){
return 1;
}
template <>
float getIdentityValue<float, ONNXReduceSumOp>(){
return (float)0;
}
template <>
int getIdentityValue<int, ONNXReduceSumOp>(){
return 0;
}
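// These identities seed the output buffer before the reduction loops run, so that
// folding the first element into the accumulator behaves as if it were the only one.
// Host-side analogy (a sketch, not generated code):
//   float acc = getIdentityValue<float, ONNXReduceMaxOp>();  // -inf
//   for (float v : values)
//     acc = std::max(acc, v);                                 // same fold the Krnl loops perform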
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering to Krnl dialect.
//===----------------------------------------------------------------------===//
@@ -788,6 +878,58 @@ Value mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types,
return result;
}
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXReduceMaxOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXReduceMaxOp>(Operation *op,
ArrayRef<Type> result_types,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
auto loc = op->getLoc();
Value lhs = operands[0];
Value rhs = operands[1];
Type element_type = lhs.getType();
if (element_type.isa<IntegerType>()) {
auto max = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs, rhs);
auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
return result;
} else if (element_type.isa<FloatType>()) {
auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
return result;
} else {
emitError(loc, "unsupported element type");
return nullptr;
}
}
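// For f32 operands this materializes a compare-and-select pair, e.g. (sketch):
//   %c = cmpf "ogt", %acc, %next : f32
//   %m = select %c, %acc, %next : f32
// which is the pattern the FileCheck tests below look for.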
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXReduceMinOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXReduceMinOp>(Operation *op,
ArrayRef<Type> result_types,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
auto loc = op->getLoc();
Value lhs = operands[0];
Value rhs = operands[1];
Type element_type = lhs.getType();
if (element_type.isa<IntegerType>()) {
auto min = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, rhs);
auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
return result;
} else if (element_type.isa<FloatType>()) {
auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
return result;
} else {
emitError(loc, "unsupported element type");
return nullptr;
}
}
// Element-wise unary ops lowering to Krnl dialect.
//===----------------------------------------------------------------------===//
template <typename ElementwiseUnaryOp>
@@ -1823,6 +1965,193 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
}
};
//===----------------------------------------------------------------------===//
// Reduction ops lowering to Krnl dialect.
//===----------------------------------------------------------------------===//
template <typename ONNXReductionOp>
struct ONNXReductionOpLowering : public ConversionPattern {
ONNXReductionOpLowering(MLIRContext *ctx)
: ConversionPattern(ONNXReductionOp::getOperationName(), 1, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
/*
* Condition: reduction function must be associative and commutative.
*
* Example 1 (here, reduction function is `+`):
* Induction variables: (i0, i1, i2)
* axes = [0, 2]
* keepdims = true
* krnl.iterate() with (i0, i1, i2) {
* Y(0, i1, 0) += X(i0, i1, i2)
* }
*
* Example 2 (here, reduction function is `+`):
* Induction variables: (i0, i1, i2)
* axes = [0, 2]
* keepdims = false
* krnl.iterate() with (i0, i1, i2) {
* Y(i1) += X(i0, i1, i2)
* }
*
*/
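// For intuition, the keepdims = false case above corresponds to this plain C++
// loop nest (an illustrative sketch, not the code emitted by the pass):
//   for (int64_t i1 = 0; i1 < I1; ++i1)
//     Y[i1] = identity;                          // e.g. 0 for ReduceSum
//   for (int64_t i0 = 0; i0 < I0; ++i0)
//     for (int64_t i1 = 0; i1 < I1; ++i1)
//       for (int64_t i2 = 0; i2 < I2; ++i2)
//         Y[i1] = reduce(Y[i1], X[i0][i1][i2]);  // e.g. + for ReduceSum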
auto loc = op->getLoc();
auto memRefInType = operands[0].getType().cast<MemRefType>();
auto memRefInShape = memRefInType.getShape();
auto tensorOutType = (*op->result_type_begin()).cast<TensorType>();
int64_t inRank = memRefInType.getRank();
int64_t outRank = tensorOutType.getRank();
// Get attributes
ArrayAttr axisAttrs = llvm::dyn_cast<ONNXReductionOp>(op).axesAttr();
std::vector<int64_t> axes;
if (axisAttrs) {
for (auto axisAttr : axisAttrs.getValue()) {
int64_t axis = axisAttr.cast<IntegerAttr>().getInt();
axis = axis >= 0 ? axis : (inRank + axis);
assert(axis >= -inRank && axis <= inRank - 1);
if (std::find(axes.begin(), axes.end(), axis) == axes.end())
axes.push_back(axis);
}
} else {
for (decltype(inRank) i = 0; i < inRank; ++i) {
axes.push_back(i);
}
}
// KeepDims
auto keepdims =
llvm::dyn_cast<ONNXReductionOp>(op).keepdims();
bool isKeepdims = (keepdims == 1) ? true : false;
// Get type information
auto memRefOutType = convertTensorToMemRef(tensorOutType);
auto memRefOutShape = memRefOutType.getShape();
auto elementOutType = memRefOutType.getElementType();
std::map<int64_t, int64_t> outInDimMap =
getReductionMapping(memRefInType, axes, isKeepdims);
// Insert an allocation and deallocation for the result of this operation.
Value alloc;
bool insertDealloc = checkInsertDealloc(op);
if (hasAllConstantDimensions(memRefOutType)) {
alloc = insertAllocAndDealloc(memRefOutType, loc, rewriter, insertDealloc);
} else {
SmallVector<Value, 2> allocOperands;
for (decltype(outRank) i = 0; i < outRank; ++i) {
if (memRefOutShape[i] < 0) {
auto dim = rewriter.create<DimOp>(loc, operands[0], outInDimMap[i]);
allocOperands.push_back(dim);
}
}
alloc = rewriter.create<AllocOp>(loc, memRefOutType, allocOperands);
if (insertDealloc) {
auto *parentBlock = alloc.getDefiningOp()->getBlock();
auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
dealloc.getOperation()->moveBefore(&parentBlock->back());
}
}
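// Dynamic-shape example (illustrative): reducing memref<?x2x2xf32> over axes = [1]
// with keepdims = false yields memref<?x2xf32>; the single dynamic dimension is
// recovered as `dim %input, 0` because outInDimMap maps output dim 0 to input dim 0.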
// There are two Krnl loops:
// - One to initialize the result memref, and
// - One to do the reduction.
// Define loops to initialize the result.
std::vector<Value> originalLoopsInit;
std::vector<Value> optimizedLoopsInit;
Block *optimizationBlockInit = defineLoops(rewriter, loc, originalLoopsInit,
optimizedLoopsInit, outRank);
// Iteration information
KrnlIterateOperandPack packInit(rewriter, originalLoopsInit,
optimizedLoopsInit);
for (decltype(outRank) i = 0; i < outRank; ++i) {
addDimensionToPack(rewriter, loc, packInit, alloc, i);
}
auto iterateOpInit = rewriter.create<KrnlIterateOp>(loc, packInit);
Block &iterationBlockInit = iterateOpInit.bodyRegion().front();
// Perform the insertions into the body of the initialization loop.
// No optimization
rewriter.setInsertionPointToEnd(optimizationBlockInit);
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoopsInit);
// Insert instructions inside the KrnlIterateOp body.
rewriter.setInsertionPointToStart(&iterationBlockInit);
// Handle the operation:
SmallVector<Value, 4> loopIVs;
for (auto arg : iterationBlockInit.getArguments()) {
loopIVs.push_back(arg);
}
Value identity;
if (elementOutType.isa<FloatType>()) {
identity = rewriter.create<ConstantOp>(
loc, FloatAttr::get(elementOutType,
getIdentityValue<float, ONNXReductionOp>()));
} else if (elementOutType.isa<IntegerType>()) {
identity = rewriter.create<ConstantOp>(
loc, IntegerAttr::get(elementOutType,
getIdentityValue<int, ONNXReductionOp>()));
} else {
emitError(loc, "unsupported element type");
}
rewriter.create<StoreOp>(loc, identity, alloc, loopIVs);
// Define a Krnl loop to do the reduction.
rewriter.setInsertionPointAfter(iterateOpInit);
std::vector<Value> originalLoops, optimizedLoops;
Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
optimizedLoops, inRank);
// Iteration information
KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
for (decltype(inRank) i = 0; i < inRank; ++i) {
addDimensionToPack(rewriter, loc, pack, operands[0], i);
}
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
Block &iterationBlock = iterateOp.bodyRegion().front();
// Perform the insertions into the body of the reduction loop.
// No optimization
rewriter.setInsertionPointToEnd(optimizationBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
// Insert instructions inside the KrnlIterateOp body.
rewriter.setInsertionPointToStart(&iterationBlock);
// Handle the operation:
SmallVector<Value, 4> inLoopIVs, outLoopIVs;
auto args = iterationBlock.getArguments();
for (int i = 0; i < args.size(); ++i) {
inLoopIVs.push_back(args[i]);
}
Value zeroIndex = nullptr;
for (decltype(inRank) i = 0; i < outRank; ++i) {
if (outInDimMap.find(i) != outInDimMap.end()) {
outLoopIVs.push_back(inLoopIVs[outInDimMap[i]]);
} else {
if (zeroIndex) {
outLoopIVs.push_back(zeroIndex);
} else {
zeroIndex = rewriter.create<ConstantIndexOp>(loc, 0);
outLoopIVs.push_back(zeroIndex);
}
}
}
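// Index-mapping example (illustrative), for inRank = 3 and axes = [1]:
//   keepdims = true :  outLoopIVs = (i0, %c0, i2)   -- the reduced axis always uses index 0
//   keepdims = false:  outLoopIVs = (i0, i2)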
Value next, accumulated;
next = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
accumulated = rewriter.create<LoadOp>(loc, alloc, outLoopIVs);
accumulated = mapToLowerScalarOp<ONNXReductionOp>(
op, memRefOutType.getElementType(), {accumulated, next}, rewriter);
rewriter.create<StoreOp>(loc, accumulated, alloc, outLoopIVs);
rewriter.replaceOp(op, alloc);
return matchSuccess();
}
};
//===----------------------------------------------------------------------===//
// EntryPoint Op lowering to Krnl Entry Point.
//===----------------------------------------------------------------------===//
@@ -1952,6 +2281,10 @@ void FrontendToKrnlLoweringPass::runOnModule() {
ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
ONNXReshapeOpLowering, ONNXEntryPointLowering,
ONNXReductionOpLowering<mlir::ONNXReduceMaxOp>,
ONNXReductionOpLowering<mlir::ONNXReduceMinOp>,
ONNXReductionOpLowering<mlir::ONNXReduceProdOp>,
ONNXReductionOpLowering<mlir::ONNXReduceSumOp>,
ONNXSoftmaxOpLowering, ONNXGemmOpLowering,
ONNXUnsqueezeOpLowering, ONNXTransposeOpLowering,
ONNXIdentityOpLowering, ONNXConvNoBiasOpLowering

View File

@@ -121,6 +121,10 @@ public:
op->getName().getStringRef() != "onnx.GemmNoBias" &&
op->getName().getStringRef() != "onnx.Reshape" &&
op->getName().getStringRef() != "onnx.Transpose" &&
op->getName().getStringRef() != "onnx.ReduceMax" &&
op->getName().getStringRef() != "onnx.ReduceMin" &&
op->getName().getStringRef() != "onnx.ReduceProd" &&
op->getName().getStringRef() != "onnx.ReduceSum" &&
op->getName().getStringRef() != "onnx.Softmax" &&
op->getName().getStringRef() != "onnx.Sqrt" &&
op->getName().getStringRef() != "onnx.ConvNoBias" &&

View File

@@ -134,6 +134,46 @@ test_to_enable = [
# Relu Op:
"test_relu_cpu",
# ReduceMax Op:
"test_reduce_max_default_axes_keepdim_example_cpu",
"test_reduce_max_default_axes_keepdims_random_cpu",
"test_reduce_max_do_not_keepdims_example_cpu",
"test_reduce_max_do_not_keepdims_random_cpu",
"test_reduce_max_keepdims_example_cpu",
"test_reduce_max_keepdims_random_cpu",
"test_reduce_max_negative_axes_keepdims_example_cpu",
"test_reduce_max_negative_axes_keepdims_random_cpu",
# ReduceMin Op:
"test_reduce_min_default_axes_keepdims_example_cpu",
"test_reduce_min_default_axes_keepdims_random_cpu",
"test_reduce_min_do_not_keepdims_example_cpu",
"test_reduce_min_do_not_keepdims_random_cpu",
"test_reduce_min_keepdims_example_cpu",
"test_reduce_min_keepdims_random_cpu",
"test_reduce_min_negative_axes_keepdims_example_cpu",
"test_reduce_min_negative_axes_keepdims_random_cpu",
# ReduceProd Op:
"test_reduce_prod_default_axes_keepdims_example_cpu",
"test_reduce_prod_default_axes_keepdims_random_cpu",
"test_reduce_prod_do_not_keepdims_example_cpu",
"test_reduce_prod_do_not_keepdims_random_cpu",
"test_reduce_prod_keepdims_example_cpu",
"test_reduce_prod_keepdims_random_cpu",
"test_reduce_prod_negative_axes_keepdims_example_cpu",
"test_reduce_prod_negative_axes_keepdims_random_cpu",
# ReduceSum Op:
"test_reduce_sum_default_axes_keepdims_example_cpu",
"test_reduce_sum_default_axes_keepdims_random_cpu",
"test_reduce_sum_do_not_keepdims_example_cpu",
"test_reduce_sum_do_not_keepdims_random_cpu",
"test_reduce_sum_keepdims_example_cpu",
"test_reduce_sum_keepdims_random_cpu",
"test_reduce_sum_negative_axes_keepdims_example_cpu",
"test_reduce_sum_negative_axes_keepdims_random_cpu",
# Selu Op:
"test_selu_cpu",
"test_selu_default_cpu",

View File

@@ -587,6 +587,116 @@ func @test_add_with_broadcasting(%arg0 : tensor<?xf32>, %arg1 : tensor<?x10xf32>
// CHECK: return [[RES]] : memref<?x10xf32>
}
func @test_reducemax(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> {
%0 ="onnx.ReduceMax"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<3x2x2xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_reducemax
// CHECK: [[RES:%.+]] = alloc() : memref<3x2xf32>
// CHECK: [[DEF_LOOPS1:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS1:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS1]]#0, [[DEF_LOOPS1]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS1]]#0, [[OPT_LOOPS1]]#1) with ([[DEF_LOOPS1]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS1]]#1 -> %arg2 = 0 to 2) {
// CHECK: [[IDENTITY:%.+]] = constant 0xFF800000 : f32
// CHECK: store [[IDENTITY]], [[RES]][%arg1, %arg2] : memref<3x2xf32>
// CHECK: [[DEF_LOOPS2:%.+]]:3 = krnl.define_loops 3
// CHECK: [[OPT_LOOPS2:%.+]]:3 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1, [[DEF_LOOPS2]]#2
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS2]]#0, [[OPT_LOOPS2]]#1, [[OPT_LOOPS2]]#2) with ([[DEF_LOOPS2]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS2]]#1 -> %arg2 = 0 to 2, [[DEF_LOOPS2]]#2 -> %arg3 = 0 to 2) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg1, %arg2, %arg3] : memref<3x2x2xf32>
// CHECK: [[LOAD2:%.+]] = load %0[%arg1, %arg3] : memref<3x2xf32>
// CHECK: [[CMP:%.+]] = cmpf "ogt", [[LOAD2]], [[LOAD1]] : f32
// CHECK: [[SELECT:%.+]] = select %7, %6, %5 : f32
// CHECK: store [[SELECT]], [[RES]][%arg1, %arg3] : memref<3x2xf32>
// CHECK: }
// CHECK: return [[RES]] : memref<3x2xf32>
}
func @test_reducemin(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> {
%0 ="onnx.ReduceMin"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<3x2x2xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_reducemin
// CHECK: [[RES:%.+]] = alloc() : memref<3x2xf32>
// CHECK: [[DEF_LOOPS1:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS1:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS1]]#0, [[DEF_LOOPS1]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS1]]#0, [[OPT_LOOPS1]]#1) with ([[DEF_LOOPS1]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS1]]#1 -> %arg2 = 0 to 2) {
// CHECK: [[IDENTITY:%.+]] = constant 0x7F800000 : f32
// CHECK: store [[IDENTITY]], [[RES]][%arg1, %arg2] : memref<3x2xf32>
// CHECK: [[DEF_LOOPS2:%.+]]:3 = krnl.define_loops 3
// CHECK: [[OPT_LOOPS2:%.+]]:3 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1, [[DEF_LOOPS2]]#2
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS2]]#0, [[OPT_LOOPS2]]#1, [[OPT_LOOPS2]]#2) with ([[DEF_LOOPS2]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS2]]#1 -> %arg2 = 0 to 2, [[DEF_LOOPS2]]#2 -> %arg3 = 0 to 2) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg1, %arg2, %arg3] : memref<3x2x2xf32>
// CHECK: [[LOAD2:%.+]] = load %0[%arg1, %arg3] : memref<3x2xf32>
// CHECK: [[CMP:%.+]] = cmpf "olt", [[LOAD2]], [[LOAD1]] : f32
// CHECK: [[SELECT:%.+]] = select %7, %6, %5 : f32
// CHECK: store [[SELECT]], [[RES]][%arg1, %arg3] : memref<3x2xf32>
// CHECK: }
// CHECK: return [[RES]] : memref<3x2xf32>
}
func @test_reduceprod(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> {
%0 ="onnx.ReduceProd"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<3x2x2xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_reduceprod
// CHECK: [[RES:%.+]] = alloc() : memref<3x2xf32>
// CHECK: [[DEF_LOOPS1:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS1:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS1]]#0, [[DEF_LOOPS1]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS1]]#0, [[OPT_LOOPS1]]#1) with ([[DEF_LOOPS1]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS1]]#1 -> %arg2 = 0 to 2) {
// CHECK: [[IDENTITY:%.+]] = constant 1.000000e+00 : f32
// CHECK: store [[IDENTITY]], [[RES]][%arg1, %arg2] : memref<3x2xf32>
// CHECK: [[DEF_LOOPS2:%.+]]:3 = krnl.define_loops 3
// CHECK: [[OPT_LOOPS2:%.+]]:3 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1, [[DEF_LOOPS2]]#2
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS2]]#0, [[OPT_LOOPS2]]#1, [[OPT_LOOPS2]]#2) with ([[DEF_LOOPS2]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS2]]#1 -> %arg2 = 0 to 2, [[DEF_LOOPS2]]#2 -> %arg3 = 0 to 2) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg1, %arg2, %arg3] : memref<3x2x2xf32>
// CHECK: [[LOAD2:%.+]] = load %0[%arg1, %arg3] : memref<3x2xf32>
// CHECK: [[REDUCE:%.+]] = mulf %6, %5 : f32
// CHECK: store [[REDUCE]], [[RES]][%arg1, %arg3] : memref<3x2xf32>
// CHECK: }
// CHECK: return [[RES]] : memref<3x2xf32>
}
func @test_reducesum(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> {
%0 ="onnx.ReduceSum"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<3x2x2xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_reducesum
// CHECK: [[RES:%.+]] = alloc() : memref<3x2xf32>
// CHECK: [[DEF_LOOPS1:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS1:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS1]]#0, [[DEF_LOOPS1]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS1]]#0, [[OPT_LOOPS1]]#1) with ([[DEF_LOOPS1]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS1]]#1 -> %arg2 = 0 to 2) {
// CHECK: [[IDENTITY:%.+]] = constant 0.000000e+00 : f32
// CHECK: store [[IDENTITY]], [[RES]][%arg1, %arg2] : memref<3x2xf32>
// CHECK: [[DEF_LOOPS2:%.+]]:3 = krnl.define_loops 3
// CHECK: [[OPT_LOOPS2:%.+]]:3 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1, [[DEF_LOOPS2]]#2
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS2]]#0, [[OPT_LOOPS2]]#1, [[OPT_LOOPS2]]#2) with ([[DEF_LOOPS2]]#0 -> %arg1 = 0 to 3, [[DEF_LOOPS2]]#1 -> %arg2 = 0 to 2, [[DEF_LOOPS2]]#2 -> %arg3 = 0 to 2) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg1, %arg2, %arg3] : memref<3x2x2xf32>
// CHECK: [[LOAD2:%.+]] = load %0[%arg1, %arg3] : memref<3x2xf32>
// CHECK: [[REDUCE:%.+]] = addf %6, %5 : f32
// CHECK: store [[REDUCE]], [[RES]][%arg1, %arg3] : memref<3x2xf32>
// CHECK: }
// CHECK: return [[RES]] : memref<3x2xf32>
}
func @test_softmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Softmax"(%arg0) {axis=1:i64} : (tensor<10x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()