Use template to support lowering all binary onnx ops to kernel ir (#387)
This commit is contained in:
		
							parent
							
								
									7fb2f80dce
								
							
						
					
					
						commit
						05e16dafae
					
				| 
						 | 
				
			
			@ -99,11 +99,12 @@ static bool checkInsertDealloc(Operation *currentOp) {
 | 
			
		|||
namespace {
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// AddOp lowering to Krnl dialect.
 | 
			
		||||
// Binary ops lowering to Krnl dialect.
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
struct ONNXAddOpLowering : public ConversionPattern {
 | 
			
		||||
  ONNXAddOpLowering(MLIRContext* ctx)
 | 
			
		||||
      : ConversionPattern(mlir::ONNXAddOp::getOperationName(), 1, ctx) {}
 | 
			
		||||
template <typename BinaryOp, typename LoweredBinaryOp>
 | 
			
		||||
struct ONNXBinaryOpLowering : public ConversionPattern {
 | 
			
		||||
  ONNXBinaryOpLowering(MLIRContext* ctx)
 | 
			
		||||
      : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
 | 
			
		||||
 | 
			
		||||
  PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
 | 
			
		||||
      ConversionPatternRewriter& rewriter) const final {
 | 
			
		||||
| 
						 | 
				
			
			@ -193,11 +194,11 @@ struct ONNXAddOpLowering : public ConversionPattern {
 | 
			
		|||
    auto loadedSecondVal = rewriter.create<LoadOp>(loc, operands[1], loopIVs);
 | 
			
		||||
 | 
			
		||||
    // TODO: Choose type of the Add for now use the Float Add.
 | 
			
		||||
    auto addOpResult =
 | 
			
		||||
        rewriter.create<AddFOp>(loc, loadedFirstVal, loadedSecondVal);
 | 
			
		||||
    auto loweredOpResult =
 | 
			
		||||
        rewriter.create<LoweredBinaryOp>(loc, loadedFirstVal, loadedSecondVal);
 | 
			
		||||
 | 
			
		||||
    // Store result in the resulting array.
 | 
			
		||||
    rewriter.create<StoreOp>(loc, addOpResult, alloc, loopIVs);
 | 
			
		||||
    rewriter.create<StoreOp>(loc, loweredOpResult, alloc, loopIVs);
 | 
			
		||||
 | 
			
		||||
    rewriter.replaceOp(op, alloc);
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -205,6 +206,11 @@ struct ONNXAddOpLowering : public ConversionPattern {
 | 
			
		|||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// AddOp lowering to Krnl dialect.
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
using ONNXAddOpLowering = ONNXBinaryOpLowering<mlir::ONNXAddOp, AddFOp>;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Conversion from Tensor type to the Standard dialect MemRef type.
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue