Import graph output type from protobuf (#333)
* import output type Signed-off-by: daquexian <daquexian566@gmail.com> * rename input->value_info, update doc Signed-off-by: daquexian <daquexian566@gmail.com> * infer shape on return op inputs Signed-off-by: daquexian <daquexian566@gmail.com> * import output type from protobuf only if it has shape Signed-off-by: daquexian <daquexian566@gmail.com> * fix wrong gather test Signed-off-by: daquexian <daquexian566@gmail.com> * add comments Signed-off-by: daquexian <daquexian566@gmail.com>
This commit is contained in:
parent
0db735f48d
commit
cb3d1e4f64
|
@ -61,20 +61,17 @@ private:
|
|||
mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); }
|
||||
|
||||
/*!
|
||||
* Import an onnx input tensor type by determining and recording its type
|
||||
* in a list of input tensor mlir types.
|
||||
* @param input onnx input tensor ValueInfoProto.
|
||||
* @param arg_types list of mlir types representing types of graph input.
|
||||
* Import an onnx tensor type by determining and returning its type
|
||||
* @param value_info onnx tensor ValueInfoProto.
|
||||
*/
|
||||
mlir::Type ImportInputTensorType(const onnx::ValueInfoProto &input) {
|
||||
mlir::Type ImportTensorType(const onnx::ValueInfoProto &value_info) {
|
||||
std::vector<int64_t> dims;
|
||||
auto shape_proto = input.type().tensor_type().shape();
|
||||
auto input_tensor_legalized_name = legalize_name(input.name());
|
||||
auto shape_proto = value_info.type().tensor_type().shape();
|
||||
for (int i = 0; i < shape_proto.dim_size(); i++) {
|
||||
if (shape_proto.dim()[i].dim_value()) {
|
||||
int dim_numeric_size = shape_proto.dim()[i].dim_value();
|
||||
assert(dim_numeric_size != 0 &&
|
||||
"Parsed an input tensor with a dimension size of zero");
|
||||
"Parsed an tensor with a dimension size of zero");
|
||||
if (dim_numeric_size > 0) {
|
||||
dims.push_back(dim_numeric_size);
|
||||
} else { // If dim_value < 0, then dim is parametric.
|
||||
|
@ -88,7 +85,7 @@ private:
|
|||
}
|
||||
|
||||
auto elementOnnxType =
|
||||
(onnx::TensorProto_DataType)input.type().tensor_type().elem_type();
|
||||
(onnx::TensorProto_DataType)value_info.type().tensor_type().elem_type();
|
||||
mlir::Type elementType =
|
||||
convertONNXTypeToMLIRType(builder_, elementOnnxType);
|
||||
llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size());
|
||||
|
@ -532,6 +529,11 @@ private:
|
|||
|
||||
auto tensor_val =
|
||||
frontend_symbols_.GetTensorByOnnxName(output_tensor_legalized_name);
|
||||
if (output.type().value_case() == onnx::TypeProto::kTensorType) {
|
||||
if (output.type().tensor_type().has_shape()) {
|
||||
tensor_val.setType(ImportTensorType(output));
|
||||
}
|
||||
}
|
||||
ret_types.emplace_back(tensor_val.getType());
|
||||
ret_vals.push_back(tensor_val);
|
||||
}
|
||||
|
@ -560,7 +562,7 @@ private:
|
|||
for (const auto &input : graph.input()) {
|
||||
if (!initializedTensors.ContainKey(legalize_name(input.name()))) {
|
||||
inputNames.push_back(input.name());
|
||||
arg_types.emplace_back(ImportInputTensorType(input));
|
||||
arg_types.emplace_back(ImportTensorType(input));
|
||||
// numInputs is the number of graph inputs not contained within the
|
||||
// initializer
|
||||
++numInputs;
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
|
@ -32,9 +33,13 @@ public:
|
|||
auto f = getFunction();
|
||||
|
||||
// Iterate on the operations that need shape inference i.e the operations
|
||||
// that return a dynamic shape.
|
||||
// that return a dynamic shape or followed by a return op.
|
||||
f.walk([&](mlir::Operation *op) {
|
||||
if (returnsDynamicShape(op)) {
|
||||
// The shape of graph output has been imported from onnx protobuf model,
|
||||
// so the ops followed by a return op may not have dynamic shape output.
|
||||
// However, shape inference is still need on these ops
|
||||
// to infer optional attributes.
|
||||
if (isUsedByReturnOp(op) || returnsDynamicShape(op)) {
|
||||
if (auto shape_op = dyn_cast<ShapeInference>(op)) {
|
||||
if (failed(shape_op.inferShapes())) {
|
||||
op->emitError("shape inference failed");
|
||||
|
@ -69,6 +74,15 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
static bool isUsedByReturnOp(Operation *op) {
|
||||
for (auto *user : op->getUsers()) {
|
||||
if (dyn_cast<ReturnOp>(user)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/*!
|
||||
* Check if the given operation has a dynamically shaped result.
|
||||
*/
|
||||
|
@ -86,4 +100,4 @@ public:
|
|||
*/
|
||||
std::unique_ptr<mlir::Pass> mlir::createShapeInferencePass() {
|
||||
return std::make_unique<ShapeInferencePass>();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2197,13 +2197,13 @@ func @test_gather_axis0(%arg0 : tensor<3x2xf32>) -> tensor<2x2x2xf32> {
|
|||
// -----
|
||||
|
||||
// Test gather along axis 1, second example in ONNX for Gather.
|
||||
func @test_gather_axis1(%arg0 : tensor<3x3xf32>) -> tensor<1x3x2xf32> {
|
||||
func @test_gather_axis1(%arg0 : tensor<3x3xf32>) -> tensor<3x1x2xf32> {
|
||||
%indices = "onnx.Constant"() {value = dense<[[0, 2]]> : tensor<1x2xi64>} : () -> tensor<1x2xi64>
|
||||
%0 = "onnx.Gather"(%arg0, %indices) {axis = 1 : si64} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<1x3x2xf32>
|
||||
"std.return"(%0) : (tensor<1x3x2xf32>) -> ()
|
||||
%0 = "onnx.Gather"(%arg0, %indices) {axis = 1 : si64} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<3x1x2xf32>
|
||||
"std.return"(%0) : (tensor<3x1x2xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_gather_axis1
|
||||
// CHECK: [[ALLOC:%.+]] = alloc() : memref<1x3x2xf32>
|
||||
// CHECK: [[ALLOC:%.+]] = alloc() : memref<3x1x2xf32>
|
||||
// CHECK: [[GLOBAL:%.+]] = "krnl.global"() {name = "constant_0", shape = [1, 2], value = dense<{{\[+}}0, 2{{\]+}}> : tensor<1x2xi64>} : () -> memref<1x2xi64>
|
||||
// CHECK: [[LOOP:%.+]]:3 = krnl.define_loops 3
|
||||
// CHECK: [[ZERO:%.+]] = constant 0 : index
|
||||
|
@ -2216,7 +2216,7 @@ func @test_gather_axis1(%arg0 : tensor<3x3xf32>) -> tensor<1x3x2xf32> {
|
|||
// CHECK: [[CMP:%.+]] = cmpi "slt", [[AFFINE2]], [[ZERO]] : index
|
||||
// CHECK: [[AFFINE4:%.+]] = select [[CMP]], [[AFFINE3]], [[AFFINE2]] : index
|
||||
// CHECK: [[DATA:%.+]] = load %arg0{{.}}[[ARG1]], [[AFFINE4]]{{.}} : memref<3x3xf32>
|
||||
// CHECK: affine.store [[DATA]], [[ALLOC]]{{.}}[[ARG1]], [[ARG2]], [[ARG3]]{{.}} : memref<1x3x2xf32>
|
||||
// CHECK: affine.store [[DATA]], [[ALLOC]]{{.}}[[ARG1]], [[ARG2]], [[ARG3]]{{.}} : memref<3x1x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
Loading…
Reference in New Issue