Use SqrtOp in Standard dialect (#108)

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
Tung D. Le 2020-02-27 02:03:24 +09:00 committed by GitHub
parent 0c4a010283
commit 5357fc1421
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 8 additions and 109 deletions

View File

@ -86,8 +86,8 @@ struct ScalarOp<ONNXLogOp> {
template <>
struct ScalarOp<ONNXSqrtOp> {
using FOp = KrnlSqrtOp;
using IOp = KrnlSqrtOp; // not use
using FOp = SqrtOp;
using IOp = SqrtOp; // not use
};
//===----------------------------------------------------------------------===//

View File

@ -123,8 +123,7 @@ struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
auto dividend = rewriter.create<SubFOp>(loc, xVal, meanVal);
auto adjustedVarianceVal =
rewriter.create<AddFOp>(loc, varianceVal, epsilon);
auto divisor = rewriter.create<KrnlSqrtOp>(loc, memRefType.getElementType(),
adjustedVarianceVal);
auto divisor = rewriter.create<SqrtOp>(loc, adjustedVarianceVal);
auto normVal = rewriter.create<DivFOp>(loc, dividend, divisor);
// scale and shift
auto scaleNormVal = rewriter.create<MulFOp>(loc, scaleVal, normVal);

View File

@ -190,17 +190,3 @@ def KrnlMemcpyOp : Op<Krnl_Dialect, "memcpy"> {
let parser = ?;
let printer = ?;
}
// Krnl dialect operation "sqrt", declared with the NoSideEffect trait.
// It accepts a single FloatLike operand and yields a single FloatLike result.
// NOTE(review): the description promises that the result has the same type as
// the operand, but no same-type constraint is visible in this definition —
// confirm whether it is enforced elsewhere.
def KrnlSqrtOp : Op<Krnl_Dialect, "sqrt", [NoSideEffect]> {
let summary = "Krnl sqrt operation";
let description = [{
"The `sqrt` computes the square root value. It takes one operand and returns
one result with the same type."
}];
let arguments = (ins FloatLike:$operand);
let results = (outs FloatLike);
let parser = ?;
let printer = ?;
}

View File

@ -144,7 +144,6 @@ void KrnlToAffineLoweringPass::runOnFunction() {
target.addIllegalDialect<KrnlOpsDialect>();
target.addLegalOp<KrnlMemcpyOp>();
target.addLegalOp<KrnlEntryPointOp>();
target.addLegalOp<KrnlSqrtOp>();
OwningRewritePatternList patterns;
patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,

View File

@ -480,67 +480,6 @@ private:
}
}
};
//===----------------------------------------------------------------------===//
// KRNL to LLVM: KrnlSqrtOpLowering
//===----------------------------------------------------------------------===//
/// Conversion pattern that rewrites a `KrnlSqrtOp` into an `LLVM::CallOp`
/// targeting the square-root function matching the operand type
/// (`llvm.sqrt.f32` for float, `llvm.sqrt.f64` for double). The callee is
/// declared at the start of the parent module the first time it is needed.
class KrnlSqrtOpLowering : public ConversionPattern {
public:
  explicit KrnlSqrtOpLowering(MLIRContext *context)
      : ConversionPattern(KrnlSqrtOp::getOperationName(), 1, context) {}

  /// Replaces `op` with a call to the sqrt function for its operand type.
  ///
  /// Returns `matchFailure()` when the (already converted) operand is not an
  /// LLVM dialect type, or when it is neither float nor double; otherwise
  /// performs the rewrite and returns `matchSuccess()`.
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    OperandAdaptor<KrnlSqrtOp> adaptor(operands);
    LLVM::LLVMType operandType =
        adaptor.operand().getType().dyn_cast_or_null<LLVM::LLVMType>();
    if (!operandType)
      return matchFailure();

    std::string functionName;
    if (operandType.isFloatTy())
      functionName = "llvm.sqrt.f32";
    else if (operandType.isDoubleTy())
      functionName = "llvm.sqrt.f64";
    else
      // Report a match failure instead of asserting: an assert is compiled
      // out under NDEBUG, which would let the code below declare and call a
      // function with an empty name.
      return matchFailure();

    // Get a symbol reference to the sqrt function, inserting it if necessary.
    ModuleOp parentModule = op->getParentOfType<ModuleOp>();
    auto sqrtRef =
        getOrInsertSqrt(rewriter, parentModule, functionName, operandType);

    // Replace the op with a call whose result type equals the operand type.
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, operandType, sqrtRef,
                                              adaptor.operand());
    return matchSuccess();
  }

