From 05e16dafae31e8f9c61e7e4e44eef3a0de1a6222 Mon Sep 17 00:00:00 2001 From: TUNG LEDUC Date: Fri, 29 Nov 2019 04:52:29 +0900 Subject: [PATCH] Use template to support lowering all binary onnx ops to kernel ir (#387) --- src/compiler/pass/lower_frontend_to_krnl.cpp | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/compiler/pass/lower_frontend_to_krnl.cpp b/src/compiler/pass/lower_frontend_to_krnl.cpp index d30c7bf..22b1b68 100644 --- a/src/compiler/pass/lower_frontend_to_krnl.cpp +++ b/src/compiler/pass/lower_frontend_to_krnl.cpp @@ -99,11 +99,12 @@ static bool checkInsertDealloc(Operation *currentOp) { namespace { //===----------------------------------------------------------------------===// -// AddOp lowering to Krnl dialect. +// Binary ops lowering to Krnl dialect. //===----------------------------------------------------------------------===// -struct ONNXAddOpLowering : public ConversionPattern { - ONNXAddOpLowering(MLIRContext* ctx) - : ConversionPattern(mlir::ONNXAddOp::getOperationName(), 1, ctx) {} +template +struct ONNXBinaryOpLowering : public ConversionPattern { + ONNXBinaryOpLowering(MLIRContext* ctx) + : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} PatternMatchResult matchAndRewrite(Operation* op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { @@ -193,11 +194,11 @@ struct ONNXAddOpLowering : public ConversionPattern { auto loadedSecondVal = rewriter.create(loc, operands[1], loopIVs); // TODO: Choose type of the Add for now use the Float Add. - auto addOpResult = - rewriter.create(loc, loadedFirstVal, loadedSecondVal); + auto loweredOpResult = + rewriter.create(loc, loadedFirstVal, loadedSecondVal); // Store result in the resulting array. - rewriter.create(loc, addOpResult, alloc, loopIVs); + rewriter.create(loc, loweredOpResult, alloc, loopIVs); rewriter.replaceOp(op, alloc); @@ -205,6 +206,11 @@ struct ONNXAddOpLowering : public ConversionPattern { } }; +//===----------------------------------------------------------------------===// +// AddOp lowering to Krnl dialect. +//===----------------------------------------------------------------------===// +using ONNXAddOpLowering = ONNXBinaryOpLowering; + //===----------------------------------------------------------------------===// // Conversion from Tensor type to the Standard dialect MemRef type. //===----------------------------------------------------------------------===//