Use template to support lowering all binary onnx ops to kernel ir (#387)

This commit is contained in:
TUNG LEDUC 2019-11-29 04:52:29 +09:00 committed by Tian Jin
parent 7fb2f80dce
commit 05e16dafae
1 changed files with 13 additions and 7 deletions

View File

@ -99,11 +99,12 @@ static bool checkInsertDealloc(Operation *currentOp) {
namespace {
//===----------------------------------------------------------------------===//
// AddOp lowering to Krnl dialect.
// Binary ops lowering to Krnl dialect.
//===----------------------------------------------------------------------===//
struct ONNXAddOpLowering : public ConversionPattern {
ONNXAddOpLowering(MLIRContext* ctx)
: ConversionPattern(mlir::ONNXAddOp::getOperationName(), 1, ctx) {}
template <typename BinaryOp, typename LoweredBinaryOp>
struct ONNXBinaryOpLowering : public ConversionPattern {
ONNXBinaryOpLowering(MLIRContext* ctx)
: ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) const final {
@ -193,11 +194,11 @@ struct ONNXAddOpLowering : public ConversionPattern {
auto loadedSecondVal = rewriter.create<LoadOp>(loc, operands[1], loopIVs);
// TODO: Choose type of the Add for now use the Float Add.
auto addOpResult =
rewriter.create<AddFOp>(loc, loadedFirstVal, loadedSecondVal);
auto loweredOpResult =
rewriter.create<LoweredBinaryOp>(loc, loadedFirstVal, loadedSecondVal);
// Store result in the resulting array.
rewriter.create<StoreOp>(loc, addOpResult, alloc, loopIVs);
rewriter.create<StoreOp>(loc, loweredOpResult, alloc, loopIVs);
rewriter.replaceOp(op, alloc);
@ -205,6 +206,11 @@ struct ONNXAddOpLowering : public ConversionPattern {
}
};
//===----------------------------------------------------------------------===//
// AddOp lowering to Krnl dialect.
//===----------------------------------------------------------------------===//
using ONNXAddOpLowering = ONNXBinaryOpLowering<mlir::ONNXAddOp, AddFOp>;
//===----------------------------------------------------------------------===//
// Conversion from Tensor type to the Standard dialect MemRef type.
//===----------------------------------------------------------------------===//