Use template to support lowering all binary onnx ops to kernel ir (#387)
This commit is contained in:
parent
7fb2f80dce
commit
05e16dafae
|
@ -99,11 +99,12 @@ static bool checkInsertDealloc(Operation *currentOp) {
|
|||
namespace {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AddOp lowering to Krnl dialect.
|
||||
// Binary ops lowering to Krnl dialect.
|
||||
//===----------------------------------------------------------------------===//
|
||||
struct ONNXAddOpLowering : public ConversionPattern {
|
||||
ONNXAddOpLowering(MLIRContext* ctx)
|
||||
: ConversionPattern(mlir::ONNXAddOp::getOperationName(), 1, ctx) {}
|
||||
template <typename BinaryOp, typename LoweredBinaryOp>
|
||||
struct ONNXBinaryOpLowering : public ConversionPattern {
|
||||
ONNXBinaryOpLowering(MLIRContext* ctx)
|
||||
: ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
|
@ -193,11 +194,11 @@ struct ONNXAddOpLowering : public ConversionPattern {
|
|||
auto loadedSecondVal = rewriter.create<LoadOp>(loc, operands[1], loopIVs);
|
||||
|
||||
// TODO: Choose type of the Add for now use the Float Add.
|
||||
auto addOpResult =
|
||||
rewriter.create<AddFOp>(loc, loadedFirstVal, loadedSecondVal);
|
||||
auto loweredOpResult =
|
||||
rewriter.create<LoweredBinaryOp>(loc, loadedFirstVal, loadedSecondVal);
|
||||
|
||||
// Store result in the resulting array.
|
||||
rewriter.create<StoreOp>(loc, addOpResult, alloc, loopIVs);
|
||||
rewriter.create<StoreOp>(loc, loweredOpResult, alloc, loopIVs);
|
||||
|
||||
rewriter.replaceOp(op, alloc);
|
||||
|
||||
|
@ -205,6 +206,11 @@ struct ONNXAddOpLowering : public ConversionPattern {
|
|||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AddOp lowering to Krnl dialect.
|
||||
//===----------------------------------------------------------------------===//
|
||||
using ONNXAddOpLowering = ONNXBinaryOpLowering<mlir::ONNXAddOp, AddFOp>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Conversion from Tensor type to the Standard dialect MemRef type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
Loading…
Reference in New Issue