[MLIR] Add support for Relu (#392)

* Add support for Relu

* Add comments
This commit is contained in:
TUNG LEDUC 2019-12-06 14:31:17 +09:00 committed by Tian Jin
parent 82f5bfec9f
commit 45608282e0
7 changed files with 110 additions and 19 deletions

View File

@ -263,7 +263,7 @@ def collect_types(schema, input) :
return allowedTypeStr
def gen_schema(schema) :
ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid',
ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
'MatMul', 'Gemm']
CanonicalList=['Add', 'Identity']

View File

@ -78,6 +78,14 @@ void ONNXSigmoidOp::inferShapes() {
getResult()->setType(getOperand()->getType());
}
//===----------------------------------------------------------------------===//
// Relu
/// Infer the output shape of the ONNXReluOp. This method is required by the
/// shape inference interface.
void ONNXReluOp::inferShapes() {
getResult()->setType(getOperand()->getType());
}
//===----------------------------------------------------------------------===//
// Add
/// Infer the output shape of the ONNXAddOp. This method is required by the

View File

@ -2185,7 +2185,7 @@ def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare",
}
def ONNXReluOp:ONNX_Op<"Relu",
[NoSideEffect]> {
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Relu operation";
let description = [{
"Relu takes one input data (Tensor<T>) and produces one output data"

View File

@ -251,6 +251,23 @@ Value* mapToLowerScalarOp<ONNXSigmoidOp>(Location loc,
return result;
}
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXReluOp
//===----------------------------------------------------------------------===//
template <>
Value* mapToLowerScalarOp<ONNXReluOp>(Location loc, ArrayRef<Type> result_types,
ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
// ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
// ConstantOp 0,
// %X)
Value* operand = operands[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto lessThanZero =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
auto result = rewriter.create<SelectOp>(loc, lessThanZero, zero, operand);
return result;
}
//===----------------------------------------------------------------------===//
// Element-wise n-ary ops lowering to Krnl dialect.
//===----------------------------------------------------------------------===//
@ -452,6 +469,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>,
ONNXElementwiseBinaryOpLowering<mlir::ONNXAddOp>,
ONNXElementwiseBinaryOpLowering<mlir::ONNXMulOp>,
ONNXElementwiseBinaryOpLowering<mlir::ONNXDivOp>,

View File

@ -89,20 +89,21 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
// shaped outputs. All those operation need to implement the inferShape()
// method.
if (op->getName().getStringRef() != "onnx.Exp" &&
op->getName().getStringRef() != "onnx.Tanh" &&
op->getName().getStringRef() != "onnx.Sinh" &&
op->getName().getStringRef() != "onnx.Cosh" &&
op->getName().getStringRef() != "onnx.Sigmoid" &&
op->getName().getStringRef() != "onnx.Mul" &&
op->getName().getStringRef() != "onnx.Add" &&
op->getName().getStringRef() != "onnx.Div" &&
op->getName().getStringRef() != "onnx.Sub" &&
op->getName().getStringRef() != "onnx.And" &&
op->getName().getStringRef() != "onnx.Or" &&
op->getName().getStringRef() != "onnx.Xor" &&
op->getName().getStringRef() != "onnx.MatMul" &&
op->getName().getStringRef() != "onnx.Gemm" &&
op->getName().getStringRef() != "onnx.FullGemm")
op->getName().getStringRef() != "onnx.Tanh" &&
op->getName().getStringRef() != "onnx.Sinh" &&
op->getName().getStringRef() != "onnx.Cosh" &&
op->getName().getStringRef() != "onnx.Sigmoid" &&
op->getName().getStringRef() != "onnx.Relu" &&
op->getName().getStringRef() != "onnx.Mul" &&
op->getName().getStringRef() != "onnx.Add" &&
op->getName().getStringRef() != "onnx.Div" &&
op->getName().getStringRef() != "onnx.Sub" &&
op->getName().getStringRef() != "onnx.And" &&
op->getName().getStringRef() != "onnx.Or" &&
op->getName().getStringRef() != "onnx.Xor" &&
op->getName().getStringRef() != "onnx.MatMul" &&
op->getName().getStringRef() != "onnx.Gemm" &&
op->getName().getStringRef() != "onnx.FullGemm")
return false;
return llvm::any_of(op->getResultTypes(),
[](Type result_type) { return !result_type.isa<RankedTensorType>(); });
@ -118,4 +119,4 @@ std::unique_ptr<mlir::Pass> mlir::createShapeInferencePass() {
}
static PassRegistration<ShapeInferencePass> pass(
"shape-inference", "Shape inference for frontend dialects.");
"shape-inference", "Shape inference for frontend dialects.");

View File

@ -256,4 +256,25 @@ func @test_sigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
// CHECK: [[SIGMOID_RES:%.+]] = divf [[ONE]], [[DIVISOR]] : f32
// CHECK: store [[SIGMOID_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
// CHECK: return [[RES]] : memref<?x10xf32>
}
}
func @test_relu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Relu"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_relu
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[LTZERO:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32
// CHECK: [[RELU_RES:%.+]] = select [[LTZERO]], [[ZERO]], [[LOAD]] : f32
// CHECK: store [[RELU_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
// CHECK: return [[RES]] : memref<?x10xf32>
}

View File

@ -527,4 +527,47 @@ func @test_sigmoid_sigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
// CHECK: return [[RET_RES]] : memref<?x10xf32>
}
}
func @test_relu_relu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Relu"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.Relu"(%0) : (tensor<*xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_relu_relu
/// First Relu
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[LTZERO:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32
// CHECK: [[RELU_RES:%.+]] = select [[LTZERO]], [[ZERO]], [[LOAD]] : f32
// CHECK: store [[RELU_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
/// Second Relu
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[LTZERO:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32
// CHECK: [[RELU_RES:%.+]] = select [[LTZERO]], [[ZERO]], [[LOAD]] : f32
// CHECK: store [[RELU_RES]], [[RET_RES]][%arg1, %arg2] : memref<?x10xf32>
/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<?x10xf32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
// CHECK: return [[RET_RES]] : memref<?x10xf32>
}