[MLIR] Change LMHLO Conditional and While to capture needed buffers, instead of passing them by operands.

This is consistent with the design of LMHLO FusionOp, and it simplifies the
usage. Before the change, those redundant operands ended up unused as all sub-regions can already capture needed buffers.

PiperOrigin-RevId: 362381155
This commit is contained in:
Tim Shen 2021-03-11 14:41:50 -08:00 committed by TensorFlow MLIR Team
parent 8066794eea
commit d16860d26d
2 changed files with 23 additions and 37 deletions

View File

@ -232,34 +232,21 @@ def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", []>, BASE_HLO_ReduceWindowOp {
let regions = (region SizedRegion<1>:$body);
}
// TODO(timshen): Add a custom parser to hide operand_segment_sizes. For example,
// A tuple-like pattern match syntax could work:
// lmhlo.case %index, (%input0, %input1, %input2), (%output0, %output1) {
// ...
// }, {
// ...
// } : (type_input0, type_input1, type_input2, type_output0, type_output1) -> ()
// TODO(timshen): Add a custom syntax for this.
def LHLO_CaseOp: LHLO_Op<"case", [
AttrSizedOperandSegments,
SingleBlockImplicitTerminator<"TerminatorOp">
]>, BASE_HLO_CaseOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$index,
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$branch_operands,
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$out
);
let arguments = (ins Arg<LHLO_PredOrIntBuffer, "", [MemRead]>:$index);
let regions = (region VariadicRegion<SizedRegion<1>>:$branches);
}
// TODO(timshen): Add a custom syntax for this.
def LHLO_WhileOp: LHLO_Op<"while", [SameVariadicOperandSize]>,
BASE_HLO_WhileOp {
def LHLO_WhileOp: LHLO_Op<"while", []>, BASE_HLO_WhileOp {
let arguments = (ins
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$val,
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$output
);
Arg<Variadic<LHLO_PredBuffer>, "", [MemWrite]>:$cond_val,
OptionalAttr<I64Attr>:$trip_count);
let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
}

View File

@ -472,21 +472,20 @@ func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: m
// CHECK-LABEL: func @case_memref
func @case_memref(%index: memref<i32>, %operand_1: memref<f32>, %operand_2: memref<f32>, %operand_3: memref<f32>, %out: memref<f32>) -> () {
"lmhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( {
^bb0(%arg0: memref<f32>):
"lmhlo.negate"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
"lmhlo.case"(%index) ( {
^bb0:
"lmhlo.negate"(%operand_1, %out) : (memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> ()
}, {
^bb0(%arg0: memref<f32>):
"lmhlo.copy"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
^bb0:
"lmhlo.copy"(%operand_2, %out) : (memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> ()
}, {
^bb0(%arg0: memref<f32>):
"lmhlo.add"(%arg0, %arg0, %out) : (memref<f32>, memref<f32>, memref<f32>) -> ()
^bb0:
"lmhlo.add"(%operand_3, %operand_3, %out) : (memref<f32>, memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> ()
}
) {operand_segment_sizes = dense<[1, 3, 1]> : vector<3xi32>}
: (memref<i32>, memref<f32>, memref<f32>, memref<f32>, memref<f32>) -> ()
) : (memref<i32>) -> ()
return
}
@ -908,22 +907,22 @@ func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %
// -----
// CHECK-LABEL: func @while_memrefs
func @while_memrefs(%arg0: memref<i64>, %arg_out: memref<i64>) -> () {
"lmhlo.while"(%arg0, %arg_out) (
{ ^bb0(%arg: memref<i64>, %cond: memref<i1>): "lmhlo.terminator"() : () -> () },
{ ^bb0(%arg: memref<i64>, %body_out: memref<i64>): "lmhlo.terminator"() : () -> () }
) : (memref<i64>, memref<i64>) -> ()
func @while_memrefs(%arg0: memref<i64>, %arg_out: memref<i64>, %cond: memref<i1>) -> () {
"lmhlo.while"(%cond) (
{ ^bb0: "lmhlo.terminator"() : () -> () },
{ ^bb0: "lmhlo.terminator"() : () -> () }
) : (memref<i1>) -> ()
return
}
// -----
// CHECK-LABEL: func @while_memrefs
func @while_memrefs(%arg0: memref<i64>, %arg1: memref<5xf32>, %arg0_out: memref<i64>, %arg1_out: memref<5xf32>) -> () {
"lmhlo.while"(%arg0, %arg1, %arg0_out, %arg1_out) (
{ ^bb0(%cur0: memref<i64>, %cur1: memref<5xf32>, %cond: memref<i1>): "lmhlo.terminator"() : () -> () },
{ ^bb0(%cur0: memref<i64>, %cur1: memref<5xf32>, %body_out0: memref<i64>, %body_out1: memref<5xf32>): "lmhlo.terminator"() : () -> () }
) : (memref<i64>, memref<5xf32>, memref<i64>, memref<5xf32>) -> ()
func @while_memrefs(%arg0: memref<i64>, %arg1: memref<5xf32>, %arg0_out: memref<i64>, %arg1_out: memref<5xf32>, %cond: memref<i1>) -> () {
"lmhlo.while"(%cond) (
{ ^bb0: "lmhlo.terminator"() : () -> () },
{ ^bb0: "lmhlo.terminator"() : () -> () }
) : (memref<i1>) -> ()
return
}