[MLIR] Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor (#388)
* Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor
* Edit gen_doc.py to avoid changes about AnyTypeOf<[AnyMemRef, AnyTensor]>
* Fix a missing space
* Add tests
* Shorten ONNXElementWiseBinaryOpLowering into ONNXEWBinaryOpLowering
* Move lowering patterns into runOnModule()
* Remove a redundant space
parent 05e16dafae
commit c3ef1d93ae
@@ -263,7 +263,7 @@ def collect_types(schema, input) :
     return allowedTypeStr

 def gen_schema(schema) :
-    ShapeInferenceList=['Add', 'MatMul', 'Gemm']
+    ShapeInferenceList=['Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor', 'MatMul', 'Gemm']
     CanonicalList=['Add', 'Identity']
     line_indent = ' '
@@ -314,7 +314,7 @@ def gen_schema(schema) :
             #TODO handle (variadic, heterogeneous)"
             print('variadic, heterogeneous', input.name)
         if etypes == '':
-            s+= 'AnyTensor'
+            s+= 'AnyTypeOf<[AnyMemRef, AnyTensor]>'
         else:
             s+= 'TensorOf<['+etypes+']>'
@@ -339,7 +339,7 @@ def gen_schema(schema) :
     #need to interpret output.typeStr
     etypes=collect_types(schema, output)
     if etypes == '':
-        s+= 'AnyTensor'
+        s+= 'AnyTypeOf<[AnyMemRef, AnyTensor]>'
     else:
         s+= 'TensorOf<['+etypes+']>'
     s+= ');'
@@ -46,6 +46,54 @@ void ONNXAddOp::inferShapes() {
   getResult()->setType(getOperand(0)->getType());
 }

+//===----------------------------------------------------------------------===//
+// Mul
+/// Infer the output shape of the ONNXMulOp. This method is required by the
+/// shape inference interface.
+void ONNXMulOp::inferShapes() {
+  getResult()->setType(getOperand(0)->getType());
+}
+
+//===----------------------------------------------------------------------===//
+// Div
+/// Infer the output shape of the ONNXDivOp. This method is required by the
+/// shape inference interface.
+void ONNXDivOp::inferShapes() {
+  getResult()->setType(getOperand(0)->getType());
+}
+
+//===----------------------------------------------------------------------===//
+// Sub
+/// Infer the output shape of the ONNXSubOp. This method is required by the
+/// shape inference interface.
+void ONNXSubOp::inferShapes() {
+  getResult()->setType(getOperand(0)->getType());
+}
+
+//===----------------------------------------------------------------------===//
+// And
+/// Infer the output shape of the ONNXAndOp. This method is required by the
+/// shape inference interface.
+void ONNXAndOp::inferShapes() {
+  getResult()->setType(getOperand(0)->getType());
+}
+
+//===----------------------------------------------------------------------===//
+// Or
+/// Infer the output shape of the ONNXOrOp. This method is required by the
+/// shape inference interface.
+void ONNXOrOp::inferShapes() {
+  getResult()->setType(getOperand(0)->getType());
+}
+
+//===----------------------------------------------------------------------===//
+// Xor
+/// Infer the output shape of the ONNXXorOp. This method is required by the
+/// shape inference interface.
+void ONNXXorOp::inferShapes() {
+  getResult()->setType(getOperand(0)->getType());
+}
+
 //===----------------------------------------------------------------------===//
 // MatMul
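All six new inferShapes() methods mirror ONNXAddOp: the result simply takes the type of the first operand, so Numpy-style broadcasting between operands is not modeled yet. As a sketch (not part of this commit) of how a driver can reach these methods through the ShapeInferenceOpInterface declared in the TableGen file; the generated C++ interface class name `ShapeInference` is an assumption:

// Sketch only, not part of the diff. Assumes TableGen emits a
// `ShapeInference` interface class for ShapeInferenceOpInterface.
bool tryInferShapes(mlir::Operation *op) {
  if (auto shapeOp = llvm::dyn_cast<ShapeInference>(op)) {
    shapeOp.inferShapes(); // e.g. ONNXMulOp: result type = operand 0's type
    return true;
  }
  return false; // op does not implement the interface
}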
@@ -44,7 +44,7 @@ def ONNXAddOp:ONNX_Op<"Add",
 }

 def ONNXAndOp:ONNX_Op<"And",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX And operation";
   let description = [{
   "Returns the tensor resulted from performing the `and` logical operation"
@@ -61,8 +61,11 @@ def ONNXArgMaxOp:ONNX_Op<"ArgMax",
   let summary = "ONNX ArgMax operation";
   let description = [{
   "Computes the indices of the max elements of the input tensor's element along the "
-  "provided axis. The resulted tensor has the same rank as the input if keepdims equal 1."
-  "If keepdims equal 0, then the resulted tensor have the reduced dimension pruned. "
+  "provided axis. The resulting tensor has the same rank as the input if keepdims equal 1. "
+  "If keepdims equal 0, then the resulting tensor have the reduced dimension pruned. "
+  "If select_last_index is True (default False), the index of the last occurence of the max "
+  "is selected if the max appears more than once in the input. Otherwise the index of the "
+  "first occurence is selected."
   "The type of the output tensor is integer."
   }];
   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data);
@@ -74,8 +77,11 @@ def ONNXArgMinOp:ONNX_Op<"ArgMin",
   let summary = "ONNX ArgMin operation";
   let description = [{
   "Computes the indices of the min elements of the input tensor's element along the "
-  "provided axis. The resulted tensor has the same rank as the input if keepdims equal 1."
-  "If keepdims equal 0, then the resulted tensor have the reduced dimension pruned. "
+  "provided axis. The resulting tensor has the same rank as the input if keepdims equal 1. "
+  "If keepdims equal 0, then the resulting tensor have the reduced dimension pruned. "
+  "If select_last_index is True (default False), the index of the last occurence of the min "
+  "is selected if the min appears more than once in the input. Otherwise the index of the "
+  "first occurence is selected."
   "The type of the output tensor is integer."
   }];
   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data);
@@ -467,7 +473,7 @@ def ONNXDetOp:ONNX_Op<"Det",
 }

 def ONNXDivOp:ONNX_Op<"Div",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Div operation";
   let description = [{
   "Performs element-wise binary division (with Numpy-style broadcasting support)."
@@ -1576,7 +1582,7 @@ def ONNXModOp:ONNX_Op<"Mod",
 }

 def ONNXMulOp:ONNX_Op<"Mul",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Mul operation";
   let description = [{
   "Performs element-wise binary multiplication (with Numpy-style broadcasting support)."
@@ -1678,7 +1684,7 @@ def ONNXOneHotOp:ONNX_Op<"OneHot",
 }

 def ONNXOrOp:ONNX_Op<"Or",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Or operation";
   let description = [{
   "Returns the tensor resulted from performing the `or` logical operation"
@@ -2954,7 +2960,7 @@ def ONNXStringNormalizerOp:ONNX_Op<"StringNormalizer",
 }

 def ONNXSubOp:ONNX_Op<"Sub",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Sub operation";
   let description = [{
   "Performs element-wise binary subtraction (with Numpy-style broadcasting support)."
@@ -3223,7 +3229,7 @@ def ONNXWhereOp:ONNX_Op<"Where",
 }

 def ONNXXorOp:ONNX_Op<"Xor",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Xor operation";
   let description = [{
   "Returns the tensor resulted from performing the `xor` logical operation"
@@ -99,17 +99,17 @@ static bool checkInsertDealloc(Operation *currentOp) {
 namespace {

 //===----------------------------------------------------------------------===//
-// Binary ops lowering to Krnl dialect.
+// Element-wise binary ops lowering to Krnl dialect.
 //===----------------------------------------------------------------------===//
 template <typename BinaryOp, typename LoweredBinaryOp>
-struct ONNXBinaryOpLowering : public ConversionPattern {
-  ONNXBinaryOpLowering(MLIRContext* ctx)
+struct ONNXEWBinaryOpLowering : public ConversionPattern {
+  ONNXEWBinaryOpLowering(MLIRContext* ctx)
       : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}

   PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
       ConversionPatternRewriter& rewriter) const final {
     // TODO: Check that the types are valid.
-    // Add is an operation that must have all operands and the result of
+    // An element-wise binary operation must have all operands and the result of
     // the same type. This should have been verified by the verifier.
     auto tensorType = (*op->result_type_begin()).cast<TensorType>();
     auto loc = op->getLoc();
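The two template parameters pair one ONNX op with the standard-dialect op it lowers to; the registrations in runOnModule() further down are equivalent to writing per-op aliases such as:

// Equivalent spellings of the instantiations registered in runOnModule().
using ONNXMulOpLowering = ONNXEWBinaryOpLowering<mlir::ONNXMulOp, MulFOp>;
using ONNXAndOpLowering = ONNXEWBinaryOpLowering<mlir::ONNXAndOp, AndOp>;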
@@ -118,8 +118,8 @@ struct ONNXBinaryOpLowering : public ConversionPattern {
     auto memRefType = convertTensorToMemRef(tensorType);

     // If the output has a dynamic dimension, pass the operands required for
-    // each dynamic dimension to the AllocOp. The first operand of the Add
-    // operation is used. The operands of the Add need to match in terms of
+    // each dynamic dimension to the AllocOp. The first operand of the binary
+    // operation is used. The operands of the op need to match in terms of
     // dimensions with the result at this pre-optimization phase.
     // TODO: verify that dimensions match.
     // TODO: can the dimension of the result differ after optimizations?
@@ -186,14 +186,13 @@ struct ONNXBinaryOpLowering : public ConversionPattern {
     // 2. Insert instructions inside the KernelIterateOp body.
     rewriter.setInsertionPointToStart(&iterationBlock);

-    // Handle AddOp:
+    // Handle the operation:
     SmallVector<Value*, 4> loopIVs;
     for (auto arg : iterationBlock.getArguments())
       loopIVs.push_back(arg);
     auto loadedFirstVal = rewriter.create<LoadOp>(loc, operands[0], loopIVs);
     auto loadedSecondVal = rewriter.create<LoadOp>(loc, operands[1], loopIVs);

-    // TODO: Choose type of the Add for now use the Float Add.
     auto loweredOpResult =
         rewriter.create<LoweredBinaryOp>(loc, loadedFirstVal, loadedSecondVal);
@@ -206,11 +205,6 @@ struct ONNXBinaryOpLowering : public ConversionPattern {
   }
 };

-//===----------------------------------------------------------------------===//
-// AddOp lowering to Krnl dialect.
-//===----------------------------------------------------------------------===//
-using ONNXAddOpLowering = ONNXBinaryOpLowering<mlir::ONNXAddOp, AddFOp>;
-
 //===----------------------------------------------------------------------===//
 // Conversion from Tensor type to the Standard dialect MemRef type.
 //===----------------------------------------------------------------------===//
@@ -291,7 +285,15 @@ void FrontendToKrnlLoweringPass::runOnModule() {
       patterns, &getContext(), tensor_to_memref_converter);

   // Frontent operation lowering.
-  patterns.insert<ONNXAddOpLowering>(&getContext());
+  // TODO: Support 1-N mapping (e.g. different types of the lowered op)
+  patterns.insert<ONNXEWBinaryOpLowering<mlir::ONNXAddOp, AddFOp>,
+                  ONNXEWBinaryOpLowering<mlir::ONNXMulOp, MulFOp>,
+                  ONNXEWBinaryOpLowering<mlir::ONNXDivOp, DivFOp>,
+                  ONNXEWBinaryOpLowering<mlir::ONNXSubOp, SubFOp>,
+                  ONNXEWBinaryOpLowering<mlir::ONNXAndOp, AndOp>,
+                  ONNXEWBinaryOpLowering<mlir::ONNXOrOp, OrOp>,
+                  ONNXEWBinaryOpLowering<mlir::ONNXXorOp, XOrOp>>
+      (&getContext());

   // With the target and rewrite patterns defined, we can now attempt the
   // conversion. The conversion will signal failure if any of our `illegal`
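The new TODO records that each ONNX op is still pinned to exactly one lowered op: the float variants for Add/Mul/Div/Sub, integer ops for And/Or/Xor. One possible shape for the 1-N mapping is to dispatch on element type; `emitScalarAdd` below is a hypothetical helper sketched by the editor, not part of this commit:

// Hypothetical helper for the "1-N mapping" TODO: pick the lowered op from
// the element type instead of fixing one target op per pattern.
Value* emitScalarAdd(ConversionPatternRewriter& rewriter, Location loc,
                     Type elementType, Value* lhs, Value* rhs) {
  if (elementType.isa<FloatType>())
    return rewriter.create<AddFOp>(loc, lhs, rhs); // float add
  return rewriter.create<AddIOp>(loc, lhs, rhs);   // integer add
}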
@@ -89,6 +89,12 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
         // shaped outputs. All those operation need to implement the inferShape()
         // method.
         if (op->getName().getStringRef() != "onnx.Add" &&
+            op->getName().getStringRef() != "onnx.Mul" &&
+            op->getName().getStringRef() != "onnx.Div" &&
+            op->getName().getStringRef() != "onnx.Sub" &&
+            op->getName().getStringRef() != "onnx.And" &&
+            op->getName().getStringRef() != "onnx.Or" &&
+            op->getName().getStringRef() != "onnx.Xor" &&
             op->getName().getStringRef() != "onnx.MatMul" &&
             op->getName().getStringRef() != "onnx.Gemm" &&
             op->getName().getStringRef() != "onnx.FullGemm")
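The whitelist grows by one string comparison per op. A sketch of an equivalent, more scalable check using llvm::StringSwitch (a refactoring suggestion, not what the commit does):

#include "llvm/ADT/StringSwitch.h"

// Same predicate as the if-chain above, written as a table-like switch.
static bool hasShapeInference(mlir::Operation* op) {
  return llvm::StringSwitch<bool>(op->getName().getStringRef())
      .Cases("onnx.Add", "onnx.Mul", "onnx.Div", "onnx.Sub", true)
      .Cases("onnx.And", "onnx.Or", "onnx.Xor", true)
      .Cases("onnx.MatMul", "onnx.Gemm", "onnx.FullGemm", true)
      .Default(false);
}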
@@ -1,23 +1,141 @@
 // RUN: onnf-opt --shape-inference --lower-frontend %s -split-input-file | FileCheck %s

-module {
-  func @test_sigmoid(%a1 : tensor<?x10xf32>, %a2 : tensor<?x10xf32>) -> tensor<*xf32> {
-    %0 = "onnx.Add"(%a1, %a2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
-    "std.return"(%0) : (tensor<*xf32>) -> ()
-  }
-
-  // CHECK: func @test_sigmoid([[ARG0:%.+]]: memref<?x10xf32>, [[ARG1:%.+]]: memref<?x10xf32>) -> memref<?x10xf32> {
-  // CHECK: [[DIM_0:%.+]] = dim [[ARG0]], 0 : memref<?x10xf32>
-  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
-  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
-  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
-  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
-  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
-  // CHECK: [[DIM_2:%.+]] = dim [[ARG0]], 0 : memref<?x10xf32>
-  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
-  // CHECK: [[LOAD1:%.+]] = load [[ARG0]][%arg2, %arg3] : memref<?x10xf32>
-  // CHECK: [[LOAD2:%.+]] = load [[ARG1]][%arg2, %arg3] : memref<?x10xf32>
-  // CHECK: [[ADDF:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
-  // CHECK: store [[ADDF]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
-  // CHECK: return [[RES]] : memref<?x10xf32>
-}
+func @test_add(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Add"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_add
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[ADDF:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: store [[ADDF]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
+
+func @test_mul(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Mul"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_mul
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[MULF:%.+]] = mulf [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: store [[MULF]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
+
+func @test_div(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Div"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_div
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[DIVF:%.+]] = divf [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: store [[DIVF]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
+
+func @test_sub(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Sub"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_sub
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[SUBF:%.+]] = subf [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: store [[SUBF]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
+
+func @test_and(%arg0 : tensor<?x10xi32>, %arg1 : tensor<?x10xi32>) -> tensor<*xi32> {
+  %0 = "onnx.And"(%arg0, %arg1) : (tensor<?x10xi32>, tensor<?x10xi32>) -> tensor<*xi32>
+  "std.return"(%0) : (tensor<*xi32>) -> ()
+
+  // CHECK-LABEL: test_and
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xi32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xi32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xi32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xi32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xi32>
+  // CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i32
+  // CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref<?x10xi32>
+  // CHECK: return [[RES]] : memref<?x10xi32>
+}
+
+func @test_or(%arg0 : tensor<?x10xi32>, %arg1 : tensor<?x10xi32>) -> tensor<*xi32> {
+  %0 = "onnx.Or"(%arg0, %arg1) : (tensor<?x10xi32>, tensor<?x10xi32>) -> tensor<*xi32>
+  "std.return"(%0) : (tensor<*xi32>) -> ()
+
+  // CHECK-LABEL: test_or
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xi32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xi32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xi32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xi32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xi32>
+  // CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i32
+  // CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref<?x10xi32>
+  // CHECK: return [[RES]] : memref<?x10xi32>
+}
+
+func @test_xor(%arg0 : tensor<?x10xi32>, %arg1 : tensor<?x10xi32>) -> tensor<*xi32> {
+  %0 = "onnx.Xor"(%arg0, %arg1) : (tensor<?x10xi32>, tensor<?x10xi32>) -> tensor<*xi32>
+  "std.return"(%0) : (tensor<*xi32>) -> ()
+
+  // CHECK-LABEL: test_xor
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xi32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xi32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xi32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xi32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xi32>
+  // CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i32
+  // CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref<?x10xi32>
+  // CHECK: return [[RES]] : memref<?x10xi32>
+}
@@ -1,45 +1,289 @@
 // RUN: dlc-opt --shape-inference --lower-frontend %s -split-input-file | FileCheck %s

-module {
-  func @test_sigmoid(%a1 : tensor<?x10xf32>, %a2 : tensor<?x10xf32>) -> tensor<*xf32> {
-    %0 = "onnx.Add"(%a1, %a2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
-    %1 = "onnx.Add"(%0, %a2) : (tensor<*xf32>, tensor<?x10xf32>) -> tensor<*xf32>
-    "std.return"(%1) : (tensor<*xf32>) -> ()
-  }
-
-  // CHECK: func @test_sigmoid([[ARG0:%.+]]: memref<?x10xf32>, [[ARG1:%.+]]: memref<?x10xf32>) -> memref<?x10xf32> {
-
-  /// First Add
-  // CHECK: [[DIM_0:%.+]] = dim [[ARG0]], 0 : memref<?x10xf32>
-  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
-  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
-  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
-  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
-  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
-  // CHECK: [[DIM_2:%.+]] = dim [[ARG0]], 0 : memref<?x10xf32>
-  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
-  // CHECK: [[LOAD1:%.+]] = load [[ARG0]][%arg2, %arg3] : memref<?x10xf32>
-  // CHECK: [[LOAD2:%.+]] = load [[ARG1]][%arg2, %arg3] : memref<?x10xf32>
-  // CHECK: [[ADDF:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
-  // CHECK: store [[ADDF]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
-
-  /// Second Add
-  // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
-  // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
-  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
-  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
-  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
-  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
-  // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
-  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
-  // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xf32>
-  // CHECK: [[LOAD2:%.+]] = load [[ARG1]][%arg2, %arg3] : memref<?x10xf32>
-  // CHECK: [[ADDF:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
-  // CHECK: store [[ADDF]], [[RET_RES]][%arg2, %arg3] : memref<?x10xf32>
-
-  /// Dealloc of first result.
-  // CHECK: dealloc [[RES]] : memref<?x10xf32>
-  // CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
-
-  // CHECK: return [[RET_RES]] : memref<?x10xf32>
-}
+func @test_add_add(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Add"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
+  %1 = "onnx.Add"(%0, %arg1) : (tensor<*xf32>, tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%1) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_add_add
+  /// First Add
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[ADDF:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: store [[ADDF]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
+
+  /// Second Add
+  // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
+  // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[ADDF:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: store [[ADDF]], [[RET_RES]][%arg2, %arg3] : memref<?x10xf32>
+
+  /// Dealloc of first result.
+  // CHECK: dealloc [[RES]] : memref<?x10xf32>
+  // CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
+
+  // CHECK: return [[RET_RES]] : memref<?x10xf32>
+}
+
+func @test_mul_mul(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Mul"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
+  %1 = "onnx.Mul"(%0, %arg1) : (tensor<*xf32>, tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%1) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_mul_mul
+  /// First Mul
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[MULF:%.+]] = mulf [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: store [[MULF]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
+
+  /// Second Mul
+  // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
+  // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[MULF:%.+]] = mulf [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: store [[MULF]], [[RET_RES]][%arg2, %arg3] : memref<?x10xf32>
+
+  /// Dealloc of first result.
+  // CHECK: dealloc [[RES]] : memref<?x10xf32>
+  // CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
+
+  // CHECK: return [[RET_RES]] : memref<?x10xf32>
+}
+
+func @test_div_div(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Div"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
+  %1 = "onnx.Div"(%0, %arg1) : (tensor<*xf32>, tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%1) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_div_div
+  /// First Div
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[DIVF:%.+]] = divf [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: store [[DIVF]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
+
+  /// Second Div
+  // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
+  // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[DIVF:%.+]] = divf [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: store [[DIVF]], [[RET_RES]][%arg2, %arg3] : memref<?x10xf32>
+
+  /// Dealloc of first result.
+  // CHECK: dealloc [[RES]] : memref<?x10xf32>
+  // CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
+
+  // CHECK: return [[RET_RES]] : memref<?x10xf32>
+}
+
+func @test_sub_sub(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Sub"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
+  %1 = "onnx.Sub"(%0, %arg1) : (tensor<*xf32>, tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%1) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_sub_sub
+  /// First Sub
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[SUBF:%.+]] = subf [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: store [[SUBF]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
+
+  /// Second Sub
+  // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
+  // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[SUBF:%.+]] = subf [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: store [[SUBF]], [[RET_RES]][%arg2, %arg3] : memref<?x10xf32>
+
+  /// Dealloc of first result.
+  // CHECK: dealloc [[RES]] : memref<?x10xf32>
+  // CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
+
+  // CHECK: return [[RET_RES]] : memref<?x10xf32>
+}
+
+func @test_and_and(%arg0 : tensor<?x10xi32>, %arg1 : tensor<?x10xi32>) -> tensor<*xi32> {
+  %0 = "onnx.And"(%arg0, %arg1) : (tensor<?x10xi32>, tensor<?x10xi32>) -> tensor<*xi32>
+  %1 = "onnx.And"(%0, %arg1) : (tensor<*xi32>, tensor<?x10xi32>) -> tensor<*xi32>
+  "std.return"(%1) : (tensor<*xi32>) -> ()
+
+  // CHECK-LABEL: test_and_and
+  /// First And
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xi32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xi32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xi32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xi32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xi32>
+  // CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i32
+  // CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref<?x10xi32>
+
+  /// Second And
+  // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xi32>
+  // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xi32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xi32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xi32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xi32>
+  // CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i32
+  // CHECK: store [[AND]], [[RET_RES]][%arg2, %arg3] : memref<?x10xi32>
+
+  /// Dealloc of first result.
+  // CHECK: dealloc [[RES]] : memref<?x10xi32>
+  // CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xi32>
+
+  // CHECK: return [[RET_RES]] : memref<?x10xi32>
+}
+
+func @test_or_or(%arg0 : tensor<?x10xi32>, %arg1 : tensor<?x10xi32>) -> tensor<*xi32> {
+  %0 = "onnx.Or"(%arg0, %arg1) : (tensor<?x10xi32>, tensor<?x10xi32>) -> tensor<*xi32>
+  %1 = "onnx.Or"(%0, %arg1) : (tensor<*xi32>, tensor<?x10xi32>) -> tensor<*xi32>
+  "std.return"(%1) : (tensor<*xi32>) -> ()
+
+  // CHECK-LABEL: test_or_or
+  /// First Or
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xi32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xi32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xi32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xi32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xi32>
+  // CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i32
+  // CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref<?x10xi32>
+
+  /// Second Or
+  // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xi32>
+  // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xi32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xi32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xi32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xi32>
+  // CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i32
+  // CHECK: store [[OR]], [[RET_RES]][%arg2, %arg3] : memref<?x10xi32>
+
+  /// Dealloc of first result.
+  // CHECK: dealloc [[RES]] : memref<?x10xi32>
+  // CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xi32>
+
+  // CHECK: return [[RET_RES]] : memref<?x10xi32>
+}
+
+func @test_xor_xor(%arg0 : tensor<?x10xi32>, %arg1 : tensor<?x10xi32>) -> tensor<*xi32> {
+  %0 = "onnx.Xor"(%arg0, %arg1) : (tensor<?x10xi32>, tensor<?x10xi32>) -> tensor<*xi32>
+  %1 = "onnx.Xor"(%0, %arg1) : (tensor<*xi32>, tensor<?x10xi32>) -> tensor<*xi32>
+  "std.return"(%1) : (tensor<*xi32>) -> ()
+
+  // CHECK-LABEL: test_xor_xor
+  /// First Xor
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xi32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xi32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xi32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xi32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xi32>
+  // CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i32
+  // CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref<?x10xi32>
+
+  /// Second Xor
+  // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xi32>
+  // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xi32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xi32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xi32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xi32>
+  // CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i32
+  // CHECK: store [[XOR]], [[RET_RES]][%arg2, %arg3] : memref<?x10xi32>
+
+  /// Dealloc of first result.
+  // CHECK: dealloc [[RES]] : memref<?x10xi32>
+  // CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xi32>
+
+  // CHECK: return [[RET_RES]] : memref<?x10xi32>
+}