diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index d1aa3fe..4a64395 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -232,34 +232,21 @@ def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", []>, BASE_HLO_ReduceWindowOp { let regions = (region SizedRegion<1>:$body); } -// TODO(timshen): Add a custom parser to hide operand_segment_sizes. For example, -// A tuple-like pattern match syntax could work: -// lmhlo.case %index, (%input0, %input1, %input2), (%output0, %output1) { -// ... -// }, { -// ... -// } : (type_input0, type_input1, type_input2, type_output0, type_output1) -> () +// TODO(timshen): Add a custom syntax for this. def LHLO_CaseOp: LHLO_Op<"case", [ - AttrSizedOperandSegments, SingleBlockImplicitTerminator<"TerminatorOp"> ]>, BASE_HLO_CaseOp { - let arguments = (ins - Arg:$index, - Arg, "", [MemRead]>:$branch_operands, - Arg, "", [MemWrite]>:$out - ); + let arguments = (ins Arg:$index); let regions = (region VariadicRegion>:$branches); } // TODO(timshen): Add a custom syntax for this. -def LHLO_WhileOp: LHLO_Op<"while", [SameVariadicOperandSize]>, - BASE_HLO_WhileOp { +def LHLO_WhileOp: LHLO_Op<"while", []>, BASE_HLO_WhileOp { let arguments = (ins - Arg, "", [MemRead]>:$val, - Arg, "", [MemWrite]>:$output - ); + Arg, "", [MemWrite]>:$cond_val, + OptionalAttr:$trip_count); let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body); } diff --git a/tests/lhlo_ops.mlir b/tests/lhlo_ops.mlir index a7fc702..69a88d4 100644 --- a/tests/lhlo_ops.mlir +++ b/tests/lhlo_ops.mlir @@ -472,21 +472,20 @@ func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: m // CHECK-LABEL: func @case_memref func @case_memref(%index: memref, %operand_1: memref, %operand_2: memref, %operand_3: memref, %out: memref) -> () { - "lmhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( { - ^bb0(%arg0: memref): - "lmhlo.negate"(%arg0, %out) : (memref, memref) -> () + "lmhlo.case"(%index) ( { + ^bb0: + "lmhlo.negate"(%operand_1, %out) : (memref, memref) -> () "lmhlo.terminator"() : () -> () }, { - ^bb0(%arg0: memref): - "lmhlo.copy"(%arg0, %out) : (memref, memref) -> () + ^bb0: + "lmhlo.copy"(%operand_2, %out) : (memref, memref) -> () "lmhlo.terminator"() : () -> () }, { - ^bb0(%arg0: memref): - "lmhlo.add"(%arg0, %arg0, %out) : (memref, memref, memref) -> () + ^bb0: + "lmhlo.add"(%operand_3, %operand_3, %out) : (memref, memref, memref) -> () "lmhlo.terminator"() : () -> () } - ) {operand_segment_sizes = dense<[1, 3, 1]> : vector<3xi32>} - : (memref, memref, memref, memref, memref) -> () + ) : (memref) -> () return } @@ -908,22 +907,22 @@ func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, % // ----- // CHECK-LABEL: func @while_memrefs -func @while_memrefs(%arg0: memref, %arg_out: memref) -> () { - "lmhlo.while"(%arg0, %arg_out) ( - { ^bb0(%arg: memref, %cond: memref): "lmhlo.terminator"() : () -> () }, - { ^bb0(%arg: memref, %body_out: memref): "lmhlo.terminator"() : () -> () } - ) : (memref, memref) -> () +func @while_memrefs(%arg0: memref, %arg_out: memref, %cond: memref) -> () { + "lmhlo.while"(%cond) ( + { ^bb0: "lmhlo.terminator"() : () -> () }, + { ^bb0: "lmhlo.terminator"() : () -> () } + ) : (memref) -> () return } // ----- // CHECK-LABEL: func @while_memrefs -func @while_memrefs(%arg0: memref, %arg1: memref<5xf32>, %arg0_out: memref, %arg1_out: memref<5xf32>) -> () { - "lmhlo.while"(%arg0, %arg1, %arg0_out, %arg1_out) ( - { ^bb0(%cur0: memref, %cur1: memref<5xf32>, %cond: memref): "lmhlo.terminator"() : () -> () }, - { ^bb0(%cur0: memref, %cur1: memref<5xf32>, %body_out0: memref, %body_out1: memref<5xf32>): "lmhlo.terminator"() : () -> () } - ) : (memref, memref<5xf32>, memref, memref<5xf32>) -> () +func @while_memrefs(%arg0: memref, %arg1: memref<5xf32>, %arg0_out: memref, %arg1_out: memref<5xf32>, %cond: memref) -> () { + "lmhlo.while"(%cond) ( + { ^bb0: "lmhlo.terminator"() : () -> () }, + { ^bb0: "lmhlo.terminator"() : () -> () } + ) : (memref) -> () return }