From cb3d1e4f64b36b899effb3e0beea780f7365e8ae Mon Sep 17 00:00:00 2001 From: daquexian Date: Sat, 3 Oct 2020 23:21:15 +0800 Subject: [PATCH] Import graph output type from protobuf (#333) * import output type Signed-off-by: daquexian * rename input->value_info, update doc Signed-off-by: daquexian * infer shape on return op inputs Signed-off-by: daquexian * import output type from protobuf only if it has shape Signed-off-by: daquexian * fix wrong gather test Signed-off-by: daquexian * add comments Signed-off-by: daquexian --- src/Builder/FrontendDialectTransformer.cpp | 22 ++++++++++++---------- src/Transform/ONNX/ShapeInferencePass.cpp | 20 +++++++++++++++++--- test/mlir/onnx/onnx_lowering.mlir | 10 +++++----- 3 files changed, 34 insertions(+), 18 deletions(-) diff --git a/src/Builder/FrontendDialectTransformer.cpp b/src/Builder/FrontendDialectTransformer.cpp index 5198aa5..a71de9d 100644 --- a/src/Builder/FrontendDialectTransformer.cpp +++ b/src/Builder/FrontendDialectTransformer.cpp @@ -61,20 +61,17 @@ private: mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); } /*! - * Import an onnx input tensor type by determining and recording its type - * in a list of input tensor mlir types. - * @param input onnx input tensor ValueInfoProto. - * @param arg_types list of mlir types representing types of graph input. + * Import an onnx tensor type by determining and returning its type + * @param value_info onnx tensor ValueInfoProto. */ - mlir::Type ImportInputTensorType(const onnx::ValueInfoProto &input) { + mlir::Type ImportTensorType(const onnx::ValueInfoProto &value_info) { std::vector dims; - auto shape_proto = input.type().tensor_type().shape(); - auto input_tensor_legalized_name = legalize_name(input.name()); + auto shape_proto = value_info.type().tensor_type().shape(); for (int i = 0; i < shape_proto.dim_size(); i++) { if (shape_proto.dim()[i].dim_value()) { int dim_numeric_size = shape_proto.dim()[i].dim_value(); assert(dim_numeric_size != 0 && - "Parsed an input tensor with a dimension size of zero"); + "Parsed an tensor with a dimension size of zero"); if (dim_numeric_size > 0) { dims.push_back(dim_numeric_size); } else { // If dim_value < 0, then dim is parametric. @@ -88,7 +85,7 @@ private: } auto elementOnnxType = - (onnx::TensorProto_DataType)input.type().tensor_type().elem_type(); + (onnx::TensorProto_DataType)value_info.type().tensor_type().elem_type(); mlir::Type elementType = convertONNXTypeToMLIRType(builder_, elementOnnxType); llvm::ArrayRef tensor_dims(dims.data(), dims.size()); @@ -532,6 +529,11 @@ private: auto tensor_val = frontend_symbols_.GetTensorByOnnxName(output_tensor_legalized_name); + if (output.type().value_case() == onnx::TypeProto::kTensorType) { + if (output.type().tensor_type().has_shape()) { + tensor_val.setType(ImportTensorType(output)); + } + } ret_types.emplace_back(tensor_val.getType()); ret_vals.push_back(tensor_val); } @@ -560,7 +562,7 @@ private: for (const auto &input : graph.input()) { if (!initializedTensors.ContainKey(legalize_name(input.name()))) { inputNames.push_back(input.name()); - arg_types.emplace_back(ImportInputTensorType(input)); + arg_types.emplace_back(ImportTensorType(input)); // numInputs is the number of graph inputs not contained within the // initializer ++numInputs; diff --git a/src/Transform/ONNX/ShapeInferencePass.cpp b/src/Transform/ONNX/ShapeInferencePass.cpp index 232f005..4be90b1 100644 --- a/src/Transform/ONNX/ShapeInferencePass.cpp +++ b/src/Transform/ONNX/ShapeInferencePass.cpp @@ -9,6 +9,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Pass/Pass.h" #include "llvm/ADT/SmallPtrSet.h" @@ -32,9 +33,13 @@ public: auto f = getFunction(); // Iterate on the operations that need shape inference i.e the operations - // that return a dynamic shape. + // that return a dynamic shape or followed by a return op. f.walk([&](mlir::Operation *op) { - if (returnsDynamicShape(op)) { + // The shape of graph output has been imported from onnx protobuf model, + // so the ops followed by a return op may not have dynamic shape output. + // However, shape inference is still need on these ops + // to infer optional attributes. + if (isUsedByReturnOp(op) || returnsDynamicShape(op)) { if (auto shape_op = dyn_cast(op)) { if (failed(shape_op.inferShapes())) { op->emitError("shape inference failed"); @@ -69,6 +74,15 @@ public: } } + static bool isUsedByReturnOp(Operation *op) { + for (auto *user : op->getUsers()) { + if (dyn_cast(user)) { + return true; + } + } + return false; + } + /*! * Check if the given operation has a dynamically shaped result. */ @@ -86,4 +100,4 @@ public: */ std::unique_ptr mlir::createShapeInferencePass() { return std::make_unique(); -} \ No newline at end of file +} diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index 0c30e44..de3511d 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -2197,13 +2197,13 @@ func @test_gather_axis0(%arg0 : tensor<3x2xf32>) -> tensor<2x2x2xf32> { // ----- // Test gather along axis 1, second example in ONNX for Gather. -func @test_gather_axis1(%arg0 : tensor<3x3xf32>) -> tensor<1x3x2xf32> { +func @test_gather_axis1(%arg0 : tensor<3x3xf32>) -> tensor<3x1x2xf32> { %indices = "onnx.Constant"() {value = dense<[[0, 2]]> : tensor<1x2xi64>} : () -> tensor<1x2xi64> - %0 = "onnx.Gather"(%arg0, %indices) {axis = 1 : si64} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<1x3x2xf32> - "std.return"(%0) : (tensor<1x3x2xf32>) -> () + %0 = "onnx.Gather"(%arg0, %indices) {axis = 1 : si64} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<3x1x2xf32> + "std.return"(%0) : (tensor<3x1x2xf32>) -> () // CHECK-LABEL: test_gather_axis1 - // CHECK: [[ALLOC:%.+]] = alloc() : memref<1x3x2xf32> + // CHECK: [[ALLOC:%.+]] = alloc() : memref<3x1x2xf32> // CHECK: [[GLOBAL:%.+]] = "krnl.global"() {name = "constant_0", shape = [1, 2], value = dense<{{\[+}}0, 2{{\]+}}> : tensor<1x2xi64>} : () -> memref<1x2xi64> // CHECK: [[LOOP:%.+]]:3 = krnl.define_loops 3 // CHECK: [[ZERO:%.+]] = constant 0 : index @@ -2216,7 +2216,7 @@ func @test_gather_axis1(%arg0 : tensor<3x3xf32>) -> tensor<1x3x2xf32> { // CHECK: [[CMP:%.+]] = cmpi "slt", [[AFFINE2]], [[ZERO]] : index // CHECK: [[AFFINE4:%.+]] = select [[CMP]], [[AFFINE3]], [[AFFINE2]] : index // CHECK: [[DATA:%.+]] = load %arg0{{.}}[[ARG1]], [[AFFINE4]]{{.}} : memref<3x3xf32> - // CHECK: affine.store [[DATA]], [[ALLOC]]{{.}}[[ARG1]], [[ARG2]], [[ARG3]]{{.}} : memref<1x3x2xf32> + // CHECK: affine.store [[DATA]], [[ALLOC]]{{.}}[[ARG1]], [[ARG2]], [[ARG3]]{{.}} : memref<3x1x2xf32> } // -----