diff --git a/src/Conversion/ONNXToKrnl/CMakeLists.txt b/src/Conversion/ONNXToKrnl/CMakeLists.txt
index 7684649..bab2dd3 100644
--- a/src/Conversion/ONNXToKrnl/CMakeLists.txt
+++ b/src/Conversion/ONNXToKrnl/CMakeLists.txt
@@ -24,6 +24,7 @@ add_library(OMONNXToKrnl
   Tensor/Concat.cpp
   Tensor/Split.cpp
   Tensor/Gather.cpp
+  Tensor/Size.cpp
   ConvertONNXToKrnl.cpp)

 target_link_libraries(OMONNXToKrnl onnx)
diff --git a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
index 083202f..a559c40 100644
--- a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
+++ b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
@@ -103,6 +103,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
   populateLoweringONNXConcatOpPattern(patterns, &getContext());
   populateLoweringONNXSqueezeOpPattern(patterns, &getContext());
   populateLoweringONNXSplitOpPattern(patterns, &getContext());
+  populateLoweringONNXSizeOpPattern(patterns, &getContext());
   // Neural network
   populateLoweringONNXConvOpPattern(patterns, &getContext());
   populateLoweringONNXNormalizationOpPattern(patterns, &getContext());
diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
index d39012e..aa3d848 100644
--- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
+++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
@@ -253,6 +253,9 @@ void populateLoweringONNXSqueezeOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx);

 void populateLoweringONNXSplitOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx);

+void populateLoweringONNXSizeOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
 bool checkOpResultIsUsedByGetRef(AllocOp *allocOp);

 int64_t getMemRefSizeInBytes(Value val);
diff --git a/src/Conversion/ONNXToKrnl/Tensor/Size.cpp b/src/Conversion/ONNXToKrnl/Tensor/Size.cpp
new file mode 100644
index 0000000..fff208b
--- /dev/null
+++ b/src/Conversion/ONNXToKrnl/Tensor/Size.cpp
@@ -0,0 +1,73 @@
+//===---------------- Size.cpp - Lowering Size Op
+//-------------------===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file lowers the ONNX Size Operator to Krnl dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
+
+using namespace mlir;
+
+struct ONNXSizeOpLowering : public ConversionPattern {
+  ONNXSizeOpLowering(MLIRContext *ctx)
+      : ConversionPattern(mlir::ONNXSizeOp::getOperationName(), 1, ctx) {}
+
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
+    // Gather info.
+    Location loc = op->getLoc();
+    Value alloc;
+    bool insertDealloc = checkInsertDealloc(op);
+    ONNXSizeOp sizeOp = llvm::dyn_cast<ONNXSizeOp>(op);
+
+    ONNXSizeOpAdaptor operandAdaptor(operands);
+    Value data = operandAdaptor.data();
+    ArrayRef<int64_t> dataShape = data.getType().cast<MemRefType>().getShape();
+    Value resultOperand = sizeOp.size();
+    ValueRange indices;
+    MemRefType memRefType = convertToMemRefType(*op->result_type_begin());
+
+    if (hasAllConstantDimensions(memRefType))
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
+    else
+      alloc = insertAllocAndDealloc(
+          memRefType, loc, rewriter, insertDealloc, {resultOperand});
+
+    // Accumulate static dimensions first.
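+    // The result is the total element count, i.e. the product of all
+    // dimension sizes. Fold the statically known extents into a
+    // compile-time constant here; dynamic (-1) extents are multiplied
+    // in below with DimOp / IndexCastOp / MulIOp at runtime.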
+    int64_t staticNumElement = 1;
+    bool allStaticDimensions = true;
+    for (unsigned i = 0; i < dataShape.size(); i++) {
+      if (dataShape[i] != -1)
+        staticNumElement *= dataShape[i];
+      else
+        allStaticDimensions = false;
+    }
+    // Accumulate the remaining dimensions that are unknown.
+    Value noElements = emitConstantOp(
+        rewriter, loc, memRefType.getElementType(), staticNumElement);
+    if (!allStaticDimensions) {
+      for (unsigned i = 0; i < dataShape.size(); i++) {
+        if (dataShape[i] == -1) {
+          Value index = rewriter.create<DimOp>(loc, data, i);
+          Value dim = rewriter.create<IndexCastOp>(
+              loc, index, memRefType.getElementType());
+          noElements = rewriter.create<MulIOp>(loc, noElements, dim);
+        }
+      }
+    }
+
+    rewriter.create<AffineStoreOp>(loc, noElements, alloc, llvm::None);
+    rewriter.replaceOp(op, alloc);
+    return success();
+  }
+};
+
+void populateLoweringONNXSizeOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<ONNXSizeOpLowering>(ctx);
+}
diff --git a/test/backend/test.py b/test/backend/test.py
index 94c5523..a25f891 100644
--- a/test/backend/test.py
+++ b/test/backend/test.py
@@ -422,6 +422,12 @@ test_to_enable = [

     # ConstantOfShape
     "test_constantofshape_float_ones_cpu",
+
+    # Size
+    # TODO(tjingrant): fix unit test for size ops.
+    # "test_size_cpu",
+    # "test_size_example_cpu",
+
     # Error:
     # Items are not equal:
     #  ACTUAL: dtype('int32')
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index 6519a42..0882f0d 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -2135,6 +2135,42 @@ func @cast_lowering_f64f32_10(%arg0: tensor<10xf64>) -> tensor<*xf32> {

 // -----

+func @test_size_known(%arg0: tensor<2x2xf32>) -> tensor<i64> {
+  %1 = "onnx.Size"(%arg0) : (tensor<2x2xf32>) -> tensor<i64>
+  "std.return"(%1) : (tensor<i64>) -> ()
+
+  // CHECK-LABEL: test_size_known
+  // CHECK: [[RES:%.+]] = alloc() : memref<i64>
+  // CHECK-NEXT: [[SIZE:%.+]] = constant 4 : i64
+  // CHECK-NEXT: affine.store [[SIZE]], [[RES]][] : memref<i64>
+  // CHECK-NEXT: return [[RES]] : memref<i64>
+
+}
+
+// -----
+
+func @test_size_unknown(%arg0 : tensor<?x2x?xf32>) -> tensor<i64> {
+
+  // CHECK-LABEL: test_size_unknown
+  // CHECK: [[RES:%.+]] = alloc() : memref<i64>
+  // CHECK-NEXT: [[INIT:%.+]] = constant 2 : i64
+  // CHECK-NEXT: [[IND1:%.+]] = constant 0 : index
+  // CHECK-NEXT: [[DIM1:%.+]] = dim %arg0, [[IND1]] : memref<?x2x?xf32>
+  // CHECK-NEXT: [[CAST1:%.+]] = index_cast [[DIM1]] : index to i64
+  // CHECK-NEXT: [[TMP1:%.+]] = muli [[INIT]], [[CAST1]] : i64
+  // CHECK-NEXT: [[IND2:%.+]] = constant 2 : index
+  // CHECK-NEXT: [[DIM2:%.+]] = dim %arg0, [[IND2]] : memref<?x2x?xf32>
+  // CHECK-NEXT: [[IND3:%.+]] = index_cast [[DIM2]] : index to i64
+  // CHECK-NEXT: [[SIZE:%.+]] = muli [[TMP1]], [[IND3]] : i64
+  // CHECK-NEXT: affine.store [[SIZE]], [[RES]][] : memref<i64>
+  // CHECK-NEXT: return [[RES]] : memref<i64>
+
+  %1 = "onnx.Size"(%arg0) : (tensor<?x2x?xf32>) -> tensor<i64>
+  "std.return"(%1) : (tensor<i64>) -> ()
+}
+
+// -----
+
 // Test gather along axis 0, first example in ONNX for Gather.
 func @test_gather_axis0(%arg0 : tensor<3x2xf32>) -> tensor<2x2x2xf32> {
   %indices = "onnx.Constant"() {value = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>} : () -> tensor<2x2xi64>
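
For reference, the behaviour checked by the currently skipped backend tests ("test_size_cpu", "test_size_example_cpu") and computed by the lowering above is simply the total element count as a scalar int64, the output dtype fixed by the ONNX Size spec. A minimal numpy sketch of that reference semantics, assuming numpy is available; the helper name reference_size is made up for illustration and is not part of the test suite:

    import numpy as np

    def reference_size(x):
        # onnx.Size: a 0-d int64 value holding the total element count,
        # i.e. the product of all dimension sizes of the input tensor.
        return np.int64(np.prod(x.shape, dtype=np.int64))

    # Mirrors test_size_known above: a 2x2 input has 4 elements.
    assert reference_size(np.zeros((2, 2), dtype=np.float32)) == 4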