diff --git a/src/dialect/krnl/krnl_ops.td b/src/dialect/krnl/krnl_ops.td
index bb894f7..1649146 100644
--- a/src/dialect/krnl/krnl_ops.td
+++ b/src/dialect/krnl/krnl_ops.td
@@ -190,3 +190,17 @@ def KrnlMemcpyOp : Op<Krnl_Dialect, "memcpy"> {
   let parser = ?;
   let printer = ?;
 }
+
+def KrnlSqrtOp : Op<Krnl_Dialect, "sqrt"> {
+  let summary = "Krnl sqrt operation";
+  let description = [{
+    "The `sqrt` operation computes the square root of its input. It takes
+    one operand and returns one result with the same type."
+  }];
+
+  let arguments = (ins FloatLike:$operand);
+  let results = (outs FloatLike);
+
+  let parser = ?;
+  let printer = ?;
+}
diff --git a/src/dialect/onnx/gen_doc.py b/src/dialect/onnx/gen_doc.py
index a40680d..0f56234 100644
--- a/src/dialect/onnx/gen_doc.py
+++ b/src/dialect/onnx/gen_doc.py
@@ -268,7 +268,7 @@ def gen_schema(schema) :
                    'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
                    'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
                    'Identity', 'Cos', 'Log', 'Transpose', 'Softmax',
-                   'Softplus', 'Softsign']
+                   'Softplus', 'Softsign', 'Sqrt']
     CanonicalList=['Add', 'Identity']
     manual_code = dict([
         ('DummyExample', '  let extraClassDeclaration = [{ \n'+
diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp
index 0a4fb5e..edb2611 100644
--- a/src/dialect/onnx/onnx_ops.cpp
+++ b/src/dialect/onnx/onnx_ops.cpp
@@ -182,6 +182,14 @@ void ONNXSoftsignOp::inferShapes() {
   getResult().setType(getOperand().getType());
 }
 
+//===----------------------------------------------------------------------===//
+// Sqrt
+/// Infer the output shape of the ONNXSqrtOp. This method is required by
+/// the shape inference interface.
+void ONNXSqrtOp::inferShapes() {
+  getResult().setType(getOperand().getType());
+}
+
 //===----------------------------------------------------------------------===//
 // Add
 /// Infer the output shape of the ONNXAddOp. This method is required by the
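The inferShapes() hook added above simply propagates the operand type to the result, resolving the unranked tensor type produced by the importer. A minimal before/after sketch (the 3x4 shape is illustrative, not taken from the patch):

    // Before shape inference: the result type is unranked.
    %y = "onnx.Sqrt"(%x) : (tensor<3x4xf32>) -> tensor<*xf32>
    // After shape inference: the result adopts the operand's type.
    %y = "onnx.Sqrt"(%x) : (tensor<3x4xf32>) -> tensor<3x4xf32>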
diff --git a/src/dialect/onnx/onnxop.inc b/src/dialect/onnx/onnxop.inc
index 5db12e9..2467dc1 100644
--- a/src/dialect/onnx/onnxop.inc
+++ b/src/dialect/onnx/onnxop.inc
@@ -3212,7 +3212,7 @@ def ONNXSplitToSequenceOp:ONNX_Op<"SplitToSequence",
 }
 
 def ONNXSqrtOp:ONNX_Op<"Sqrt",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Sqrt operation";
   let description = [{
     "Square root takes one input data (Tensor<T>) and produces one output data"
diff --git a/src/pass/lower_frontend_to_krnl.cpp b/src/pass/lower_frontend_to_krnl.cpp
index c1277ff..8b5e296 100644
--- a/src/pass/lower_frontend_to_krnl.cpp
+++ b/src/pass/lower_frontend_to_krnl.cpp
@@ -304,6 +304,12 @@ struct ScalarOp<ONNXLogOp> {
   using IOp = LogOp; // not used
 };
 
+template <>
+struct ScalarOp<ONNXSqrtOp> {
+  using FOp = KrnlSqrtOp;
+  using IOp = KrnlSqrtOp; // not used
+};
+
 template <typename ElementwiseNaryOp>
 using ScalarFOp = typename ScalarOp<ElementwiseNaryOp>::FOp;
 template <typename ElementwiseNaryOp>
@@ -1267,6 +1273,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
                   ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>,
                   ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftplusOp>,
                   ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftsignOp>,
+                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSqrtOp>,
                   ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
                   ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
                   ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>,
diff --git a/src/pass/shape_inference_pass.cpp b/src/pass/shape_inference_pass.cpp
index c95191f..a861ed7 100644
--- a/src/pass/shape_inference_pass.cpp
+++ b/src/pass/shape_inference_pass.cpp
@@ -120,6 +120,7 @@ public:
         op->getName().getStringRef() != "onnx.Reshape" &&
         op->getName().getStringRef() != "onnx.Transpose" &&
         op->getName().getStringRef() != "onnx.Softmax" &&
+        op->getName().getStringRef() != "onnx.Sqrt" &&
         op->getName().getStringRef() != "onnx.ConvNoBias")
       return false;
     return llvm::any_of(op->getResultTypes(), [](Type result_type) {
diff --git a/src/transform/lower_krnl.cpp b/src/transform/lower_krnl.cpp
index da16f81..193e1bf 100644
--- a/src/transform/lower_krnl.cpp
+++ b/src/transform/lower_krnl.cpp
@@ -144,6 +144,7 @@ void KrnlToAffineLoweringPass::runOnFunction() {
   target.addIllegalDialect<KrnlOpsDialect>();
   target.addLegalOp<KrnlMemcpyOp>();
   target.addLegalOp<KrnlEntryPointOp>();
+  target.addLegalOp<KrnlSqrtOp>();
 
   OwningRewritePatternList patterns;
   patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
@@ -161,4 +162,4 @@ std::unique_ptr<Pass> mlir::createLowerKrnlPass() {
 }
 
 static PassRegistration<KrnlToAffineLoweringPass> pass("lower-krnl",
-                                                       "Lower Krnl dialect.");
\ No newline at end of file
+                                                       "Lower Krnl dialect.");
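Marking KrnlSqrtOp legal in KrnlToAffineLoweringPass means the krnl-to-affine conversion lowers the surrounding loop machinery while deliberately leaving the scalar op in place for the LLVM conversion that follows. A hypothetical fragment after this pass (buffer names and the 10-element shape are illustrative):

    affine.for %i = 0 to 10 {
      %v = load %arg0[%i] : memref<10xf32>
      // krnl.sqrt is still legal at this stage; it is only rewritten by the
      // krnl-to-llvm conversion in lower_to_llvm.cpp below.
      %s = "krnl.sqrt"(%v) : (f32) -> f32
      store %s, %res[%i] : memref<10xf32>
    }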
diff --git a/src/transform/lower_to_llvm.cpp b/src/transform/lower_to_llvm.cpp
index 9d229f5..de6e671 100644
--- a/src/transform/lower_to_llvm.cpp
+++ b/src/transform/lower_to_llvm.cpp
@@ -460,6 +460,67 @@ private:
     }
   }
 };
+
+//===----------------------------------------------------------------------===//
+// KRNL to LLVM: KrnlSqrtOpLowering
+//===----------------------------------------------------------------------===//
+
+class KrnlSqrtOpLowering : public ConversionPattern {
+public:
+  explicit KrnlSqrtOpLowering(MLIRContext *context)
+      : ConversionPattern(KrnlSqrtOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    OperandAdaptor<KrnlSqrtOp> adaptor(operands);
+    LLVM::LLVMType operandType =
+        adaptor.operand().getType().dyn_cast_or_null<LLVM::LLVMType>();
+
+    if (!operandType)
+      return matchFailure();
+
+    std::string functionName;
+    if (operandType.isFloatTy())
+      functionName = "llvm.sqrt.f32";
+    else if (operandType.isDoubleTy())
+      functionName = "llvm.sqrt.f64";
+    else
+      assert(false && "Unsupported operand type.");
+
+    // Get a symbol reference to the sqrt function, inserting it if necessary.
+    ModuleOp parentModule = op->getParentOfType<ModuleOp>();
+    auto sqrtRef =
+        getOrInsertSqrt(rewriter, parentModule, functionName, operandType);
+
+    // Sqrt call
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, operandType, sqrtRef,
+                                              adaptor.operand());
+
+    return matchSuccess();
+  }
+
+private:
+  /// Return a symbol reference to the sqrt function, inserting it into the
+  /// module if necessary.
+  static FlatSymbolRefAttr getOrInsertSqrt(PatternRewriter &rewriter,
+                                           ModuleOp module, std::string fnName,
+                                           LLVM::LLVMType operandType) {
+    auto *context = module.getContext();
+    if (module.lookupSymbol<LLVM::LLVMFuncOp>(fnName))
+      return SymbolRefAttr::get(fnName, context);
+    // Create a function declaration for sqrt, the signature is:
+    //   * `float (float)`
+    auto llvmFnType =
+        LLVM::LLVMType::getFunctionTy(operandType, operandType, false);
+
+    // Insert the sqrt function into the body of the parent module.
+    PatternRewriter::InsertionGuard insertGuard(rewriter);
+    rewriter.setInsertionPointToStart(module.getBody());
+    rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), fnName, llvmFnType);
+    return SymbolRefAttr::get(fnName, context);
+  }
+};
 } // end namespace
 
//===----------------------------------------------------------------------===//
@@ -489,8 +550,8 @@ void KrnlToLLVMLoweringPass::runOnModule() {
   populateStdToLLVMConversionPatterns(typeConverter, patterns);
 
   // Lower from the `krnl` dialect i.e. the Reshape operation.
-  patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering>(
-      &getContext());
+  patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering,
+                  KrnlSqrtOpLowering>(&getContext());
 
   // We want to completely lower to LLVM, so we use a `FullConversion`. This
   // ensures that only legal operations will remain after the conversion.
diff --git a/test/backend/test.py b/test/backend/test.py
index c850462..7cebffc 100644
--- a/test/backend/test.py
+++ b/test/backend/test.py
@@ -138,6 +138,10 @@ test_to_enable = [
     "test_softmax_example_cpu",
     "test_softmax_large_number_cpu",
 
+    # Sqrt Op:
+    "test_sqrt_cpu",
+    "test_sqrt_example_cpu",
+
     # Sum Op:
     "test_sum_example_cpu",
     "test_sum_one_input_cpu",
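Because getOrInsertSqrt() checks lookupSymbol() before creating a declaration, a module containing several krnl.sqrt ops of the same element type ends up with a single llvm.sqrt.f32 declaration shared by all call sites. A sketch of the expected output (the function @f and its body are hypothetical):

    llvm.func @llvm.sqrt.f32(!llvm.float) -> !llvm.float
    llvm.func @f(%arg0: !llvm.float, %arg1: !llvm.float) -> !llvm.float {
      %0 = llvm.call @llvm.sqrt.f32(%arg0) : (!llvm.float) -> !llvm.float
      %1 = llvm.call @llvm.sqrt.f32(%arg1) : (!llvm.float) -> !llvm.float
      %2 = llvm.fadd %0, %1 : !llvm.float
      llvm.return %2 : !llvm.float
    }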
+ "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_sqrt + // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref + // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) { + // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref + // CHECK: [[SQRT:%.+]] = "krnl.sqrt"([[LOAD]]) : (f32) -> f32 + // CHECK: store [[SQRT]], [[RES]][%arg1, %arg2] : memref + // CHECK: return [[RES]] : memref +} +