Add KrnlSqrtOp (#22)
* Initial lowering of KrnlSqrtOp
* Fix errors and add a test case
* Fix typos
* Add the MLIR example
* Restore doc/doc_check/CMakeLists.txt
* Clean up the code
* Edit comments
* Remove redundant parts
* Change the use of -> to .
* Add a test for f64
* Support ONNXSqrtOp
* Fix indentation

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
Parent: f00206cecf
Commit: 195bf9d15d
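In outline, this patch wires `onnx.Sqrt` through the whole stack: shape inference, lowering to a `krnl` loop nest around a scalar `krnl.sqrt`, and a final rewrite to the LLVM sqrt intrinsic. A hedged sketch of the pipeline (the sample values are illustrative; the exact IR appears in the tests at the bottom of the patch):

// 1. Lowering to krnl: each element goes through the new scalar op.
%y = "krnl.sqrt"(%x) : (f32) -> f32
// 2. krnl-to-affine keeps krnl.sqrt legal, so it survives loop lowering.
// 3. krnl-to-llvm replaces it with a call to the LLVM intrinsic.
%z = llvm.call @llvm.sqrt.f32(%w) : (!llvm.float) -> !llvm.float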
@ -190,3 +190,17 @@ def KrnlMemcpyOp : Op<Krnl_Dialect, "memcpy"> {
   let parser = ?;
   let printer = ?;
 }
+
+def KrnlSqrtOp : Op<Krnl_Dialect, "sqrt", [NoSideEffect]> {
+  let summary = "Krnl sqrt operation";
+  let description = [{
+    "The `sqrt` computes the square root value. It takes one operand and returns
+    one result with the same type."
+  }];
+
+  let arguments = (ins FloatLike:$operand);
+  let results = (outs FloatLike);
+
+  let parser = ?;
+  let printer = ?;
+}
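Since `parser` and `printer` are left as `?`, the op round-trips in MLIR's generic form. A minimal sketch of its use on a scalar (the function name is illustrative; compare the new test file at the bottom of the patch):

func @scalar_sqrt(%arg0 : f32) -> f32 {
  // Generic form: quoted op name, explicit functional type.
  %0 = "krnl.sqrt"(%arg0) : (f32) -> f32
  "std.return"(%0) : (f32) -> ()
}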
@ -268,7 +268,7 @@ def gen_schema(schema) :
     'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
     'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
     'Identity', 'Cos', 'Log', 'Transpose', 'Softmax',
-    'Softplus', 'Softsign']
+    'Softplus', 'Softsign', 'Sqrt']
 CanonicalList=['Add', 'Identity']
 manual_code = dict([
   ('DummyExample', ' let extraClassDeclaration = [{ \n'+
@ -182,6 +182,14 @@ void ONNXSoftsignOp::inferShapes() {
   getResult().setType(getOperand().getType());
 }
 
+//===----------------------------------------------------------------------===//
+// Sqrt
+/// Infer the output shape of the ONNXSqrtOp. This method is required by
+/// the shape inference interface.
+void ONNXSqrtOp::inferShapes() {
+  getResult().setType(getOperand().getType());
+}
+
 //===----------------------------------------------------------------------===//
 // Add
 /// Infer the output shape of the ONNXAddOp. This method is required by the
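The effect on IR, as a hedged sketch (the tensor shape is illustrative): before shape inference the result is unranked, afterwards it copies the operand's type.

// Before --shape-inference:
%0 = "onnx.Sqrt"(%arg0) : (tensor<3x2xf32>) -> tensor<*xf32>
// After: the result takes the operand's type.
%0 = "onnx.Sqrt"(%arg0) : (tensor<3x2xf32>) -> tensor<3x2xf32>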
@ -3212,7 +3212,7 @@ def ONNXSplitToSequenceOp:ONNX_Op<"SplitToSequence",
 }
 
 def ONNXSqrtOp:ONNX_Op<"Sqrt",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Sqrt operation";
   let description = [{
     "Square root takes one input data (Tensor<T>) and produces one output data"
@ -304,6 +304,12 @@ struct ScalarOp<ONNXLogOp> {
   using IOp = LogOp; // not use
 };
 
+template <>
+struct ScalarOp<ONNXSqrtOp> {
+  using FOp = KrnlSqrtOp;
+  using IOp = KrnlSqrtOp; // not use
+};
+
 template <typename ElementwiseNaryOp>
 using ScalarFOp = typename ScalarOp<ElementwiseNaryOp>::FOp;
 template <typename ElementwiseNaryOp>
@ -1267,6 +1273,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
       ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>,
       ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftplusOp>,
       ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftsignOp>,
+      ONNXElementwiseUnaryOpLowering<mlir::ONNXSqrtOp>,
       ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
       ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
       ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>,
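With the `ScalarOp` specialization in place, the generic elementwise lowering emits a loop nest whose body applies `krnl.sqrt` per element. A sketch for a hypothetical 1-D static shape, with syntax modeled on the FileCheck test at the bottom of this patch:

// Result buffer and loop handles for one dimension.
%res = alloc() : memref<2xf32>
%def = krnl.define_loops 1
%opt = krnl.optimize_loops {
  krnl.return_loops %def
} : () -> !krnl.loop
// Iterate over the elements, applying the scalar op.
krnl.iterate(%opt) with (%def -> %i = 0 to 2) {
  %x = load %arg0[%i] : memref<2xf32>
  %y = "krnl.sqrt"(%x) : (f32) -> f32
  store %y, %res[%i] : memref<2xf32>
}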
@ -120,6 +120,7 @@ public:
         op->getName().getStringRef() != "onnx.Reshape" &&
         op->getName().getStringRef() != "onnx.Transpose" &&
         op->getName().getStringRef() != "onnx.Softmax" &&
+        op->getName().getStringRef() != "onnx.Sqrt" &&
         op->getName().getStringRef() != "onnx.ConvNoBias")
       return false;
     return llvm::any_of(op->getResultTypes(), [](Type result_type) {
@ -144,6 +144,7 @@ void KrnlToAffineLoweringPass::runOnFunction() {
   target.addIllegalDialect<KrnlOpsDialect>();
   target.addLegalOp<KrnlMemcpyOp>();
   target.addLegalOp<KrnlEntryPointOp>();
+  target.addLegalOp<KrnlSqrtOp>();
 
   OwningRewritePatternList patterns;
   patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
@ -162,4 +163,4 @@ std::unique_ptr<Pass> mlir::createLowerKrnlPass() {
 }
 
 static PassRegistration<KrnlToAffineLoweringPass> pass("lower-krnl",
-  "Lower Krnl dialect.");
+                                                       "Lower Krnl dialect.");
@ -460,6 +460,67 @@ private:
     }
   }
 };
+
+//===----------------------------------------------------------------------===//
+// KRNL to LLVM: KrnlSqrtOpLowering
+//===----------------------------------------------------------------------===//
+
+class KrnlSqrtOpLowering : public ConversionPattern {
+public:
+  explicit KrnlSqrtOpLowering(MLIRContext *context)
+      : ConversionPattern(KrnlSqrtOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    OperandAdaptor<KrnlSqrtOp> adaptor(operands);
+    LLVM::LLVMType operandType =
+        adaptor.operand().getType().dyn_cast_or_null<LLVM::LLVMType>();
+
+    if (!operandType)
+      return matchFailure();
+
+    std::string functionName;
+    if (operandType.isFloatTy())
+      functionName = "llvm.sqrt.f32";
+    else if (operandType.isDoubleTy())
+      functionName = "llvm.sqrt.f64";
+    else
+      assert(false && "Unsupported operand type.");
+
+    // Get a symbol reference to the sqrt function, inserting it if necessary.
+    ModuleOp parentModule = op->getParentOfType<ModuleOp>();
+    auto sqrtRef =
+        getOrInsertSqrt(rewriter, parentModule, functionName, operandType);
+
+    // Sqrt call
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, operandType, sqrtRef,
+                                              adaptor.operand());
+
+    return matchSuccess();
+  }
+
+private:
+  /// Return a symbol reference to the sqrt function, inserting it into the
+  /// module if necessary.
+  static FlatSymbolRefAttr getOrInsertSqrt(PatternRewriter &rewriter,
+                                           ModuleOp module, std::string fnName,
+                                           LLVM::LLVMType operandType) {
+    auto *context = module.getContext();
+    if (module.lookupSymbol<LLVM::LLVMFuncOp>(fnName))
+      return SymbolRefAttr::get(fnName, context);
+    // Create a function declaration for sqrt, the signature is:
+    //   * `float (float)`
+    auto llvmFnType =
+        LLVM::LLVMType::getFunctionTy(operandType, operandType, false);
+
+    // Insert the sqrt function into the body of the parent module.
+    PatternRewriter::InsertionGuard insertGuard(rewriter);
+    rewriter.setInsertionPointToStart(module.getBody());
+    rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), fnName, llvmFnType);
+    return SymbolRefAttr::get(fnName, context);
+  }
+};
 } // end namespace
 
 //===----------------------------------------------------------------------===//
@ -489,8 +550,8 @@ void KrnlToLLVMLoweringPass::runOnModule() {
   populateStdToLLVMConversionPatterns(typeConverter, patterns);
 
   // Lower from the `krnl` dialect i.e. the Reshape operation.
-  patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering>(
-      &getContext());
+  patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering,
+                  KrnlSqrtOpLowering>(&getContext());
 
   // We want to completely lower to LLVM, so we use a `FullConversion`. This
   // ensures that only legal operations will remain after the conversion.
@ -138,6 +138,10 @@ test_to_enable = [
     "test_softmax_example_cpu",
     "test_softmax_large_number_cpu",
 
+    # Sqrt Op:
+    "test_sqrt_cpu",
+    "test_sqrt_example_cpu",
+
     # Sum Op:
     "test_sum_example_cpu",
     "test_sum_one_input_cpu",
@ -0,0 +1,24 @@
+// RUN: onnf-opt --shape-inference --lower-all-llvm %s -split-input-file | FileCheck %s
+module {
+  func @test_sqrt_32(%arg0 : f32) -> f32 {
+    %0 = "krnl.sqrt"(%arg0) : (f32) -> f32
+    "std.return"(%0) : (f32) -> ()
+
+    // CHECK: llvm.func @llvm.sqrt.f32(!llvm.float) -> !llvm.float
+    // CHECK-NEXT: llvm.func @test_sqrt_32(%arg0: !llvm.float) -> !llvm.float {
+    // CHECK-NEXT: [[RES:%.+]] = llvm.call @llvm.sqrt.f32(%arg0) : (!llvm.float) -> !llvm.float
+    // CHECK-NEXT: llvm.return [[RES]] : !llvm.float
+  }
+}
+
+module {
+  func @test_sqrt_64(%arg0 : f64) -> f64 {
+    %0 = "krnl.sqrt"(%arg0) : (f64) -> f64
+    "std.return"(%0) : (f64) -> ()
+
+    // CHECK: llvm.func @llvm.sqrt.f64(!llvm.double) -> !llvm.double
+    // CHECK-NEXT: llvm.func @test_sqrt_64(%arg0: !llvm.double) -> !llvm.double {
+    // CHECK-NEXT: [[RES:%.+]] = llvm.call @llvm.sqrt.f64(%arg0) : (!llvm.double) -> !llvm.double
+    // CHECK-NEXT: llvm.return [[RES]] : !llvm.double
+  }
+}
@ -623,3 +623,23 @@ func @test_softmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {
   // CHECK: dealloc [[MAX]] : memref<f32>
   // CHECK: return [[RES]] : memref<10x10xf32>
 }
+
+func @test_sqrt(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Sqrt"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_sqrt
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[SQRT:%.+]] = "krnl.sqrt"([[LOAD]]) : (f32) -> f32
+  // CHECK: store [[SQRT]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
+