[MLIR] Add conversion to MLIR types. (#360)
* Add conversion of tensor proto types to MLIR types. * Fix integer type conversion. Use new type conversion for ouput type.
This commit is contained in:
parent
626552f4a0
commit
a9960f6e44
|
@ -114,8 +114,38 @@ class FrontendGenImpl {
|
||||||
|
|
||||||
mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); }
|
mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); }
|
||||||
|
|
||||||
|
// Convert type to MLIR type.
|
||||||
|
// A complete list of types can be found in:
|
||||||
|
// <dlc-build-folder>/third_party/onnx/onnx/onnx.pb.h
|
||||||
mlir::Type TypeConvert(onnx::TensorProto_DataType intype) {
|
mlir::Type TypeConvert(onnx::TensorProto_DataType intype) {
|
||||||
|
switch (intype) {
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
|
||||||
|
return builder_.getF16Type();
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
|
||||||
return builder_.getF32Type();
|
return builder_.getF32Type();
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
|
||||||
|
return builder_.getF64Type();
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
|
||||||
|
return builder_.getIntegerType(8);
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
|
||||||
|
return builder_.getIntegerType(16);
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
|
||||||
|
return builder_.getIntegerType(32);
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
|
||||||
|
return builder_.getIntegerType(64);
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
|
||||||
|
return builder_.getI1Type();
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_STRING:
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
|
||||||
|
DLC_REQUIRE_MSG_CTX(false, "Unsupported data type encountered.");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ImportInputTensor(onnx::ValueInfoProto& input) {
|
void ImportInputTensor(onnx::ValueInfoProto& input) {
|
||||||
|
@ -166,7 +196,8 @@ class FrontendGenImpl {
|
||||||
if (OpName == "Add") {
|
if (OpName == "Add") {
|
||||||
auto op =
|
auto op =
|
||||||
builder_.create<mlir::ONNXAddOp>(UnknownLoc(), inputs[0], inputs[1]);
|
builder_.create<mlir::ONNXAddOp>(UnknownLoc(), inputs[0], inputs[1]);
|
||||||
frontend_symbols_.AddMapping(legalize_name(node.output()[0]), op.getResult());
|
frontend_symbols_.AddMapping(
|
||||||
|
legalize_name(node.output()[0]), op.getResult());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -187,8 +218,11 @@ class FrontendGenImpl {
|
||||||
|
|
||||||
void ImportOutputTensor(onnx::ValueInfoProto& output) {
|
void ImportOutputTensor(onnx::ValueInfoProto& output) {
|
||||||
if (frontend_symbols_.ContainKey(legalize_name(output.name()))) {
|
if (frontend_symbols_.ContainKey(legalize_name(output.name()))) {
|
||||||
mlir::OperationState result(UnknownLoc(), "frontend.output " + output.name());
|
mlir::OperationState result(
|
||||||
result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
UnknownLoc(), "frontend.output " + output.name());
|
||||||
|
mlir::Type elementType =
|
||||||
|
TypeConvert(output.type().tensor_type().elem_type());
|
||||||
|
result.addTypes(mlir::UnrankedTensorType::get(elementType));
|
||||||
result.addOperands(frontend_symbols_.GetTensorByOnnxName(output.name()));
|
result.addOperands(frontend_symbols_.GetTensorByOnnxName(output.name()));
|
||||||
builder_.createOperation(result);
|
builder_.createOperation(result);
|
||||||
} else {
|
} else {
|
||||||
|
|
Loading…
Reference in New Issue