From 06a968d4a1f5f8972485772e67a5ba3612600d5c Mon Sep 17 00:00:00 2001 From: TUNG LEDUC Date: Fri, 20 Dec 2019 01:28:06 +0900 Subject: [PATCH] [MLIR] Add broadcasting support for element wise operations (#398) * Add broadcasting support for elementwise operations * Remove MLIRDialect from MLIRWholeArchiveLibs * Rewrite getLoopIVsForBroadcasting * Compute dimensions for allocating result memory * Compute dimensions for allocating result memory (revised) * Use static dimension for element-wise operation testcases * Add a test for addition with broadcasting * Missed Traits.h when merging * Revise * Update SharedWork.md * Broadcasting for variadic operations * Edit comments * Update SharedWork.md * Reorganize the code * Add CHECK-LABEL for test_add_with_broadcasting --- MLIR.cmake | 2 + SharingWork.md | 22 +- src/compiler/dialect/onnx/onnx_ops.cpp | 86 ++++- src/compiler/pass/lower_frontend_to_krnl.cpp | 166 +++++++- test/mlir/onnx/onnx_lowering.mlir | 206 +++++----- .../mlir/onnx/onnx_lowering_with_dealloc.mlir | 360 ++++++++---------- 6 files changed, 501 insertions(+), 341 deletions(-) diff --git a/MLIR.cmake b/MLIR.cmake index d8d9da7..39cb0c9 100644 --- a/MLIR.cmake +++ b/MLIR.cmake @@ -57,6 +57,7 @@ endfunction(find_mlir_lib) find_mlir_lib(MLIRAffineOps) find_mlir_lib(MLIRAffineToStandard) find_mlir_lib(MLIRAnalysis) +find_mlir_lib(MLIRDialect) find_mlir_lib(MLIRExecutionEngine) find_mlir_lib(MLIRIR) find_mlir_lib(MLIRLLVMIR) @@ -114,6 +115,7 @@ set(MLIRLibsOnce MLIRAffineOps MLIRAffineToStandard MLIRAnalysis + MLIRDialect MLIRExecutionEngine MLIRIR MLIRLLVMIR diff --git a/SharingWork.md b/SharingWork.md index fa12559..6b6c063 100644 --- a/SharingWork.md +++ b/SharingWork.md @@ -8,10 +8,10 @@ ONNX operations for which some work is needed. | ONNX Oper | Person working on it | ONNX 2 KRNL | Basic functionality | Extended functionality (e.g. broadcast) | | ---------- | --------------------- | -------------- | --------------------- | ---------------------------------------- | -| Add | Tung (updated) | v | v | noM | -| And | Tung | v | v | noM | -| Cosh | Tung | v | v | noM | -| Div | Tung | v | v | | +| Add | Tung (updated) | v | v | M | +| And | Tung | v | v | M | +| Cosh | Tung | v | v | | +| Div | Tung | v | v | M | | Elu | Tung | v | v | | | Exp | Tung | v | v | | | FullGemm | | | | noU | @@ -19,18 +19,18 @@ ONNX operations for which some work is needed. 
 | HardSigmoid | Tung | v | v | |
 | LeakyRelu | Tung | v | v | |
 | MatMul | | | | noM |
-| Max | Tung | v | v | noM |
-| Min | Tung | v | v | noM |
-| Mul | Tung | v | v | noM |
-| Or | Tung | v | v | noM |
+| Max | Tung | v | v | M |
+| Min | Tung | v | v | M |
+| Mul | Tung | v | v | M |
+| Or | Tung | v | v | M |
 | Relu | Tung | v | v | |
 | Selu | Tung | v | v | |
 | Sigmoid | Tung | v | v | |
 | Sinh | Tung | v | v | |
-| Sub | Tung | v | v | noM |
-| Sum | Tung | v | v | noM |
+| Sub | Tung | v | v | M |
+| Sum | Tung | v | v | M |
 | Tanh | Tung | v | v | |
-| Xor | Tung | v | v | noM |
+| Xor | Tung | v | v | M |
 
 ONNX operations for which the work is completed (full functionality) and tested
 
diff --git a/src/compiler/dialect/onnx/onnx_ops.cpp b/src/compiler/dialect/onnx/onnx_ops.cpp
index eb7c60e..da50632 100644
--- a/src/compiler/dialect/onnx/onnx_ops.cpp
+++ b/src/compiler/dialect/onnx/onnx_ops.cpp
@@ -8,6 +8,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Function.h"
@@ -21,6 +22,7 @@
 #include "onnx_ops.hpp"
 
 using namespace mlir;
+using namespace mlir::OpTrait::util;
 
 //===----------------------------------------------------------------------===//
 // ONNXOpsDialect
@@ -127,7 +129,12 @@ void ONNXReciprocalOp::inferShapes() {
 /// Infer the output shape of the ONNXAddOp. This method is required by the
 /// shape inference interface.
 void ONNXAddOp::inferShapes() {
-  getResult()->setType(getOperand(0)->getType());
+  if (!getOperand(0)->getType().isa<RankedTensorType>() ||
+      !getOperand(1)->getType().isa<RankedTensorType>())
+    return;
+  auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
+  auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
+  getResult()->setType(getBroadcastedType(lhsTy, rhsTy));
 }
 
 //===----------------------------------------------------------------------===//
@@ -135,7 +142,12 @@ void ONNXAddOp::inferShapes() {
 /// Infer the output shape of the ONNXMulOp. This method is required by the
 /// shape inference interface.
 void ONNXMulOp::inferShapes() {
-  getResult()->setType(getOperand(0)->getType());
+  if (!getOperand(0)->getType().isa<RankedTensorType>() ||
+      !getOperand(1)->getType().isa<RankedTensorType>())
+    return;
+  auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
+  auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
+  getResult()->setType(getBroadcastedType(lhsTy, rhsTy));
 }
 
 //===----------------------------------------------------------------------===//
@@ -143,7 +155,12 @@ void ONNXMulOp::inferShapes() {
 /// Infer the output shape of the ONNXDivOp. This method is required by the
 /// shape inference interface.
 void ONNXDivOp::inferShapes() {
-  getResult()->setType(getOperand(0)->getType());
+  if (!getOperand(0)->getType().isa<RankedTensorType>() ||
+      !getOperand(1)->getType().isa<RankedTensorType>())
+    return;
+  auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
+  auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
+  getResult()->setType(getBroadcastedType(lhsTy, rhsTy));
 }
 
 //===----------------------------------------------------------------------===//
@@ -151,7 +168,12 @@ void ONNXDivOp::inferShapes() {
 /// Infer the output shape of the ONNXSubOp. This method is required by the
 /// shape inference interface.
 void ONNXSubOp::inferShapes() {
-  getResult()->setType(getOperand(0)->getType());
+  if (!getOperand(0)->getType().isa<RankedTensorType>() ||
+      !getOperand(1)->getType().isa<RankedTensorType>())
+    return;
+  auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
+  auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
+  getResult()->setType(getBroadcastedType(lhsTy, rhsTy));
 }
 
 //===----------------------------------------------------------------------===//
@@ -159,21 +181,38 @@ void ONNXSubOp::inferShapes() {
 /// Infer the output shape of the ONNXAndOp. This method is required by the
 /// shape inference interface.
 void ONNXAndOp::inferShapes() {
-  getResult()->setType(getOperand(0)->getType());
+  if (!getOperand(0)->getType().isa<RankedTensorType>() ||
+      !getOperand(1)->getType().isa<RankedTensorType>())
+    return;
+  auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
+  auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
+  getResult()->setType(getBroadcastedType(lhsTy, rhsTy));
 }
 
 //===----------------------------------------------------------------------===//
 // Or
 /// Infer the output shape of the ONNXOrOp. This method is required by the
 /// shape inference interface.
-void ONNXOrOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
+void ONNXOrOp::inferShapes() {
+  if (!getOperand(0)->getType().isa<RankedTensorType>() ||
+      !getOperand(1)->getType().isa<RankedTensorType>())
+    return;
+  auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
+  auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
+  getResult()->setType(getBroadcastedType(lhsTy, rhsTy));
+}
 
 //===----------------------------------------------------------------------===//
 // Xor
 /// Infer the output shape of the ONNXXorOp. This method is required by the
 /// shape inference interface.
 void ONNXXorOp::inferShapes() {
-  getResult()->setType(getOperand(0)->getType());
+  if (!getOperand(0)->getType().isa<RankedTensorType>() ||
+      !getOperand(1)->getType().isa<RankedTensorType>())
+    return;
+  auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
+  auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
+  getResult()->setType(getBroadcastedType(lhsTy, rhsTy));
 }
 
 //===----------------------------------------------------------------------===//
@@ -183,7 +222,16 @@ void ONNXXorOp::inferShapes() {
 /// Infer the output shape of the ONNXSumOp. This method is required by the
 /// shape inference interface.
 void ONNXSumOp::inferShapes() {
-  getResult()->setType(getOperand(0)->getType());
+  for (int i = 0; i < getNumOperands(); ++i) {
+    if (!getOperand(i)->getType().isa<RankedTensorType>())
+      return;
+  }
+  Type resultTy = getOperand(0)->getType().cast<RankedTensorType>();
+  for (int i = 1; i < getNumOperands(); ++i) {
+    Type nextTy = getOperand(i)->getType().cast<RankedTensorType>();
+    resultTy = getBroadcastedType(resultTy, nextTy);
+  }
+  getResult()->setType(resultTy);
 }
 
 //===----------------------------------------------------------------------===//
@@ -191,7 +239,16 @@ void ONNXSumOp::inferShapes() {
 /// Infer the output shape of the ONNXMaxOp. This method is required by the
 /// shape inference interface.
 void ONNXMaxOp::inferShapes() {
-  getResult()->setType(getOperand(0)->getType());
+  for (int i = 0; i < getNumOperands(); ++i) {
+    if (!getOperand(i)->getType().isa<RankedTensorType>())
+      return;
+  }
+  Type resultTy = getOperand(0)->getType().cast<RankedTensorType>();
+  for (int i = 1; i < getNumOperands(); ++i) {
+    Type nextTy = getOperand(i)->getType().cast<RankedTensorType>();
+    resultTy = getBroadcastedType(resultTy, nextTy);
+  }
+  getResult()->setType(resultTy);
 }
 
 //===----------------------------------------------------------------------===//
@@ -199,7 +256,16 @@ void ONNXMaxOp::inferShapes() {
 /// Infer the output shape of the ONNXMinOp. This method is required by the
 /// shape inference interface.
 void ONNXMinOp::inferShapes() {
-  getResult()->setType(getOperand(0)->getType());
+  for (int i = 0; i < getNumOperands(); ++i) {
+    if (!getOperand(i)->getType().isa<RankedTensorType>())
+      return;
+  }
+  Type resultTy = getOperand(0)->getType().cast<RankedTensorType>();
+  for (int i = 1; i < getNumOperands(); ++i) {
+    Type nextTy = getOperand(i)->getType().cast<RankedTensorType>();
+    resultTy = getBroadcastedType(resultTy, nextTy);
+  }
+  getResult()->setType(resultTy);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/src/compiler/pass/lower_frontend_to_krnl.cpp b/src/compiler/pass/lower_frontend_to_krnl.cpp
index 593bcbb..9a17df0 100644
--- a/src/compiler/pass/lower_frontend_to_krnl.cpp
+++ b/src/compiler/pass/lower_frontend_to_krnl.cpp
@@ -9,6 +9,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include <map>
+
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/Sequence.h"
 #include "mlir/Dialect/AffineOps/AffineOps.h"
@@ -44,16 +46,51 @@ static MemRefType convertTensorToMemRef(TensorType type) {
 }
 
 /// Insert an allocation and deallocation for the given MemRefType.
-static Value* insertAllocAndDealloc(MemRefType type, Location loc,
-    PatternRewriter& rewriter, bool insertDealloc, Value* oldMemRef = nullptr) {
+static Value *insertAllocAndDealloc(MemRefType type, Location loc,
+                                    PatternRewriter &rewriter,
+                                    bool insertDealloc,
+                                    ArrayRef<Value *> operands = {}) {
   // Put together alloc operands for any dynamic dimensions of the memref.
   AllocOp alloc;
-  if (oldMemRef) {
-    SmallVector<Value *, 4> allocOperands;
+  if (!operands.empty()) {
     auto memRefShape = type.getShape();
-    for (int i = 0; i < memRefShape.size(); ++i)
+    auto rank = memRefShape.size();
+
+    std::map<int, Value *> fromOperands;
+    for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
+      int memRefDimIdx = rank - 1 - reversedIdx;
+      if (memRefShape[memRefDimIdx] < 0) { // unknown dimension
+        Value *maxDim = nullptr;
+        for (int i = 0; i < operands.size(); i++) {
+          auto operandShape =
+              operands[i]->getType().cast<MemRefType>().getShape();
+          int operandDimIdx = operandShape.size() - 1 - reversedIdx;
+
+          if (operandDimIdx < 0)
+            continue;
+
+          // In case of operations with broadcasting, the dimension of the
+          // alloc result is the maximum size along each dimension of the
+          // operands.
+          auto operandDim =
+              rewriter.create<DimOp>(loc, operands[i], operandDimIdx);
+          if (maxDim) {
+            auto maxCondition = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt,
+                                                        operandDim, maxDim);
+            maxDim = rewriter.create<SelectOp>(loc, maxCondition, operandDim,
+                                               maxDim);
+          } else {
+            maxDim = operandDim;
+          }
+        }
+        fromOperands.insert(std::make_pair(memRefDimIdx, maxDim));
+      }
+    }
+
+    SmallVector<Value *, 4> allocOperands;
+    for (int i = 0; i < rank; ++i)
       if (memRefShape[i] < 0)
-        allocOperands.push_back(rewriter.create<DimOp>(loc, oldMemRef, i));
+        allocOperands.push_back(fromOperands[i]);
     alloc = rewriter.create<AllocOp>(loc, type, allocOperands);
   } else {
    alloc = rewriter.create<AllocOp>(loc, type);
  }
@@ -109,6 +146,89 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
   return llvm::divideCeil(sizeInBits, 8);
 }
 
+// Get run-time dimension information for unknown dimensions used for
+// broadcasting.
+std::map<int, std::map<int, Value *>>
+getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
+                      MemRefType memRefType, ArrayRef<Value *> operands) {
+  auto memRefShape = memRefType.getShape();
+  int64_t rank = memRefShape.size();
+  // For unknown dimensions, we need to get dimension values at runtime in
+  // order to do broadcasting.
+  std::map<int, std::map<int, Value *>> DimInfo;
+  // For each result dimension, compute the number of sharing operands.
+  // Sharing operands are operands sharing the same index (counting from the
+  // rightmost to the leftmost) for a given dimension.
+  std::map<int, int> sharedDimCount;
+  for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
+    int dimIdx = rank - 1 - reversedIdx;
+    sharedDimCount[dimIdx] = 0;
+    for (int i = 0; i < operands.size(); ++i) {
+      auto shape = operands[i]->getType().cast<MemRefType>().getShape();
+      if (reversedIdx <= shape.size() - 1)
+        sharedDimCount[dimIdx]++;
+    }
+  }
+  // An unknown dimension can have a value of 1 or N (N > 1).
+  // If its value is 1, it is a broadcasted dimension.
+  // Otherwise, it is a non-broadcasted dimension.
+  // We only care about unknown dimensions whose number of sharing operands is
+  // more than one, since they are potentially broadcasted dimensions.
+  for (int i = 0; i < operands.size(); ++i) {
+    std::map<int, Value *> broadcastedDims;
+    auto shape = operands[i]->getType().cast<MemRefType>().getShape();
+    int size = shape.size();
+    for (int j = 0; j < shape.size(); ++j) {
+      if (shape[j] < 0 && sharedDimCount[rank - size + j] > 1) {
+        auto dim = rewriter.create<DimOp>(loc, operands[i], j).getResult();
+        auto one = rewriter.create<ConstantIndexOp>(loc, 1);
+        auto isBroadcasted =
+            rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
+        broadcastedDims.insert(std::make_pair(j, isBroadcasted));
+      }
+    }
+    DimInfo.insert(std::make_pair(i, broadcastedDims));
+  }
+  return DimInfo;
+}
+
+// Extract induction variables that are used for broadcasting values of a
+// given operand.
+std::vector<Value *>
+getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
+                          ArrayRef<Value *> loopIVs, Value *operand,
+                          std::map<int, Value *> broadcastedDims) {
+  // `operand` must have a ranked type. This should have been checked by the
+  // shape inference pass.
+  auto operandShape = operand->getType().cast<MemRefType>().getShape();
+  auto rank = operandShape.size();
+  auto loopCount = loopIVs.size();
+
+  std::vector<Value *> newLoopIVs;
+  for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
+    auto dimIdx = rank - 1 - reversedIdx;
+    auto loopIdx = loopCount - 1 - reversedIdx;
+    if (operandShape[dimIdx] == 1) {
+      // Broadcasted dimension
+      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
+      newLoopIVs.insert(newLoopIVs.begin(), zero);
+    } else if ((operandShape[dimIdx] == -1) &&
+               (broadcastedDims.find(dimIdx) != broadcastedDims.end())) {
+      // Unknown dimension, it can have a value of 1 or N (N > 1).
+      // If its value is 1, it is a broadcasted dimension.
+      // Otherwise, it is a non-broadcasted dimension.
+      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
+      auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx],
+                                           zero, loopIVs[loopIdx]);
+      newLoopIVs.insert(newLoopIVs.begin(), idx);
+    } else {
+      // Non-broadcasted dimension
+      newLoopIVs.insert(newLoopIVs.begin(), loopIVs[loopIdx]);
+    }
+  }
+  return newLoopIVs;
+}
+
 namespace {
 
 template <typename ElementwiseUnaryOp>
 struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
@@ -505,7 +625,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
     else
       alloc = insertAllocAndDealloc(
-          memRefType, loc, rewriter, insertDealloc, operands[0]);
+          memRefType, loc, rewriter, insertDealloc, {operands[0]});
 
     // Number of loops
     auto memRefShape = memRefType.getShape();
@@ -595,20 +715,18 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
     // Insert an allocation and deallocation for the result of this operation.
     auto memRefType = convertTensorToMemRef(tensorType);
 
-    // If the output has a dynamic dimension, pass the operands required for
-    // each dynamic dimension to the AllocOp. The first operand of the
-    // operation is used. The operands of the op need to match in terms of
-    // dimensions with the result at this pre-optimization phase.
-    // TODO: verify that dimensions match.
-    // TODO: can the dimension of the result differ after optimizations?
     Value* alloc;
     bool insertDealloc = checkInsertDealloc(op);
-
+    // If the output has a dynamic dimension, we compute its dimension at
+    // runtime by using dimensions from the operands.
+    // In particular, we need to know which operand a result dimension
+    // comes from.
+    // TODO: can the dimension of the result differ after optimizations?
     if (hasAllConstantDimensions(memRefType))
       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
     else
       alloc = insertAllocAndDealloc(
-          memRefType, loc, rewriter, insertDealloc, operands[0]);
+          memRefType, loc, rewriter, insertDealloc, operands);
 
     // Number of loops
     auto memRefShape = memRefType.getShape();
@@ -639,13 +757,18 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
       if (memRefShape[i] < 0) {
         pack.pushConstantBound(0);
         pack.pushOperandBound(
-            rewriter.create<DimOp>(loc, operands[0], i).getResult());
+            rewriter.create<DimOp>(loc, alloc, i).getResult());
       } else {
         pack.pushConstantBound(0);
         pack.pushConstantBound(memRefShape[i]);
       }
     }
 
+    // Get run-time dimension information for unknown dimensions used for
+    // broadcasting.
+    std::map<int, std::map<int, Value *>> broadcastedDimInfo =
+        getBroadcastedDimInfo(loc, rewriter, memRefType, operands);
+
     auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
     Block& iterationBlock = iterateOp.bodyRegion().front();
@@ -655,8 +778,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
     // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
     rewriter.setInsertionPointToEnd(&optimizationBlock);
     // Return from KrnlOptimizeLoopsOp body.
-    // When no optimizations are present we just return the loops
-    // unchaged.
+    // When no optimizations are present we just return the loops unchanged.
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
 
     rewriter.setInsertionPoint(optimizedLoopsOp);
 
@@ -670,9 +792,13 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
 
     // Fold over operands for each of their scalar values
     Value *accumulated, *next;
-    accumulated = rewriter.create<LoadOp>(loc, operands[0], loopIVs);
+    auto accumulatedLoopIVs = getLoopIVsForBroadcasting(
+        loc, rewriter, loopIVs, operands[0], broadcastedDimInfo[0]);
+    accumulated = rewriter.create<LoadOp>(loc, operands[0], accumulatedLoopIVs);
     for (unsigned i = 1; i < numArgs; i++) {
-      next = rewriter.create<LoadOp>(loc, operands[i], loopIVs);
+      auto nextLoopIVs = getLoopIVsForBroadcasting(
+          loc, rewriter, loopIVs, operands[i], broadcastedDimInfo[i]);
+      next = rewriter.create<LoadOp>(loc, operands[i], nextLoopIVs);
       accumulated = mapToLowerScalarOp<ElementwiseVariadicOp>(
           op, memRefType.getElementType(), {accumulated, next}, rewriter);
     }
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index 9cff02c..92d4a0f 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -1,143 +1,129 @@
 // RUN: onnf-opt --shape-inference --lower-frontend %s -split-input-file | FileCheck %s
 
-func @test_add(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
-  %0 = "onnx.Add"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
+func @test_add(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Add"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_add
-  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
-  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
   // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
   // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
   // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
   // CHECK: } : () -> (!krnl.loop, !krnl.loop)
-  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
-  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
-  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
-  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xf32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xf32>
   // CHECK: [[ADDF:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
-  // CHECK: store [[ADDF]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
-  // CHECK: return [[RES]] : memref<?x10xf32>
+  // CHECK: store [[ADDF]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
+  // CHECK: return [[RES]] : memref<10x10xf32>
 }
 
-func @test_mul(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
-  %0 = "onnx.Mul"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
+func @test_mul(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Mul"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_mul
-  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
-  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
   // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
   // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
   // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
   // CHECK: } : () -> (!krnl.loop, !krnl.loop)
  - // CHECK:
[[DIM_2:%.+]] = dim %arg0, 0 : memref - // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xf32> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xf32> // CHECK: [[MULF:%.+]] = mulf [[LOAD1]], [[LOAD2]] : f32 - // CHECK: store [[MULF]], [[RES]][%arg2, %arg3] : memref - // CHECK: return [[RES]] : memref + // CHECK: store [[MULF]], [[RES]][%arg2, %arg3] : memref<10x10xf32> + // CHECK: return [[RES]] : memref<10x10xf32> } -func @test_div(%arg0 : tensor, %arg1 : tensor) -> tensor<*xf32> { - %0 = "onnx.Div"(%arg0, %arg1) : (tensor, tensor) -> tensor<*xf32> +func @test_div(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> { + %0 = "onnx.Div"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_div - // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref - // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32> // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref - // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xf32> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xf32> // CHECK: [[DIVF:%.+]] = divf [[LOAD1]], [[LOAD2]] : f32 - // CHECK: store [[DIVF]], [[RES]][%arg2, %arg3] : memref - // CHECK: return [[RES]] : memref + // CHECK: store [[DIVF]], [[RES]][%arg2, %arg3] : memref<10x10xf32> + // CHECK: return [[RES]] : memref<10x10xf32> } -func @test_sub(%arg0 : tensor, %arg1 : tensor) -> tensor<*xf32> { - %0 = "onnx.Sub"(%arg0, %arg1) : (tensor, tensor) -> tensor<*xf32> +func @test_sub(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> { + %0 = "onnx.Sub"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_sub - // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref - // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32> // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref - // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : 
memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xf32> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xf32> // CHECK: [[SUBF:%.+]] = subf [[LOAD1]], [[LOAD2]] : f32 - // CHECK: store [[SUBF]], [[RES]][%arg2, %arg3] : memref - // CHECK: return [[RES]] : memref + // CHECK: store [[SUBF]], [[RES]][%arg2, %arg3] : memref<10x10xf32> + // CHECK: return [[RES]] : memref<10x10xf32> } -func @test_and(%arg0 : tensor, %arg1 : tensor) -> tensor<*xi32> { - %0 = "onnx.And"(%arg0, %arg1) : (tensor, tensor) -> tensor<*xi32> +func @test_and(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> { + %0 = "onnx.And"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32> "std.return"(%0) : (tensor<*xi32>) -> () // CHECK-LABEL: test_and - // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref - // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32> // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref - // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> // CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i32 - // CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref - // CHECK: return [[RES]] : memref + // CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref<10x10xi32> + // CHECK: return [[RES]] : memref<10x10xi32> } -func @test_or(%arg0 : tensor, %arg1 : tensor) -> tensor<*xi32> { - %0 = "onnx.Or"(%arg0, %arg1) : (tensor, tensor) -> tensor<*xi32> +func @test_or(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> { + %0 = "onnx.Or"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32> "std.return"(%0) : (tensor<*xi32>) -> () // CHECK-LABEL: test_or - // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref - // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32> // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref - // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> // CHECK: [[OR:%.+]] = 
or [[LOAD1]], [[LOAD2]] : i32 - // CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref - // CHECK: return [[RES]] : memref + // CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref<10x10xi32> + // CHECK: return [[RES]] : memref<10x10xi32> } -func @test_xor(%arg0 : tensor, %arg1 : tensor) -> tensor<*xi32> { - %0 = "onnx.Xor"(%arg0, %arg1) : (tensor, tensor) -> tensor<*xi32> +func @test_xor(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> { + %0 = "onnx.Xor"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32> "std.return"(%0) : (tensor<*xi32>) -> () // CHECK-LABEL: test_xor - // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref - // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32> // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref - // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> // CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i32 - // CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref - // CHECK: return [[RES]] : memref + // CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref<10x10xi32> + // CHECK: return [[RES]] : memref<10x10xi32> } func @test_exp(%arg0 : tensor) -> tensor<*xf32> { @@ -310,66 +296,60 @@ func @test_reshape(%arg0 : tensor, %arg1 : tensor<4xi32>) -> tensor<*x // CHECK: return [[ALLOC]] : memref } -func @test_sum(%arg0 : tensor, %arg1 : tensor) -> tensor<*xf32> { - %0 = "onnx.Sum"(%arg0, %arg1) : (tensor, tensor) -> tensor<*xf32> +func @test_sum(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> { + %0 = "onnx.Sum"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_sum - // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref - // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32> // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref - // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xf32> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xf32> // CHECK: [[ADD:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32 - // CHECK: store [[ADD]], [[RES]][%arg2, %arg3] : memref - // CHECK: return [[RES]] : memref + // 
CHECK: store [[ADD]], [[RES]][%arg2, %arg3] : memref<10x10xf32> + // CHECK: return [[RES]] : memref<10x10xf32> } -func @test_max(%arg0 : tensor, %arg1 : tensor) -> tensor<*xf32> { - %0 = "onnx.Max"(%arg0, %arg1) : (tensor, tensor) -> tensor<*xf32> +func @test_max(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> { + %0 = "onnx.Max"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_max - // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref - // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32> // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref - // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xf32> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xf32> // CHECK: [[MAX:%.+]] = cmpf "ogt", [[LOAD1]], [[LOAD2]] : f32 // CHECK: [[RELU_RES:%.+]] = select [[MAX]], [[LOAD1]], [[LOAD2]] : f32 - // CHECK: store [[RELU_RES]], [[RES]][%arg2, %arg3] : memref - // CHECK: return [[RES]] : memref + // CHECK: store [[RELU_RES]], [[RES]][%arg2, %arg3] : memref<10x10xf32> + // CHECK: return [[RES]] : memref<10x10xf32> } -func @test_min(%arg0 : tensor, %arg1 : tensor) -> tensor<*xf32> { - %0 = "onnx.Min"(%arg0, %arg1) : (tensor, tensor) -> tensor<*xf32> +func @test_min(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> { + %0 = "onnx.Min"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_min - // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref - // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32> // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref - // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xf32> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xf32> // CHECK: [[MIN:%.+]] = cmpf "olt", [[LOAD1]], [[LOAD2]] : f32 // CHECK: [[RELU_RES:%.+]] = select [[MIN]], [[LOAD1]], [[LOAD2]] : f32 - // CHECK: store [[RELU_RES]], [[RES]][%arg2, %arg3] : memref - // CHECK: return [[RES]] : memref + // CHECK: store [[RELU_RES]], [[RES]][%arg2, %arg3] : memref<10x10xf32> + // CHECK: return [[RES]] : memref<10x10xf32> } func @test_elu(%arg0 
 : tensor<?x10xf32>) -> tensor<*xf32> {
@@ -495,3 +475,29 @@ func @test_reciprocal(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
   // CHECK: store [[RECIPROCAL_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
   // CHECK: return [[RES]] : memref<?x10xf32>
 }
+
+func @test_add_with_broadcasting(%arg0 : tensor<?xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_add_with_broadcasting
+  // CHECK: [[DIM1:%.+]] = dim %arg1, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM1]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
+  // CHECK: [[DIM3:%.+]] = dim %arg0, 0 : memref<?xf32>
+  // CHECK: [[ONE:%.+]] = constant 1 : index
+  // CHECK: [[IS_ONE:%.+]] = cmpi "eq", [[DIM3]], [[ONE]] : index
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[ZERO:%.+]] = constant 0 : index
+  // CHECK: %[[SELECT1:.+]] = select [[IS_ONE]], [[ZERO]], %arg3 : index
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%[[SELECT1]]] : memref<?xf32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[ADD:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: store [[ADD]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: }
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
diff --git a/test/mlir/onnx/onnx_lowering_with_dealloc.mlir b/test/mlir/onnx/onnx_lowering_with_dealloc.mlir
index cbd4d39..4216ffa 100644
--- a/test/mlir/onnx/onnx_lowering_with_dealloc.mlir
+++ b/test/mlir/onnx/onnx_lowering_with_dealloc.mlir
@@ -1,291 +1,263 @@
 // RUN: onnf-opt --shape-inference --lower-frontend %s -split-input-file | FileCheck %s
 
-func @test_add_add(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
-  %0 = "onnx.Add"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
-  %1 = "onnx.Add"(%0, %arg1) : (tensor<*xf32>, tensor<?x10xf32>) -> tensor<*xf32>
+func @test_add_add(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Add"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32>
+  %1 = "onnx.Add"(%0, %arg1) : (tensor<*xf32>, tensor<10x10xf32>) -> tensor<*xf32>
   "std.return"(%1) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_add_add
   /// First Add
-  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
-  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xf32>
+  // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
   // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
   // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
   // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
   // CHECK: } : () -> (!krnl.loop, !krnl.loop)
-  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
-  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
-  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
-  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xf32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xf32>
   // CHECK: [[ADDF:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
  - //
CHECK: store [[ADDF]], [[RES]][%arg2, %arg3] : memref + // CHECK: store [[ADDF]], [[RES]][%arg2, %arg3] : memref<10x10xf32> /// Second Add - // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref - // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref - // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xf32> // CHECK: [[ADDF:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32 - // CHECK: store [[ADDF]], [[RET_RES]][%arg2, %arg3] : memref + // CHECK: store [[ADDF]], [[RET_RES]][%arg2, %arg3] : memref<10x10xf32> /// Dealloc of first result. - // CHECK: dealloc [[RES]] : memref - // CHECK-NOT: dealloc [[RET_RES]] : memref + // CHECK: dealloc [[RES]] : memref<10x10xf32> + // CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xf32> - // CHECK: return [[RET_RES]] : memref + // CHECK: return [[RET_RES]] : memref<10x10xf32> } -func @test_mul_mul(%arg0 : tensor, %arg1 : tensor) -> tensor<*xf32> { - %0 = "onnx.Mul"(%arg0, %arg1) : (tensor, tensor) -> tensor<*xf32> - %1 = "onnx.Mul"(%0, %arg1) : (tensor<*xf32>, tensor) -> tensor<*xf32> +func @test_mul_mul(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> { + %0 = "onnx.Mul"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> + %1 = "onnx.Mul"(%0, %arg1) : (tensor<*xf32>, tensor<10x10xf32>) -> tensor<*xf32> "std.return"(%1) : (tensor<*xf32>) -> () // CHECK-LABEL: test_mul_mul /// First Mul - // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref - // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xf32> + // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32> // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref - // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xf32> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xf32> // CHECK: [[MULF:%.+]] = mulf [[LOAD1]], [[LOAD2]] : f32 - // CHECK: store [[MULF]], [[RES]][%arg2, %arg3] : memref + // CHECK: store [[MULF]], [[RES]][%arg2, %arg3] : memref<10x10xf32> /// Second Mul - // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref - // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS:%.+]]:2 
= krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref - // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xf32> // CHECK: [[MULF:%.+]] = mulf [[LOAD1]], [[LOAD2]] : f32 - // CHECK: store [[MULF]], [[RET_RES]][%arg2, %arg3] : memref + // CHECK: store [[MULF]], [[RET_RES]][%arg2, %arg3] : memref<10x10xf32> /// Dealloc of first result. - // CHECK: dealloc [[RES]] : memref - // CHECK-NOT: dealloc [[RET_RES]] : memref + // CHECK: dealloc [[RES]] : memref<10x10xf32> + // CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xf32> - // CHECK: return [[RET_RES]] : memref + // CHECK: return [[RET_RES]] : memref<10x10xf32> } -func @test_div_div(%arg0 : tensor, %arg1 : tensor) -> tensor<*xf32> { - %0 = "onnx.Div"(%arg0, %arg1) : (tensor, tensor) -> tensor<*xf32> - %1 = "onnx.Div"(%0, %arg1) : (tensor<*xf32>, tensor) -> tensor<*xf32> +func @test_div_div(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> { + %0 = "onnx.Div"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> + %1 = "onnx.Div"(%0, %arg1) : (tensor<*xf32>, tensor<10x10xf32>) -> tensor<*xf32> "std.return"(%1) : (tensor<*xf32>) -> () // CHECK-LABEL: test_div_div /// First Div - // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref - // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xf32> + // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32> // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref - // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xf32> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xf32> // CHECK: [[DIVF:%.+]] = divf [[LOAD1]], [[LOAD2]] : f32 - // CHECK: store [[DIVF]], [[RES]][%arg2, %arg3] : memref + // CHECK: store [[DIVF]], [[RES]][%arg2, %arg3] : memref<10x10xf32> /// Second Div - // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref - // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref - // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { 
- // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xf32> // CHECK: [[DIVF:%.+]] = divf [[LOAD1]], [[LOAD2]] : f32 - // CHECK: store [[DIVF]], [[RET_RES]][%arg2, %arg3] : memref + // CHECK: store [[DIVF]], [[RET_RES]][%arg2, %arg3] : memref<10x10xf32> /// Dealloc of first result. - // CHECK: dealloc [[RES]] : memref - // CHECK-NOT: dealloc [[RET_RES]] : memref + // CHECK: dealloc [[RES]] : memref<10x10xf32> + // CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xf32> - // CHECK: return [[RET_RES]] : memref + // CHECK: return [[RET_RES]] : memref<10x10xf32> } -func @test_sub_sub(%arg0 : tensor, %arg1 : tensor) -> tensor<*xf32> { - %0 = "onnx.Sub"(%arg0, %arg1) : (tensor, tensor) -> tensor<*xf32> - %1 = "onnx.Sub"(%0, %arg1) : (tensor<*xf32>, tensor) -> tensor<*xf32> +func @test_sub_sub(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> { + %0 = "onnx.Sub"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> + %1 = "onnx.Sub"(%0, %arg1) : (tensor<*xf32>, tensor<10x10xf32>) -> tensor<*xf32> "std.return"(%1) : (tensor<*xf32>) -> () // CHECK-LABEL: test_sub_sub /// First Sub - // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref - // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xf32> + // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32> // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref - // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xf32> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xf32> // CHECK: [[SUBF:%.+]] = subf [[LOAD1]], [[LOAD2]] : f32 - // CHECK: store [[SUBF]], [[RES]][%arg2, %arg3] : memref + // CHECK: store [[SUBF]], [[RES]][%arg2, %arg3] : memref<10x10xf32> /// Second Sub - // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref - // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref - // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : 
memref<10x10xf32> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xf32> // CHECK: [[SUBF:%.+]] = subf [[LOAD1]], [[LOAD2]] : f32 - // CHECK: store [[SUBF]], [[RET_RES]][%arg2, %arg3] : memref + // CHECK: store [[SUBF]], [[RET_RES]][%arg2, %arg3] : memref<10x10xf32> /// Dealloc of first result. - // CHECK: dealloc [[RES]] : memref - // CHECK-NOT: dealloc [[RET_RES]] : memref + // CHECK: dealloc [[RES]] : memref<10x10xf32> + // CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xf32> - // CHECK: return [[RET_RES]] : memref + // CHECK: return [[RET_RES]] : memref<10x10xf32> } -func @test_and_and(%arg0 : tensor, %arg1 : tensor) -> tensor<*xi32> { - %0 = "onnx.And"(%arg0, %arg1) : (tensor, tensor) -> tensor<*xi32> - %1 = "onnx.And"(%0, %arg1) : (tensor<*xi32>, tensor) -> tensor<*xi32> +func @test_and_and(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> { + %0 = "onnx.And"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32> + %1 = "onnx.And"(%0, %arg1) : (tensor<*xi32>, tensor<10x10xi32>) -> tensor<*xi32> "std.return"(%1) : (tensor<*xi32>) -> () // CHECK-LABEL: test_and_and /// First And - // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref - // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi32> + // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32> // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref - // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> // CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i32 - // CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref + // CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref<10x10xi32> /// Second And - // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref - // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref - // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi32> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> // CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i32 - // CHECK: store [[AND]], [[RET_RES]][%arg2, %arg3] : memref + // CHECK: store [[AND]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi32> /// Dealloc of first result. 
- // CHECK: dealloc [[RES]] : memref - // CHECK-NOT: dealloc [[RET_RES]] : memref + // CHECK: dealloc [[RES]] : memref<10x10xi32> + // CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi32> - // CHECK: return [[RET_RES]] : memref + // CHECK: return [[RET_RES]] : memref<10x10xi32> } -func @test_or_or(%arg0 : tensor, %arg1 : tensor) -> tensor<*xi32> { - %0 = "onnx.Or"(%arg0, %arg1) : (tensor, tensor) -> tensor<*xi32> - %1 = "onnx.Or"(%0, %arg1) : (tensor<*xi32>, tensor) -> tensor<*xi32> +func @test_or_or(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> { + %0 = "onnx.Or"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32> + %1 = "onnx.Or"(%0, %arg1) : (tensor<*xi32>, tensor<10x10xi32>) -> tensor<*xi32> "std.return"(%1) : (tensor<*xi32>) -> () // CHECK-LABEL: test_or_or /// First Or - // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref - // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi32> + // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32> // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref - // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> // CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i32 - // CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref + // CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref<10x10xi32> /// Second Or - // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref - // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref - // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi32> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> // CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i32 - // CHECK: store [[OR]], [[RET_RES]][%arg2, %arg3] : memref + // CHECK: store [[OR]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi32> /// Dealloc of first result. 
- // CHECK: dealloc [[RES]] : memref - // CHECK-NOT: dealloc [[RET_RES]] : memref + // CHECK: dealloc [[RES]] : memref<10x10xi32> + // CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi32> - // CHECK: return [[RET_RES]] : memref + // CHECK: return [[RET_RES]] : memref<10x10xi32> } -func @test_xor_xor(%arg0 : tensor, %arg1 : tensor) -> tensor<*xi32> { - %0 = "onnx.Xor"(%arg0, %arg1) : (tensor, tensor) -> tensor<*xi32> - %1 = "onnx.Xor"(%0, %arg1) : (tensor<*xi32>, tensor) -> tensor<*xi32> +func @test_xor_xor(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> { + %0 = "onnx.Xor"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32> + %1 = "onnx.Xor"(%0, %arg1) : (tensor<*xi32>, tensor<10x10xi32>) -> tensor<*xi32> "std.return"(%1) : (tensor<*xi32>) -> () // CHECK-LABEL: test_xor_xor /// First Xor - // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref - // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi32> + // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32> // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref - // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> // CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i32 - // CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref + // CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref<10x10xi32> /// Second Xor - // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref - // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref - // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi32> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> // CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i32 - // CHECK: store [[XOR]], [[RET_RES]][%arg2, %arg3] : memref + // CHECK: store [[XOR]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi32> /// Dealloc of first result. 
-  // CHECK: dealloc [[RES]] : memref<?x10xi32>
-  // CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xi32>
+  // CHECK: dealloc [[RES]] : memref<10x10xi32>
+  // CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi32>
 
-  // CHECK: return [[RET_RES]] : memref<?x10xi32>
+  // CHECK: return [[RET_RES]] : memref<10x10xi32>
 }
 
 func @test_exp_exp(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
@@ -572,131 +544,119 @@ func @test_relu_relu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
   // CHECK: return [[RET_RES]] : memref<?x10xf32>
 }
 
-func @test_sum_sum(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
-  %0 = "onnx.Sum"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
-  %1 = "onnx.Sum"(%0, %arg1) : (tensor<*xf32>, tensor<?x10xf32>) -> tensor<*xf32>
+func @test_sum_sum(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Sum"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32>
+  %1 = "onnx.Sum"(%0, %arg1) : (tensor<*xf32>, tensor<10x10xf32>) -> tensor<*xf32>
   "std.return"(%1) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_sum_sum
   /// First Sum
-  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
-  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xf32>
+  // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
   // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
   // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
   // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
   // CHECK: } : () -> (!krnl.loop, !krnl.loop)
-  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
-  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
-  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
-  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xf32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xf32>
   // CHECK: [[ADD:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
-  // CHECK: store [[ADD]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: store [[ADD]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
 
   /// Second Sum
-  // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
-  // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
   // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
   // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
   // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
   // CHECK: } : () -> (!krnl.loop, !krnl.loop)
-  // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
-  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
-  // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xf32>
-  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xf32>
   // CHECK: [[ADD:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
-  // CHECK: store [[ADD]], [[RET_RES]][%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: store [[ADD]], [[RET_RES]][%arg2, %arg3] : memref<10x10xf32>
 
   /// Dealloc of first result.
-  // CHECK: dealloc [[RES]] : memref<?x10xf32>
-  // CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
+  // CHECK: dealloc [[RES]] : memref<10x10xf32>
+  // CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xf32>
 
-  // CHECK: return [[RET_RES]] : memref<?x10xf32>
+  // CHECK: return [[RET_RES]] : memref<10x10xf32>
 }
 
-func @test_max_max(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
-  %0 = "onnx.Max"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
-  %1 = "onnx.Max"(%0, %arg1) : (tensor<*xf32>, tensor<?x10xf32>) -> tensor<*xf32>
+func @test_max_max(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Max"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32>
+  %1 = "onnx.Max"(%0, %arg1) : (tensor<*xf32>, tensor<10x10xf32>) -> tensor<*xf32>
   "std.return"(%1) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_max_max
   /// First Max
-  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
-  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xf32>
+  // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
   // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
   // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
   // CHECK: } : () -> (!krnl.loop, !krnl.loop)
-  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
-  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
-  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
-  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xf32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xf32>
   // CHECK: [[MAX:%.+]] = cmpf "ogt", [[LOAD1]], [[LOAD2]] : f32
   // CHECK: [[RELU_RES:%.+]] = select [[MAX]], [[LOAD1]], [[LOAD2]] : f32
-  // CHECK: store [[RELU_RES]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: store [[RELU_RES]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
 
   /// Second Max
-  // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
-  // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
   // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
   // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
   // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
   // CHECK: } : () -> (!krnl.loop, !krnl.loop)
-  // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
-  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
-  // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xf32>
-  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xf32>
   // CHECK: [[MAX:%.+]] = cmpf "ogt", [[LOAD1]], [[LOAD2]] : f32
   // CHECK: [[RELU_RES:%.+]] = select [[MAX]], [[LOAD1]], [[LOAD2]] : f32
-  // CHECK: store [[RELU_RES]], [[RET_RES]][%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: store [[RELU_RES]], [[RET_RES]][%arg2, %arg3] : memref<10x10xf32>
 
   /// Dealloc of first result.
-  // CHECK: dealloc [[RES]] : memref<?x10xf32>
-  // CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
+  // CHECK: dealloc [[RES]] : memref<10x10xf32>
+  // CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xf32>
 
-  // CHECK: return [[RET_RES]] : memref<?x10xf32>
+  // CHECK: return [[RET_RES]] : memref<10x10xf32>
 }
 
-func @test_min_min(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
-  %0 = "onnx.Min"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
-  %1 = "onnx.Min"(%0, %arg1) : (tensor<*xf32>, tensor<?x10xf32>) -> tensor<*xf32>
+func @test_min_min(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Min"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32>
+  %1 = "onnx.Min"(%0, %arg1) : (tensor<*xf32>, tensor<10x10xf32>) -> tensor<*xf32>
   "std.return"(%1) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_min_min
   /// First Min
-  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
-  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xf32>
+  // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
   // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
   // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
   // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
   // CHECK: } : () -> (!krnl.loop, !krnl.loop)
-  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
-  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
-  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
-  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xf32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xf32>
   // CHECK: [[MIN:%.+]] = cmpf "olt", [[LOAD1]], [[LOAD2]] : f32
   // CHECK: [[RELU_RES:%.+]] = select [[MIN]], [[LOAD1]], [[LOAD2]] : f32
-  // CHECK: store [[RELU_RES]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: store [[RELU_RES]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
 
   /// Second Min
-  // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
-  // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
   // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
   // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
   // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
   // CHECK: } : () -> (!krnl.loop, !krnl.loop)
-  // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
-  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
-  // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xf32>
-  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xf32>
   // CHECK: [[MIN:%.+]] = cmpf "olt", [[LOAD1]], [[LOAD2]] : f32
   // CHECK: [[RELU_RES:%.+]] = select [[MIN]], [[LOAD1]], [[LOAD2]] : f32
-  // CHECK: store [[RELU_RES]], [[RET_RES]][%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: store [[RELU_RES]], [[RET_RES]][%arg2, %arg3] : memref<10x10xf32>
 
   /// Dealloc of first result.
-  // CHECK: dealloc [[RES]] : memref<?x10xf32>
-  // CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
+  // CHECK: dealloc [[RES]] : memref<10x10xf32>
+  // CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xf32>
 
-  // CHECK: return [[RET_RES]] : memref<?x10xf32>
+  // CHECK: return [[RET_RES]] : memref<10x10xf32>
 }
 
 func @test_elu_elu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {