Import graph output type from protobuf (#333)
* import output type Signed-off-by: daquexian <daquexian566@gmail.com> * rename input->value_info, update doc Signed-off-by: daquexian <daquexian566@gmail.com> * infer shape on return op inputs Signed-off-by: daquexian <daquexian566@gmail.com> * import output type from protobuf only if it has shape Signed-off-by: daquexian <daquexian566@gmail.com> * fix wrong gather test Signed-off-by: daquexian <daquexian566@gmail.com> * add comments Signed-off-by: daquexian <daquexian566@gmail.com>
This commit is contained in:
		
							parent
							
								
									0db735f48d
								
							
						
					
					
						commit
						cb3d1e4f64
					
				| 
						 | 
					@ -61,20 +61,17 @@ private:
 | 
				
			||||||
  mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); }
 | 
					  mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /*!
 | 
					  /*!
 | 
				
			||||||
   * Import an onnx input tensor type by determining and recording its type
 | 
					   * Import an onnx tensor type by determining and returning its type
 | 
				
			||||||
   * in a list of input tensor mlir types.
 | 
					   * @param value_info onnx tensor ValueInfoProto.
 | 
				
			||||||
   * @param input onnx input tensor ValueInfoProto.
 | 
					 | 
				
			||||||
   * @param arg_types list of mlir types representing types of graph input.
 | 
					 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  mlir::Type ImportInputTensorType(const onnx::ValueInfoProto &input) {
 | 
					  mlir::Type ImportTensorType(const onnx::ValueInfoProto &value_info) {
 | 
				
			||||||
    std::vector<int64_t> dims;
 | 
					    std::vector<int64_t> dims;
 | 
				
			||||||
    auto shape_proto = input.type().tensor_type().shape();
 | 
					    auto shape_proto = value_info.type().tensor_type().shape();
 | 
				
			||||||
    auto input_tensor_legalized_name = legalize_name(input.name());
 | 
					 | 
				
			||||||
    for (int i = 0; i < shape_proto.dim_size(); i++) {
 | 
					    for (int i = 0; i < shape_proto.dim_size(); i++) {
 | 
				
			||||||
      if (shape_proto.dim()[i].dim_value()) {
 | 
					      if (shape_proto.dim()[i].dim_value()) {
 | 
				
			||||||
        int dim_numeric_size = shape_proto.dim()[i].dim_value();
 | 
					        int dim_numeric_size = shape_proto.dim()[i].dim_value();
 | 
				
			||||||
        assert(dim_numeric_size != 0 &&
 | 
					        assert(dim_numeric_size != 0 &&
 | 
				
			||||||
               "Parsed an input tensor with a dimension size of zero");
 | 
					               "Parsed an tensor with a dimension size of zero");
 | 
				
			||||||
        if (dim_numeric_size > 0) {
 | 
					        if (dim_numeric_size > 0) {
 | 
				
			||||||
          dims.push_back(dim_numeric_size);
 | 
					          dims.push_back(dim_numeric_size);
 | 
				
			||||||
        } else { // If dim_value < 0, then dim is parametric.
 | 
					        } else { // If dim_value < 0, then dim is parametric.
 | 
				
			||||||
| 
						 | 
					@ -88,7 +85,7 @@ private:
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    auto elementOnnxType =
 | 
					    auto elementOnnxType =
 | 
				
			||||||
        (onnx::TensorProto_DataType)input.type().tensor_type().elem_type();
 | 
					        (onnx::TensorProto_DataType)value_info.type().tensor_type().elem_type();
 | 
				
			||||||
    mlir::Type elementType =
 | 
					    mlir::Type elementType =
 | 
				
			||||||
        convertONNXTypeToMLIRType(builder_, elementOnnxType);
 | 
					        convertONNXTypeToMLIRType(builder_, elementOnnxType);
 | 
				
			||||||
    llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size());
 | 
					    llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size());
 | 
				
			||||||
| 
						 | 
					@ -532,6 +529,11 @@ private:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    auto tensor_val =
 | 
					    auto tensor_val =
 | 
				
			||||||
        frontend_symbols_.GetTensorByOnnxName(output_tensor_legalized_name);
 | 
					        frontend_symbols_.GetTensorByOnnxName(output_tensor_legalized_name);
 | 
				
			||||||
 | 
					    if (output.type().value_case() == onnx::TypeProto::kTensorType) {
 | 
				
			||||||
 | 
					      if (output.type().tensor_type().has_shape()) {
 | 
				
			||||||
 | 
					        tensor_val.setType(ImportTensorType(output));
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
    ret_types.emplace_back(tensor_val.getType());
 | 
					    ret_types.emplace_back(tensor_val.getType());
 | 
				
			||||||
    ret_vals.push_back(tensor_val);
 | 
					    ret_vals.push_back(tensor_val);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
| 
						 | 
					@ -560,7 +562,7 @@ private:
 | 
				
			||||||
    for (const auto &input : graph.input()) {
 | 
					    for (const auto &input : graph.input()) {
 | 
				
			||||||
      if (!initializedTensors.ContainKey(legalize_name(input.name()))) {
 | 
					      if (!initializedTensors.ContainKey(legalize_name(input.name()))) {
 | 
				
			||||||
        inputNames.push_back(input.name());
 | 
					        inputNames.push_back(input.name());
 | 
				
			||||||
        arg_types.emplace_back(ImportInputTensorType(input));
 | 
					        arg_types.emplace_back(ImportTensorType(input));
 | 
				
			||||||
        // numInputs is the number of graph inputs not contained within the
 | 
					        // numInputs is the number of graph inputs not contained within the
 | 
				
			||||||
        // initializer
 | 
					        // initializer
 | 
				
			||||||
        ++numInputs;
 | 
					        ++numInputs;
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -9,6 +9,7 @@
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "mlir/Dialect/StandardOps/IR/Ops.h"
 | 
				
			||||||
#include "mlir/IR/StandardTypes.h"
 | 
					#include "mlir/IR/StandardTypes.h"
 | 
				
			||||||
#include "mlir/Pass/Pass.h"
 | 
					#include "mlir/Pass/Pass.h"
 | 
				
			||||||
#include "llvm/ADT/SmallPtrSet.h"
 | 
					#include "llvm/ADT/SmallPtrSet.h"
 | 
				
			||||||
| 
						 | 
					@ -32,9 +33,13 @@ public:
 | 
				
			||||||
    auto f = getFunction();
 | 
					    auto f = getFunction();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Iterate on the operations that need shape inference i.e the operations
 | 
					    // Iterate on the operations that need shape inference i.e the operations
 | 
				
			||||||
    // that return a dynamic shape.
 | 
					    // that return a dynamic shape or followed by a return op.
 | 
				
			||||||
    f.walk([&](mlir::Operation *op) {
 | 
					    f.walk([&](mlir::Operation *op) {
 | 
				
			||||||
      if (returnsDynamicShape(op)) {
 | 
					      // The shape of graph output has been imported from onnx protobuf model,
 | 
				
			||||||
 | 
					      // so the ops followed by a return op may not have dynamic shape output.
 | 
				
			||||||
 | 
					      // However, shape inference is still need on these ops
 | 
				
			||||||
 | 
					      // to infer optional attributes.
 | 
				
			||||||
 | 
					      if (isUsedByReturnOp(op) || returnsDynamicShape(op)) {
 | 
				
			||||||
        if (auto shape_op = dyn_cast<ShapeInference>(op)) {
 | 
					        if (auto shape_op = dyn_cast<ShapeInference>(op)) {
 | 
				
			||||||
          if (failed(shape_op.inferShapes())) {
 | 
					          if (failed(shape_op.inferShapes())) {
 | 
				
			||||||
            op->emitError("shape inference failed");
 | 
					            op->emitError("shape inference failed");
 | 
				
			||||||
| 
						 | 
					@ -69,6 +74,15 @@ public:
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  static bool isUsedByReturnOp(Operation *op) {
 | 
				
			||||||
 | 
					    for (auto *user : op->getUsers()) {
 | 
				
			||||||
 | 
					      if (dyn_cast<ReturnOp>(user)) {
 | 
				
			||||||
 | 
					        return true;
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    return false;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /*!
 | 
					  /*!
 | 
				
			||||||
   *  Check if the given operation has a dynamically shaped result.
 | 
					   *  Check if the given operation has a dynamically shaped result.
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2197,13 +2197,13 @@ func @test_gather_axis0(%arg0 : tensor<3x2xf32>) -> tensor<2x2x2xf32> {
 | 
				
			||||||
// -----
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Test gather along axis 1, second example in ONNX for Gather.
 | 
					// Test gather along axis 1, second example in ONNX for Gather.
 | 
				
			||||||
func @test_gather_axis1(%arg0 : tensor<3x3xf32>) -> tensor<1x3x2xf32> {
 | 
					func @test_gather_axis1(%arg0 : tensor<3x3xf32>) -> tensor<3x1x2xf32> {
 | 
				
			||||||
  %indices = "onnx.Constant"() {value = dense<[[0, 2]]> : tensor<1x2xi64>} : () -> tensor<1x2xi64>
 | 
					  %indices = "onnx.Constant"() {value = dense<[[0, 2]]> : tensor<1x2xi64>} : () -> tensor<1x2xi64>
 | 
				
			||||||
  %0 = "onnx.Gather"(%arg0, %indices) {axis = 1 : si64} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<1x3x2xf32>
 | 
					  %0 = "onnx.Gather"(%arg0, %indices) {axis = 1 : si64} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<3x1x2xf32>
 | 
				
			||||||
  "std.return"(%0) : (tensor<1x3x2xf32>) -> ()
 | 
					  "std.return"(%0) : (tensor<3x1x2xf32>) -> ()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // CHECK-LABEL: test_gather_axis1
 | 
					  // CHECK-LABEL: test_gather_axis1
 | 
				
			||||||
  // CHECK: [[ALLOC:%.+]] = alloc() : memref<1x3x2xf32>
 | 
					  // CHECK: [[ALLOC:%.+]] = alloc() : memref<3x1x2xf32>
 | 
				
			||||||
  // CHECK: [[GLOBAL:%.+]] = "krnl.global"() {name = "constant_0", shape = [1, 2], value = dense<{{\[+}}0, 2{{\]+}}> : tensor<1x2xi64>} : () -> memref<1x2xi64>
 | 
					  // CHECK: [[GLOBAL:%.+]] = "krnl.global"() {name = "constant_0", shape = [1, 2], value = dense<{{\[+}}0, 2{{\]+}}> : tensor<1x2xi64>} : () -> memref<1x2xi64>
 | 
				
			||||||
  // CHECK: [[LOOP:%.+]]:3 = krnl.define_loops 3
 | 
					  // CHECK: [[LOOP:%.+]]:3 = krnl.define_loops 3
 | 
				
			||||||
  // CHECK: [[ZERO:%.+]] = constant 0 : index
 | 
					  // CHECK: [[ZERO:%.+]] = constant 0 : index
 | 
				
			||||||
| 
						 | 
					@ -2216,7 +2216,7 @@ func @test_gather_axis1(%arg0 : tensor<3x3xf32>) -> tensor<1x3x2xf32> {
 | 
				
			||||||
  // CHECK: [[CMP:%.+]] = cmpi "slt", [[AFFINE2]], [[ZERO]] : index
 | 
					  // CHECK: [[CMP:%.+]] = cmpi "slt", [[AFFINE2]], [[ZERO]] : index
 | 
				
			||||||
  // CHECK: [[AFFINE4:%.+]] = select [[CMP]], [[AFFINE3]], [[AFFINE2]] : index
 | 
					  // CHECK: [[AFFINE4:%.+]] = select [[CMP]], [[AFFINE3]], [[AFFINE2]] : index
 | 
				
			||||||
  // CHECK: [[DATA:%.+]] = load %arg0{{.}}[[ARG1]], [[AFFINE4]]{{.}} : memref<3x3xf32>
 | 
					  // CHECK: [[DATA:%.+]] = load %arg0{{.}}[[ARG1]], [[AFFINE4]]{{.}} : memref<3x3xf32>
 | 
				
			||||||
  // CHECK: affine.store [[DATA]], [[ALLOC]]{{.}}[[ARG1]], [[ARG2]], [[ARG3]]{{.}} : memref<1x3x2xf32>
 | 
					  // CHECK: affine.store [[DATA]], [[ALLOC]]{{.}}[[ARG1]], [[ARG2]], [[ARG3]]{{.}} : memref<3x1x2xf32>
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// -----
 | 
					// -----
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue