diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index f95f1c5..3fc1b38 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -88,6 +88,7 @@ add_library(onnf_lower_frontend
         conversion/onnx_to_krnl/nn/pooling.cpp
         conversion/onnx_to_krnl/tensor/identity.cpp
         conversion/onnx_to_krnl/tensor/reshape.cpp
+        conversion/onnx_to_krnl/tensor/padconstantvaluepad.cpp
         conversion/onnx_to_krnl/tensor/transpose.cpp
         conversion/onnx_to_krnl/tensor/unsqueeze.cpp
         conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp)
diff --git a/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp b/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp
index 3373143..9797b44 100644
--- a/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp
+++ b/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp
@@ -93,6 +93,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
   populateLoweringONNXMatMulOpPattern(patterns, &getContext());
   // Tensor
   populateLoweringONNXReshapeOpPattern(patterns, &getContext());
+  populateLoweringONNXPadConstantValuePadOpPattern(patterns, &getContext());
   populateLoweringONNXUnsqueezeOpPattern(patterns, &getContext());
   populateLoweringONNXTransposeOpPattern(patterns, &getContext());
   populateLoweringONNXIdentityOpPattern(patterns, &getContext());
diff --git a/src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp b/src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp
index 7750c16..4c5413d 100644
--- a/src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp
+++ b/src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp
@@ -234,6 +234,9 @@ void populateLoweringONNXUnsqueezeOpPattern(
 void populateLoweringONNXTransposeOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx);
 
+void populateLoweringONNXPadConstantValuePadOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
 void populateLoweringONNXReshapeOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx);
 
