Use SqrtOp in Standard dialect (#108)
Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent
0c4a010283
commit
5357fc1421
|
@ -86,8 +86,8 @@ struct ScalarOp<ONNXLogOp> {
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct ScalarOp<ONNXSqrtOp> {
|
struct ScalarOp<ONNXSqrtOp> {
|
||||||
using FOp = KrnlSqrtOp;
|
using FOp = SqrtOp;
|
||||||
using IOp = KrnlSqrtOp; // not use
|
using IOp = SqrtOp; // not use
|
||||||
};
|
};
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -123,8 +123,7 @@ struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
|
||||||
auto dividend = rewriter.create<SubFOp>(loc, xVal, meanVal);
|
auto dividend = rewriter.create<SubFOp>(loc, xVal, meanVal);
|
||||||
auto adjustedVarianceVal =
|
auto adjustedVarianceVal =
|
||||||
rewriter.create<AddFOp>(loc, varianceVal, epsilon);
|
rewriter.create<AddFOp>(loc, varianceVal, epsilon);
|
||||||
auto divisor = rewriter.create<KrnlSqrtOp>(loc, memRefType.getElementType(),
|
auto divisor = rewriter.create<SqrtOp>(loc, adjustedVarianceVal);
|
||||||
adjustedVarianceVal);
|
|
||||||
auto normVal = rewriter.create<DivFOp>(loc, dividend, divisor);
|
auto normVal = rewriter.create<DivFOp>(loc, dividend, divisor);
|
||||||
// scale and shift
|
// scale and shift
|
||||||
auto scaleNormVal = rewriter.create<MulFOp>(loc, scaleVal, normVal);
|
auto scaleNormVal = rewriter.create<MulFOp>(loc, scaleVal, normVal);
|
||||||
|
|
|
@ -190,17 +190,3 @@ def KrnlMemcpyOp : Op<Krnl_Dialect, "memcpy"> {
|
||||||
let parser = ?;
|
let parser = ?;
|
||||||
let printer = ?;
|
let printer = ?;
|
||||||
}
|
}
|
||||||
|
|
||||||
def KrnlSqrtOp : Op<Krnl_Dialect, "sqrt", [NoSideEffect]> {
|
|
||||||
let summary = "Krnl sqrt operation";
|
|
||||||
let description = [{
|
|
||||||
"The `sqrt` computes the square root value. It takes one operand and returns
|
|
||||||
one result with the same type."
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins FloatLike:$operand);
|
|
||||||
let results = (outs FloatLike);
|
|
||||||
|
|
||||||
let parser = ?;
|
|
||||||
let printer = ?;
|
|
||||||
}
|
|
||||||
|
|
|
@ -144,7 +144,6 @@ void KrnlToAffineLoweringPass::runOnFunction() {
|
||||||
target.addIllegalDialect<KrnlOpsDialect>();
|
target.addIllegalDialect<KrnlOpsDialect>();
|
||||||
target.addLegalOp<KrnlMemcpyOp>();
|
target.addLegalOp<KrnlMemcpyOp>();
|
||||||
target.addLegalOp<KrnlEntryPointOp>();
|
target.addLegalOp<KrnlEntryPointOp>();
|
||||||
target.addLegalOp<KrnlSqrtOp>();
|
|
||||||
|
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
|
patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
|
||||||
|
|
|
@ -480,67 +480,6 @@ private:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// KRNL to LLVM: KrnlSqrtOpLowering
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
class KrnlSqrtOpLowering : public ConversionPattern {
|
|
||||||
public:
|
|
||||||
explicit KrnlSqrtOpLowering(MLIRContext *context)
|
|
||||||
: ConversionPattern(KrnlSqrtOp::getOperationName(), 1, context) {}
|
|
||||||
|
|
||||||
PatternMatchResult
|
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
|
||||||
OperandAdaptor<KrnlSqrtOp> adaptor(operands);
|
|
||||||
LLVM::LLVMType operandType =
|
|
||||||
adaptor.operand().getType().dyn_cast_or_null<LLVM::LLVMType>();
|
|
||||||
|
|
||||||
if (!operandType)
|
|
||||||
return matchFailure();
|
|
||||||
|
|
||||||
std::string functionName;
|
|
||||||
if (operandType.isFloatTy())
|
|
||||||
functionName = "llvm.sqrt.f32";
|
|
||||||
else if (operandType.isDoubleTy())
|
|
||||||
functionName = "llvm.sqrt.f64";
|
|
||||||
else
|
|
||||||
assert(false && "Unsupported operand type.");
|
|
||||||
|
|
||||||
// Get a symbol reference to the sqrt function, inserting it if necessary.
|
|
||||||
ModuleOp parentModule = op->getParentOfType<ModuleOp>();
|
|
||||||
auto sqrtRef =
|
|
||||||
getOrInsertSqrt(rewriter, parentModule, functionName, operandType);
|
|
||||||
|
|
||||||
// Sqrt call
|
|
||||||
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, operandType, sqrtRef,
|
|
||||||
adaptor.operand());
|
|
||||||
|
|
||||||
return matchSuccess();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
/// Return a symbol reference to the sqrt function, inserting it into the
|
|
||||||
/// module if necessary.
|
|
||||||
static FlatSymbolRefAttr getOrInsertSqrt(PatternRewriter &rewriter,
|
|
||||||
ModuleOp module, std::string fnName,
|
|
||||||
LLVM::LLVMType operandType) {
|
|
||||||
auto *context = module.getContext();
|
|
||||||
if (module.lookupSymbol<LLVM::LLVMFuncOp>(fnName))
|
|
||||||
return SymbolRefAttr::get(fnName, context);
|
|
||||||
// Create a function declaration for sqrt, the signature is:
|
|
||||||
// * `T (T)`, where T is the operand type (float or double)
|
|
||||||
auto llvmFnType =
|
|
||||||
LLVM::LLVMType::getFunctionTy(operandType, operandType, false);
|
|
||||||
|
|
||||||
// Insert the sqrt function into the body of the parent module.
|
|
||||||
PatternRewriter::InsertionGuard insertGuard(rewriter);
|
|
||||||
rewriter.setInsertionPointToStart(module.getBody());
|
|
||||||
rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), fnName, llvmFnType);
|
|
||||||
return SymbolRefAttr::get(fnName, context);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // end namespace
|
} // end namespace
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -572,8 +511,8 @@ void KrnlToLLVMLoweringPass::runOnModule() {
|
||||||
/*emitCWrapper=*/true);
|
/*emitCWrapper=*/true);
|
||||||
|
|
||||||
// Lower from the `krnl` dialect i.e. the Reshape operation.
|
// Lower from the `krnl` dialect i.e. the Reshape operation.
|
||||||
patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering,
|
patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering>(
|
||||||
KrnlSqrtOpLowering>(&getContext());
|
&getContext());
|
||||||
|
|
||||||
// We want to completely lower to LLVM, so we use a `FullConversion`. This
|
// We want to completely lower to LLVM, so we use a `FullConversion`. This
|
||||||
// ensures that only legal operations will remain after the conversion.
|
// ensures that only legal operations will remain after the conversion.
|
||||||
|
|
|
@ -1,24 +0,0 @@
|
||||||
// RUN: onnf-opt --shape-inference --lower-all-llvm %s -split-input-file | FileCheck %s
|
|
||||||
module {
|
|
||||||
func @test_sqrt_32(%arg0 : f32) -> f32 {
|
|
||||||
%0 = "krnl.sqrt"(%arg0) : (f32) -> f32
|
|
||||||
"std.return"(%0) : (f32) -> ()
|
|
||||||
|
|
||||||
// CHECK: llvm.func @llvm.sqrt.f32(!llvm.float) -> !llvm.float
|
|
||||||
// CHECK-NEXT: llvm.func @test_sqrt_32(%arg0: !llvm.float) -> !llvm.float {
|
|
||||||
// CHECK-NEXT: [[RES:%.+]] = llvm.call @llvm.sqrt.f32(%arg0) : (!llvm.float) -> !llvm.float
|
|
||||||
// CHECK-NEXT: llvm.return [[RES]] : !llvm.float
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
module{
|
|
||||||
func @test_sqrt_64(%arg0 : f64) -> f64 {
|
|
||||||
%0 = "krnl.sqrt"(%arg0) : (f64) -> f64
|
|
||||||
"std.return"(%0) : (f64) -> ()
|
|
||||||
|
|
||||||
// CHECK: llvm.func @llvm.sqrt.f64(!llvm.double) -> !llvm.double
|
|
||||||
// CHECK-NEXT: llvm.func @test_sqrt_64(%arg0: !llvm.double) -> !llvm.double {
|
|
||||||
// CHECK-NEXT: [[RES:%.+]] = llvm.call @llvm.sqrt.f64(%arg0) : (!llvm.double) -> !llvm.double
|
|
||||||
// CHECK-NEXT: llvm.return [[RES]] : !llvm.double
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -820,7 +820,7 @@ func @test_sqrt(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
|
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
|
||||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
|
||||||
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
|
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
|
||||||
// CHECK: [[SQRT:%.+]] = "krnl.sqrt"([[LOAD]]) : (f32) -> f32
|
// CHECK: [[SQRT:%.+]] = sqrt [[LOAD]] : f32
|
||||||
// CHECK: store [[SQRT]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
|
// CHECK: store [[SQRT]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
|
||||||
// CHECK: return [[RES]] : memref<?x10xf32>
|
// CHECK: return [[RES]] : memref<?x10xf32>
|
||||||
}
|
}
|
||||||
|
@ -1305,7 +1305,7 @@ func @test_batchnorm_testmode_Nd(%arg0: tensor<1x2x1x3xf32>, %arg1: tensor<2xf32
|
||||||
// CHECK: [[LOADED_VAL:%.+]] = load %arg0[%arg6, %arg5, %arg7, %arg8] : memref<1x2x1x3xf32>
|
// CHECK: [[LOADED_VAL:%.+]] = load %arg0[%arg6, %arg5, %arg7, %arg8] : memref<1x2x1x3xf32>
|
||||||
// CHECK: [[DIVIDEND:%.+]] = subf [[LOADED_VAL]], [[MEAN]] : f32
|
// CHECK: [[DIVIDEND:%.+]] = subf [[LOADED_VAL]], [[MEAN]] : f32
|
||||||
// CHECK: [[ADJUSTED_VARIANCE:%.+]] = addf [[VARIANCE]], [[EPSILON]] : f32
|
// CHECK: [[ADJUSTED_VARIANCE:%.+]] = addf [[VARIANCE]], [[EPSILON]] : f32
|
||||||
// CHECK: [[DIVISOR:%.+]] = "krnl.sqrt"([[ADJUSTED_VARIANCE]]) : (f32) -> f32
|
// CHECK: [[DIVISOR:%.+]] = sqrt [[ADJUSTED_VARIANCE]] : f32
|
||||||
// CHECK: [[NORM:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
|
// CHECK: [[NORM:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
|
||||||
// CHECK: [[SCALE_NORM:%.+]] = mulf [[SCALE]], [[NORM]] : f32
|
// CHECK: [[SCALE_NORM:%.+]] = mulf [[SCALE]], [[NORM]] : f32
|
||||||
// CHECK: [[SHIFT_SCALE_NORM:%.+]] = addf [[SCALE_NORM]], [[BIAS]] : f32
|
// CHECK: [[SHIFT_SCALE_NORM:%.+]] = addf [[SCALE_NORM]], [[BIAS]] : f32
|
||||||
|
@ -1335,7 +1335,7 @@ func @test_batchnorm_testmode_1d(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>, %a
|
||||||
// CHECK: [[LOADED_VAL:%.+]] = load %arg0[%arg5] : memref<10xf32>
|
// CHECK: [[LOADED_VAL:%.+]] = load %arg0[%arg5] : memref<10xf32>
|
||||||
// CHECK: [[DIVIDEND:%.+]] = subf [[LOADED_VAL]], [[MEAN]] : f32
|
// CHECK: [[DIVIDEND:%.+]] = subf [[LOADED_VAL]], [[MEAN]] : f32
|
||||||
// CHECK: [[ADJUSTED_VARIANCE:%.+]] = addf [[VARIANCE]], [[EPSILON]] : f32
|
// CHECK: [[ADJUSTED_VARIANCE:%.+]] = addf [[VARIANCE]], [[EPSILON]] : f32
|
||||||
// CHECK: [[DIVISOR:%.+]] = "krnl.sqrt"([[ADJUSTED_VARIANCE]]) : (f32) -> f32
|
// CHECK: [[DIVISOR:%.+]] = sqrt [[ADJUSTED_VARIANCE]] : f32
|
||||||
// CHECK: [[NORM:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
|
// CHECK: [[NORM:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
|
||||||
// CHECK: [[SCALE_NORM:%.+]] = mulf [[SCALE]], [[NORM]] : f32
|
// CHECK: [[SCALE_NORM:%.+]] = mulf [[SCALE]], [[NORM]] : f32
|
||||||
// CHECK: [[SHIFT_SCALE_NORM:%.+]] = addf [[SCALE_NORM]], [[BIAS]] : f32
|
// CHECK: [[SHIFT_SCALE_NORM:%.+]] = addf [[SCALE_NORM]], [[BIAS]] : f32
|
||||||
|
|
Loading…
Reference in New Issue