diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index ea1b452..2c5e36f 100644 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -818,7 +818,11 @@ def parse_type_str(allowedType): 'complex128' : 'Complex', 'string' : 'StringType'} - for key, item in onnx_to_mlir_type_dict.items(): + # Apply substitutions in decreasing order of key-length, so that float16 is replaced + # before float, and uint16 is replaced before int16, etc. + mapping = list(onnx_to_mlir_type_dict.items()) + mapping.sort(key=lambda pair:len(pair[0]), reverse=True) + for key, item in mapping: allowedType = allowedType.replace(key, item) return allowedType