Add KrnlSqrtOp (#22)
* Initial lowering of KrnlSqrtOp
* Fix errors and add a testcase
* Fix typos
* Add the MLIR example
* Restore doc/doc_check/CMakeLists.txt
* Clean the code
* Edit comments
* Remove redundant parts
* Change the use of -> to .
* Add a test for f64
* Support ONNXSqrtOp
* Fix indentation

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
commit 195bf9d15d (parent f00206cecf)
@@ -190,3 +190,17 @@ def KrnlMemcpyOp : Op<Krnl_Dialect, "memcpy"> {
   let parser = ?;
   let printer = ?;
 }
+
+def KrnlSqrtOp : Op<Krnl_Dialect, "sqrt", [NoSideEffect]> {
+  let summary = "Krnl sqrt operation";
+  let description = [{
+    "The `sqrt` computes the square root value. It takes one operand and returns
+    one result with the same type."
+  }];
+
+  let arguments = (ins FloatLike:$operand);
+  let results = (outs FloatLike);
+
+  let parser = ?;
+  let printer = ?;
+}
@@ -268,7 +268,7 @@ def gen_schema(schema) :
                'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
                'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
                'Identity', 'Cos', 'Log', 'Transpose', 'Softmax',
-               'Softplus', 'Softsign']
+               'Softplus', 'Softsign', 'Sqrt']
     CanonicalList=['Add', 'Identity']
     manual_code = dict([
         ('DummyExample', ' let extraClassDeclaration = [{ \n'+
@@ -182,6 +182,14 @@ void ONNXSoftsignOp::inferShapes() {
   getResult().setType(getOperand().getType());
 }
 
+//===----------------------------------------------------------------------===//
+// Sqrt
+/// Infer the output shape of the ONNXSqrtOp. This method is required by
+/// the shape inference interface.
+void ONNXSqrtOp::inferShapes() {
+  getResult().setType(getOperand().getType());
+}
+
 //===----------------------------------------------------------------------===//
 // Add
 /// Infer the output shape of the ONNXAddOp. This method is required by the
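The shape inference added here is the identity rule for unary elementwise ops: the result takes exactly the operand's type, which refines an unranked tensor<*xf32> result once the operand's shape is known. A minimal self-contained sketch of that rule (toy types for illustration, not the MLIR API):

#include <cassert>
#include <string>

// Toy "type" standing in for an MLIR tensor type.
struct Type { std::string desc; };

// Mirrors ONNXSqrtOp::inferShapes(): a unary elementwise op's result
// has exactly the operand's type.
struct UnaryOp {
  Type operandType;
  Type resultType;
  void inferShapes() { resultType = operandType; } // same-type propagation
};

int main() {
  UnaryOp sqrtOp{{"tensor<?x10xf32>"}, {"tensor<*xf32>"}};
  sqrtOp.inferShapes();
  // The unranked result has been refined to the operand's shape.
  assert(sqrtOp.resultType.desc == "tensor<?x10xf32>");
}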
@@ -3212,7 +3212,7 @@ def ONNXSplitToSequenceOp:ONNX_Op<"SplitToSequence",
 }
 
 def ONNXSqrtOp:ONNX_Op<"Sqrt",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Sqrt operation";
   let description = [{
   "Square root takes one input data (Tensor<T>) and produces one output data"
@@ -304,6 +304,12 @@ struct ScalarOp<ONNXLogOp> {
   using IOp = LogOp; // not use
 };
 
+template <>
+struct ScalarOp<ONNXSqrtOp> {
+  using FOp = KrnlSqrtOp;
+  using IOp = KrnlSqrtOp; // not use
+};
+
 template <typename ElementwiseNaryOp>
 using ScalarFOp = typename ScalarOp<ElementwiseNaryOp>::FOp;
 template <typename ElementwiseNaryOp>
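The ScalarOp trait above is how the generic elementwise lowering learns which scalar op to emit: FOp for float element types, IOp for integer ones. For ONNXSqrtOp both slots name KrnlSqrtOp, since sqrt is only lowered for floating-point operands. A compilable sketch of the trait idiom with stand-in tag types (hypothetical names, not the real MLIR op classes; requires C++17):

#include <iostream>
#include <type_traits>

// Stand-in tag types; the real code uses MLIR op classes.
struct AddFOp {};
struct AddIOp {};
struct KrnlSqrtOp {};
struct ONNXAddOp {};
struct ONNXSqrtOp {};

// Primary template left undefined: an ONNX op without a specialization
// fails to compile, which catches missing mappings early.
template <typename ElementwiseNaryOp> struct ScalarOp;

template <> struct ScalarOp<ONNXAddOp> {
  using FOp = AddFOp; // scalar op for float element types
  using IOp = AddIOp; // scalar op for integer element types
};
template <> struct ScalarOp<ONNXSqrtOp> {
  using FOp = KrnlSqrtOp;
  using IOp = KrnlSqrtOp; // unused: sqrt is float-only here
};

template <typename ElementwiseNaryOp>
using ScalarFOp = typename ScalarOp<ElementwiseNaryOp>::FOp;

int main() {
  // The lowering resolves the scalar op at compile time from the trait.
  static_assert(std::is_same_v<ScalarFOp<ONNXSqrtOp>, KrnlSqrtOp>);
  std::cout << "ONNXSqrtOp -> KrnlSqrtOp for float elements\n";
}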
@@ -1267,6 +1273,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
       ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>,
       ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftplusOp>,
       ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftsignOp>,
+      ONNXElementwiseUnaryOpLowering<mlir::ONNXSqrtOp>,
       ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
       ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
       ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>,
@@ -120,6 +120,7 @@ public:
            op->getName().getStringRef() != "onnx.Reshape" &&
            op->getName().getStringRef() != "onnx.Transpose" &&
            op->getName().getStringRef() != "onnx.Softmax" &&
+           op->getName().getStringRef() != "onnx.Sqrt" &&
            op->getName().getStringRef() != "onnx.ConvNoBias")
          return false;
        return llvm::any_of(op->getResultTypes(), [](Type result_type) {
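This pass guards shape inference with an explicit allowlist of op names; adding "onnx.Sqrt" to the chain of != comparisons is what lets the new op through. The same check, sketched stand-alone with a set (an alternative formulation for illustration, not the patch's code):

#include <iostream>
#include <set>
#include <string>

// Hypothetical helper mirroring the pass's allowlist: shape inference
// runs only for op names the pass knows how to handle.
static bool canInferShapes(const std::string &opName) {
  static const std::set<std::string> supported = {
      "onnx.Reshape", "onnx.Transpose", "onnx.Softmax",
      "onnx.Sqrt", "onnx.ConvNoBias" /* ...and the other listed ops */};
  return supported.count(opName) != 0;
}

int main() {
  std::cout << canInferShapes("onnx.Sqrt") << "\n"; // 1: newly supported
  std::cout << canInferShapes("onnx.Foo") << "\n";  // 0: not in the list
}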
@@ -144,6 +144,7 @@ void KrnlToAffineLoweringPass::runOnFunction() {
   target.addIllegalDialect<KrnlOpsDialect>();
   target.addLegalOp<KrnlMemcpyOp>();
   target.addLegalOp<KrnlEntryPointOp>();
+  target.addLegalOp<KrnlSqrtOp>();
 
   OwningRewritePatternList patterns;
   patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
@@ -162,4 +163,4 @@ std::unique_ptr<Pass> mlir::createLowerKrnlPass() {
 }
 
 static PassRegistration<KrnlToAffineLoweringPass> pass("lower-krnl",
     "Lower Krnl dialect.");
@@ -460,6 +460,67 @@ private:
     }
   }
 };
 
+//===----------------------------------------------------------------------===//
+// KRNL to LLVM: KrnlSqrtOpLowering
+//===----------------------------------------------------------------------===//
+
+class KrnlSqrtOpLowering : public ConversionPattern {
+public:
+  explicit KrnlSqrtOpLowering(MLIRContext *context)
+      : ConversionPattern(KrnlSqrtOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    OperandAdaptor<KrnlSqrtOp> adaptor(operands);
+    LLVM::LLVMType operandType =
+        adaptor.operand().getType().dyn_cast_or_null<LLVM::LLVMType>();
+
+    if (!operandType)
+      return matchFailure();
+
+    std::string functionName;
+    if (operandType.isFloatTy())
+      functionName = "llvm.sqrt.f32";
+    else if (operandType.isDoubleTy())
+      functionName = "llvm.sqrt.f64";
+    else
+      assert(false && "Unsupported operand type.");
+
+    // Get a symbol reference to the sqrt function, inserting it if necessary.
+    ModuleOp parentModule = op->getParentOfType<ModuleOp>();
+    auto sqrtRef =
+        getOrInsertSqrt(rewriter, parentModule, functionName, operandType);
+
+    // Sqrt call
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, operandType, sqrtRef,
+                                              adaptor.operand());
+
+    return matchSuccess();
+  }
+
+private:
+  /// Return a symbol reference to the sqrt function, inserting it into the
+  /// module if necessary.
+  static FlatSymbolRefAttr getOrInsertSqrt(PatternRewriter &rewriter,
+                                           ModuleOp module, std::string fnName,
+                                           LLVM::LLVMType operandType) {
+    auto *context = module.getContext();
+    if (module.lookupSymbol<LLVM::LLVMFuncOp>(fnName))
+      return SymbolRefAttr::get(fnName, context);
+    // Create a function declaration for sqrt, the signature is:
+    //   * `float (float)`
+    auto llvmFnType =
+        LLVM::LLVMType::getFunctionTy(operandType, operandType, false);
+
+    // Insert the sqrt function into the body of the parent module.
+    PatternRewriter::InsertionGuard insertGuard(rewriter);
+    rewriter.setInsertionPointToStart(module.getBody());
+    rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), fnName, llvmFnType);
+    return SymbolRefAttr::get(fnName, context);
+  }
+};
 } // end namespace
 
 //===----------------------------------------------------------------------===//
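getOrInsertSqrt follows the usual symbol-caching idiom: look the function up in the module's symbol table and only create the declaration on first use, so repeated sqrt lowerings share one llvm.sqrt.f32 (or .f64) declaration. A self-contained analog with a plain set standing in for the symbol table:

#include <iostream>
#include <set>
#include <string>

std::set<std::string> moduleSymbols; // stand-in for the module's symbol table

// Declare the intrinsic once per module and return its symbol name;
// later calls find the existing symbol and skip the declaration.
std::string getOrInsertSqrt(bool isDouble) {
  const std::string fnName = isDouble ? "llvm.sqrt.f64" : "llvm.sqrt.f32";
  if (moduleSymbols.count(fnName))
    return fnName; // already declared, reuse it
  moduleSymbols.insert(fnName);
  std::cout << "declare " << fnName << "\n"; // emit the declaration once
  return fnName;
}

int main() {
  getOrInsertSqrt(false); // prints: declare llvm.sqrt.f32
  getOrInsertSqrt(false); // reuses the symbol; prints nothing
  getOrInsertSqrt(true);  // prints: declare llvm.sqrt.f64
}

In the real pattern, PatternRewriter::InsertionGuard saves and restores the insertion point so the declaration lands at the start of the module without disturbing the rewriter's current position.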
@@ -489,8 +550,8 @@ void KrnlToLLVMLoweringPass::runOnModule() {
   populateStdToLLVMConversionPatterns(typeConverter, patterns);
 
   // Lower from the `krnl` dialect i.e. the Reshape operation.
-  patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering>(
-      &getContext());
+  patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering,
+                  KrnlSqrtOpLowering>(&getContext());
 
   // We want to completely lower to LLVM, so we use a `FullConversion`. This
   // ensures that only legal operations will remain after the conversion.
@@ -138,6 +138,10 @@ test_to_enable = [
     "test_softmax_example_cpu",
     "test_softmax_large_number_cpu",
 
+    # Sqrt Op:
+    "test_sqrt_cpu",
+    "test_sqrt_example_cpu",
+
    # Sum Op:
     "test_sum_example_cpu",
     "test_sum_one_input_cpu",
@@ -0,0 +1,24 @@
+// RUN: onnf-opt --shape-inference --lower-all-llvm %s -split-input-file | FileCheck %s
+module {
+  func @test_sqrt_32(%arg0 : f32) -> f32 {
+    %0 = "krnl.sqrt"(%arg0) : (f32) -> f32
+    "std.return"(%0) : (f32) -> ()
+
+    // CHECK: llvm.func @llvm.sqrt.f32(!llvm.float) -> !llvm.float
+    // CHECK-NEXT: llvm.func @test_sqrt_32(%arg0: !llvm.float) -> !llvm.float {
+    // CHECK-NEXT: [[RES:%.+]] = llvm.call @llvm.sqrt.f32(%arg0) : (!llvm.float) -> !llvm.float
+    // CHECK-NEXT: llvm.return [[RES]] : !llvm.float
+  }
+}
+
+module {
+  func @test_sqrt_64(%arg0 : f64) -> f64 {
+    %0 = "krnl.sqrt"(%arg0) : (f64) -> f64
+    "std.return"(%0) : (f64) -> ()
+
+    // CHECK: llvm.func @llvm.sqrt.f64(!llvm.double) -> !llvm.double
+    // CHECK-NEXT: llvm.func @test_sqrt_64(%arg0: !llvm.double) -> !llvm.double {
+    // CHECK-NEXT: [[RES:%.+]] = llvm.call @llvm.sqrt.f64(%arg0) : (!llvm.double) -> !llvm.double
+    // CHECK-NEXT: llvm.return [[RES]] : !llvm.double
+  }
+}
@@ -623,3 +623,23 @@ func @test_softmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {
   // CHECK: dealloc [[MAX]] : memref<f32>
   // CHECK: return [[RES]] : memref<10x10xf32>
 }
+
+func @test_sqrt(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Sqrt"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_sqrt
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[SQRT:%.+]] = "krnl.sqrt"([[LOAD]]) : (f32) -> f32
+  // CHECK: store [[SQRT]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
+