Use template to support lowering all binary onnx ops to kernel ir (#387)
This commit is contained in:
parent
7fb2f80dce
commit
05e16dafae
|
@ -99,11 +99,12 @@ static bool checkInsertDealloc(Operation *currentOp) {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AddOp lowering to Krnl dialect.
|
// Binary ops lowering to Krnl dialect.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
struct ONNXAddOpLowering : public ConversionPattern {
|
template <typename BinaryOp, typename LoweredBinaryOp>
|
||||||
ONNXAddOpLowering(MLIRContext* ctx)
|
struct ONNXBinaryOpLowering : public ConversionPattern {
|
||||||
: ConversionPattern(mlir::ONNXAddOp::getOperationName(), 1, ctx) {}
|
ONNXBinaryOpLowering(MLIRContext* ctx)
|
||||||
|
: ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
|
||||||
|
|
||||||
PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
|
PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
|
||||||
ConversionPatternRewriter& rewriter) const final {
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
@ -193,11 +194,11 @@ struct ONNXAddOpLowering : public ConversionPattern {
|
||||||
auto loadedSecondVal = rewriter.create<LoadOp>(loc, operands[1], loopIVs);
|
auto loadedSecondVal = rewriter.create<LoadOp>(loc, operands[1], loopIVs);
|
||||||
|
|
||||||
// TODO: Choose type of the Add for now use the Float Add.
|
// TODO: Choose type of the Add for now use the Float Add.
|
||||||
auto addOpResult =
|
auto loweredOpResult =
|
||||||
rewriter.create<AddFOp>(loc, loadedFirstVal, loadedSecondVal);
|
rewriter.create<LoweredBinaryOp>(loc, loadedFirstVal, loadedSecondVal);
|
||||||
|
|
||||||
// Store result in the resulting array.
|
// Store result in the resulting array.
|
||||||
rewriter.create<StoreOp>(loc, addOpResult, alloc, loopIVs);
|
rewriter.create<StoreOp>(loc, loweredOpResult, alloc, loopIVs);
|
||||||
|
|
||||||
rewriter.replaceOp(op, alloc);
|
rewriter.replaceOp(op, alloc);
|
||||||
|
|
||||||
|
@ -205,6 +206,11 @@ struct ONNXAddOpLowering : public ConversionPattern {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AddOp lowering to Krnl dialect.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
using ONNXAddOpLowering = ONNXBinaryOpLowering<mlir::ONNXAddOp, AddFOp>;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Conversion from Tensor type to the Standard dialect MemRef type.
|
// Conversion from Tensor type to the Standard dialect MemRef type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
Loading…
Reference in New Issue