[MLIR] Add conversion to MLIR types. (#360)

* Add conversion of tensor proto types to MLIR types.

* Fix integer type conversion. Use new type conversion for ouput type.
This commit is contained in:
GHEORGHE-TEOD BERCEA 2019-11-05 17:03:15 -05:00 committed by Doru Bercea
parent 626552f4a0
commit a9960f6e44
1 changed files with 38 additions and 4 deletions

View File

@ -114,8 +114,38 @@ class FrontendGenImpl {
mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); } mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); }
// Convert type to MLIR type.
// A complete list of types can be found in:
// <dlc-build-folder>/third_party/onnx/onnx/onnx.pb.h
mlir::Type TypeConvert(onnx::TensorProto_DataType intype) { mlir::Type TypeConvert(onnx::TensorProto_DataType intype) {
switch (intype) {
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
return builder_.getF16Type();
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
return builder_.getF32Type(); return builder_.getF32Type();
case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
return builder_.getF64Type();
case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
return builder_.getIntegerType(8);
case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
return builder_.getIntegerType(16);
case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
return builder_.getIntegerType(32);
case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
return builder_.getIntegerType(64);
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
return builder_.getI1Type();
case onnx::TensorProto_DataType::TensorProto_DataType_STRING:
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
DLC_REQUIRE_MSG_CTX(false, "Unsupported data type encountered.");
return nullptr;
}
} }
void ImportInputTensor(onnx::ValueInfoProto& input) { void ImportInputTensor(onnx::ValueInfoProto& input) {
@ -166,7 +196,8 @@ class FrontendGenImpl {
if (OpName == "Add") { if (OpName == "Add") {
auto op = auto op =
builder_.create<mlir::ONNXAddOp>(UnknownLoc(), inputs[0], inputs[1]); builder_.create<mlir::ONNXAddOp>(UnknownLoc(), inputs[0], inputs[1]);
frontend_symbols_.AddMapping(legalize_name(node.output()[0]), op.getResult()); frontend_symbols_.AddMapping(
legalize_name(node.output()[0]), op.getResult());
return; return;
} }
@ -187,8 +218,11 @@ class FrontendGenImpl {
void ImportOutputTensor(onnx::ValueInfoProto& output) { void ImportOutputTensor(onnx::ValueInfoProto& output) {
if (frontend_symbols_.ContainKey(legalize_name(output.name()))) { if (frontend_symbols_.ContainKey(legalize_name(output.name()))) {
mlir::OperationState result(UnknownLoc(), "frontend.output " + output.name()); mlir::OperationState result(
result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type())); UnknownLoc(), "frontend.output " + output.name());
mlir::Type elementType =
TypeConvert(output.type().tensor_type().elem_type());
result.addTypes(mlir::UnrankedTensorType::get(elementType));
result.addOperands(frontend_symbols_.GetTensorByOnnxName(output.name())); result.addOperands(frontend_symbols_.GetTensorByOnnxName(output.name()));
builder_.createOperation(result); builder_.createOperation(result);
} else { } else {