modify builder (#214)

This commit is contained in:
chentong319 2020-07-21 22:05:18 -04:00 committed by GitHub
parent a58594ec81
commit f43f26a79c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 48 additions and 48 deletions

View File

@ -95,8 +95,8 @@ def ONNXAddOp:ONNX_Op<"Add",
let results = (outs AnyTypeOf<[TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$C); let results = (outs AnyTypeOf<[TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$C);
let builders = [ let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
auto lhsTy = A.getType().cast<RankedTensorType>(); auto lhsTy = A.getType();
auto rhsTy = B.getType().cast<RankedTensorType>(); auto rhsTy = B.getType();
auto elementType = getBroadcastedType(lhsTy, rhsTy); auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>(); auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) { if (!shapedType || !shapedType.hasStaticShape()) {
@ -106,8 +106,8 @@ def ONNXAddOp:ONNX_Op<"Add",
build(builder, state, elementType, A, B); build(builder, state, elementType, A, B);
}]>, }]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{ OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto lhsTy = operands[0].getType().cast<RankedTensorType>(); auto lhsTy = operands[0].getType();
auto rhsTy = operands[1].getType().cast<RankedTensorType>(); auto rhsTy = operands[1].getType();
auto elementType = getBroadcastedType(lhsTy, rhsTy); auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>(); auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) { if (!shapedType || !shapedType.hasStaticShape()) {
@ -146,8 +146,8 @@ def ONNXAndOp:ONNX_Op<"And",
let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C); let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C);
let builders = [ let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
auto lhsTy = A.getType().cast<RankedTensorType>(); auto lhsTy = A.getType();
auto rhsTy = B.getType().cast<RankedTensorType>(); auto rhsTy = B.getType();
auto elementType = getBroadcastedType(lhsTy, rhsTy); auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>(); auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) { if (!shapedType || !shapedType.hasStaticShape()) {
@ -157,8 +157,8 @@ def ONNXAndOp:ONNX_Op<"And",
build(builder, state, elementType, A, B); build(builder, state, elementType, A, B);
}]>, }]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{ OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto lhsTy = operands[0].getType().cast<RankedTensorType>(); auto lhsTy = operands[0].getType();
auto rhsTy = operands[1].getType().cast<RankedTensorType>(); auto rhsTy = operands[1].getType();
auto elementType = getBroadcastedType(lhsTy, rhsTy); auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>(); auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) { if (!shapedType || !shapedType.hasStaticShape()) {
@ -987,8 +987,8 @@ def ONNXDivOp:ONNX_Op<"Div",
let results = (outs AnyTypeOf<[TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$C); let results = (outs AnyTypeOf<[TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$C);
let builders = [ let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
auto lhsTy = A.getType().cast<RankedTensorType>(); auto lhsTy = A.getType();
auto rhsTy = B.getType().cast<RankedTensorType>(); auto rhsTy = B.getType();
auto elementType = getBroadcastedType(lhsTy, rhsTy); auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>(); auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) { if (!shapedType || !shapedType.hasStaticShape()) {
@ -998,8 +998,8 @@ def ONNXDivOp:ONNX_Op<"Div",
build(builder, state, elementType, A, B); build(builder, state, elementType, A, B);
}]>, }]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{ OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto lhsTy = operands[0].getType().cast<RankedTensorType>(); auto lhsTy = operands[0].getType();
auto rhsTy = operands[1].getType().cast<RankedTensorType>(); auto rhsTy = operands[1].getType();
auto elementType = getBroadcastedType(lhsTy, rhsTy); auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>(); auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) { if (!shapedType || !shapedType.hasStaticShape()) {
@ -1135,8 +1135,8 @@ def ONNXEqualOp:ONNX_Op<"Equal",
let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C); let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C);
let builders = [ let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
auto lhsTy = A.getType().cast<RankedTensorType>(); auto lhsTy = A.getType();
auto rhsTy = B.getType().cast<RankedTensorType>(); auto rhsTy = B.getType();
auto elementType = getBroadcastedType(lhsTy, rhsTy); auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>(); auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) { if (!shapedType || !shapedType.hasStaticShape()) {
@ -1146,8 +1146,8 @@ def ONNXEqualOp:ONNX_Op<"Equal",
build(builder, state, elementType, A, B); build(builder, state, elementType, A, B);
}]>, }]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{ OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto lhsTy = operands[0].getType().cast<RankedTensorType>(); auto lhsTy = operands[0].getType();
auto rhsTy = operands[1].getType().cast<RankedTensorType>(); auto rhsTy = operands[1].getType();
auto elementType = getBroadcastedType(lhsTy, rhsTy); auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>(); auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) { if (!shapedType || !shapedType.hasStaticShape()) {
@ -1803,8 +1803,8 @@ def ONNXGreaterOp:ONNX_Op<"Greater",
let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C); let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C);
let builders = [ let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
auto lhsTy = A.getType().cast<RankedTensorType>(); auto lhsTy = A.getType();
auto rhsTy = B.getType().cast<RankedTensorType>(); auto rhsTy = B.getType();
auto elementType = getBroadcastedType(lhsTy, rhsTy); auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>(); auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) { if (!shapedType || !shapedType.hasStaticShape()) {
@ -1814,8 +1814,8 @@ def ONNXGreaterOp:ONNX_Op<"Greater",
build(builder, state, elementType, A, B); build(builder, state, elementType, A, B);
}]>, }]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{ OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto lhsTy = operands[0].getType().cast<RankedTensorType>(); auto lhsTy = operands[0].getType();
auto rhsTy = operands[1].getType().cast<RankedTensorType>(); auto rhsTy = operands[1].getType();
auto elementType = getBroadcastedType(lhsTy, rhsTy); auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>(); auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) { if (!shapedType || !shapedType.hasStaticShape()) {
@ -2207,8 +2207,8 @@ def ONNXLessOp:ONNX_Op<"Less",
let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C); let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C);
let builders = [ let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
auto lhsTy = A.getType().cast<RankedTensorType>(); auto lhsTy = A.getType();
auto rhsTy = B.getType().cast<RankedTensorType>(); auto rhsTy = B.getType();
auto elementType = getBroadcastedType(lhsTy, rhsTy); auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>(); auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) { if (!shapedType || !shapedType.hasStaticShape()) {
@ -2218,8 +2218,8 @@ def ONNXLessOp:ONNX_Op<"Less",
build(builder, state, elementType, A, B); build(builder, state, elementType, A, B);
}]>, }]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{ OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto lhsTy = operands[0].getType().cast<RankedTensorType>(); auto lhsTy = operands[0].getType();
auto rhsTy = operands[1].getType().cast<RankedTensorType>(); auto rhsTy = operands[1].getType();
auto elementType = getBroadcastedType(lhsTy, rhsTy); auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>(); auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) { if (!shapedType || !shapedType.hasStaticShape()) {
@ -2802,8 +2802,8 @@ def ONNXMulOp:ONNX_Op<"Mul",
let results = (outs AnyTypeOf<[TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$C); let results = (outs AnyTypeOf<[TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$C);
let builders = [ let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
auto lhsTy = A.getType().cast<RankedTensorType>(); auto lhsTy = A.getType();
auto rhsTy = B.getType().cast<RankedTensorType>(); auto rhsTy = B.getType();
auto elementType = getBroadcastedType(lhsTy, rhsTy); auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>(); auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) { if (!shapedType || !shapedType.hasStaticShape()) {
@ -2813,8 +2813,8 @@ def ONNXMulOp:ONNX_Op<"Mul",
build(builder, state, elementType, A, B); build(builder, state, elementType, A, B);
}]>, }]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{ OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto lhsTy = operands[0].getType().cast<RankedTensorType>(); auto lhsTy = operands[0].getType();
auto rhsTy = operands[1].getType().cast<RankedTensorType>(); auto rhsTy = operands[1].getType();
auto elementType = getBroadcastedType(lhsTy, rhsTy); auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>(); auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) { if (!shapedType || !shapedType.hasStaticShape()) {
@ -3020,8 +3020,8 @@ def ONNXOrOp:ONNX_Op<"Or",
let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C); let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C);
let builders = [ let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
auto lhsTy = A.getType().cast<RankedTensorType>(); auto lhsTy = A.getType();
auto rhsTy = B.getType().cast<RankedTensorType>(); auto rhsTy = B.getType();
auto elementType = getBroadcastedType(lhsTy, rhsTy); auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>(); auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) { if (!shapedType || !shapedType.hasStaticShape()) {
@ -3031,8 +3031,8 @@ def ONNXOrOp:ONNX_Op<"Or",
build(builder, state, elementType, A, B); build(builder, state, elementType, A, B);
}]>, }]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{ OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto lhsTy = operands[0].getType().cast<RankedTensorType>(); auto lhsTy = operands[0].getType();
auto rhsTy = operands[1].getType().cast<RankedTensorType>(); auto rhsTy = operands[1].getType();
auto elementType = getBroadcastedType(lhsTy, rhsTy); auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>(); auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) { if (!shapedType || !shapedType.hasStaticShape()) {
@ -3215,8 +3215,8 @@ def ONNXPowOp:ONNX_Op<"Pow",
let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$Z); let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$Z);
let builders = [ let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value X, Value Y", [{ OpBuilder<"OpBuilder &builder, OperationState &state, Value X, Value Y", [{
auto lhsTy = X.getType().cast<RankedTensorType>(); auto lhsTy = X.getType();
auto rhsTy = Y.getType().cast<RankedTensorType>(); auto rhsTy = Y.getType();
auto elementType = getBroadcastedType(lhsTy, rhsTy); auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>(); auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) { if (!shapedType || !shapedType.hasStaticShape()) {
@ -3226,8 +3226,8 @@ def ONNXPowOp:ONNX_Op<"Pow",
build(builder, state, elementType, X, Y); build(builder, state, elementType, X, Y);
}]>, }]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{ OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto lhsTy = operands[0].getType().cast<RankedTensorType>(); auto lhsTy = operands[0].getType();
auto rhsTy = operands[1].getType().cast<RankedTensorType>(); auto rhsTy = operands[1].getType();
auto elementType = getBroadcastedType(lhsTy, rhsTy); auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>(); auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) { if (!shapedType || !shapedType.hasStaticShape()) {
@ -5162,8 +5162,8 @@ def ONNXSubOp:ONNX_Op<"Sub",
let results = (outs AnyTypeOf<[TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$C); let results = (outs AnyTypeOf<[TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$C);
let builders = [ let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
auto lhsTy = A.getType().cast<RankedTensorType>(); auto lhsTy = A.getType();
auto rhsTy = B.getType().cast<RankedTensorType>(); auto rhsTy = B.getType();
auto elementType = getBroadcastedType(lhsTy, rhsTy); auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>(); auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) { if (!shapedType || !shapedType.hasStaticShape()) {
@ -5173,8 +5173,8 @@ def ONNXSubOp:ONNX_Op<"Sub",
build(builder, state, elementType, A, B); build(builder, state, elementType, A, B);
}]>, }]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{ OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto lhsTy = operands[0].getType().cast<RankedTensorType>(); auto lhsTy = operands[0].getType();
auto rhsTy = operands[1].getType().cast<RankedTensorType>(); auto rhsTy = operands[1].getType();
auto elementType = getBroadcastedType(lhsTy, rhsTy); auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>(); auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) { if (!shapedType || !shapedType.hasStaticShape()) {
@ -5629,8 +5629,8 @@ def ONNXXorOp:ONNX_Op<"Xor",
let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C); let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C);
let builders = [ let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
auto lhsTy = A.getType().cast<RankedTensorType>(); auto lhsTy = A.getType();
auto rhsTy = B.getType().cast<RankedTensorType>(); auto rhsTy = B.getType();
auto elementType = getBroadcastedType(lhsTy, rhsTy); auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>(); auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) { if (!shapedType || !shapedType.hasStaticShape()) {
@ -5640,8 +5640,8 @@ def ONNXXorOp:ONNX_Op<"Xor",
build(builder, state, elementType, A, B); build(builder, state, elementType, A, B);
}]>, }]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{ OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto lhsTy = operands[0].getType().cast<RankedTensorType>(); auto lhsTy = operands[0].getType();
auto rhsTy = operands[1].getType().cast<RankedTensorType>(); auto rhsTy = operands[1].getType();
auto elementType = getBroadcastedType(lhsTy, rhsTy); auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>(); auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) { if (!shapedType || !shapedType.hasStaticShape()) {

View File

@ -851,9 +851,9 @@ def gen_op_def(schema):
build_type_name = '' build_type_name = ''
if schema.name in custom_builder_broadcast_ops_list: if schema.name in custom_builder_broadcast_ops_list:
second_operand_name = list(ins.items())[1][0] second_operand_name = list(ins.items())[1][0]
s += indent + 'auto lhsTy = {}.getType().cast<RankedTensorType>();\n'. \ s += indent + 'auto lhsTy = {}.getType();\n'. \
format(first_operand_name) format(first_operand_name)
s += indent + 'auto rhsTy = {}.getType().cast<RankedTensorType>();\n'. \ s += indent + 'auto rhsTy = {}.getType();\n'. \
format(second_operand_name) format(second_operand_name)
s += indent + 'auto elementType = getBroadcastedType(lhsTy, rhsTy);\n' s += indent + 'auto elementType = getBroadcastedType(lhsTy, rhsTy);\n'
s += indent + 'auto shapedType = elementType.dyn_cast_or_null<ShapedType>();\n'; s += indent + 'auto shapedType = elementType.dyn_cast_or_null<ShapedType>();\n';
@ -881,8 +881,8 @@ def gen_op_def(schema):
'ValueRange operands, ArrayRef<NamedAttribute> attributes", [{\n' 'ValueRange operands, ArrayRef<NamedAttribute> attributes", [{\n'
indent = inc_indent(indent) indent = inc_indent(indent)
if schema.name in custom_builder_broadcast_ops_list: if schema.name in custom_builder_broadcast_ops_list:
s += indent + 'auto lhsTy = operands[0].getType().cast<RankedTensorType>();\n' s += indent + 'auto lhsTy = operands[0].getType();\n'
s += indent + 'auto rhsTy = operands[1].getType().cast<RankedTensorType>();\n' s += indent + 'auto rhsTy = operands[1].getType();\n'
s += indent + 'auto elementType = getBroadcastedType(lhsTy, rhsTy);\n' s += indent + 'auto elementType = getBroadcastedType(lhsTy, rhsTy);\n'
s += indent + 'auto shapedType = elementType.dyn_cast_or_null<ShapedType>();\n'; s += indent + 'auto shapedType = elementType.dyn_cast_or_null<ShapedType>();\n';
s += indent + 'if (!shapedType || !shapedType.hasStaticShape()) {\n'; s += indent + 'if (!shapedType || !shapedType.hasStaticShape()) {\n';