[MLIR] Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor (#388)

* Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor

* Edit gen_doc.py to avoid changes to AnyTypeOf<[AnyMemRef, AnyTensor]>

* Fix a missing space

* Add tests

* Shorten ONNXElementWiseBinaryOpLowering into ONNXEWBinaryOpLowering

* Move lowering patterns into runOnModule()

* Remove a redundant space
TUNG LEDUC 2019-12-04 01:17:21 +09:00 committed by Tian Jin
parent 05e16dafae
commit c3ef1d93ae
7 changed files with 507 additions and 83 deletions

View File

@@ -263,7 +263,7 @@ def collect_types(schema, input) :
return allowedTypeStr
def gen_schema(schema) :
ShapeInferenceList=['Add', 'MatMul', 'Gemm']
ShapeInferenceList=['Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor', 'MatMul', 'Gemm']
CanonicalList=['Add', 'Identity']
line_indent = ' '
@@ -314,7 +314,7 @@ def gen_schema(schema) :
#TODO handle (variadic, heterogeneous)"
print('variadic, heterogeneous', input.name)
if etypes == '':
s+= 'AnyTensor'
s+= 'AnyTypeOf<[AnyMemRef, AnyTensor]>'
else:
s+= 'TensorOf<['+etypes+']>'
@@ -339,7 +339,7 @@ def gen_schema(schema) :
#need to interpret output.typeStr
etypes=collect_types(schema, output)
if etypes == '':
s+= 'AnyTensor'
s+= 'AnyTypeOf<[AnyMemRef, AnyTensor]>'
else:
s+= 'TensorOf<['+etypes+']>'
s+= ');'

View File

@@ -46,6 +46,54 @@ void ONNXAddOp::inferShapes() {
getResult()->setType(getOperand(0)->getType());
}
//===----------------------------------------------------------------------===//
// Mul
/// Infer the output shape of the ONNXMulOp. This method is required by the
/// shape inference interface.
void ONNXMulOp::inferShapes() {
getResult()->setType(getOperand(0)->getType());
}
//===----------------------------------------------------------------------===//
// Div
/// Infer the output shape of the ONNXDivOp. This method is required by the
/// shape inference interface.
void ONNXDivOp::inferShapes() {
getResult()->setType(getOperand(0)->getType());
}
//===----------------------------------------------------------------------===//
// Sub
/// Infer the output shape of the ONNXSubOp. This method is required by the
/// shape inference interface.
void ONNXSubOp::inferShapes() {
getResult()->setType(getOperand(0)->getType());
}
//===----------------------------------------------------------------------===//
// And
/// Infer the output shape of the ONNXAndOp. This method is required by the
/// shape inference interface.
void ONNXAndOp::inferShapes() {
getResult()->setType(getOperand(0)->getType());
}
//===----------------------------------------------------------------------===//
// Or
/// Infer the output shape of the ONNXOrOp. This method is required by the
/// shape inference interface.
void ONNXOrOp::inferShapes() {
getResult()->setType(getOperand(0)->getType());
}
//===----------------------------------------------------------------------===//
// Xor
/// Infer the output shape of the ONNXXorOp. This method is required by the
/// shape inference interface.
void ONNXXorOp::inferShapes() {
getResult()->setType(getOperand(0)->getType());
}
//===----------------------------------------------------------------------===//
// MatMul
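
Each of the new inferShapes() bodies above applies the same rule: the result simply takes the type of the first operand, since broadcasting is not handled yet. Purely as a hedged illustration (inferSameTypeAsFirstOperand is a hypothetical helper, not part of this commit, and it assumes the Value*-based MLIR API used in the hunk above), the shared rule amounts to:

#include "mlir/IR/Operation.h"

// Hypothetical helper: propagate the first operand's type to the single
// result, which is what every element-wise binary inferShapes() above does.
static void inferSameTypeAsFirstOperand(mlir::Operation *op) {
  op->getResult(0)->setType(op->getOperand(0)->getType());
}

The commit keeps one body per op because DeclareOpInterfaceMethods<ShapeInferenceOpInterface> declares inferShapes() on each op class; a shared helper like the one above could only be called from those per-op bodies.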

View File

@@ -44,7 +44,7 @@ def ONNXAddOp:ONNX_Op<"Add",
}
def ONNXAndOp:ONNX_Op<"And",
[NoSideEffect]> {
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX And operation";
let description = [{
"Returns the tensor resulted from performing the `and` logical operation"
@@ -61,8 +61,11 @@ def ONNXArgMaxOp:ONNX_Op<"ArgMax",
let summary = "ONNX ArgMax operation";
let description = [{
"Computes the indices of the max elements of the input tensor's element along the "
"provided axis. The resulted tensor has the same rank as the input if keepdims equal 1."
"If keepdims equal 0, then the resulted tensor have the reduced dimension pruned. "
"provided axis. The resulting tensor has the same rank as the input if keepdims equal 1. "
"If keepdims equal 0, then the resulting tensor have the reduced dimension pruned. "
"If select_last_index is True (default False), the index of the last occurence of the max "
"is selected if the max appears more than once in the input. Otherwise the index of the "
"first occurence is selected."
"The type of the output tensor is integer."
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data);
@@ -74,8 +77,11 @@ def ONNXArgMinOp:ONNX_Op<"ArgMin",
let summary = "ONNX ArgMin operation";
let description = [{
"Computes the indices of the min elements of the input tensor's element along the "
"provided axis. The resulted tensor has the same rank as the input if keepdims equal 1."
"If keepdims equal 0, then the resulted tensor have the reduced dimension pruned. "
"provided axis. The resulting tensor has the same rank as the input if keepdims equal 1. "
"If keepdims equal 0, then the resulting tensor have the reduced dimension pruned. "
"If select_last_index is True (default False), the index of the last occurence of the min "
"is selected if the min appears more than once in the input. Otherwise the index of the "
"first occurence is selected."
"The type of the output tensor is integer."
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data);
@@ -467,7 +473,7 @@ def ONNXDetOp:ONNX_Op<"Det",
}
def ONNXDivOp:ONNX_Op<"Div",
[NoSideEffect]> {
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Div operation";
let description = [{
"Performs element-wise binary division (with Numpy-style broadcasting support)."
@@ -1576,7 +1582,7 @@ def ONNXModOp:ONNX_Op<"Mod",
}
def ONNXMulOp:ONNX_Op<"Mul",
[NoSideEffect]> {
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Mul operation";
let description = [{
"Performs element-wise binary multiplication (with Numpy-style broadcasting support)."
@@ -1678,7 +1684,7 @@ def ONNXOneHotOp:ONNX_Op<"OneHot",
}
def ONNXOrOp:ONNX_Op<"Or",
[NoSideEffect]> {
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Or operation";
let description = [{
"Returns the tensor resulted from performing the `or` logical operation"
@@ -2954,7 +2960,7 @@ def ONNXStringNormalizerOp:ONNX_Op<"StringNormalizer",
}
def ONNXSubOp:ONNX_Op<"Sub",
[NoSideEffect]> {
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Sub operation";
let description = [{
"Performs element-wise binary subtraction (with Numpy-style broadcasting support)."
@@ -3223,7 +3229,7 @@ def ONNXWhereOp:ONNX_Op<"Where",
}
def ONNXXorOp:ONNX_Op<"Xor",
[NoSideEffect]> {
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Xor operation";
let description = [{
"Returns the tensor resulted from performing the `xor` logical operation"

View File

@@ -99,17 +99,17 @@ static bool checkInsertDealloc(Operation *currentOp) {
namespace {
//===----------------------------------------------------------------------===//
// Binary ops lowering to Krnl dialect.
// Element-wise binary ops lowering to Krnl dialect.
//===----------------------------------------------------------------------===//
template <typename BinaryOp, typename LoweredBinaryOp>
struct ONNXBinaryOpLowering : public ConversionPattern {
ONNXBinaryOpLowering(MLIRContext* ctx)
struct ONNXEWBinaryOpLowering : public ConversionPattern {
ONNXEWBinaryOpLowering(MLIRContext* ctx)
: ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) const final {
// TODO: Check that the types are valid.
// Add is an operation that must have all operands and the result of
// An element-wise binary operation must have all operands and the result of
// the same type. This should have been verified by the verifier.
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc();
@@ -118,8 +118,8 @@ struct ONNXBinaryOpLowering : public ConversionPattern {
auto memRefType = convertTensorToMemRef(tensorType);
// If the output has a dynamic dimension, pass the operands required for
// each dynamic dimension to the AllocOp. The first operand of the Add
// operation is used. The operands of the Add need to match in terms of
// each dynamic dimension to the AllocOp. The first operand of the binary
// operation is used. The operands of the op need to match in terms of
// dimensions with the result at this pre-optimization phase.
// TODO: verify that dimensions match.
// TODO: can the dimension of the result differ after optimizations?
@@ -186,14 +186,13 @@ struct ONNXBinaryOpLowering : public ConversionPattern {
// 2. Insert instructions inside the KernelIterateOp body.
rewriter.setInsertionPointToStart(&iterationBlock);
// Handle AddOp:
// Handle the operation:
SmallVector<Value*, 4> loopIVs;
for (auto arg : iterationBlock.getArguments())
loopIVs.push_back(arg);
auto loadedFirstVal = rewriter.create<LoadOp>(loc, operands[0], loopIVs);
auto loadedSecondVal = rewriter.create<LoadOp>(loc, operands[1], loopIVs);
// TODO: Choose type of the Add for now use the Float Add.
auto loweredOpResult =
rewriter.create<LoweredBinaryOp>(loc, loadedFirstVal, loadedSecondVal);
@@ -206,11 +205,6 @@ struct ONNXBinaryOpLowering : public ConversionPattern {
}
};
//===----------------------------------------------------------------------===//
// AddOp lowering to Krnl dialect.
//===----------------------------------------------------------------------===//
using ONNXAddOpLowering = ONNXBinaryOpLowering<mlir::ONNXAddOp, AddFOp>;
//===----------------------------------------------------------------------===//
// Conversion from Tensor type to the Standard dialect MemRef type.
//===----------------------------------------------------------------------===//
@@ -291,7 +285,15 @@ void FrontendToKrnlLoweringPass::runOnModule() {
patterns, &getContext(), tensor_to_memref_converter);
// Frontend operation lowering.
patterns.insert<ONNXAddOpLowering>(&getContext());
// TODO: Support 1-N mapping (e.g. different types of the lowered op)
patterns.insert<ONNXEWBinaryOpLowering<mlir::ONNXAddOp, AddFOp>,
ONNXEWBinaryOpLowering<mlir::ONNXMulOp, MulFOp>,
ONNXEWBinaryOpLowering<mlir::ONNXDivOp, DivFOp>,
ONNXEWBinaryOpLowering<mlir::ONNXSubOp, SubFOp>,
ONNXEWBinaryOpLowering<mlir::ONNXAndOp, AndOp>,
ONNXEWBinaryOpLowering<mlir::ONNXOrOp, OrOp>,
ONNXEWBinaryOpLowering<mlir::ONNXXorOp, XOrOp>>
(&getContext());
// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`

View File

@@ -89,6 +89,12 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
// shaped outputs. All those operation need to implement the inferShape()
// method.
if (op->getName().getStringRef() != "onnx.Add" &&
op->getName().getStringRef() != "onnx.Mul" &&
op->getName().getStringRef() != "onnx.Div" &&
op->getName().getStringRef() != "onnx.Sub" &&
op->getName().getStringRef() != "onnx.And" &&
op->getName().getStringRef() != "onnx.Or" &&
op->getName().getStringRef() != "onnx.Xor" &&
op->getName().getStringRef() != "onnx.MatMul" &&
op->getName().getStringRef() != "onnx.Gemm" &&
op->getName().getStringRef() != "onnx.FullGemm")

View File

@@ -1,23 +1,141 @@
// RUN: onnf-opt --shape-inference --lower-frontend %s -split-input-file | FileCheck %s
module {
func @test_sigmoid(%a1 : tensor<?x10xf32>, %a2 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Add"(%a1, %a2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
}
func @test_add(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Add"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_add
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[ADDF:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[ADDF]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: return [[RES]] : memref<?x10xf32>
}
// CHECK: func @test_sigmoid([[ARG0:%.+]]: memref<?x10xf32>, [[ARG1:%.+]]: memref<?x10xf32>) -> memref<?x10xf32> {
// CHECK: [[DIM_0:%.+]] = dim [[ARG0]], 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[ARG0]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[ARG0]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load [[ARG1]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[ADDF:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[ADDF]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: return [[RES]] : memref<?x10xf32>
func @test_mul(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Mul"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_mul
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[MULF:%.+]] = mulf [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[MULF]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: return [[RES]] : memref<?x10xf32>
}
func @test_div(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Div"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_div
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[DIVF:%.+]] = divf [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[DIVF]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: return [[RES]] : memref<?x10xf32>
}
func @test_sub(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Sub"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_sub
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[SUBF:%.+]] = subf [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[SUBF]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: return [[RES]] : memref<?x10xf32>
}
func @test_and(%arg0 : tensor<?x10xi32>, %arg1 : tensor<?x10xi32>) -> tensor<*xi32> {
%0 = "onnx.And"(%arg0, %arg1) : (tensor<?x10xi32>, tensor<?x10xi32>) -> tensor<*xi32>
"std.return"(%0) : (tensor<*xi32>) -> ()
// CHECK-LABEL: test_and
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xi32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xi32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xi32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xi32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xi32>
// CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i32
// CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref<?x10xi32>
// CHECK: return [[RES]] : memref<?x10xi32>
}
func @test_or(%arg0 : tensor<?x10xi32>, %arg1 : tensor<?x10xi32>) -> tensor<*xi32> {
%0 = "onnx.Or"(%arg0, %arg1) : (tensor<?x10xi32>, tensor<?x10xi32>) -> tensor<*xi32>
"std.return"(%0) : (tensor<*xi32>) -> ()
// CHECK-LABEL: test_or
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xi32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xi32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xi32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xi32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xi32>
// CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i32
// CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref<?x10xi32>
// CHECK: return [[RES]] : memref<?x10xi32>
}
func @test_xor(%arg0 : tensor<?x10xi32>, %arg1 : tensor<?x10xi32>) -> tensor<*xi32> {
%0 = "onnx.Xor"(%arg0, %arg1) : (tensor<?x10xi32>, tensor<?x10xi32>) -> tensor<*xi32>
"std.return"(%0) : (tensor<*xi32>) -> ()
// CHECK-LABEL: test_xor
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xi32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xi32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xi32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xi32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xi32>
// CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i32
// CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref<?x10xi32>
// CHECK: return [[RES]] : memref<?x10xi32>
}

View File

@@ -1,45 +1,289 @@
// RUN: dlc-opt --shape-inference --lower-frontend %s -split-input-file | FileCheck %s
module {
func @test_sigmoid(%a1 : tensor<?x10xf32>, %a2 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Add"(%a1, %a2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.Add"(%0, %a2) : (tensor<*xf32>, tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
}
func @test_add_add(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Add"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.Add"(%0, %arg1) : (tensor<*xf32>, tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_add_add
/// First Add
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[ADDF:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[ADDF]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
/// Second Add
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[ADDF:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[ADDF]], [[RET_RES]][%arg2, %arg3] : memref<?x10xf32>
/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<?x10xf32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
// CHECK: return [[RET_RES]] : memref<?x10xf32>
}
// CHECK: func @test_sigmoid([[ARG0:%.+]]: memref<?x10xf32>, [[ARG1:%.+]]: memref<?x10xf32>) -> memref<?x10xf32> {
/// First Add
// CHECK: [[DIM_0:%.+]] = dim [[ARG0]], 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[ARG0]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[ARG0]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load [[ARG1]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[ADDF:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[ADDF]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
func @test_mul_mul(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Mul"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.Mul"(%0, %arg1) : (tensor<*xf32>, tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
/// Second Add
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load [[ARG1]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[ADDF:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[ADDF]], [[RET_RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK-LABEL: test_mul_mul
/// First Mul
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[MULF:%.+]] = mulf [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[MULF]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<?x10xf32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
/// Second Mul
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[MULF:%.+]] = mulf [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[MULF]], [[RET_RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: return [[RET_RES]] : memref<?x10xf32>
/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<?x10xf32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
// CHECK: return [[RET_RES]] : memref<?x10xf32>
}
func @test_div_div(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Div"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.Div"(%0, %arg1) : (tensor<*xf32>, tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_div_div
/// First Div
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[DIVF:%.+]] = divf [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[DIVF]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
/// Second Div
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[DIVF:%.+]] = divf [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[DIVF]], [[RET_RES]][%arg2, %arg3] : memref<?x10xf32>
/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<?x10xf32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
// CHECK: return [[RET_RES]] : memref<?x10xf32>
}
func @test_sub_sub(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Sub"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.Sub"(%0, %arg1) : (tensor<*xf32>, tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_sub_sub
/// First Sub
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[SUBF:%.+]] = subf [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[SUBF]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
/// Second Sub
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[SUBF:%.+]] = subf [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[SUBF]], [[RET_RES]][%arg2, %arg3] : memref<?x10xf32>
/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<?x10xf32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
// CHECK: return [[RET_RES]] : memref<?x10xf32>
}
func @test_and_and(%arg0 : tensor<?x10xi32>, %arg1 : tensor<?x10xi32>) -> tensor<*xi32> {
%0 = "onnx.And"(%arg0, %arg1) : (tensor<?x10xi32>, tensor<?x10xi32>) -> tensor<*xi32>
%1 = "onnx.And"(%0, %arg1) : (tensor<*xi32>, tensor<?x10xi32>) -> tensor<*xi32>
"std.return"(%1) : (tensor<*xi32>) -> ()
// CHECK-LABEL: test_and_and
/// First And
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xi32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xi32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xi32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xi32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xi32>
// CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i32
// CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref<?x10xi32>
/// Second And
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xi32>
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xi32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xi32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xi32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xi32>
// CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i32
// CHECK: store [[AND]], [[RET_RES]][%arg2, %arg3] : memref<?x10xi32>
/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<?x10xi32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xi32>
// CHECK: return [[RET_RES]] : memref<?x10xi32>
}
func @test_or_or(%arg0 : tensor<?x10xi32>, %arg1 : tensor<?x10xi32>) -> tensor<*xi32> {
%0 = "onnx.Or"(%arg0, %arg1) : (tensor<?x10xi32>, tensor<?x10xi32>) -> tensor<*xi32>
%1 = "onnx.Or"(%0, %arg1) : (tensor<*xi32>, tensor<?x10xi32>) -> tensor<*xi32>
"std.return"(%1) : (tensor<*xi32>) -> ()
// CHECK-LABEL: test_or_or
/// First Or
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xi32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xi32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xi32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xi32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xi32>
// CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i32
// CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref<?x10xi32>
/// Second Or
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xi32>
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xi32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xi32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xi32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xi32>
// CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i32
// CHECK: store [[OR]], [[RET_RES]][%arg2, %arg3] : memref<?x10xi32>
/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<?x10xi32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xi32>
// CHECK: return [[RET_RES]] : memref<?x10xi32>
}
func @test_xor_xor(%arg0 : tensor<?x10xi32>, %arg1 : tensor<?x10xi32>) -> tensor<*xi32> {
%0 = "onnx.Xor"(%arg0, %arg1) : (tensor<?x10xi32>, tensor<?x10xi32>) -> tensor<*xi32>
%1 = "onnx.Xor"(%0, %arg1) : (tensor<*xi32>, tensor<?x10xi32>) -> tensor<*xi32>
"std.return"(%1) : (tensor<*xi32>) -> ()
// CHECK-LABEL: test_xor_xor
/// First Xor
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xi32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xi32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xi32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xi32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xi32>
// CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i32
// CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref<?x10xi32>
/// Second Xor
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xi32>
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xi32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xi32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xi32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xi32>
// CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i32
// CHECK: store [[XOR]], [[RET_RES]][%arg2, %arg3] : memref<?x10xi32>
/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<?x10xi32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xi32>
// CHECK: return [[RET_RES]] : memref<?x10xi32>
}