[MLIR] Add conversion to MLIR types. (#360)
* Add conversion of tensor proto types to MLIR types.
* Fix integer type conversion. Use the new type conversion for the output type.
parent 626552f4a0
commit a9960f6e44
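Reading the diff below: the new TypeConvert helper maps ONNX tensor element types to MLIR types, and ImportOutputTensor now derives the output element type from the model instead of hard-coding f32. A minimal sketch of the intended effect, assuming the FrontendGenImpl members shown in the diff (the snippet is illustrative and not part of the commit):

    // Illustrative only: an INT64 ONNX output is now typed tensor<*xi64>
    // rather than the previously hard-coded tensor<*xf32>.
    mlir::Type elemTy =
        TypeConvert(onnx::TensorProto_DataType::TensorProto_DataType_INT64);
    mlir::Type outTy = mlir::UnrankedTensorType::get(elemTy);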
@@ -114,8 +114,38 @@ class FrontendGenImpl {
   mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); }
 
+  // Convert type to MLIR type.
+  // A complete list of types can be found in:
+  // <dlc-build-folder>/third_party/onnx/onnx/onnx.pb.h
   mlir::Type TypeConvert(onnx::TensorProto_DataType intype) {
-    return builder_.getF32Type();
+    switch (intype) {
+    case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
+      return builder_.getF16Type();
+    case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
+      return builder_.getF32Type();
+    case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
+      return builder_.getF64Type();
+    case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
+    case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
+      return builder_.getIntegerType(8);
+    case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
+    case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
+      return builder_.getIntegerType(16);
+    case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
+    case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
+      return builder_.getIntegerType(32);
+    case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
+    case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
+      return builder_.getIntegerType(64);
+    case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
+      return builder_.getI1Type();
+    case onnx::TensorProto_DataType::TensorProto_DataType_STRING:
+    case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
+    case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
+    case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
+      DLC_REQUIRE_MSG_CTX(false, "Unsupported data type encountered.");
+      return nullptr;
+    }
+  }
 
   void ImportInputTensor(onnx::ValueInfoProto& input) {
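One way the helper could be combined with shape information, sketched as an assumption rather than code from this commit (the dims value and the ranked-type construction are hypothetical):

    // Hypothetical sketch: pair the converted element type with a shape to
    // build a ranked tensor type, e.g. for a float16[2x3] input.
    llvm::SmallVector<int64_t, 2> dims = {2, 3};
    mlir::Type elemTy =
        TypeConvert(onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16);
    mlir::Type rankedTy = mlir::RankedTensorType::get(dims, elemTy);  // tensor<2x3xf16>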
@@ -166,7 +196,8 @@ class FrontendGenImpl {
     if (OpName == "Add") {
       auto op =
           builder_.create<mlir::ONNXAddOp>(UnknownLoc(), inputs[0], inputs[1]);
-      frontend_symbols_.AddMapping(legalize_name(node.output()[0]), op.getResult());
+      frontend_symbols_.AddMapping(
+          legalize_name(node.output()[0]), op.getResult());
       return;
     }
 
@@ -187,8 +218,11 @@ class FrontendGenImpl {
 
   void ImportOutputTensor(onnx::ValueInfoProto& output) {
     if (frontend_symbols_.ContainKey(legalize_name(output.name()))) {
-      mlir::OperationState result(UnknownLoc(), "frontend.output " + output.name());
-      result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type()));
+      mlir::OperationState result(
+          UnknownLoc(), "frontend.output " + output.name());
+      mlir::Type elementType =
+          TypeConvert(output.type().tensor_type().elem_type());
+      result.addTypes(mlir::UnrankedTensorType::get(elementType));
       result.addOperands(frontend_symbols_.GetTensorByOnnxName(output.name()));
       builder_.createOperation(result);
     } else {
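A note on the integer cases in TypeConvert: the signed and unsigned ONNX widths intentionally collapse to the same signless MLIR integer type via getIntegerType, following the MLIR convention of keeping signedness out of the builtin integer types. A small illustrative check, not part of the commit:

    // Both INT32 and UINT32 map to the signless i32 type.
    assert(TypeConvert(onnx::TensorProto_DataType::TensorProto_DataType_INT32) ==
           TypeConvert(onnx::TensorProto_DataType::TensorProto_DataType_UINT32));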