[MLIR] Add support for Relu (#392)

* Add support for Relu

* Add comments
This commit is contained in:
TUNG LEDUC 2019-12-06 14:31:17 +09:00 committed by Tian Jin
parent 82f5bfec9f
commit 45608282e0
7 changed files with 110 additions and 19 deletions

View File

@ -263,7 +263,7 @@ def collect_types(schema, input) :
return allowedTypeStr return allowedTypeStr
def gen_schema(schema) : def gen_schema(schema) :
ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor', 'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
'MatMul', 'Gemm'] 'MatMul', 'Gemm']
CanonicalList=['Add', 'Identity'] CanonicalList=['Add', 'Identity']

View File

@ -78,6 +78,14 @@ void ONNXSigmoidOp::inferShapes() {
getResult()->setType(getOperand()->getType()); getResult()->setType(getOperand()->getType());
} }
//===----------------------------------------------------------------------===//
// Relu
/// Infer the output shape of the ONNXReluOp. This method is required by the
/// shape inference interface.
void ONNXReluOp::inferShapes() {
getResult()->setType(getOperand()->getType());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Add // Add
/// Infer the output shape of the ONNXAddOp. This method is required by the /// Infer the output shape of the ONNXAddOp. This method is required by the

View File

@ -2185,7 +2185,7 @@ def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare",
} }
def ONNXReluOp:ONNX_Op<"Relu", def ONNXReluOp:ONNX_Op<"Relu",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Relu operation"; let summary = "ONNX Relu operation";
let description = [{ let description = [{
"Relu takes one input data (Tensor<T>) and produces one output data" "Relu takes one input data (Tensor<T>) and produces one output data"

View File

@ -251,6 +251,23 @@ Value* mapToLowerScalarOp<ONNXSigmoidOp>(Location loc,
return result; return result;
} }
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXReluOp
//===----------------------------------------------------------------------===//
template <>
Value* mapToLowerScalarOp<ONNXReluOp>(Location loc, ArrayRef<Type> result_types,
ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
// ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
// ConstantOp 0,
// %X)
Value* operand = operands[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto lessThanZero =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
auto result = rewriter.create<SelectOp>(loc, lessThanZero, zero, operand);
return result;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Element-wise n-ary ops lowering to Krnl dialect. // Element-wise n-ary ops lowering to Krnl dialect.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -452,6 +469,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>, ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>, ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>, ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>,
ONNXElementwiseBinaryOpLowering<mlir::ONNXAddOp>, ONNXElementwiseBinaryOpLowering<mlir::ONNXAddOp>,
ONNXElementwiseBinaryOpLowering<mlir::ONNXMulOp>, ONNXElementwiseBinaryOpLowering<mlir::ONNXMulOp>,
ONNXElementwiseBinaryOpLowering<mlir::ONNXDivOp>, ONNXElementwiseBinaryOpLowering<mlir::ONNXDivOp>,

View File

@ -93,6 +93,7 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
op->getName().getStringRef() != "onnx.Sinh" && op->getName().getStringRef() != "onnx.Sinh" &&
op->getName().getStringRef() != "onnx.Cosh" && op->getName().getStringRef() != "onnx.Cosh" &&
op->getName().getStringRef() != "onnx.Sigmoid" && op->getName().getStringRef() != "onnx.Sigmoid" &&
op->getName().getStringRef() != "onnx.Relu" &&
op->getName().getStringRef() != "onnx.Mul" && op->getName().getStringRef() != "onnx.Mul" &&
op->getName().getStringRef() != "onnx.Add" && op->getName().getStringRef() != "onnx.Add" &&
op->getName().getStringRef() != "onnx.Div" && op->getName().getStringRef() != "onnx.Div" &&

View File

@ -257,3 +257,24 @@ func @test_sigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
// CHECK: store [[SIGMOID_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32> // CHECK: store [[SIGMOID_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
// CHECK: return [[RES]] : memref<?x10xf32> // CHECK: return [[RES]] : memref<?x10xf32>
} }
func @test_relu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Relu"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_relu
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[LTZERO:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32
// CHECK: [[RELU_RES:%.+]] = select [[LTZERO]], [[ZERO]], [[LOAD]] : f32
// CHECK: store [[RELU_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
// CHECK: return [[RES]] : memref<?x10xf32>
}

View File

@ -528,3 +528,46 @@ func @test_sigmoid_sigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
// CHECK: return [[RET_RES]] : memref<?x10xf32> // CHECK: return [[RET_RES]] : memref<?x10xf32>
} }
func @test_relu_relu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Relu"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.Relu"(%0) : (tensor<*xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_relu_relu
/// First Relu
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[LTZERO:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32
// CHECK: [[RELU_RES:%.+]] = select [[LTZERO]], [[ZERO]], [[LOAD]] : f32
// CHECK: store [[RELU_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
/// Second Relu
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[LTZERO:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32
// CHECK: [[RELU_RES:%.+]] = select [[LTZERO]], [[ZERO]], [[LOAD]] : f32
// CHECK: store [[RELU_RES]], [[RET_RES]][%arg1, %arg2] : memref<?x10xf32>
/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<?x10xf32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
// CHECK: return [[RET_RES]] : memref<?x10xf32>
}