[MLIR] Add SizeOp conversion from ONNX dialect to Krnl dialect (#295)

* [MLIR] Add SizeOp conversion from ONNX dialect to Krnl dialect

Added ONNXSizeOp conversion from ONNX dialect to Krnl dialect. This op is added as a part of --convert-onnx-to-krnl pass.

Signed-off-by: Prashant Kumar <pk5561@gmail.com>

* Add unit tests for Size op.

* Remove unit tests.

Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
Prashant Kumar 2020-09-21 15:25:21 +05:30 committed by GitHub
parent 3520dbd6e1
commit 4cc16aceb7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 120 additions and 0 deletions

View File

@ -24,6 +24,7 @@ add_library(OMONNXToKrnl
Tensor/Concat.cpp Tensor/Concat.cpp
Tensor/Split.cpp Tensor/Split.cpp
Tensor/Gather.cpp Tensor/Gather.cpp
Tensor/Size.cpp
ConvertONNXToKrnl.cpp) ConvertONNXToKrnl.cpp)
target_link_libraries(OMONNXToKrnl target_link_libraries(OMONNXToKrnl
onnx) onnx)

View File

@ -103,6 +103,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
populateLoweringONNXConcatOpPattern(patterns, &getContext()); populateLoweringONNXConcatOpPattern(patterns, &getContext());
populateLoweringONNXSqueezeOpPattern(patterns, &getContext()); populateLoweringONNXSqueezeOpPattern(patterns, &getContext());
populateLoweringONNXSplitOpPattern(patterns, &getContext()); populateLoweringONNXSplitOpPattern(patterns, &getContext());
populateLoweringONNXSizeOpPattern(patterns, &getContext());
// Neural network // Neural network
populateLoweringONNXConvOpPattern(patterns, &getContext()); populateLoweringONNXConvOpPattern(patterns, &getContext());
populateLoweringONNXNormalizationOpPattern(patterns, &getContext()); populateLoweringONNXNormalizationOpPattern(patterns, &getContext());

View File

@ -253,6 +253,9 @@ void populateLoweringONNXSqueezeOpPattern(
void populateLoweringONNXSplitOpPattern( void populateLoweringONNXSplitOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx); OwningRewritePatternList &patterns, MLIRContext *ctx);
void populateLoweringONNXSizeOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);
bool checkOpResultIsUsedByGetRef(AllocOp *allocOp); bool checkOpResultIsUsedByGetRef(AllocOp *allocOp);
int64_t getMemRefSizeInBytes(Value val); int64_t getMemRefSizeInBytes(Value val);

View File

@ -0,0 +1,73 @@
//===---------------- Size.cpp - Lowering Size Op
//-------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX Size Operator to Krnl dialect.
//
//===----------------------------------------------------------------------===//
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
using namespace mlir;
struct ONNXSizeOpLowering : public ConversionPattern {
ONNXSizeOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXSizeOp::getOperationName(), 1, ctx) {}
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// Gather info.
Location loc = op->getLoc();
Value alloc;
bool insertDealloc = checkInsertDealloc(op);
ONNXSizeOp sizeOp = llvm::dyn_cast<ONNXSizeOp>(op);
ONNXSizeOpAdaptor operandAdaptor(operands);
Value data = operandAdaptor.data();
ArrayRef<int64_t> dataShape = data.getType().cast<MemRefType>().getShape();
Value resultOperand = sizeOp.size();
ValueRange indices;
MemRefType memRefType = convertToMemRefType(*op->result_type_begin());
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else
alloc = insertAllocAndDealloc(
memRefType, loc, rewriter, insertDealloc, {resultOperand});
// Accumulate static dimensions first.
int64_t staticNumElement = 1;
bool allStaticDimensions = true;
for (unsigned i = 0; i < dataShape.size(); i++) {
if (dataShape[i] != -1)
staticNumElement *= dataShape[i];
else
allStaticDimensions = false;
}
// Accumulate the remaining dimensions that are unknown.
Value noElements = emitConstantOp(
rewriter, loc, memRefType.getElementType(), staticNumElement);
if (!allStaticDimensions) {
for (unsigned i = 0; i < dataShape.size(); i++) {
if (dataShape[i] == -1) {
Value index = rewriter.create<DimOp>(loc, data, i);
Value dim = rewriter.create<IndexCastOp>(
loc, index, memRefType.getElementType());
noElements = rewriter.create<MulIOp>(loc, noElements, dim);
}
}
}
rewriter.create<AffineStoreOp>(loc, noElements, alloc, llvm::None);
rewriter.replaceOp(op, alloc);
return success();
}
};
void populateLoweringONNXSizeOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
patterns.insert<ONNXSizeOpLowering>(ctx);
}

