Add KrnlSqrtOp (#22)

* Initial lowering of KrnlSqrtOp

* Fix errors and add a testcase

* typos

* Add the MLIR example

* Restore doc/doc_check/CMakeLists.txt

* Clean the code

* Edit comments

* Remove redundant parts

* Change the use of -> to .

* Add a test for f64

* Support ONNXSqrtOp

* Fix indentation

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
Tung D. Le 2020-01-29 01:10:47 +09:00 committed by Gheorghe-Teodor Bercea
parent f00206cecf
commit 195bf9d15d
11 changed files with 145 additions and 5 deletions


@@ -190,3 +190,17 @@ def KrnlMemcpyOp : Op<Krnl_Dialect, "memcpy"> {
let parser = ?;
let printer = ?;
}
def KrnlSqrtOp : Op<Krnl_Dialect, "sqrt", [NoSideEffect]> {
let summary = "Krnl sqrt operation";
let description = [{
"The `sqrt` computes the square root value. It takes one operand and returns
one result with the same type."
}];
let arguments = (ins FloatLike:$operand);
let results = (outs FloatLike);
let parser = ?;
let printer = ?;
}
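
For orientation, here is a minimal sketch of krnl.sqrt in generic MLIR form (the test added at the end of this commit uses exactly this syntax):

// Element-wise square root of a scalar f32 value.
%res = "krnl.sqrt"(%x) : (f32) -> f32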


@@ -268,7 +268,7 @@ def gen_schema(schema) :
'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax',
'Softplus', 'Softsign', 'Sqrt']
CanonicalList=['Add', 'Identity']
manual_code = dict([
('DummyExample', ' let extraClassDeclaration = [{ \n'+


@@ -182,6 +182,14 @@ void ONNXSoftsignOp::inferShapes() {
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===//
// Sqrt
/// Infer the output shape of the ONNXSqrtOp. This method is required by
/// the shape inference interface.
void ONNXSqrtOp::inferShapes() {
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===//
// Add
/// Infer the output shape of the ONNXAddOp. This method is required by the
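
With this hook in place, shape inference refines an unranked onnx.Sqrt result by copying the operand type. A before/after sketch:

// Before shape inference: the result type is unranked.
%0 = "onnx.Sqrt"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
// After shape inference: the result takes the operand's type.
%0 = "onnx.Sqrt"(%arg0) : (tensor<?x10xf32>) -> tensor<?x10xf32>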


@@ -3212,7 +3212,7 @@ def ONNXSplitToSequenceOp:ONNX_Op<"SplitToSequence",
}
def ONNXSqrtOp:ONNX_Op<"Sqrt",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Sqrt operation";
let description = [{
"Square root takes one input data (Tensor<T>) and produces one output data"


@@ -304,6 +304,12 @@ struct ScalarOp<ONNXLogOp> {
using IOp = LogOp; // not used
};
template <>
struct ScalarOp<ONNXSqrtOp> {
using FOp = KrnlSqrtOp;
using IOp = KrnlSqrtOp; // not used
};
template <typename ElementwiseNaryOp>
using ScalarFOp = typename ScalarOp<ElementwiseNaryOp>::FOp;
template <typename ElementwiseNaryOp>
@@ -1267,6 +1273,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftplusOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftsignOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXSqrtOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>,
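
Registering ONNXSqrtOp with the generic unary elementwise lowering yields the usual krnl loop nest, with krnl.sqrt as the scalar op. A condensed sketch, assuming a static 10x10 input for brevity (the dynamic ?x10 case is what the new lowering test at the end of this commit checks):

// onnx.Sqrt on memref<10x10xf32>, after lowering to the krnl dialect.
%res = alloc() : memref<10x10xf32>
%def_loops:2 = krnl.define_loops 2
%opt_loops:2 = krnl.optimize_loops {
  krnl.return_loops %def_loops#0, %def_loops#1
} : () -> (!krnl.loop, !krnl.loop)
krnl.iterate(%opt_loops#0, %opt_loops#1) with (%def_loops#0 -> %i = 0 to 10, %def_loops#1 -> %j = 0 to 10) {
  %x = load %arg0[%i, %j] : memref<10x10xf32>
  %y = "krnl.sqrt"(%x) : (f32) -> f32
  store %y, %res[%i, %j] : memref<10x10xf32>
}
return %res : memref<10x10xf32>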


@@ -120,6 +120,7 @@ public:
op->getName().getStringRef() != "onnx.Reshape" &&
op->getName().getStringRef() != "onnx.Transpose" &&
op->getName().getStringRef() != "onnx.Softmax" &&
op->getName().getStringRef() != "onnx.Sqrt" &&
op->getName().getStringRef() != "onnx.ConvNoBias")
return false;
return llvm::any_of(op->getResultTypes(), [](Type result_type) {


@@ -144,6 +144,7 @@ void KrnlToAffineLoweringPass::runOnFunction() {
target.addIllegalDialect<KrnlOpsDialect>();
target.addLegalOp<KrnlMemcpyOp>();
target.addLegalOp<KrnlEntryPointOp>();
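// krnl.sqrt stays legal at this stage; it is lowered during the LLVM conversion below.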
target.addLegalOp<KrnlSqrtOp>();
OwningRewritePatternList patterns;
patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
@@ -162,4 +163,4 @@ std::unique_ptr<Pass> mlir::createLowerKrnlPass() {
}
static PassRegistration<KrnlToAffineLoweringPass> pass("lower-krnl",
"Lower Krnl dialect.");


@@ -460,6 +460,67 @@ private:
}
}
};
//===----------------------------------------------------------------------===//
// KRNL to LLVM: KrnlSqrtOpLowering
//===----------------------------------------------------------------------===//
class KrnlSqrtOpLowering : public ConversionPattern {
public:
explicit KrnlSqrtOpLowering(MLIRContext *context)
: ConversionPattern(KrnlSqrtOp::getOperationName(), 1, context) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
OperandAdaptor<KrnlSqrtOp> adaptor(operands);
LLVM::LLVMType operandType =
adaptor.operand().getType().dyn_cast_or_null<LLVM::LLVMType>();
if (!operandType)
return matchFailure();
std::string functionName;
if (operandType.isFloatTy())
functionName = "llvm.sqrt.f32";
else if (operandType.isDoubleTy())
functionName = "llvm.sqrt.f64";
else
assert(false && "Unsupported operand type.");
// Get a symbol reference to the sqrt function, inserting it if necessary.
ModuleOp parentModule = op->getParentOfType<ModuleOp>();
auto sqrtRef =
getOrInsertSqrt(rewriter, parentModule, functionName, operandType);
// Sqrt call
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, operandType, sqrtRef,
adaptor.operand());
return matchSuccess();
}
private:
/// Return a symbol reference to the sqrt function, inserting it into the
/// module if necessary.
static FlatSymbolRefAttr getOrInsertSqrt(PatternRewriter &rewriter,
ModuleOp module, std::string fnName,
LLVM::LLVMType operandType) {
auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>(fnName))
return SymbolRefAttr::get(fnName, context);
// Create a function declaration for sqrt; the signature is
// `T (T)`, where T is the operand type (f32 or f64).
auto llvmFnType =
LLVM::LLVMType::getFunctionTy(operandType, operandType, false);
// Insert the sqrt function into the body of the parent module.
PatternRewriter::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), fnName, llvmFnType);
return SymbolRefAttr::get(fnName, context);
}
};
} // end namespace
//===----------------------------------------------------------------------===//
@@ -489,8 +550,8 @@ void KrnlToLLVMLoweringPass::runOnModule() {
populateStdToLLVMConversionPatterns(typeConverter, patterns);
// Lower from the `krnl` dialect i.e. the Reshape operation.
patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering,
KrnlSqrtOpLowering>(&getContext());
// We want to completely lower to LLVM, so we use a `FullConversion`. This
// ensures that only legal operations will remain after the conversion.
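
Because the krnl-to-affine pass above keeps krnl.sqrt legal, the op survives until this conversion, where it becomes a call to the llvm.sqrt.* intrinsic declared by getOrInsertSqrt at module scope. Roughly, for the f32 case (this is exactly what the new sqrt.mlir test checks):

llvm.func @llvm.sqrt.f32(!llvm.float) -> !llvm.float
llvm.func @test_sqrt_32(%arg0: !llvm.float) -> !llvm.float {
  %0 = llvm.call @llvm.sqrt.f32(%arg0) : (!llvm.float) -> !llvm.float
  llvm.return %0 : !llvm.float
}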


@@ -138,6 +138,10 @@ test_to_enable = [
"test_softmax_example_cpu",
"test_softmax_large_number_cpu",
# Sqrt Op:
"test_sqrt_cpu",
"test_sqrt_example_cpu",
# Sum Op:
"test_sum_example_cpu",
"test_sum_one_input_cpu",

test/mlir/krnl/sqrt.mlir (new file, 24 lines)

@@ -0,0 +1,24 @@
// RUN: onnf-opt --shape-inference --lower-all-llvm %s -split-input-file | FileCheck %s
module {
func @test_sqrt_32(%arg0 : f32) -> f32 {
%0 = "krnl.sqrt"(%arg0) : (f32) -> f32
"std.return"(%0) : (f32) -> ()
// CHECK: llvm.func @llvm.sqrt.f32(!llvm.float) -> !llvm.float
// CHECK-NEXT: llvm.func @test_sqrt_32(%arg0: !llvm.float) -> !llvm.float {
// CHECK-NEXT: [[RES:%.+]] = llvm.call @llvm.sqrt.f32(%arg0) : (!llvm.float) -> !llvm.float
// CHECK-NEXT: llvm.return [[RES]] : !llvm.float
}
}
module {
func @test_sqrt_64(%arg0 : f64) -> f64 {
%0 = "krnl.sqrt"(%arg0) : (f64) -> f64
"std.return"(%0) : (f64) -> ()
// CHECK: llvm.func @llvm.sqrt.f64(!llvm.double) -> !llvm.double
// CHECK-NEXT: llvm.func @test_sqrt_64(%arg0: !llvm.double) -> !llvm.double {
// CHECK-NEXT: [[RES:%.+]] = llvm.call @llvm.sqrt.f64(%arg0) : (!llvm.double) -> !llvm.double
// CHECK-NEXT: llvm.return [[RES]] : !llvm.double
}
}


@@ -623,3 +623,23 @@ func @test_softmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {
// CHECK: dealloc [[MAX]] : memref<f32>
// CHECK: return [[RES]] : memref<10x10xf32>
}
func @test_sqrt(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Sqrt"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_sqrt
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[SQRT:%.+]] = "krnl.sqrt"([[LOAD]]) : (f32) -> f32
// CHECK: store [[SQRT]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
// CHECK: return [[RES]] : memref<?x10xf32>
}