diff --git a/src/builder/frontend_dialect_transformer.cpp b/src/builder/frontend_dialect_transformer.cpp index a957897..0821b81 100644 --- a/src/builder/frontend_dialect_transformer.cpp +++ b/src/builder/frontend_dialect_transformer.cpp @@ -114,8 +114,38 @@ class FrontendGenImpl { mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); } + // Convert type to MLIR type. + // A complete list of types can be found in: + // /third_party/onnx/onnx/onnx.pb.h mlir::Type TypeConvert(onnx::TensorProto_DataType intype) { - return builder_.getF32Type(); + switch (intype) { + case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16: + return builder_.getF16Type(); + case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT: + return builder_.getF32Type(); + case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE: + return builder_.getF64Type(); + case onnx::TensorProto_DataType::TensorProto_DataType_INT8: + case onnx::TensorProto_DataType::TensorProto_DataType_UINT8: + return builder_.getIntegerType(8); + case onnx::TensorProto_DataType::TensorProto_DataType_INT16: + case onnx::TensorProto_DataType::TensorProto_DataType_UINT16: + return builder_.getIntegerType(16); + case onnx::TensorProto_DataType::TensorProto_DataType_INT32: + case onnx::TensorProto_DataType::TensorProto_DataType_UINT32: + return builder_.getIntegerType(32); + case onnx::TensorProto_DataType::TensorProto_DataType_INT64: + case onnx::TensorProto_DataType::TensorProto_DataType_UINT64: + return builder_.getIntegerType(64); + case onnx::TensorProto_DataType::TensorProto_DataType_BOOL: + return builder_.getI1Type(); + case onnx::TensorProto_DataType::TensorProto_DataType_STRING: + case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64: + case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128: + case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED: + DLC_REQUIRE_MSG_CTX(false, "Unsupported data type encountered."); + return nullptr; + } } void ImportInputTensor(onnx::ValueInfoProto& input) { @@ -166,7 +196,8 @@ class FrontendGenImpl { if (OpName == "Add") { auto op = builder_.create(UnknownLoc(), inputs[0], inputs[1]); - frontend_symbols_.AddMapping(legalize_name(node.output()[0]), op.getResult()); + frontend_symbols_.AddMapping( + legalize_name(node.output()[0]), op.getResult()); return; } @@ -187,8 +218,11 @@ class FrontendGenImpl { void ImportOutputTensor(onnx::ValueInfoProto& output) { if (frontend_symbols_.ContainKey(legalize_name(output.name()))) { - mlir::OperationState result(UnknownLoc(), "frontend.output " + output.name()); - result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type())); + mlir::OperationState result( + UnknownLoc(), "frontend.output " + output.name()); + mlir::Type elementType = + TypeConvert(output.type().tensor_type().elem_type()); + result.addTypes(mlir::UnrankedTensorType::get(elementType)); result.addOperands(frontend_symbols_.GetTensorByOnnxName(output.name())); builder_.createOperation(result); } else {