Lower constant padding operation to KRNL dialect (#27)
This commit is contained in:
parent
e8a0b47e10
commit
391f565a66
|
@ -88,6 +88,7 @@ add_library(onnf_lower_frontend
|
|||
conversion/onnx_to_krnl/nn/pooling.cpp
|
||||
conversion/onnx_to_krnl/tensor/identity.cpp
|
||||
conversion/onnx_to_krnl/tensor/reshape.cpp
|
||||
conversion/onnx_to_krnl/tensor/padconstantvaluepad.cpp
|
||||
conversion/onnx_to_krnl/tensor/transpose.cpp
|
||||
conversion/onnx_to_krnl/tensor/unsqueeze.cpp
|
||||
conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp)
|
||||
|
|
|
@ -93,6 +93,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
|
|||
populateLoweringONNXMatMulOpPattern(patterns, &getContext());
|
||||
// Tensor
|
||||
populateLoweringONNXReshapeOpPattern(patterns, &getContext());
|
||||
populateLoweringONNXPadConstantValuePadOpPattern(patterns, &getContext());
|
||||
populateLoweringONNXUnsqueezeOpPattern(patterns, &getContext());
|
||||
populateLoweringONNXTransposeOpPattern(patterns, &getContext());
|
||||
populateLoweringONNXIdentityOpPattern(patterns, &getContext());
|
||||
|
|
|
@ -234,6 +234,9 @@ void populateLoweringONNXUnsqueezeOpPattern(
|
|||
void populateLoweringONNXTransposeOpPattern(
|
||||
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
||||
|
||||
void populateLoweringONNXPadConstantValuePadOpPattern(
|
||||
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
||||
|
||||
void populateLoweringONNXReshapeOpPattern(
|
||||
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
||||
|
||||
|
|
|
@ -0,0 +1,108 @@
|
|||
//===----padconstantvaluepad.cpp - Lowering PadConstantValuePad Op --------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file lowers the ONNX PadConstantValuePad Operator to Krnl dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
struct ONNXPadConstantValuePadOpLowering : public ConversionPattern {
|
||||
ONNXPadConstantValuePadOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(mlir::ONNXPadConstantValuePadOp::getOperationName(),
|
||||
1, ctx) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto tensorType = (*op->result_type_begin());
|
||||
auto loc = op->getLoc();
|
||||
|
||||
// Only constant padding is supported now.
|
||||
auto padMode = llvm::dyn_cast<ONNXPadConstantValuePadOp>(op).mode();
|
||||
if (padMode != "constant")
|
||||
emitError(loc, "unsupported mode for PadConstantValuePad");
|
||||
auto constantValAttr =
|
||||
llvm::dyn_cast<ONNXPadConstantValuePadOp>(op).constant_valueAttr();
|
||||
|
||||
// Insert an allocation and deallocation for the result of this operation.
|
||||
auto memRefType = convertToMemRefType(tensorType);
|
||||
Value alloc;
|
||||
bool insertDealloc = checkInsertDealloc(op);
|
||||
|
||||
if (hasAllConstantDimensions(memRefType))
|
||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||
else
|
||||
emitError(loc, "unexpected output has non-Constant shape");
|
||||
|
||||
// Number of loops
|
||||
auto memRefShape = memRefType.getShape();
|
||||
int64_t rank = memRefShape.size();
|
||||
|
||||
// Iterate over the loop nest using the output shape.
|
||||
BuildKrnlLoop padLoops(rewriter, loc, rank);
|
||||
padLoops.createDefineAndOptimizeOp();
|
||||
for (int i = 0; i < rank; ++i)
|
||||
padLoops.pushBounds(0, alloc, i);
|
||||
padLoops.createIterateOp();
|
||||
|
||||
// Iterate over the loop nest using the input shape.
|
||||
BuildKrnlLoop valueLoops(rewriter, loc, rank);
|
||||
valueLoops.createDefineAndOptimizeOp();
|
||||
for (int i = 0; i < rank; ++i)
|
||||
valueLoops.pushBounds(0, operands[0], i);
|
||||
valueLoops.createIterateOp();
|
||||
|
||||
// Copy the input data into the output.
|
||||
rewriter.setInsertionPointToStart(valueLoops.getIterateBlock());
|
||||
|
||||
SmallVector<Value, 4> inLoopIVs;
|
||||
for (int i = 0; i < rank; ++i)
|
||||
inLoopIVs.emplace_back(valueLoops.getInductionVar(i));
|
||||
|
||||
auto pads = llvm::dyn_cast<ONNXPadConstantValuePadOp>(op).pads();
|
||||
SmallVector<int64_t, 4> pad_begin;
|
||||
for (int i = 0; i < pads.size()/2; ++i) {
|
||||
pad_begin.emplace_back(pads.getValue()[i].cast<IntegerAttr>().getInt());
|
||||
}
|
||||
|
||||
SmallVector<Value, 4> outLoopIVs;
|
||||
for (int i = 0; i < rank; ++i) {
|
||||
// Calculate the index for the load and store.
|
||||
if (pad_begin[i] == 0) {
|
||||
outLoopIVs.emplace_back(valueLoops.getInductionVar(i));
|
||||
} else {
|
||||
auto outIV = rewriter.create<AddIOp>(
|
||||
loc, rewriter.create<ConstantIndexOp>(loc, pad_begin[i]),
|
||||
valueLoops.getInductionVar(i));
|
||||
outLoopIVs.emplace_back(outIV);
|
||||
}
|
||||
}
|
||||
|
||||
auto inVal = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
|
||||
rewriter.create<StoreOp>(loc, inVal, alloc, outLoopIVs);
|
||||
rewriter.setInsertionPointToStart(padLoops.getIterateBlock());
|
||||
|
||||
SmallVector<Value, 4> outLoopIVs1;
|
||||
for (int i = 0; i < rank; ++i)
|
||||
outLoopIVs1.emplace_back(padLoops.getInductionVar(i));
|
||||
|
||||
auto inVal1 = rewriter.create<ConstantOp>(loc, constantValAttr);
|
||||
rewriter.create<StoreOp>(loc, inVal1, alloc, outLoopIVs1);
|
||||
|
||||
// Replace the original op with the generated code.
|
||||
rewriter.replaceOp(op, alloc);
|
||||
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
void populateLoweringONNXPadConstantValuePadOpPattern(
|
||||
OwningRewritePatternList &patterns, MLIRContext *ctx) {
|
||||
patterns.insert<ONNXPadConstantValuePadOpLowering>(ctx);
|
||||
}
|
|
@ -1510,3 +1510,28 @@ func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode_w_unknown_dims(%arg
|
|||
// CHECK: }
|
||||
// CHECK: return [[RES]] : memref<?x3x?x16xf32>
|
||||
}
|
||||
|
||||
func @test_constant_pad1(%arg0: tensor<16x16xf32>) -> tensor<18x20xf32> {
|
||||
%0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 3, 2, 1]} : (tensor<16x16xf32>) -> tensor<18x20xf32>
|
||||
return %0 : tensor<18x20xf32>
|
||||
// CHECK-LABEL: test_constant_pad1
|
||||
// CHECK: [[RES:%.+]] = alloc() : memref<18x20xf32>
|
||||
// CHECK: [[DEF_LOOPS1:%.+]]:2 = krnl.define_loops 2
|
||||
// CHECK: [[OPT_LOOPS1:%.+]]:2 = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[DEF_LOOPS1]]#0, [[DEF_LOOPS1]]#1
|
||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS1]]#0, [[OPT_LOOPS1]]#1) with ([[DEF_LOOPS1]]#0 -> %arg1 = 0 to 18, [[DEF_LOOPS1]]#1 -> %arg2 = 0 to 20) {
|
||||
// CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
|
||||
// CHECK: store [[CST]], [[RES]][%arg1, %arg2] : memref<18x20xf32>
|
||||
// CHECK: }
|
||||
// CHECK: [[DEF_LOOPS2:%.+]]:2 = krnl.define_loops 2
|
||||
// CHECK: [[OPT_LOOPS2:%.+]]:2 = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1
|
||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS2]]#0, [[OPT_LOOPS2]]#1) with ([[DEF_LOOPS2]]#0 -> %arg1 = 0 to 16, [[DEF_LOOPS2]]#1 -> %arg2 = 0 to 16) {
|
||||
// CHECK: [[CST1:%.+]] = constant 3 : index
|
||||
// CHECK: [[ADD:%.+]] = addi [[CST1]], %arg2 : index
|
||||
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<16x16xf32>
|
||||
// CHECK: store [[LOAD]], [[RES]][%arg1, [[ADD]]] : memref<18x20xf32>
|
||||
// CHECK: }
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue