diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index 5589f75..c7ae1b9 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -95,8 +95,8 @@ def ONNXAddOp:ONNX_Op<"Add", let results = (outs AnyTypeOf<[TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$C); let builders = [ OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ - auto lhsTy = A.getType().cast(); - auto rhsTy = B.getType().cast(); + auto lhsTy = A.getType(); + auto rhsTy = B.getType(); auto elementType = getBroadcastedType(lhsTy, rhsTy); auto shapedType = elementType.dyn_cast_or_null(); if (!shapedType || !shapedType.hasStaticShape()) { @@ -106,8 +106,8 @@ def ONNXAddOp:ONNX_Op<"Add", build(builder, state, elementType, A, B); }]>, OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{ - auto lhsTy = operands[0].getType().cast(); - auto rhsTy = operands[1].getType().cast(); + auto lhsTy = operands[0].getType(); + auto rhsTy = operands[1].getType(); auto elementType = getBroadcastedType(lhsTy, rhsTy); auto shapedType = elementType.dyn_cast_or_null(); if (!shapedType || !shapedType.hasStaticShape()) { @@ -146,8 +146,8 @@ def ONNXAndOp:ONNX_Op<"And", let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C); let builders = [ OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ - auto lhsTy = A.getType().cast(); - auto rhsTy = B.getType().cast(); + auto lhsTy = A.getType(); + auto rhsTy = B.getType(); auto elementType = getBroadcastedType(lhsTy, rhsTy); auto shapedType = elementType.dyn_cast_or_null(); if (!shapedType || !shapedType.hasStaticShape()) { @@ -157,8 +157,8 @@ def ONNXAndOp:ONNX_Op<"And", build(builder, state, elementType, A, B); }]>, OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{ - auto lhsTy = operands[0].getType().cast(); - auto rhsTy = operands[1].getType().cast(); + auto lhsTy = operands[0].getType(); + auto rhsTy = operands[1].getType(); auto elementType = getBroadcastedType(lhsTy, rhsTy); auto shapedType = elementType.dyn_cast_or_null(); if (!shapedType || !shapedType.hasStaticShape()) { @@ -987,8 +987,8 @@ def ONNXDivOp:ONNX_Op<"Div", let results = (outs AnyTypeOf<[TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$C); let builders = [ OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ - auto lhsTy = A.getType().cast(); - auto rhsTy = B.getType().cast(); + auto lhsTy = A.getType(); + auto rhsTy = B.getType(); auto elementType = getBroadcastedType(lhsTy, rhsTy); auto shapedType = elementType.dyn_cast_or_null(); if (!shapedType || !shapedType.hasStaticShape()) { @@ -998,8 +998,8 @@ def ONNXDivOp:ONNX_Op<"Div", build(builder, state, elementType, A, B); }]>, OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{ - auto lhsTy = operands[0].getType().cast(); - auto rhsTy = operands[1].getType().cast(); + auto lhsTy = operands[0].getType(); + auto rhsTy = operands[1].getType(); auto elementType = getBroadcastedType(lhsTy, rhsTy); auto shapedType = elementType.dyn_cast_or_null(); if (!shapedType || !shapedType.hasStaticShape()) { @@ -1135,8 +1135,8 @@ def ONNXEqualOp:ONNX_Op<"Equal", let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C); let builders = [ OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ - auto lhsTy = A.getType().cast(); - auto rhsTy = B.getType().cast(); + auto lhsTy = A.getType(); + auto rhsTy = B.getType(); auto elementType = getBroadcastedType(lhsTy, rhsTy); auto shapedType = elementType.dyn_cast_or_null(); if (!shapedType || !shapedType.hasStaticShape()) { @@ -1146,8 +1146,8 @@ def ONNXEqualOp:ONNX_Op<"Equal", build(builder, state, elementType, A, B); }]>, OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{ - auto lhsTy = operands[0].getType().cast(); - auto rhsTy = operands[1].getType().cast(); + auto lhsTy = operands[0].getType(); + auto rhsTy = operands[1].getType(); auto elementType = getBroadcastedType(lhsTy, rhsTy); auto shapedType = elementType.dyn_cast_or_null(); if (!shapedType || !shapedType.hasStaticShape()) { @@ -1803,8 +1803,8 @@ def ONNXGreaterOp:ONNX_Op<"Greater", let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C); let builders = [ OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ - auto lhsTy = A.getType().cast(); - auto rhsTy = B.getType().cast(); + auto lhsTy = A.getType(); + auto rhsTy = B.getType(); auto elementType = getBroadcastedType(lhsTy, rhsTy); auto shapedType = elementType.dyn_cast_or_null(); if (!shapedType || !shapedType.hasStaticShape()) { @@ -1814,8 +1814,8 @@ def ONNXGreaterOp:ONNX_Op<"Greater", build(builder, state, elementType, A, B); }]>, OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{ - auto lhsTy = operands[0].getType().cast(); - auto rhsTy = operands[1].getType().cast(); + auto lhsTy = operands[0].getType(); + auto rhsTy = operands[1].getType(); auto elementType = getBroadcastedType(lhsTy, rhsTy); auto shapedType = elementType.dyn_cast_or_null(); if (!shapedType || !shapedType.hasStaticShape()) { @@ -2207,8 +2207,8 @@ def ONNXLessOp:ONNX_Op<"Less", let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C); let builders = [ OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ - auto lhsTy = A.getType().cast(); - auto rhsTy = B.getType().cast(); + auto lhsTy = A.getType(); + auto rhsTy = B.getType(); auto elementType = getBroadcastedType(lhsTy, rhsTy); auto shapedType = elementType.dyn_cast_or_null(); if (!shapedType || !shapedType.hasStaticShape()) { @@ -2218,8 +2218,8 @@ def ONNXLessOp:ONNX_Op<"Less", build(builder, state, elementType, A, B); }]>, OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{ - auto lhsTy = operands[0].getType().cast(); - auto rhsTy = operands[1].getType().cast(); + auto lhsTy = operands[0].getType(); + auto rhsTy = operands[1].getType(); auto elementType = getBroadcastedType(lhsTy, rhsTy); auto shapedType = elementType.dyn_cast_or_null(); if (!shapedType || !shapedType.hasStaticShape()) { @@ -2802,8 +2802,8 @@ def ONNXMulOp:ONNX_Op<"Mul", let results = (outs AnyTypeOf<[TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$C); let builders = [ OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ - auto lhsTy = A.getType().cast(); - auto rhsTy = B.getType().cast(); + auto lhsTy = A.getType(); + auto rhsTy = B.getType(); auto elementType = getBroadcastedType(lhsTy, rhsTy); auto shapedType = elementType.dyn_cast_or_null(); if (!shapedType || !shapedType.hasStaticShape()) { @@ -2813,8 +2813,8 @@ def ONNXMulOp:ONNX_Op<"Mul", build(builder, state, elementType, A, B); }]>, OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{ - auto lhsTy = operands[0].getType().cast(); - auto rhsTy = operands[1].getType().cast(); + auto lhsTy = operands[0].getType(); + auto rhsTy = operands[1].getType(); auto elementType = getBroadcastedType(lhsTy, rhsTy); auto shapedType = elementType.dyn_cast_or_null(); if (!shapedType || !shapedType.hasStaticShape()) { @@ -3020,8 +3020,8 @@ def ONNXOrOp:ONNX_Op<"Or", let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C); let builders = [ OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ - auto lhsTy = A.getType().cast(); - auto rhsTy = B.getType().cast(); + auto lhsTy = A.getType(); + auto rhsTy = B.getType(); auto elementType = getBroadcastedType(lhsTy, rhsTy); auto shapedType = elementType.dyn_cast_or_null(); if (!shapedType || !shapedType.hasStaticShape()) { @@ -3031,8 +3031,8 @@ def ONNXOrOp:ONNX_Op<"Or", build(builder, state, elementType, A, B); }]>, OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{ - auto lhsTy = operands[0].getType().cast(); - auto rhsTy = operands[1].getType().cast(); + auto lhsTy = operands[0].getType(); + auto rhsTy = operands[1].getType(); auto elementType = getBroadcastedType(lhsTy, rhsTy); auto shapedType = elementType.dyn_cast_or_null(); if (!shapedType || !shapedType.hasStaticShape()) { @@ -3215,8 +3215,8 @@ def ONNXPowOp:ONNX_Op<"Pow", let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$Z); let builders = [ OpBuilder<"OpBuilder &builder, OperationState &state, Value X, Value Y", [{ - auto lhsTy = X.getType().cast(); - auto rhsTy = Y.getType().cast(); + auto lhsTy = X.getType(); + auto rhsTy = Y.getType(); auto elementType = getBroadcastedType(lhsTy, rhsTy); auto shapedType = elementType.dyn_cast_or_null(); if (!shapedType || !shapedType.hasStaticShape()) { @@ -3226,8 +3226,8 @@ def ONNXPowOp:ONNX_Op<"Pow", build(builder, state, elementType, X, Y); }]>, OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{ - auto lhsTy = operands[0].getType().cast(); - auto rhsTy = operands[1].getType().cast(); + auto lhsTy = operands[0].getType(); + auto rhsTy = operands[1].getType(); auto elementType = getBroadcastedType(lhsTy, rhsTy); auto shapedType = elementType.dyn_cast_or_null(); if (!shapedType || !shapedType.hasStaticShape()) { @@ -5162,8 +5162,8 @@ def ONNXSubOp:ONNX_Op<"Sub", let results = (outs AnyTypeOf<[TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$C); let builders = [ OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ - auto lhsTy = A.getType().cast(); - auto rhsTy = B.getType().cast(); + auto lhsTy = A.getType(); + auto rhsTy = B.getType(); auto elementType = getBroadcastedType(lhsTy, rhsTy); auto shapedType = elementType.dyn_cast_or_null(); if (!shapedType || !shapedType.hasStaticShape()) { @@ -5173,8 +5173,8 @@ def ONNXSubOp:ONNX_Op<"Sub", build(builder, state, elementType, A, B); }]>, OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{ - auto lhsTy = operands[0].getType().cast(); - auto rhsTy = operands[1].getType().cast(); + auto lhsTy = operands[0].getType(); + auto rhsTy = operands[1].getType(); auto elementType = getBroadcastedType(lhsTy, rhsTy); auto shapedType = elementType.dyn_cast_or_null(); if (!shapedType || !shapedType.hasStaticShape()) { @@ -5629,8 +5629,8 @@ def ONNXXorOp:ONNX_Op<"Xor", let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C); let builders = [ OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ - auto lhsTy = A.getType().cast(); - auto rhsTy = B.getType().cast(); + auto lhsTy = A.getType(); + auto rhsTy = B.getType(); auto elementType = getBroadcastedType(lhsTy, rhsTy); auto shapedType = elementType.dyn_cast_or_null(); if (!shapedType || !shapedType.hasStaticShape()) { @@ -5640,8 +5640,8 @@ def ONNXXorOp:ONNX_Op<"Xor", build(builder, state, elementType, A, B); }]>, OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{ - auto lhsTy = operands[0].getType().cast(); - auto rhsTy = operands[1].getType().cast(); + auto lhsTy = operands[0].getType(); + auto rhsTy = operands[1].getType(); auto elementType = getBroadcastedType(lhsTy, rhsTy); auto shapedType = elementType.dyn_cast_or_null(); if (!shapedType || !shapedType.hasStaticShape()) { diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index da99203..7ebb8e7 100644 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -851,9 +851,9 @@ def gen_op_def(schema): build_type_name = '' if schema.name in custom_builder_broadcast_ops_list: second_operand_name = list(ins.items())[1][0] - s += indent + 'auto lhsTy = {}.getType().cast();\n'. \ + s += indent + 'auto lhsTy = {}.getType();\n'. \ format(first_operand_name) - s += indent + 'auto rhsTy = {}.getType().cast();\n'. \ + s += indent + 'auto rhsTy = {}.getType();\n'. \ format(second_operand_name) s += indent + 'auto elementType = getBroadcastedType(lhsTy, rhsTy);\n' s += indent + 'auto shapedType = elementType.dyn_cast_or_null();\n'; @@ -881,8 +881,8 @@ def gen_op_def(schema): 'ValueRange operands, ArrayRef attributes", [{\n' indent = inc_indent(indent) if schema.name in custom_builder_broadcast_ops_list: - s += indent + 'auto lhsTy = operands[0].getType().cast();\n' - s += indent + 'auto rhsTy = operands[1].getType().cast();\n' + s += indent + 'auto lhsTy = operands[0].getType();\n' + s += indent + 'auto rhsTy = operands[1].getType();\n' s += indent + 'auto elementType = getBroadcastedType(lhsTy, rhsTy);\n' s += indent + 'auto shapedType = elementType.dyn_cast_or_null();\n'; s += indent + 'if (!shapedType || !shapedType.hasStaticShape()) {\n';