Import graph output type from protobuf (#333)

* import output type

Signed-off-by: daquexian <daquexian566@gmail.com>

* rename input->value_info, update doc

Signed-off-by: daquexian <daquexian566@gmail.com>

* infer shape on return op inputs

Signed-off-by: daquexian <daquexian566@gmail.com>

* import output type from protobuf only if it has shape

Signed-off-by: daquexian <daquexian566@gmail.com>

* fix wrong gather test

Signed-off-by: daquexian <daquexian566@gmail.com>

* add comments

Signed-off-by: daquexian <daquexian566@gmail.com>
This commit is contained in:
daquexian 2020-10-03 23:21:15 +08:00 committed by GitHub
parent 0db735f48d
commit cb3d1e4f64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 34 additions and 18 deletions

View File

@ -61,20 +61,17 @@ private:
mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); } mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); }
/*! /*!
* Import an onnx input tensor type by determining and recording its type * Import an onnx tensor type by determining and returning its type
* in a list of input tensor mlir types. * @param value_info onnx tensor ValueInfoProto.
* @param input onnx input tensor ValueInfoProto.
* @param arg_types list of mlir types representing types of graph input.
*/ */
mlir::Type ImportInputTensorType(const onnx::ValueInfoProto &input) { mlir::Type ImportTensorType(const onnx::ValueInfoProto &value_info) {
std::vector<int64_t> dims; std::vector<int64_t> dims;
auto shape_proto = input.type().tensor_type().shape(); auto shape_proto = value_info.type().tensor_type().shape();
auto input_tensor_legalized_name = legalize_name(input.name());
for (int i = 0; i < shape_proto.dim_size(); i++) { for (int i = 0; i < shape_proto.dim_size(); i++) {
if (shape_proto.dim()[i].dim_value()) { if (shape_proto.dim()[i].dim_value()) {
int dim_numeric_size = shape_proto.dim()[i].dim_value(); int dim_numeric_size = shape_proto.dim()[i].dim_value();
assert(dim_numeric_size != 0 && assert(dim_numeric_size != 0 &&
"Parsed an input tensor with a dimension size of zero"); "Parsed an tensor with a dimension size of zero");
if (dim_numeric_size > 0) { if (dim_numeric_size > 0) {
dims.push_back(dim_numeric_size); dims.push_back(dim_numeric_size);
} else { // If dim_value < 0, then dim is parametric. } else { // If dim_value < 0, then dim is parametric.
@ -88,7 +85,7 @@ private:
} }
auto elementOnnxType = auto elementOnnxType =
(onnx::TensorProto_DataType)input.type().tensor_type().elem_type(); (onnx::TensorProto_DataType)value_info.type().tensor_type().elem_type();
mlir::Type elementType = mlir::Type elementType =
convertONNXTypeToMLIRType(builder_, elementOnnxType); convertONNXTypeToMLIRType(builder_, elementOnnxType);
llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size()); llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size());
@ -532,6 +529,11 @@ private:
auto tensor_val = auto tensor_val =
frontend_symbols_.GetTensorByOnnxName(output_tensor_legalized_name); frontend_symbols_.GetTensorByOnnxName(output_tensor_legalized_name);
if (output.type().value_case() == onnx::TypeProto::kTensorType) {
if (output.type().tensor_type().has_shape()) {
tensor_val.setType(ImportTensorType(output));
}
}
ret_types.emplace_back(tensor_val.getType()); ret_types.emplace_back(tensor_val.getType());
ret_vals.push_back(tensor_val); ret_vals.push_back(tensor_val);
} }
@ -560,7 +562,7 @@ private:
for (const auto &input : graph.input()) { for (const auto &input : graph.input()) {
if (!initializedTensors.ContainKey(legalize_name(input.name()))) { if (!initializedTensors.ContainKey(legalize_name(input.name()))) {
inputNames.push_back(input.name()); inputNames.push_back(input.name());
arg_types.emplace_back(ImportInputTensorType(input)); arg_types.emplace_back(ImportTensorType(input));
// numInputs is the number of graph inputs not contained within the // numInputs is the number of graph inputs not contained within the
// initializer // initializer
++numInputs; ++numInputs;

View File

@ -9,6 +9,7 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallPtrSet.h"
@ -32,9 +33,13 @@ public:
auto f = getFunction(); auto f = getFunction();
// Iterate on the operations that need shape inference i.e the operations // Iterate on the operations that need shape inference i.e the operations
// that return a dynamic shape. // that return a dynamic shape or are followed by a return op.
f.walk([&](mlir::Operation *op) { f.walk([&](mlir::Operation *op) {
if (returnsDynamicShape(op)) { // The shape of graph output has been imported from onnx protobuf model,
// so the ops followed by a return op may not have dynamic shape output.
// However, shape inference is still needed on these ops
// to infer optional attributes.
if (isUsedByReturnOp(op) || returnsDynamicShape(op)) {
if (auto shape_op = dyn_cast<ShapeInference>(op)) { if (auto shape_op = dyn_cast<ShapeInference>(op)) {
if (failed(shape_op.inferShapes())) { if (failed(shape_op.inferShapes())) {
op->emitError("shape inference failed"); op->emitError("shape inference failed");
@ -69,6 +74,15 @@ public:
} }
} }
/*!
 * Check if the given operation has at least one user that is a ReturnOp.
 * @param op operation whose users are inspected.
 * @return true if any user of op is a ReturnOp, false otherwise.
 */
static bool isUsedByReturnOp(Operation *op) {
  bool feedsReturn = false;
  for (Operation *consumer : op->getUsers()) {
    // Stop at the first ReturnOp user; one match is enough.
    if (dyn_cast<ReturnOp>(consumer)) {
      feedsReturn = true;
      break;
    }
  }
  return feedsReturn;
}
/*! /*!
* Check if the given operation has a dynamically shaped result. * Check if the given operation has a dynamically shaped result.
*/ */
@ -86,4 +100,4 @@ public:
*/ */
std::unique_ptr<mlir::Pass> mlir::createShapeInferencePass() { std::unique_ptr<mlir::Pass> mlir::createShapeInferencePass() {
return std::make_unique<ShapeInferencePass>(); return std::make_unique<ShapeInferencePass>();
} }

View File

@ -2197,13 +2197,13 @@ func @test_gather_axis0(%arg0 : tensor<3x2xf32>) -> tensor<2x2x2xf32> {
// ----- // -----
// Test gather along axis 1, second example in ONNX for Gather. // Test gather along axis 1, second example in ONNX for Gather.
func @test_gather_axis1(%arg0 : tensor<3x3xf32>) -> tensor<1x3x2xf32> { func @test_gather_axis1(%arg0 : tensor<3x3xf32>) -> tensor<3x1x2xf32> {
%indices = "onnx.Constant"() {value = dense<[[0, 2]]> : tensor<1x2xi64>} : () -> tensor<1x2xi64> %indices = "onnx.Constant"() {value = dense<[[0, 2]]> : tensor<1x2xi64>} : () -> tensor<1x2xi64>
%0 = "onnx.Gather"(%arg0, %indices) {axis = 1 : si64} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<1x3x2xf32> %0 = "onnx.Gather"(%arg0, %indices) {axis = 1 : si64} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<3x1x2xf32>
"std.return"(%0) : (tensor<1x3x2xf32>) -> () "std.return"(%0) : (tensor<3x1x2xf32>) -> ()
// CHECK-LABEL: test_gather_axis1 // CHECK-LABEL: test_gather_axis1
// CHECK: [[ALLOC:%.+]] = alloc() : memref<1x3x2xf32> // CHECK: [[ALLOC:%.+]] = alloc() : memref<3x1x2xf32>
// CHECK: [[GLOBAL:%.+]] = "krnl.global"() {name = "constant_0", shape = [1, 2], value = dense<{{\[+}}0, 2{{\]+}}> : tensor<1x2xi64>} : () -> memref<1x2xi64> // CHECK: [[GLOBAL:%.+]] = "krnl.global"() {name = "constant_0", shape = [1, 2], value = dense<{{\[+}}0, 2{{\]+}}> : tensor<1x2xi64>} : () -> memref<1x2xi64>
// CHECK: [[LOOP:%.+]]:3 = krnl.define_loops 3 // CHECK: [[LOOP:%.+]]:3 = krnl.define_loops 3
// CHECK: [[ZERO:%.+]] = constant 0 : index // CHECK: [[ZERO:%.+]] = constant 0 : index
@ -2216,7 +2216,7 @@ func @test_gather_axis1(%arg0 : tensor<3x3xf32>) -> tensor<1x3x2xf32> {
// CHECK: [[CMP:%.+]] = cmpi "slt", [[AFFINE2]], [[ZERO]] : index // CHECK: [[CMP:%.+]] = cmpi "slt", [[AFFINE2]], [[ZERO]] : index
// CHECK: [[AFFINE4:%.+]] = select [[CMP]], [[AFFINE3]], [[AFFINE2]] : index // CHECK: [[AFFINE4:%.+]] = select [[CMP]], [[AFFINE3]], [[AFFINE2]] : index
// CHECK: [[DATA:%.+]] = load %arg0{{.}}[[ARG1]], [[AFFINE4]]{{.}} : memref<3x3xf32> // CHECK: [[DATA:%.+]] = load %arg0{{.}}[[ARG1]], [[AFFINE4]]{{.}} : memref<3x3xf32>
// CHECK: affine.store [[DATA]], [[ALLOC]]{{.}}[[ARG1]], [[ARG2]], [[ARG3]]{{.}} : memref<1x3x2xf32> // CHECK: affine.store [[DATA]], [[ALLOC]]{{.}}[[ARG1]], [[ARG2]], [[ARG3]]{{.}} : memref<3x1x2xf32>
} }
// ----- // -----