diff --git a/src/conversion/onnx_to_krnl/tensor/padconstantvaluepad.cpp b/src/conversion/onnx_to_krnl/tensor/padconstantvaluepad.cpp
new file mode 100644
index 0000000..4a6d9fa
--- /dev/null
+++ b/src/conversion/onnx_to_krnl/tensor/padconstantvaluepad.cpp
@@ -0,0 +1,108 @@
+//===----padconstantvaluepad.cpp - Lowering PadConstantValuePad Op --------===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file lowers the ONNX PadConstantValuePad Operator to Krnl dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
+struct ONNXPadConstantValuePadOpLowering : public ConversionPattern {
+  ONNXPadConstantValuePadOpLowering(MLIRContext *ctx)
+      : ConversionPattern(mlir::ONNXPadConstantValuePadOp::getOperationName(),
+            1, ctx) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
+    auto tensorType = (*op->result_type_begin());
+    auto loc = op->getLoc();
+
+    // Only constant padding is supported now.
+    auto padMode = llvm::dyn_cast<ONNXPadConstantValuePadOp>(op).mode();
+    if (padMode != "constant")
+      emitError(loc, "unsupported mode for PadConstantValuePad");
+    auto constantValAttr =
+        llvm::dyn_cast<ONNXPadConstantValuePadOp>(op).constant_valueAttr();
+
+    // Insert an allocation and deallocation for the result of this operation.
+    auto memRefType = convertToMemRefType(tensorType);
+    Value alloc;
+    bool insertDealloc = checkInsertDealloc(op);
+
+    if (hasAllConstantDimensions(memRefType))
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
+    else
+      emitError(loc, "unexpected output has non-Constant shape");
+
+    // Number of loops
+    auto memRefShape = memRefType.getShape();
+    int64_t rank = memRefShape.size();
+
+    // Iterate over the loop nest using the output shape.
+    BuildKrnlLoop padLoops(rewriter, loc, rank);
+    padLoops.createDefineAndOptimizeOp();
+    for (int i = 0; i < rank; ++i)
+      padLoops.pushBounds(0, alloc, i);
+    padLoops.createIterateOp();
+
+    // Iterate over the loop nest using the input shape.
+    BuildKrnlLoop valueLoops(rewriter, loc, rank);
+    valueLoops.createDefineAndOptimizeOp();
+    for (int i = 0; i < rank; ++i)
+      valueLoops.pushBounds(0, operands[0], i);
+    valueLoops.createIterateOp();
+
+    // Copy the input data into the output.
+    rewriter.setInsertionPointToStart(valueLoops.getIterateBlock());
+
+    SmallVector<Value, 4> inLoopIVs;
+    for (int i = 0; i < rank; ++i)
+      inLoopIVs.emplace_back(valueLoops.getInductionVar(i));
+
+    auto pads = llvm::dyn_cast<ONNXPadConstantValuePadOp>(op).pads();
+    SmallVector<int64_t, 4> pad_begin;
+    for (int i = 0; i < pads.size() / 2; ++i) {
+      pad_begin.emplace_back(pads.getValue()[i].cast<IntegerAttr>().getInt());
+    }
+
+    SmallVector<Value, 4> outLoopIVs;
+    for (int i = 0; i < rank; ++i) {
+      // Calculate the index for the load and store.
+      if (pad_begin[i] == 0) {
+        outLoopIVs.emplace_back(valueLoops.getInductionVar(i));
+      } else {
+        auto outIV = rewriter.create<AddIOp>(
+            loc, rewriter.create<ConstantIndexOp>(loc, pad_begin[i]),
+            valueLoops.getInductionVar(i));
+        outLoopIVs.emplace_back(outIV);
+      }
+    }
+
+    auto inVal = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
+    rewriter.create<StoreOp>(loc, inVal, alloc, outLoopIVs);
+    rewriter.setInsertionPointToStart(padLoops.getIterateBlock());
+
+    SmallVector<Value, 4> outLoopIVs1;
+    for (int i = 0; i < rank; ++i)
+      outLoopIVs1.emplace_back(padLoops.getInductionVar(i));
+
+    auto inVal1 = rewriter.create<ConstantOp>(loc, constantValAttr);
+    rewriter.create<StoreOp>(loc, inVal1, alloc, outLoopIVs1);
+
+    // Replace the original op with the generated code.
+    rewriter.replaceOp(op, alloc);
+
+    return matchSuccess();
+  }
+};
+
+void populateLoweringONNXPadConstantValuePadOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<ONNXPadConstantValuePadOpLowering>(ctx);
+}
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index 149d724..817935a 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -1510,3 +1510,28 @@ func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode_w_unknown_dims(%arg
   // CHECK: }
   // CHECK: return [[RES]] : memref
 }
+
+func @test_constant_pad1(%arg0: tensor<16x16xf32>) -> tensor<18x20xf32> {
+  %0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 3, 2, 1]} : (tensor<16x16xf32>) -> tensor<18x20xf32>
+  return %0 : tensor<18x20xf32>
+  // CHECK-LABEL: test_constant_pad1
+  // CHECK: [[RES:%.+]] = alloc() : memref<18x20xf32>
+  // CHECK: [[DEF_LOOPS1:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS1:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS1]]#0, [[DEF_LOOPS1]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOPS1]]#0, [[OPT_LOOPS1]]#1) with ([[DEF_LOOPS1]]#0 -> %arg1 = 0 to 18, [[DEF_LOOPS1]]#1 -> %arg2 = 0 to 20) {
+  // CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
+  // CHECK: store [[CST]], [[RES]][%arg1, %arg2] : memref<18x20xf32>
+  // CHECK: }
+  // CHECK: [[DEF_LOOPS2:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS2:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOPS2]]#0, [[OPT_LOOPS2]]#1) with ([[DEF_LOOPS2]]#0 -> %arg1 = 0 to 16, [[DEF_LOOPS2]]#1 -> %arg2 = 0 to 16) {
+  // CHECK: [[CST1:%.+]] = constant 3 : index
+  // CHECK: [[ADD:%.+]] = addi [[CST1]], %arg2 : index
+  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<16x16xf32>
+  // CHECK: store [[LOAD]], [[RES]][%arg1, [[ADD]]] : memref<18x20xf32>
+  // CHECK: }
+}