Use template to support lowering all binary onnx ops to kernel ir (#387)

This commit is contained in:
TUNG LEDUC 2019-11-29 04:52:29 +09:00 committed by Tian Jin
parent 7fb2f80dce
commit 05e16dafae
1 changed files with 13 additions and 7 deletions

View File

@ -99,11 +99,12 @@ static bool checkInsertDealloc(Operation *currentOp) {
namespace { namespace {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AddOp lowering to Krnl dialect. // Binary ops lowering to Krnl dialect.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
struct ONNXAddOpLowering : public ConversionPattern { template <typename BinaryOp, typename LoweredBinaryOp>
ONNXAddOpLowering(MLIRContext* ctx) struct ONNXBinaryOpLowering : public ConversionPattern {
: ConversionPattern(mlir::ONNXAddOp::getOperationName(), 1, ctx) {} ONNXBinaryOpLowering(MLIRContext* ctx)
: ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands, PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
@ -193,11 +194,11 @@ struct ONNXAddOpLowering : public ConversionPattern {
auto loadedSecondVal = rewriter.create<LoadOp>(loc, operands[1], loopIVs); auto loadedSecondVal = rewriter.create<LoadOp>(loc, operands[1], loopIVs);
// TODO: Choose type of the Add for now use the Float Add. // TODO: Choose type of the Add for now use the Float Add.
auto addOpResult = auto loweredOpResult =
rewriter.create<AddFOp>(loc, loadedFirstVal, loadedSecondVal); rewriter.create<LoweredBinaryOp>(loc, loadedFirstVal, loadedSecondVal);
// Store result in the resulting array. // Store result in the resulting array.
rewriter.create<StoreOp>(loc, addOpResult, alloc, loopIVs); rewriter.create<StoreOp>(loc, loweredOpResult, alloc, loopIVs);
rewriter.replaceOp(op, alloc); rewriter.replaceOp(op, alloc);
@ -205,6 +206,11 @@ struct ONNXAddOpLowering : public ConversionPattern {
} }
}; };
//===----------------------------------------------------------------------===//
// AddOp lowering to Krnl dialect.
//===----------------------------------------------------------------------===//
using ONNXAddOpLowering = ONNXBinaryOpLowering<mlir::ONNXAddOp, AddFOp>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Conversion from Tensor type to the Standard dialect MemRef type. // Conversion from Tensor type to the Standard dialect MemRef type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//