private:
  /// Return a symbol reference to the function named `fnName`, inserting a
  /// declaration for it into `module` if no `LLVM::LLVMFuncOp` with that name
  /// exists yet.
  static FlatSymbolRefAttr getOrInsertSqrt(PatternRewriter &rewriter,
                                           ModuleOp module, std::string fnName,
                                           LLVM::LLVMType operandType) {
    auto *context = module.getContext();
    if (module.lookupSymbol<LLVM::LLVMFuncOp>(fnName))
      return SymbolRefAttr::get(fnName, context);

    // Create a function declaration for sqrt; the signature is
    // `T (T)` where T is `operandType` (non-variadic).
    auto llvmFnType =
        LLVM::LLVMType::getFunctionTy(operandType, operandType, false);

    // Insert the declaration at the start of the module body. The guard
    // restores the rewriter's previous insertion point on scope exit.
    PatternRewriter::InsertionGuard insertGuard(rewriter);
    rewriter.setInsertionPointToStart(module.getBody());
    rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), fnName, llvmFnType);
    return SymbolRefAttr::get(fnName, context);
  }
};
} // end namespace
//===----------------------------------------------------------------------===//
@ -572,8 +511,8 @@ void KrnlToLLVMLoweringPass::runOnModule() {
/*emitCWrapper=*/true);
// Lower from the `krnl` dialect i.e. the Reshape operation.
patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering,
KrnlSqrtOpLowering>(&getContext());
patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering>(
&getContext());
// We want to completely lower to LLVM, so we use a `FullConversion`. This
// ensures that only legal operations will remain after the conversion.

View File

@ -1,24 +0,0 @@
// RUN: onnf-opt --shape-inference --lower-all-llvm %s -split-input-file | FileCheck %s
// f32 case: a scalar "krnl.sqrt" is expected to become a declaration of
// @llvm.sqrt.f32 followed by an llvm.call to it inside the lowered function.
module {
func @test_sqrt_32(%arg0 : f32) -> f32 {
%0 = "krnl.sqrt"(%arg0) : (f32) -> f32
"std.return"(%0) : (f32) -> ()
// CHECK: llvm.func @llvm.sqrt.f32(!llvm.float) -> !llvm.float
// CHECK-NEXT: llvm.func @test_sqrt_32(%arg0: !llvm.float) -> !llvm.float {
// CHECK-NEXT: [[RES:%.+]] = llvm.call @llvm.sqrt.f32(%arg0) : (!llvm.float) -> !llvm.float
// CHECK-NEXT: llvm.return [[RES]] : !llvm.float
}
}
// f64 case: a scalar "krnl.sqrt" is expected to become a declaration of
// @llvm.sqrt.f64 followed by an llvm.call to it inside the lowered function.
module{
func @test_sqrt_64(%arg0 : f64) -> f64 {
%0 = "krnl.sqrt"(%arg0) : (f64) -> f64
"std.return"(%0) : (f64) -> ()
// CHECK: llvm.func @llvm.sqrt.f64(!llvm.double) -> !llvm.double
// CHECK-NEXT: llvm.func @test_sqrt_64(%arg0: !llvm.double) -> !llvm.double {
// CHECK-NEXT: [[RES:%.+]] = llvm.call @llvm.sqrt.f64(%arg0) : (!llvm.double) -> !llvm.double
// CHECK-NEXT: llvm.return [[RES]] : !llvm.double
}
}

View File

@ -820,7 +820,7 @@ func @test_sqrt(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[SQRT:%.+]] = "krnl.sqrt"([[LOAD]]) : (f32) -> f32
// CHECK: [[SQRT:%.+]] = sqrt [[LOAD]] : f32
// CHECK: store [[SQRT]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
// CHECK: return [[RES]] : memref<?x10xf32>
}
@ -1305,7 +1305,7 @@ func @test_batchnorm_testmode_Nd(%arg0: tensor<1x2x1x3xf32>, %arg1: tensor<2xf32
// CHECK: [[LOADED_VAL:%.+]] = load %arg0[%arg6, %arg5, %arg7, %arg8] : memref<1x2x1x3xf32>
// CHECK: [[DIVIDEND:%.+]] = subf [[LOADED_VAL]], [[MEAN]] : f32
// CHECK: [[ADJUSTED_VARIANCE:%.+]] = addf [[VARIANCE]], [[EPSILON]] : f32
// CHECK: [[DIVISOR:%.+]] = "krnl.sqrt"([[ADJUSTED_VARIANCE]]) : (f32) -> f32
// CHECK: [[DIVISOR:%.+]] = sqrt [[ADJUSTED_VARIANCE]] : f32
// CHECK: [[NORM:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
// CHECK: [[SCALE_NORM:%.+]] = mulf [[SCALE]], [[NORM]] : f32
// CHECK: [[SHIFT_SCALE_NORM:%.+]] = addf [[SCALE_NORM]], [[BIAS]] : f32
@ -1335,7 +1335,7 @@ func @test_batchnorm_testmode_1d(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>, %a
// CHECK: [[LOADED_VAL:%.+]] = load %arg0[%arg5] : memref<10xf32>
// CHECK: [[DIVIDEND:%.+]] = subf [[LOADED_VAL]], [[MEAN]] : f32
// CHECK: [[ADJUSTED_VARIANCE:%.+]] = addf [[VARIANCE]], [[EPSILON]] : f32
// CHECK: [[DIVISOR:%.+]] = "krnl.sqrt"([[ADJUSTED_VARIANCE]]) : (f32) -> f32
// CHECK: [[DIVISOR:%.+]] = sqrt [[ADJUSTED_VARIANCE]] : f32
// CHECK: [[NORM:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
// CHECK: [[SCALE_NORM:%.+]] = mulf [[SCALE]], [[NORM]] : f32
// CHECK: [[SHIFT_SCALE_NORM:%.+]] = addf [[SCALE_NORM]], [[BIAS]] : f32