View File

@ -422,6 +422,12 @@ test_to_enable = [
# ConstantOfShape # ConstantOfShape
"test_constantofshape_float_ones_cpu", "test_constantofshape_float_ones_cpu",
# Size
# TODO(tjingrant): fix unit test for size ops.
# "test_size_cpu",
# "test_size_example_cpu",
# Error: # Error:
# Items are not equal: # Items are not equal:
# ACTUAL: dtype('int32') # ACTUAL: dtype('int32')

View File

@ -2135,6 +2135,42 @@ func @cast_lowering_f64f32_10(%arg0: tensor<10xf64>) -> tensor<*xf32> {
// ----- // -----
func @test_size_known(%arg0: tensor<2x2xf32>) -> tensor<i64> {
%1 = "onnx.Size"(%arg0) : (tensor<2x2xf32>) -> tensor<i64>
"std.return"(%1) : (tensor<i64>) -> ()
// CHECK-LABEL: test_size_known
// CHECK: [[RES:%.+]] = alloc() : memref<i64>
// CHECK-NEXT [[SIZE:%.+]] = constant 4 : i64
// CHECK-NEXT affine.store [[SIZE]], [[RES]][] : memref<i64>
// CHECK-NEXT return [[RES]] : memref<i64>
}
// -----
func @test_size_unknown(%arg0 : tensor<?x2x?xf32>) -> tensor<i64> {
// CHECK-LABEL: test_size_unknown
// CHECK: [[RES:%.+]] = alloc() : memref<i64>
// CHECK-NEXT: [[INIT:%.+]] = constant 2 : i64
// CHECK-NEXT: [[IND1:%.+]] = constant 0 : index
// CHECK-NEXT: [[DIM1:%.+]] = dim %arg0, [[IND1]] : memref<?x2x?xf32>
// CHECK-NEXT: [[CAST1:%.+]] = index_cast [[DIM1]] : index to i64
// CHECK-NEXT: [[TMP1:%.+]] = muli [[INIT]], [[CAST1]] : i64
// CHECK-NEXT: [[IND2:%.+]] = constant 2 : index
// CHECK-NEXT: [[DIM2:%.+]] = dim %arg0, [[IND2]] : memref<?x2x?xf32>
// CHECK-NEXT: [[IND3:%.+]] = index_cast [[DIM2]] : index to i64
// CHECK-NEXT: [[SIZE:%.+]] = muli [[TMP1]], [[IND3]] : i64
// CHECK-NEXT: affine.store [[SIZE]], [[RES]][] : memref<i64>
// CHECK-NEXT: return [[RES]] : memref<i64>
%1 = "onnx.Size"(%arg0) : (tensor<?x2x?xf32>) -> tensor<i64>
"std.return"(%1) : (tensor<i64>) -> ()
}
// -----
// Test gather along axis 0, first example in ONNX for Gather. // Test gather along axis 0, first example in ONNX for Gather.
func @test_gather_axis0(%arg0 : tensor<3x2xf32>) -> tensor<2x2x2xf32> { func @test_gather_axis0(%arg0 : tensor<3x2xf32>) -> tensor<2x2x2xf32> {
%indices = "onnx.Constant"() {value = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>} : () -> tensor<2x2xi64> %indices = "onnx.Constant"() {value = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>} : () -> tensor<2x2xi64>