modify builder (#214)
This commit is contained in:
parent
a58594ec81
commit
f43f26a79c
|
@ -95,8 +95,8 @@ def ONNXAddOp:ONNX_Op<"Add",
|
||||||
let results = (outs AnyTypeOf<[TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$C);
|
let results = (outs AnyTypeOf<[TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$C);
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
||||||
auto lhsTy = A.getType().cast<RankedTensorType>();
|
auto lhsTy = A.getType();
|
||||||
auto rhsTy = B.getType().cast<RankedTensorType>();
|
auto rhsTy = B.getType();
|
||||||
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
if (!shapedType || !shapedType.hasStaticShape()) {
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
@ -106,8 +106,8 @@ def ONNXAddOp:ONNX_Op<"Add",
|
||||||
build(builder, state, elementType, A, B);
|
build(builder, state, elementType, A, B);
|
||||||
}]>,
|
}]>,
|
||||||
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
|
auto lhsTy = operands[0].getType();
|
||||||
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
|
auto rhsTy = operands[1].getType();
|
||||||
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
if (!shapedType || !shapedType.hasStaticShape()) {
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
@ -146,8 +146,8 @@ def ONNXAndOp:ONNX_Op<"And",
|
||||||
let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C);
|
let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C);
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
||||||
auto lhsTy = A.getType().cast<RankedTensorType>();
|
auto lhsTy = A.getType();
|
||||||
auto rhsTy = B.getType().cast<RankedTensorType>();
|
auto rhsTy = B.getType();
|
||||||
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
if (!shapedType || !shapedType.hasStaticShape()) {
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
@ -157,8 +157,8 @@ def ONNXAndOp:ONNX_Op<"And",
|
||||||
build(builder, state, elementType, A, B);
|
build(builder, state, elementType, A, B);
|
||||||
}]>,
|
}]>,
|
||||||
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
|
auto lhsTy = operands[0].getType();
|
||||||
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
|
auto rhsTy = operands[1].getType();
|
||||||
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
if (!shapedType || !shapedType.hasStaticShape()) {
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
@ -987,8 +987,8 @@ def ONNXDivOp:ONNX_Op<"Div",
|
||||||
let results = (outs AnyTypeOf<[TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$C);
|
let results = (outs AnyTypeOf<[TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$C);
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
||||||
auto lhsTy = A.getType().cast<RankedTensorType>();
|
auto lhsTy = A.getType();
|
||||||
auto rhsTy = B.getType().cast<RankedTensorType>();
|
auto rhsTy = B.getType();
|
||||||
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
if (!shapedType || !shapedType.hasStaticShape()) {
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
@ -998,8 +998,8 @@ def ONNXDivOp:ONNX_Op<"Div",
|
||||||
build(builder, state, elementType, A, B);
|
build(builder, state, elementType, A, B);
|
||||||
}]>,
|
}]>,
|
||||||
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
|
auto lhsTy = operands[0].getType();
|
||||||
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
|
auto rhsTy = operands[1].getType();
|
||||||
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
if (!shapedType || !shapedType.hasStaticShape()) {
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
@ -1135,8 +1135,8 @@ def ONNXEqualOp:ONNX_Op<"Equal",
|
||||||
let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C);
|
let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C);
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
||||||
auto lhsTy = A.getType().cast<RankedTensorType>();
|
auto lhsTy = A.getType();
|
||||||
auto rhsTy = B.getType().cast<RankedTensorType>();
|
auto rhsTy = B.getType();
|
||||||
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
if (!shapedType || !shapedType.hasStaticShape()) {
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
@ -1146,8 +1146,8 @@ def ONNXEqualOp:ONNX_Op<"Equal",
|
||||||
build(builder, state, elementType, A, B);
|
build(builder, state, elementType, A, B);
|
||||||
}]>,
|
}]>,
|
||||||
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
|
auto lhsTy = operands[0].getType();
|
||||||
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
|
auto rhsTy = operands[1].getType();
|
||||||
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
if (!shapedType || !shapedType.hasStaticShape()) {
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
@ -1803,8 +1803,8 @@ def ONNXGreaterOp:ONNX_Op<"Greater",
|
||||||
let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C);
|
let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C);
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
||||||
auto lhsTy = A.getType().cast<RankedTensorType>();
|
auto lhsTy = A.getType();
|
||||||
auto rhsTy = B.getType().cast<RankedTensorType>();
|
auto rhsTy = B.getType();
|
||||||
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
if (!shapedType || !shapedType.hasStaticShape()) {
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
@ -1814,8 +1814,8 @@ def ONNXGreaterOp:ONNX_Op<"Greater",
|
||||||
build(builder, state, elementType, A, B);
|
build(builder, state, elementType, A, B);
|
||||||
}]>,
|
}]>,
|
||||||
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
|
auto lhsTy = operands[0].getType();
|
||||||
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
|
auto rhsTy = operands[1].getType();
|
||||||
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
if (!shapedType || !shapedType.hasStaticShape()) {
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
@ -2207,8 +2207,8 @@ def ONNXLessOp:ONNX_Op<"Less",
|
||||||
let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C);
|
let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C);
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
||||||
auto lhsTy = A.getType().cast<RankedTensorType>();
|
auto lhsTy = A.getType();
|
||||||
auto rhsTy = B.getType().cast<RankedTensorType>();
|
auto rhsTy = B.getType();
|
||||||
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
if (!shapedType || !shapedType.hasStaticShape()) {
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
@ -2218,8 +2218,8 @@ def ONNXLessOp:ONNX_Op<"Less",
|
||||||
build(builder, state, elementType, A, B);
|
build(builder, state, elementType, A, B);
|
||||||
}]>,
|
}]>,
|
||||||
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
|
auto lhsTy = operands[0].getType();
|
||||||
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
|
auto rhsTy = operands[1].getType();
|
||||||
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
if (!shapedType || !shapedType.hasStaticShape()) {
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
@ -2802,8 +2802,8 @@ def ONNXMulOp:ONNX_Op<"Mul",
|
||||||
let results = (outs AnyTypeOf<[TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$C);
|
let results = (outs AnyTypeOf<[TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$C);
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
||||||
auto lhsTy = A.getType().cast<RankedTensorType>();
|
auto lhsTy = A.getType();
|
||||||
auto rhsTy = B.getType().cast<RankedTensorType>();
|
auto rhsTy = B.getType();
|
||||||
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
if (!shapedType || !shapedType.hasStaticShape()) {
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
@ -2813,8 +2813,8 @@ def ONNXMulOp:ONNX_Op<"Mul",
|
||||||
build(builder, state, elementType, A, B);
|
build(builder, state, elementType, A, B);
|
||||||
}]>,
|
}]>,
|
||||||
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
|
auto lhsTy = operands[0].getType();
|
||||||
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
|
auto rhsTy = operands[1].getType();
|
||||||
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
if (!shapedType || !shapedType.hasStaticShape()) {
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
@ -3020,8 +3020,8 @@ def ONNXOrOp:ONNX_Op<"Or",
|
||||||
let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C);
|
let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C);
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
||||||
auto lhsTy = A.getType().cast<RankedTensorType>();
|
auto lhsTy = A.getType();
|
||||||
auto rhsTy = B.getType().cast<RankedTensorType>();
|
auto rhsTy = B.getType();
|
||||||
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
if (!shapedType || !shapedType.hasStaticShape()) {
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
@ -3031,8 +3031,8 @@ def ONNXOrOp:ONNX_Op<"Or",
|
||||||
build(builder, state, elementType, A, B);
|
build(builder, state, elementType, A, B);
|
||||||
}]>,
|
}]>,
|
||||||
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
|
auto lhsTy = operands[0].getType();
|
||||||
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
|
auto rhsTy = operands[1].getType();
|
||||||
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
if (!shapedType || !shapedType.hasStaticShape()) {
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
@ -3215,8 +3215,8 @@ def ONNXPowOp:ONNX_Op<"Pow",
|
||||||
let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$Z);
|
let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$Z);
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"OpBuilder &builder, OperationState &state, Value X, Value Y", [{
|
OpBuilder<"OpBuilder &builder, OperationState &state, Value X, Value Y", [{
|
||||||
auto lhsTy = X.getType().cast<RankedTensorType>();
|
auto lhsTy = X.getType();
|
||||||
auto rhsTy = Y.getType().cast<RankedTensorType>();
|
auto rhsTy = Y.getType();
|
||||||
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
if (!shapedType || !shapedType.hasStaticShape()) {
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
@ -3226,8 +3226,8 @@ def ONNXPowOp:ONNX_Op<"Pow",
|
||||||
build(builder, state, elementType, X, Y);
|
build(builder, state, elementType, X, Y);
|
||||||
}]>,
|
}]>,
|
||||||
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
|
auto lhsTy = operands[0].getType();
|
||||||
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
|
auto rhsTy = operands[1].getType();
|
||||||
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
if (!shapedType || !shapedType.hasStaticShape()) {
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
@ -5162,8 +5162,8 @@ def ONNXSubOp:ONNX_Op<"Sub",
|
||||||
let results = (outs AnyTypeOf<[TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$C);
|
let results = (outs AnyTypeOf<[TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$C);
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
||||||
auto lhsTy = A.getType().cast<RankedTensorType>();
|
auto lhsTy = A.getType();
|
||||||
auto rhsTy = B.getType().cast<RankedTensorType>();
|
auto rhsTy = B.getType();
|
||||||
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
if (!shapedType || !shapedType.hasStaticShape()) {
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
@ -5173,8 +5173,8 @@ def ONNXSubOp:ONNX_Op<"Sub",
|
||||||
build(builder, state, elementType, A, B);
|
build(builder, state, elementType, A, B);
|
||||||
}]>,
|
}]>,
|
||||||
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
|
auto lhsTy = operands[0].getType();
|
||||||
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
|
auto rhsTy = operands[1].getType();
|
||||||
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
if (!shapedType || !shapedType.hasStaticShape()) {
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
@ -5629,8 +5629,8 @@ def ONNXXorOp:ONNX_Op<"Xor",
|
||||||
let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C);
|
let results = (outs AnyTypeOf<[TensorOf<[I1]>, AnyMemRef]>:$C);
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
||||||
auto lhsTy = A.getType().cast<RankedTensorType>();
|
auto lhsTy = A.getType();
|
||||||
auto rhsTy = B.getType().cast<RankedTensorType>();
|
auto rhsTy = B.getType();
|
||||||
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
if (!shapedType || !shapedType.hasStaticShape()) {
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
@ -5640,8 +5640,8 @@ def ONNXXorOp:ONNX_Op<"Xor",
|
||||||
build(builder, state, elementType, A, B);
|
build(builder, state, elementType, A, B);
|
||||||
}]>,
|
}]>,
|
||||||
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
|
auto lhsTy = operands[0].getType();
|
||||||
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
|
auto rhsTy = operands[1].getType();
|
||||||
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
if (!shapedType || !shapedType.hasStaticShape()) {
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
|
|
@ -851,9 +851,9 @@ def gen_op_def(schema):
|
||||||
build_type_name = ''
|
build_type_name = ''
|
||||||
if schema.name in custom_builder_broadcast_ops_list:
|
if schema.name in custom_builder_broadcast_ops_list:
|
||||||
second_operand_name = list(ins.items())[1][0]
|
second_operand_name = list(ins.items())[1][0]
|
||||||
s += indent + 'auto lhsTy = {}.getType().cast<RankedTensorType>();\n'. \
|
s += indent + 'auto lhsTy = {}.getType();\n'. \
|
||||||
format(first_operand_name)
|
format(first_operand_name)
|
||||||
s += indent + 'auto rhsTy = {}.getType().cast<RankedTensorType>();\n'. \
|
s += indent + 'auto rhsTy = {}.getType();\n'. \
|
||||||
format(second_operand_name)
|
format(second_operand_name)
|
||||||
s += indent + 'auto elementType = getBroadcastedType(lhsTy, rhsTy);\n'
|
s += indent + 'auto elementType = getBroadcastedType(lhsTy, rhsTy);\n'
|
||||||
s += indent + 'auto shapedType = elementType.dyn_cast_or_null<ShapedType>();\n';
|
s += indent + 'auto shapedType = elementType.dyn_cast_or_null<ShapedType>();\n';
|
||||||
|
@ -881,8 +881,8 @@ def gen_op_def(schema):
|
||||||
'ValueRange operands, ArrayRef<NamedAttribute> attributes", [{\n'
|
'ValueRange operands, ArrayRef<NamedAttribute> attributes", [{\n'
|
||||||
indent = inc_indent(indent)
|
indent = inc_indent(indent)
|
||||||
if schema.name in custom_builder_broadcast_ops_list:
|
if schema.name in custom_builder_broadcast_ops_list:
|
||||||
s += indent + 'auto lhsTy = operands[0].getType().cast<RankedTensorType>();\n'
|
s += indent + 'auto lhsTy = operands[0].getType();\n'
|
||||||
s += indent + 'auto rhsTy = operands[1].getType().cast<RankedTensorType>();\n'
|
s += indent + 'auto rhsTy = operands[1].getType();\n'
|
||||||
s += indent + 'auto elementType = getBroadcastedType(lhsTy, rhsTy);\n'
|
s += indent + 'auto elementType = getBroadcastedType(lhsTy, rhsTy);\n'
|
||||||
s += indent + 'auto shapedType = elementType.dyn_cast_or_null<ShapedType>();\n';
|
s += indent + 'auto shapedType = elementType.dyn_cast_or_null<ShapedType>();\n';
|
||||||
s += indent + 'if (!shapedType || !shapedType.hasStaticShape()) {\n';
|
s += indent + 'if (!shapedType || !shapedType.hasStaticShape()) {\n';
|
||||||
|
|
Loading…
Reference in New Issue