From a9960f6e4438cf473fc06bdc2e2b82fd1c6f73bf Mon Sep 17 00:00:00 2001
From: GHEORGHE-TEOD BERCEA
Date: Tue, 5 Nov 2019 17:03:15 -0500
Subject: [PATCH] [MLIR] Add conversion to MLIR types. (#360)

* Add conversion of tensor proto types to MLIR types.

* Fix integer type conversion. Use new type conversion for output type.
---
 src/builder/frontend_dialect_transformer.cpp | 42 ++++++++++++++++++--
 1 file changed, 38 insertions(+), 4 deletions(-)

diff --git a/src/builder/frontend_dialect_transformer.cpp b/src/builder/frontend_dialect_transformer.cpp
index a957897..0821b81 100644
--- a/src/builder/frontend_dialect_transformer.cpp
+++ b/src/builder/frontend_dialect_transformer.cpp
@@ -114,8 +114,38 @@ class FrontendGenImpl {
 
   mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); }
 
+  // Convert type to MLIR type.
+  // A complete list of types can be found in:
+  // /third_party/onnx/onnx/onnx.pb.h
   mlir::Type TypeConvert(onnx::TensorProto_DataType intype) {
-    return builder_.getF32Type();
+    switch (intype) {
+    case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
+      return builder_.getF16Type();
+    case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
+      return builder_.getF32Type();
+    case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
+      return builder_.getF64Type();
+    case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
+    case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
+      return builder_.getIntegerType(8);
+    case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
+    case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
+      return builder_.getIntegerType(16);
+    case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
+    case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
+      return builder_.getIntegerType(32);
+    case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
+    case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
+      return builder_.getIntegerType(64);
+    case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
+      return builder_.getI1Type();
+    case onnx::TensorProto_DataType::TensorProto_DataType_STRING:
+    case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
+    case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
+    case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
+      DLC_REQUIRE_MSG_CTX(false, "Unsupported data type encountered.");
+      return nullptr;
+    }
   }
 
   void ImportInputTensor(onnx::ValueInfoProto& input) {
@@ -166,7 +196,8 @@ class FrontendGenImpl {
 
     if (OpName == "Add") {
       auto op = builder_.create<mlir::ONNXAddOp>(UnknownLoc(), inputs[0], inputs[1]);
-      frontend_symbols_.AddMapping(legalize_name(node.output()[0]), op.getResult());
+      frontend_symbols_.AddMapping(
+          legalize_name(node.output()[0]), op.getResult());
       return;
     }
 
@@ -187,8 +218,11 @@ class FrontendGenImpl {
 
   void ImportOutputTensor(onnx::ValueInfoProto& output) {
     if (frontend_symbols_.ContainKey(legalize_name(output.name()))) {
-      mlir::OperationState result(UnknownLoc(), "frontend.output " + output.name());
-      result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type()));
+      mlir::OperationState result(
+          UnknownLoc(), "frontend.output " + output.name());
+      mlir::Type elementType =
+          TypeConvert(output.type().tensor_type().elem_type());
+      result.addTypes(mlir::UnrankedTensorType::get(elementType));
       result.addOperands(frontend_symbols_.GetTensorByOnnxName(output.name()));
       builder_.createOperation(result);
     } else {
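
For illustration only, not part of the patch: a minimal standalone sketch of the element-type mapping that TypeConvert introduces. The free function convertOnnxType, the bare OpBuilder setup, and the main() driver are assumptions made for this example; the patch itself implements the mapping as the member function FrontendGenImpl::TypeConvert and reports unsupported element types through DLC_REQUIRE_MSG_CTX rather than silently returning a null type.

#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "onnx/onnx_pb.h"

// Hypothetical free-function version of the patch's mapping, covering a
// few representative cases; unsupported types yield a null mlir::Type.
mlir::Type convertOnnxType(mlir::OpBuilder &builder,
                           onnx::TensorProto_DataType intype) {
  switch (intype) {
  case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
    return builder.getF32Type();
  case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
    return builder.getIntegerType(64);
  case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
    return builder.getI1Type();
  default:
    return nullptr;
  }
}

int main() {
  mlir::MLIRContext context;
  mlir::OpBuilder builder(&context);
  // A float ONNX tensor element type maps to MLIR's f32.
  mlir::Type ty = convertOnnxType(
      builder, onnx::TensorProto_DataType::TensorProto_DataType_FLOAT);
  ty.dump(); // prints "f32"
  return 0;
}

As the last hunk shows, ImportOutputTensor wraps the converted element type in an UnrankedTensorType, so e.g. a double ONNX output now becomes tensor<*xf64> instead of the previously hard-coded tensor<*xf32>.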