Use SqrtOp in Standard dialect (#108)
Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent
0c4a010283
commit
5357fc1421
|
@ -86,8 +86,8 @@ struct ScalarOp<ONNXLogOp> {
|
|||
|
||||
template <>
|
||||
struct ScalarOp<ONNXSqrtOp> {
|
||||
using FOp = KrnlSqrtOp;
|
||||
using IOp = KrnlSqrtOp; // not use
|
||||
using FOp = SqrtOp;
|
||||
using IOp = SqrtOp; // not use
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -123,8 +123,7 @@ struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
|
|||
auto dividend = rewriter.create<SubFOp>(loc, xVal, meanVal);
|
||||
auto adjustedVarianceVal =
|
||||
rewriter.create<AddFOp>(loc, varianceVal, epsilon);
|
||||
auto divisor = rewriter.create<KrnlSqrtOp>(loc, memRefType.getElementType(),
|
||||
adjustedVarianceVal);
|
||||
auto divisor = rewriter.create<SqrtOp>(loc, adjustedVarianceVal);
|
||||
auto normVal = rewriter.create<DivFOp>(loc, dividend, divisor);
|
||||
// scale and shift
|
||||
auto scaleNormVal = rewriter.create<MulFOp>(loc, scaleVal, normVal);
|
||||
|
|
|
@ -190,17 +190,3 @@ def KrnlMemcpyOp : Op<Krnl_Dialect, "memcpy"> {
|
|||
let parser = ?;
|
||||
let printer = ?;
|
||||
}
|
||||
|
||||
def KrnlSqrtOp : Op<Krnl_Dialect, "sqrt", [NoSideEffect]> {
|
||||
let summary = "Krnl sqrt operation";
|
||||
let description = [{
|
||||
"The `sqrt` computes the square root value. It takes one operand and returns
|
||||
one result with the same type."
|
||||
}];
|
||||
|
||||
let arguments = (ins FloatLike:$operand);
|
||||
let results = (outs FloatLike);
|
||||
|
||||
let parser = ?;
|
||||
let printer = ?;
|
||||
}
|
||||
|
|
|
@ -144,7 +144,6 @@ void KrnlToAffineLoweringPass::runOnFunction() {
|
|||
target.addIllegalDialect<KrnlOpsDialect>();
|
||||
target.addLegalOp<KrnlMemcpyOp>();
|
||||
target.addLegalOp<KrnlEntryPointOp>();
|
||||
target.addLegalOp<KrnlSqrtOp>();
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
|
||||
|
|
|
@ -480,67 +480,6 @@ private:
|
|||
}
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// KRNL to LLVM: KrnlSqrlOpLowering
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class KrnlSqrtOpLowering : public ConversionPattern {
|
||||
public:
|
||||
explicit KrnlSqrtOpLowering(MLIRContext *context)
|
||||
: ConversionPattern(KrnlSqrtOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
OperandAdaptor<KrnlSqrtOp> adaptor(operands);
|
||||
LLVM::LLVMType operandType =
|
||||
adaptor.operand().getType().dyn_cast_or_null<LLVM::LLVMType>();
|
||||
|
||||
if (!operandType)
|
||||
return matchFailure();
|
||||
|
||||
std::string functionName;
|
||||
if (operandType.isFloatTy())
|
||||
functionName = "llvm.sqrt.f32";
|
||||
else if (operandType.isDoubleTy())
|
||||
functionName = "llvm.sqrt.f64";
|
||||
else
|
||||
assert(false && "Unsupported operand type.");
|
||||
|
||||
// Get a symbol reference to the sqrt function, inserting it if necessary.
|
||||
ModuleOp parentModule = op->getParentOfType<ModuleOp>();
|
||||
auto sqrtRef =
|
||||
getOrInsertSqrt(rewriter, parentModule, functionName, operandType);
|
||||
|
||||
// Sqrt call
|
||||
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, operandType, sqrtRef,
|
||||
adaptor.operand());
|
||||
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
private:
|
||||
/// Return a symbol reference to the sqrt function, inserting it into the
|
||||
/// module if necessary.
|
||||
static FlatSymbolRefAttr getOrInsertSqrt(PatternRewriter &rewriter,
|
||||
ModuleOp module, std::string fnName,
|
||||
LLVM::LLVMType operandType) {
|
||||
auto *context = module.getContext();
|
||||
if (module.lookupSymbol<LLVM::LLVMFuncOp>(fnName))
|
||||
return SymbolRefAttr::get(fnName, context);
|
||||
// Create a function declaration for sqrt, the signature is:
|
||||
// * `float (float)`
|
||||
auto llvmFnType =
|
||||
LLVM::LLVMType::getFunctionTy(operandType, operandType, false);
|
||||
|
||||
// Insert the sqrt function into the body of the parent module.
|
||||
PatternRewriter::InsertionGuard insertGuard(rewriter);
|
||||
rewriter.setInsertionPointToStart(module.getBody());
|
||||
rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), fnName, llvmFnType);
|
||||
return SymbolRefAttr::get(fnName, context);
|
||||
}
|
||||
};
|
||||
} // end namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -572,8 +511,8 @@ void KrnlToLLVMLoweringPass::runOnModule() {
|
|||
/*emitCWrapper=*/true);
|
||||
|
||||
// Lower from the `krnl` dialect i.e. the Reshape operation.
|
||||
patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering,
|
||||
KrnlSqrtOpLowering>(&getContext());
|
||||
patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering>(
|
||||
&getContext());
|
||||
|
||||
// We want to completely lower to LLVM, so we use a `FullConversion`. This
|
||||
// ensures that only legal operations will remain after the conversion.
|
||||
|
|
|
@ -1,24 +0,0 @@
|
|||
// RUN: onnf-opt --shape-inference --lower-all-llvm %s -split-input-file | FileCheck %s
|
||||
module {
|
||||
func @test_sqrt_32(%arg0 : f32) -> f32 {
|
||||
%0 = "krnl.sqrt"(%arg0) : (f32) -> f32
|
||||
"std.return"(%0) : (f32) -> ()
|
||||
|
||||
// CHECK: llvm.func @llvm.sqrt.f32(!llvm.float) -> !llvm.float
|
||||
// CHECK-NEXT: llvm.func @test_sqrt_32(%arg0: !llvm.float) -> !llvm.float {
|
||||
// CHECK-NEXT: [[RES:%.+]] = llvm.call @llvm.sqrt.f32(%arg0) : (!llvm.float) -> !llvm.float
|
||||
// CHECK-NEXT: llvm.return [[RES]] : !llvm.float
|
||||
}
|
||||
}
|
||||
|
||||
module{
|
||||
func @test_sqrt_64(%arg0 : f64) -> f64 {
|
||||
%0 = "krnl.sqrt"(%arg0) : (f64) -> f64
|
||||
"std.return"(%0) : (f64) -> ()
|
||||
|
||||
// CHECK: llvm.func @llvm.sqrt.f64(!llvm.double) -> !llvm.double
|
||||
// CHECK-NEXT: llvm.func @test_sqrt_64(%arg0: !llvm.double) -> !llvm.double {
|
||||
// CHECK-NEXT: [[RES:%.+]] = llvm.call @llvm.sqrt.f64(%arg0) : (!llvm.double) -> !llvm.double
|
||||
// CHECK-NEXT: llvm.return [[RES]] : !llvm.double
|
||||
}
|
||||
}
|
|
@ -820,7 +820,7 @@ func @test_sqrt(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
|||
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
|
||||
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
|
||||
// CHECK: [[SQRT:%.+]] = "krnl.sqrt"([[LOAD]]) : (f32) -> f32
|
||||
// CHECK: [[SQRT:%.+]] = sqrt [[LOAD]] : f32
|
||||
// CHECK: store [[SQRT]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
|
||||
// CHECK: return [[RES]] : memref<?x10xf32>
|
||||
}
|
||||
|
@ -1305,7 +1305,7 @@ func @test_batchnorm_testmode_Nd(%arg0: tensor<1x2x1x3xf32>, %arg1: tensor<2xf32
|
|||
// CHECK: [[LOADED_VAL:%.+]] = load %arg0[%arg6, %arg5, %arg7, %arg8] : memref<1x2x1x3xf32>
|
||||
// CHECK: [[DIVIDEND:%.+]] = subf [[LOADED_VAL]], [[MEAN]] : f32
|
||||
// CHECK: [[ADJUSTED_VARIANCE:%.+]] = addf [[VARIANCE]], [[EPSILON]] : f32
|
||||
// CHECK: [[DIVISOR:%.+]] = "krnl.sqrt"([[ADJUSTED_VARIANCE]]) : (f32) -> f32
|
||||
// CHECK: [[DIVISOR:%.+]] = sqrt [[ADJUSTED_VARIANCE]] : f32
|
||||
// CHECK: [[NORM:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
|
||||
// CHECK: [[SCALE_NORM:%.+]] = mulf [[SCALE]], [[NORM]] : f32
|
||||
// CHECK: [[SHIFT_SCALE_NORM:%.+]] = addf [[SCALE_NORM]], [[BIAS]] : f32
|
||||
|
@ -1335,7 +1335,7 @@ func @test_batchnorm_testmode_1d(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>, %a
|
|||
// CHECK: [[LOADED_VAL:%.+]] = load %arg0[%arg5] : memref<10xf32>
|
||||
// CHECK: [[DIVIDEND:%.+]] = subf [[LOADED_VAL]], [[MEAN]] : f32
|
||||
// CHECK: [[ADJUSTED_VARIANCE:%.+]] = addf [[VARIANCE]], [[EPSILON]] : f32
|
||||
// CHECK: [[DIVISOR:%.+]] = "krnl.sqrt"([[ADJUSTED_VARIANCE]]) : (f32) -> f32
|
||||
// CHECK: [[DIVISOR:%.+]] = sqrt [[ADJUSTED_VARIANCE]] : f32
|
||||
// CHECK: [[NORM:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
|
||||
// CHECK: [[SCALE_NORM:%.+]] = mulf [[SCALE]], [[NORM]] : f32
|
||||
// CHECK: [[SHIFT_SCALE_NORM:%.+]] = addf [[SCALE_NORM]], [[BIAS]] : f32
|
||||
|
|
Loading…
Reference in New Issue