From d86591d61ad27f2688ffe2d8f1678979a5af4d88 Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Tue, 17 Mar 2020 00:17:28 +0900 Subject: [PATCH] Import all initialized tensors as dense constants (#30) * Import initialized tensor as dense attribute * Import all initialize tensors as dense constants * Remove unintentional code * Fix value attribute format in shape inference tests of reshape * Readd rank check for reshape's shape inference * Remove a redundant variable Co-authored-by: Gheorghe-Teodor Bercea --- src/builder/frontend_dialect_helper.cpp | 33 ++++++++++---------- src/builder/frontend_dialect_transformer.cpp | 18 +++++++---- src/dialect/onnx/onnx_ops.cpp | 17 ++++++---- test/mlir/onnx/onnx_shape_inference.mlir | 6 ++-- 4 files changed, 42 insertions(+), 32 deletions(-) diff --git a/src/builder/frontend_dialect_helper.cpp b/src/builder/frontend_dialect_helper.cpp index 1dbadf4..36e362a 100644 --- a/src/builder/frontend_dialect_helper.cpp +++ b/src/builder/frontend_dialect_helper.cpp @@ -129,10 +129,16 @@ mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor( // Initializer for input. onnx::TensorProto initializer = GetInitializedTensor(name); + // Tensor dimensions. + llvm::ArrayRef tensorDims(initializer.dims().data(), + initializer.dims().size()); + // Emit ConstantOp and record the mapping between the input and // the constant value. - mlir::ArrayAttr constantArrayAttribute; + // Create value attribute. + mlir::DenseElementsAttr constantDenseAttribute; mlir::Type elementType; + mlir::ShapedType tensorType; int length; switch (initializer.data_type()) { case (onnx::TensorProto::FLOAT): { @@ -141,8 +147,9 @@ mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor( std::vector arrayAttrInitializer( typeArray, typeArray + length); llvm::ArrayRef array(typeArray, length); - constantArrayAttribute = builder.getF32ArrayAttr(array); elementType = builder.getF32Type(); + tensorType = mlir::RankedTensorType::get(tensorDims, elementType); + constantDenseAttribute = mlir::DenseElementsAttr::get(tensorType, array); break; } case (onnx::TensorProto::INT32): { @@ -151,8 +158,9 @@ mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor( std::vector arrayAttrInitializer( typeArray, typeArray + length); llvm::ArrayRef array(typeArray, length); - constantArrayAttribute = builder.getI32ArrayAttr(array); elementType = builder.getIntegerType(32); + tensorType = mlir::RankedTensorType::get(tensorDims, elementType); + constantDenseAttribute = mlir::DenseElementsAttr::get(tensorType, array); break; } case (onnx::TensorProto::INT64): { @@ -161,25 +169,16 @@ mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor( std::vector arrayAttrInitializer( typeArray, typeArray + length); llvm::ArrayRef array(typeArray, length); - constantArrayAttribute = builder.getI64ArrayAttr(array); elementType = builder.getIntegerType(64); + tensorType = mlir::RankedTensorType::get(tensorDims, elementType); + constantDenseAttribute = mlir::DenseElementsAttr::get(tensorType, array); break; } } - // Create empty sparse_value attribute. - llvm::ArrayRef array; - auto sparseValueAttribute = builder.getI64ArrayAttr(array); - - // Create value attribute. - llvm::ArrayRef tensorDims(initializer.dims().data(), - initializer.dims().size()); - mlir::Type tensorType = - mlir::RankedTensorType::get(tensorDims, elementType); - + // Create ConstantOp for dense array. return builder.create( - loc, tensorType, sparseValueAttribute, - constantArrayAttribute); + loc, tensorType, nullptr, constantDenseAttribute); } -} // namespace onnf \ No newline at end of file +} // namespace onnf diff --git a/src/builder/frontend_dialect_transformer.cpp b/src/builder/frontend_dialect_transformer.cpp index a7a6d09..bec4094 100644 --- a/src/builder/frontend_dialect_transformer.cpp +++ b/src/builder/frontend_dialect_transformer.cpp @@ -274,8 +274,12 @@ private: int expectedNumResults = -1) { std::vector inputs; for (const auto &item : node.input()) - if (frontend_symbols_.ContainKey(legalize_name(item))) + if (initializedTensors.ContainKey(legalize_name(item))) { + inputs.push_back(initializedTensors.EmitInitializerForInputTensor( + UnknownLoc(), builder_, legalize_name(item))); + } else if (frontend_symbols_.ContainKey(legalize_name(item))) { inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); + } buildOutputAndOperation(node, inputs, expectedNumOperands, expectedNumResults); @@ -287,7 +291,7 @@ private: for (int i = 0; i < node.input().size(); ++i) { item = node.input()[i]; // For the second argument, check if there exists an initializer. - if (i == 1 && initializedTensors.ContainKey(legalize_name(item))) { + if (initializedTensors.ContainKey(legalize_name(item))) { inputs.push_back( initializedTensors.EmitInitializerForInputTensor( UnknownLoc(), builder_, legalize_name(item))); @@ -412,9 +416,10 @@ private: // * maintain a list of the defined graph llvm::SmallVector arg_types; - // Import the input tensor types that are not constant. + // Import the input tensor types that are not constant and not initialized. for (const auto &input : graph.input()) - arg_types.emplace_back(ImportInputTensorType(input)); + if (!initializedTensors.ContainKey(legalize_name(input.name()))) + arg_types.emplace_back(ImportInputTensorType(input)); // Create the main function. auto funcType = builder_.getFunctionType(arg_types, {}); @@ -438,8 +443,9 @@ private: // Map graph inputs to entry block arguments. for (int i = 0; i < graph.input().size(); ++i) - ImportInputTensorSymbol( - graph.input()[i], entryBlock.getArguments()[i]); + if (!initializedTensors.ContainKey( + legalize_name(graph.input()[i].name()))) + ImportInputTensorSymbol(graph.input()[i], entryBlock.getArguments()[i]); // Create a NoneTyped constant to be used for optional operation inputs // which are not used. diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index 82c37da..89e269d 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -876,12 +876,18 @@ void ONNXReshapeOp::inferShapes() { SmallVector dims(outputRank, -1); if (constantOp) { - // Cast attribute to ArrayAttr. - ArrayAttr valueAttribute = constantOp.valueAttr().dyn_cast(); - if (!valueAttribute) - emitError("ArrayAttr expected"); + DenseElementsAttr valueAttribute = + constantOp.valueAttr().dyn_cast(); - if (ArrayAttrSize(valueAttribute) != outputRank) + if (!valueAttribute) + emitError("DenseElementsAttr expected"); + + // Get dims from valueAttribute. + auto valueIt = valueAttribute.getValues().begin(); + for (int i=0; i().getInt(); + + if (valueIt != valueAttribute.getValues().end()) emitError("Constant value must have same rank as output"); int64_t numberOfDynamicInputs = 0; @@ -889,7 +895,6 @@ void ONNXReshapeOp::inferShapes() { int64_t dynamicValueIndex = -1; for (int i=0; i, %arg1 : tensor<4xi32>) } func @test_reshape_1(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { - %0 = "onnx.Constant"() {sparse_value = [], value = [5, 5, 16, 2] } : () -> tensor<4xi32> + %0 = "onnx.Constant"() {value = dense<[5, 5, 16, 2]> : tensor<4xi32> } : () -> tensor<4xi32> %1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor<*xf32> "std.return"(%1) : (tensor<*xf32>) -> () @@ -445,7 +445,7 @@ func @test_reshape_1(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { } func @test_reshape_2(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { - %0 = "onnx.Constant"() {sparse_value = [], value = [-1, 16, 2] } : () -> tensor<3xi32> + %0 = "onnx.Constant"() {value = dense<[-1, 16, 2]> : tensor<3xi32> } : () -> tensor<3xi32> %1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<*xf32> "std.return"(%1) : (tensor<*xf32>) -> () @@ -455,7 +455,7 @@ func @test_reshape_2(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { } func @test_reshape_3(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { - %0 = "onnx.Constant"() {sparse_value = [], value = [-1, 0, 2] } : () -> tensor<3xi32> + %0 = "onnx.Constant"() {value = dense<[-1, 0, 2]> : tensor<3xi32> } : () -> tensor<3xi32> %1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<*xf32> "std.return"(%1) : (tensor<*xf32>) -> ()