[MLIR] Add support for Relu (#392)
* Add support for Relu * Add comments
This commit is contained in:
parent
82f5bfec9f
commit
45608282e0
|
@ -263,7 +263,7 @@ def collect_types(schema, input) :
|
||||||
return allowedTypeStr
|
return allowedTypeStr
|
||||||
|
|
||||||
def gen_schema(schema) :
|
def gen_schema(schema) :
|
||||||
ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid',
|
ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
|
||||||
'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
|
'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
|
||||||
'MatMul', 'Gemm']
|
'MatMul', 'Gemm']
|
||||||
CanonicalList=['Add', 'Identity']
|
CanonicalList=['Add', 'Identity']
|
||||||
|
|
|
@ -78,6 +78,14 @@ void ONNXSigmoidOp::inferShapes() {
|
||||||
getResult()->setType(getOperand()->getType());
|
getResult()->setType(getOperand()->getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Relu
|
||||||
|
/// Infer the output shape of the ONNXReluOp. This method is required by the
|
||||||
|
/// shape inference interface.
|
||||||
|
void ONNXReluOp::inferShapes() {
|
||||||
|
getResult()->setType(getOperand()->getType());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Add
|
// Add
|
||||||
/// Infer the output shape of the ONNXAddOp. This method is required by the
|
/// Infer the output shape of the ONNXAddOp. This method is required by the
|
||||||
|
|
|
@ -2185,7 +2185,7 @@ def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXReluOp:ONNX_Op<"Relu",
|
def ONNXReluOp:ONNX_Op<"Relu",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX Relu operation";
|
let summary = "ONNX Relu operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Relu takes one input data (Tensor<T>) and produces one output data"
|
"Relu takes one input data (Tensor<T>) and produces one output data"
|
||||||
|
|
|
@ -251,6 +251,23 @@ Value* mapToLowerScalarOp<ONNXSigmoidOp>(Location loc,
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Scalar unary ops for lowering ONNXReluOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
template <>
|
||||||
|
Value* mapToLowerScalarOp<ONNXReluOp>(Location loc, ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
|
||||||
|
// ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
|
||||||
|
// ConstantOp 0,
|
||||||
|
// %X)
|
||||||
|
Value* operand = operands[0];
|
||||||
|
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
||||||
|
auto lessThanZero =
|
||||||
|
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
|
||||||
|
auto result = rewriter.create<SelectOp>(loc, lessThanZero, zero, operand);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Element-wise n-ary ops lowering to Krnl dialect.
|
// Element-wise n-ary ops lowering to Krnl dialect.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -452,6 +469,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>,
|
||||||
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>,
|
||||||
ONNXElementwiseBinaryOpLowering<mlir::ONNXAddOp>,
|
ONNXElementwiseBinaryOpLowering<mlir::ONNXAddOp>,
|
||||||
ONNXElementwiseBinaryOpLowering<mlir::ONNXMulOp>,
|
ONNXElementwiseBinaryOpLowering<mlir::ONNXMulOp>,
|
||||||
ONNXElementwiseBinaryOpLowering<mlir::ONNXDivOp>,
|
ONNXElementwiseBinaryOpLowering<mlir::ONNXDivOp>,
|
||||||
|
|
|
@ -89,20 +89,21 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
|
||||||
// shaped outputs. All those operation need to implement the inferShape()
|
// shaped outputs. All those operation need to implement the inferShape()
|
||||||
// method.
|
// method.
|
||||||
if (op->getName().getStringRef() != "onnx.Exp" &&
|
if (op->getName().getStringRef() != "onnx.Exp" &&
|
||||||
op->getName().getStringRef() != "onnx.Tanh" &&
|
op->getName().getStringRef() != "onnx.Tanh" &&
|
||||||
op->getName().getStringRef() != "onnx.Sinh" &&
|
op->getName().getStringRef() != "onnx.Sinh" &&
|
||||||
op->getName().getStringRef() != "onnx.Cosh" &&
|
op->getName().getStringRef() != "onnx.Cosh" &&
|
||||||
op->getName().getStringRef() != "onnx.Sigmoid" &&
|
op->getName().getStringRef() != "onnx.Sigmoid" &&
|
||||||
op->getName().getStringRef() != "onnx.Mul" &&
|
op->getName().getStringRef() != "onnx.Relu" &&
|
||||||
op->getName().getStringRef() != "onnx.Add" &&
|
op->getName().getStringRef() != "onnx.Mul" &&
|
||||||
op->getName().getStringRef() != "onnx.Div" &&
|
op->getName().getStringRef() != "onnx.Add" &&
|
||||||
op->getName().getStringRef() != "onnx.Sub" &&
|
op->getName().getStringRef() != "onnx.Div" &&
|
||||||
op->getName().getStringRef() != "onnx.And" &&
|
op->getName().getStringRef() != "onnx.Sub" &&
|
||||||
op->getName().getStringRef() != "onnx.Or" &&
|
op->getName().getStringRef() != "onnx.And" &&
|
||||||
op->getName().getStringRef() != "onnx.Xor" &&
|
op->getName().getStringRef() != "onnx.Or" &&
|
||||||
op->getName().getStringRef() != "onnx.MatMul" &&
|
op->getName().getStringRef() != "onnx.Xor" &&
|
||||||
op->getName().getStringRef() != "onnx.Gemm" &&
|
op->getName().getStringRef() != "onnx.MatMul" &&
|
||||||
op->getName().getStringRef() != "onnx.FullGemm")
|
op->getName().getStringRef() != "onnx.Gemm" &&
|
||||||
|
op->getName().getStringRef() != "onnx.FullGemm")
|
||||||
return false;
|
return false;
|
||||||
return llvm::any_of(op->getResultTypes(),
|
return llvm::any_of(op->getResultTypes(),
|
||||||
[](Type result_type) { return !result_type.isa<RankedTensorType>(); });
|
[](Type result_type) { return !result_type.isa<RankedTensorType>(); });
|
||||||
|
@ -118,4 +119,4 @@ std::unique_ptr<mlir::Pass> mlir::createShapeInferencePass() {
|
||||||
}
|
}
|
||||||
|
|
||||||
static PassRegistration<ShapeInferencePass> pass(
|
static PassRegistration<ShapeInferencePass> pass(
|
||||||
"shape-inference", "Shape inference for frontend dialects.");
|
"shape-inference", "Shape inference for frontend dialects.");
|
||||||
|
|
|
@ -257,3 +257,24 @@ func @test_sigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
// CHECK: store [[SIGMOID_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
|
// CHECK: store [[SIGMOID_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
|
||||||
// CHECK: return [[RES]] : memref<?x10xf32>
|
// CHECK: return [[RES]] : memref<?x10xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @test_relu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Relu"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_relu
|
||||||
|
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
|
||||||
|
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
|
||||||
|
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||||
|
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
||||||
|
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
||||||
|
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||||
|
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
|
||||||
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
|
||||||
|
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
|
||||||
|
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
|
||||||
|
// CHECK: [[LTZERO:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32
|
||||||
|
// CHECK: [[RELU_RES:%.+]] = select [[LTZERO]], [[ZERO]], [[LOAD]] : f32
|
||||||
|
// CHECK: store [[RELU_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
|
||||||
|
// CHECK: return [[RES]] : memref<?x10xf32>
|
||||||
|
}
|
||||||
|
|
|
@ -528,3 +528,46 @@ func @test_sigmoid_sigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
|
|
||||||
// CHECK: return [[RET_RES]] : memref<?x10xf32>
|
// CHECK: return [[RET_RES]] : memref<?x10xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @test_relu_relu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Relu"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
|
||||||
|
%1 = "onnx.Relu"(%0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_relu_relu
|
||||||
|
/// First Relu
|
||||||
|
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
|
||||||
|
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
|
||||||
|
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||||
|
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
||||||
|
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
||||||
|
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||||
|
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
|
||||||
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
|
||||||
|
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
|
||||||
|
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
|
||||||
|
// CHECK: [[LTZERO:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32
|
||||||
|
// CHECK: [[RELU_RES:%.+]] = select [[LTZERO]], [[ZERO]], [[LOAD]] : f32
|
||||||
|
// CHECK: store [[RELU_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
|
||||||
|
|
||||||
|
/// Second Relu
|
||||||
|
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
|
||||||
|
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
|
||||||
|
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||||
|
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
||||||
|
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
||||||
|
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||||
|
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
|
||||||
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
|
||||||
|
// CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref<?x10xf32>
|
||||||
|
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
|
||||||
|
// CHECK: [[LTZERO:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32
|
||||||
|
// CHECK: [[RELU_RES:%.+]] = select [[LTZERO]], [[ZERO]], [[LOAD]] : f32
|
||||||
|
// CHECK: store [[RELU_RES]], [[RET_RES]][%arg1, %arg2] : memref<?x10xf32>
|
||||||
|
|
||||||
|
/// Dealloc of first result.
|
||||||
|
// CHECK: dealloc [[RES]] : memref<?x10xf32>
|
||||||
|
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
|
||||||
|
|
||||||
|
// CHECK: return [[RET_RES]] : memref<?x10xf32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue