[MLIR] Change LMHLO Conditional and While to capture needed buffers, instead of passing them by operands.
This is consistent with the design of LMHLO FusionOp, and it simplifies the usage. Before the change, those redundant operands ended up unused as all sub-regions can already capture needed buffers. PiperOrigin-RevId: 362381155
This commit is contained in:
parent
8066794eea
commit
d16860d26d
|
@ -232,34 +232,21 @@ def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", []>, BASE_HLO_ReduceWindowOp {
|
|||
let regions = (region SizedRegion<1>:$body);
|
||||
}
|
||||
|
||||
// TODO(timshen): Add a custom parser to hide operand_segment_sizes. For example,
|
||||
// A tuple-like pattern match syntax could work:
|
||||
// lmhlo.case %index, (%input0, %input1, %input2), (%output0, %output1) {
|
||||
// ...
|
||||
// }, {
|
||||
// ...
|
||||
// } : (type_input0, type_input1, type_input2, type_output0, type_output1) -> ()
|
||||
// TODO(timshen): Add a custom syntax for this.
|
||||
def LHLO_CaseOp: LHLO_Op<"case", [
|
||||
AttrSizedOperandSegments,
|
||||
SingleBlockImplicitTerminator<"TerminatorOp">
|
||||
]>, BASE_HLO_CaseOp {
|
||||
|
||||
let arguments = (ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$index,
|
||||
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$branch_operands,
|
||||
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$out
|
||||
);
|
||||
let arguments = (ins Arg<LHLO_PredOrIntBuffer, "", [MemRead]>:$index);
|
||||
|
||||
let regions = (region VariadicRegion<SizedRegion<1>>:$branches);
|
||||
}
|
||||
|
||||
// TODO(timshen): Add a custom syntax for this.
|
||||
def LHLO_WhileOp: LHLO_Op<"while", [SameVariadicOperandSize]>,
|
||||
BASE_HLO_WhileOp {
|
||||
def LHLO_WhileOp: LHLO_Op<"while", []>, BASE_HLO_WhileOp {
|
||||
let arguments = (ins
|
||||
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$val,
|
||||
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$output
|
||||
);
|
||||
Arg<Variadic<LHLO_PredBuffer>, "", [MemWrite]>:$cond_val,
|
||||
OptionalAttr<I64Attr>:$trip_count);
|
||||
|
||||
let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
|
||||
}
|
||||
|
|
|
@ -472,21 +472,20 @@ func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: m
|
|||
|
||||
// CHECK-LABEL: func @case_memref
|
||||
func @case_memref(%index: memref<i32>, %operand_1: memref<f32>, %operand_2: memref<f32>, %operand_3: memref<f32>, %out: memref<f32>) -> () {
|
||||
"lmhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( {
|
||||
^bb0(%arg0: memref<f32>):
|
||||
"lmhlo.negate"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
|
||||
"lmhlo.case"(%index) ( {
|
||||
^bb0:
|
||||
"lmhlo.negate"(%operand_1, %out) : (memref<f32>, memref<f32>) -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}, {
|
||||
^bb0(%arg0: memref<f32>):
|
||||
"lmhlo.copy"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
|
||||
^bb0:
|
||||
"lmhlo.copy"(%operand_2, %out) : (memref<f32>, memref<f32>) -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}, {
|
||||
^bb0(%arg0: memref<f32>):
|
||||
"lmhlo.add"(%arg0, %arg0, %out) : (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||
^bb0:
|
||||
"lmhlo.add"(%operand_3, %operand_3, %out) : (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}
|
||||
) {operand_segment_sizes = dense<[1, 3, 1]> : vector<3xi32>}
|
||||
: (memref<i32>, memref<f32>, memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||
) : (memref<i32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -908,22 +907,22 @@ func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %
|
|||
// -----
|
||||
|
||||
// CHECK-LABEL: func @while_memrefs
|
||||
func @while_memrefs(%arg0: memref<i64>, %arg_out: memref<i64>) -> () {
|
||||
"lmhlo.while"(%arg0, %arg_out) (
|
||||
{ ^bb0(%arg: memref<i64>, %cond: memref<i1>): "lmhlo.terminator"() : () -> () },
|
||||
{ ^bb0(%arg: memref<i64>, %body_out: memref<i64>): "lmhlo.terminator"() : () -> () }
|
||||
) : (memref<i64>, memref<i64>) -> ()
|
||||
func @while_memrefs(%arg0: memref<i64>, %arg_out: memref<i64>, %cond: memref<i1>) -> () {
|
||||
"lmhlo.while"(%cond) (
|
||||
{ ^bb0: "lmhlo.terminator"() : () -> () },
|
||||
{ ^bb0: "lmhlo.terminator"() : () -> () }
|
||||
) : (memref<i1>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @while_memrefs
|
||||
func @while_memrefs(%arg0: memref<i64>, %arg1: memref<5xf32>, %arg0_out: memref<i64>, %arg1_out: memref<5xf32>) -> () {
|
||||
"lmhlo.while"(%arg0, %arg1, %arg0_out, %arg1_out) (
|
||||
{ ^bb0(%cur0: memref<i64>, %cur1: memref<5xf32>, %cond: memref<i1>): "lmhlo.terminator"() : () -> () },
|
||||
{ ^bb0(%cur0: memref<i64>, %cur1: memref<5xf32>, %body_out0: memref<i64>, %body_out1: memref<5xf32>): "lmhlo.terminator"() : () -> () }
|
||||
) : (memref<i64>, memref<5xf32>, memref<i64>, memref<5xf32>) -> ()
|
||||
func @while_memrefs(%arg0: memref<i64>, %arg1: memref<5xf32>, %arg0_out: memref<i64>, %arg1_out: memref<5xf32>, %cond: memref<i1>) -> () {
|
||||
"lmhlo.while"(%cond) (
|
||||
{ ^bb0: "lmhlo.terminator"() : () -> () },
|
||||
{ ^bb0: "lmhlo.terminator"() : () -> () }
|
||||
) : (memref<i1>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue