diff --git a/doc/gen_doc.py b/doc/gen_doc.py index 5ccfafa..9416b25 100644 --- a/doc/gen_doc.py +++ b/doc/gen_doc.py @@ -47,7 +47,7 @@ OpsWithShapeInference = [ 'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal', 'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin', 'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze', - 'Sign' + 'Sign', 'Constant' ] # Operations supporting canonicalization. diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 3fc1b38..488f77a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -91,6 +91,7 @@ add_library(onnf_lower_frontend conversion/onnx_to_krnl/tensor/padconstantvaluepad.cpp conversion/onnx_to_krnl/tensor/transpose.cpp conversion/onnx_to_krnl/tensor/unsqueeze.cpp + conversion/onnx_to_krnl/tensor/constant.cpp conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp) target_include_directories(onnf_lower_frontend PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT} diff --git a/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp b/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp index 9797b44..4499617 100644 --- a/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp +++ b/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp @@ -97,6 +97,7 @@ void FrontendToKrnlLoweringPass::runOnModule() { populateLoweringONNXUnsqueezeOpPattern(patterns, &getContext()); populateLoweringONNXTransposeOpPattern(patterns, &getContext()); populateLoweringONNXIdentityOpPattern(patterns, &getContext()); + populateLoweringONNXConstantOpPattern(patterns, &getContext()); // Neural network populateLoweringONNXConvOpPattern(patterns, &getContext()); populateLoweringONNXNormalizationOpPattern(patterns, &getContext()); diff --git a/src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp b/src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp index 4c5413d..1706398 100644 --- a/src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp +++ b/src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp @@ -242,3 +242,7 @@ void populateLoweringONNXReshapeOpPattern( void populateLoweringONNXIdentityOpPattern( OwningRewritePatternList &patterns, MLIRContext *ctx); + +void populateLoweringONNXConstantOpPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx); + diff --git a/src/conversion/onnx_to_krnl/tensor/constant.cpp b/src/conversion/onnx_to_krnl/tensor/constant.cpp new file mode 100644 index 0000000..bc90028 --- /dev/null +++ b/src/conversion/onnx_to_krnl/tensor/constant.cpp @@ -0,0 +1,100 @@ +//===---- constant.cpp - Lowering Constant Op -----------------------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers the ONNX Constant Operator to Krnl dialect. +// +//===----------------------------------------------------------------------===// + +#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp" + +using namespace mlir; + +template +void emitConstantAndStoreOpForDenseElementsAttr( + ConversionPatternRewriter &rewriter, Location loc, + DenseElementsAttr constantValue, ArrayRef valueShape, + ArrayRef constantIndices, Value alloc) { + // The following functor recursively walks the dimensions of the constant + // shape, generating a store when the recursion hits the base case. + SmallVector indices; + auto valueIt = constantValue.getValues().begin(); + std::function storeElements = [&](uint64_t dimension) { + // The last dimension is the base case of the recursion, at this point + // we store the element at the given index. + if (dimension == valueShape.size()) { + rewriter.create(loc, + rewriter.create(loc, *valueIt++), alloc, + llvm::makeArrayRef(indices)); + return; + } + + // Otherwise, iterate over the current dimension and add the indices to + // the list. + for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) { + indices.push_back(constantIndices[i]); + storeElements(dimension + 1); + indices.pop_back(); + } + }; + // Start the element storing recursion from the first dimension. + storeElements(/*dimension=*/0); +} + +struct ONNXConstantOpLowering : public ConversionPattern { + ONNXConstantOpLowering(MLIRContext *ctx) + : ConversionPattern(mlir::ONNXConstantOp::getOperationName(), 1, ctx) {} + + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + auto constantOp = llvm::dyn_cast(op); + + if (constantOp.sparse_value().hasValue()) { + emitError(loc, "Only support dense values at this time"); + } + + auto memRefType = convertToMemRefType(*op->result_type_begin()); + + Value alloc; + bool insertDealloc = checkInsertDealloc(op); + + if (hasAllConstantDimensions(memRefType)) + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + else + emitError(loc, "Unexpected output has non-Constant shape"); + + DenseElementsAttr constantValue = + constantOp.value().getValue().cast(); + + auto valueShape = memRefType.getShape(); + SmallVector constantIndices; + for (auto i : llvm::seq( + 0, *std::max_element(valueShape.begin(), valueShape.end()))) + constantIndices.push_back(rewriter.create(loc, i)); + + // The constant operation represents a multi-dimensional constant, so we + // will need to generate a store for each of the elements. + if (memRefType.getElementType().isa()) { + emitConstantAndStoreOpForDenseElementsAttr( + rewriter, loc, constantValue, valueShape, constantIndices, alloc); + } else if (memRefType.getElementType().isa()) { + emitConstantAndStoreOpForDenseElementsAttr( + rewriter, loc, constantValue, valueShape, constantIndices, alloc); + } else { + emitError(loc, "Unsupported output type"); + } + + // Replace this operation with the generated alloc. + rewriter.replaceOp(op, alloc); + + return matchSuccess(); + } +}; + +void populateLoweringONNXConstantOpPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx) { + patterns.insert(ctx); +} diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index f71032b..fc00e93 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -1171,6 +1171,23 @@ void ONNXUnsqueezeOp::inferShapes() { getResult().setType(RankedTensorType::get(dims, operandTy.getElementType())); } +//===----------------------------------------------------------------------===// +// Constant + +void ONNXConstantOp::inferShapes() { + if ((sparse_value().hasValue() && value().hasValue()) || + (!sparse_value().hasValue() && !value().hasValue())) + emitError("Require exactly one of the two attributes, either value or " + "sparse_value"); + + ElementsAttr valAttr; + if (sparse_value().hasValue()) + valAttr = sparse_valueAttr().cast(); + else + valAttr = valueAttr().cast(); + getResult().setType(valAttr.getType()); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/src/dialect/onnx/onnxop.inc b/src/dialect/onnx/onnxop.inc index 30f00bd..59e8f7a 100644 --- a/src/dialect/onnx/onnxop.inc +++ b/src/dialect/onnx/onnxop.inc @@ -340,7 +340,7 @@ def ONNXConcatFromSequenceOp:ONNX_Op<"ConcatFromSequence", } def ONNXConstantOp:ONNX_Op<"Constant", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX Constant operation"; let description = [{ "A constant tensor. Exactly one of the two attributes, either value or sparse_value," diff --git a/src/pass/shape_inference_pass.cpp b/src/pass/shape_inference_pass.cpp index 59bee64..eab5c2e 100644 --- a/src/pass/shape_inference_pass.cpp +++ b/src/pass/shape_inference_pass.cpp @@ -120,6 +120,7 @@ public: op->getName().getStringRef() != "onnx.PadConstantPad" && op->getName().getStringRef() != "onnx.PadConstantValuePad" && op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" && + op->getName().getStringRef() != "onnx.Constant" && op->getName().getStringRef() != "onnx.Unsqueeze") return false; return llvm::any_of(op->getResultTypes(), [](Type result_type) { diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index 817935a..3d05789 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -1535,3 +1535,27 @@ func @test_constant_pad1(%arg0: tensor<16x16xf32>) -> tensor<18x20xf32> { // CHECK: store [[LOAD]], [[RES]][%arg1, [[ADD]]] : memref<18x20xf32> // CHECK: } } + +func @test_constant_dense_2d_value(%arg0: tensor<1xf32>) -> tensor<*xf32> { + %0 = "onnx.Constant"() {value = dense<[[0.0, 0.0], [1.0, 1.1], [2.0, 2.1]]> : tensor<3x2xf32>} : () -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + // CHECK-LABEL: test_constant_dense_2d_value + // CHECK: [[RES:%.+]] = alloc() : memref<3x2xf32> + // CHECK: %[[INDEX_0:.+]] = constant 0 : index + // CHECK: %[[INDEX_1:.+]] = constant 1 : index + // CHECK: %[[INDEX_2:.+]] = constant 2 : index + // CHECK: [[CONSTANT_0:%.+]] = constant 0.000000e+00 : f32 + // CHECK: affine.store [[CONSTANT_0]], %0[%[[INDEX_0]], %[[INDEX_0]]] : memref<3x2xf32> + // CHECK: [[CONSTANT_1:%.+]] = constant 0.000000e+00 : f32 + // CHECK: affine.store [[CONSTANT_1]], %0[%[[INDEX_0]], %[[INDEX_1]]] : memref<3x2xf32> + // CHECK: [[CONSTANT_2:%.+]] = constant 1.000000e+00 : f32 + // CHECK: affine.store [[CONSTANT_2]], %0[%[[INDEX_1]], %[[INDEX_0]]] : memref<3x2xf32> + // CHECK: [[CONSTANT_3:%.+]] = constant 1.100000e+00 : f32 + // CHECK: affine.store [[CONSTANT_3]], %0[%[[INDEX_1]], %[[INDEX_1]]] : memref<3x2xf32> + // CHECK: [[CONSTANT_4:%.+]] = constant 2.000000e+00 : f32 + // CHECK: affine.store [[CONSTANT_4]], %0[%[[INDEX_2]], %[[INDEX_0]]] : memref<3x2xf32> + // CHECK: [[CONSTANT_5:%.+]] = constant 2.100000e+00 : f32 + // CHECK: affine.store [[CONSTANT_5]], %0[%[[INDEX_2]], %[[INDEX_1]]] : memref<3x2xf32> + // CHECK: return [[RES]] : memref<3x2xf32> +} + diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index 5a350e4..5ee339f 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -296,3 +296,38 @@ func @test_PadConstantPad_1(%arg0 : tensor<16x13xf32>, %arg1 : tensor<*xf32>) -> // CHECK: [[RES:%.+]] = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 2, 3, 1]} : (tensor<16x13xf32>, tensor<*xf32>) -> tensor<18x17xf32> // CHECK: return [[RES]] : tensor<18x17xf32> +/// Test ConstantOp shape inference for 1-D dense tensor. +func @test_constant_dense_1d_value() -> tensor<*xf32> { + %0 = "onnx.Constant"() {value = dense<[0.0, 1.0, 2.0]> : tensor<3xf32>} : () -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () +} +// CHECK-LABEL: test_constant_dense_1d_value +// CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>} : () -> tensor<3xf32> +// CHECK: return [[RES]] : tensor<3xf32> + +/// Test ConstantOp shape inference for 2-D dense tensor. +func @test_constant_dense_2d_value() -> tensor<*xf32> { + %0 = "onnx.Constant"() {value = dense<[[0.0, 0.0], [1.0, 1.1], [2.0, 2.1]]> : tensor<3x2xf32>} : () -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () +} +// CHECK-LABEL: test_constant_dense_2d_value +// CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<{{\[}}[0.000000e+00, 0.000000e+00], [1.000000e+00, 1.100000e+00], [2.000000e+00, 2.100000e+00{{\]}}]> : tensor<3x2xf32>} : () -> tensor<3x2xf32> +// CHECK: return [[RES]] : tensor<3x2xf32> + +/// Test ConstantOp shape inference for 1-D sparse tensor. +func @test_constant_sparse_1d_value() -> tensor<*xf32> { + %0 = "onnx.Constant"() {sparse_value = sparse<[[0]], [1.0]> : tensor<3xf32>} : () -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () +} +// CHECK-LABEL: test_constant_sparse_1d_value +// CHECK: [[RES:%.+]] = "onnx.Constant"() {sparse_value = sparse<0, 1.000000e+00> : tensor<3xf32>} : () -> tensor<3xf32> +// CHECK: return [[RES]] : tensor<3xf32> + +/// Test ConstantOp shape inference for 2-D sparse tensor. +func @test_constant_sparse_2d_value() -> tensor<*xf32> { + %0 = "onnx.Constant"() {sparse_value = sparse<[[0, 1]], [2.0]> : tensor<3x2xf32>} : () -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () +} +// CHECK-LABEL: test_constant_sparse_2d_value +// CHECK: [[RES:%.+]] = "onnx.Constant"() {sparse_value = sparse<{{\[}}[0, 1{{\]}}], 2.000000e+00> : tensor<3x2xf32>} : () -> tensor<3x2xf32> +// CHECK: return [[RES]] : tensor<3x2xf32>