[MLIR] Change LMHLO Conditional and While to capture needed buffers, instead of passing them by operands.

This is consistent with the design of LMHLO FusionOp, and it simplifies the
usage. Before the change, those redundant operands ended up unused as all sub-regions can already capture needed buffers.

PiperOrigin-RevId: 362381155
This commit is contained in:
Tim Shen 2021-03-11 14:41:50 -08:00 committed by TensorFlow MLIR Team
parent 8066794eea
commit d16860d26d
2 changed files with 23 additions and 37 deletions

View File

@ -232,34 +232,21 @@ def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", []>, BASE_HLO_ReduceWindowOp {
let regions = (region SizedRegion<1>:$body); let regions = (region SizedRegion<1>:$body);
} }
// TODO(timshen): Add a custom parser to hide operand_segment_sizes. For example, // TODO(timshen): Add a custom syntax for this.
// A tuple-like pattern match syntax could work:
// lmhlo.case %index, (%input0, %input1, %input2), (%output0, %output1) {
// ...
// }, {
// ...
// } : (type_input0, type_input1, type_input2, type_output0, type_output1) -> ()
def LHLO_CaseOp: LHLO_Op<"case", [ def LHLO_CaseOp: LHLO_Op<"case", [
AttrSizedOperandSegments,
SingleBlockImplicitTerminator<"TerminatorOp"> SingleBlockImplicitTerminator<"TerminatorOp">
]>, BASE_HLO_CaseOp { ]>, BASE_HLO_CaseOp {
let arguments = (ins let arguments = (ins Arg<LHLO_PredOrIntBuffer, "", [MemRead]>:$index);
Arg<LHLO_Buffer, "", [MemRead]>:$index,
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$branch_operands,
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$out
);
let regions = (region VariadicRegion<SizedRegion<1>>:$branches); let regions = (region VariadicRegion<SizedRegion<1>>:$branches);
} }
// TODO(timshen): Add a custom syntax for this. // TODO(timshen): Add a custom syntax for this.
def LHLO_WhileOp: LHLO_Op<"while", [SameVariadicOperandSize]>, def LHLO_WhileOp: LHLO_Op<"while", []>, BASE_HLO_WhileOp {
BASE_HLO_WhileOp {
let arguments = (ins let arguments = (ins
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$val, Arg<Variadic<LHLO_PredBuffer>, "", [MemWrite]>:$cond_val,
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$output OptionalAttr<I64Attr>:$trip_count);
);
let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body); let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
} }

View File

@ -472,21 +472,20 @@ func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: m
// CHECK-LABEL: func @case_memref // CHECK-LABEL: func @case_memref
func @case_memref(%index: memref<i32>, %operand_1: memref<f32>, %operand_2: memref<f32>, %operand_3: memref<f32>, %out: memref<f32>) -> () { func @case_memref(%index: memref<i32>, %operand_1: memref<f32>, %operand_2: memref<f32>, %operand_3: memref<f32>, %out: memref<f32>) -> () {
"lmhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( { "lmhlo.case"(%index) ( {
^bb0(%arg0: memref<f32>): ^bb0:
"lmhlo.negate"(%arg0, %out) : (memref<f32>, memref<f32>) -> () "lmhlo.negate"(%operand_1, %out) : (memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> () "lmhlo.terminator"() : () -> ()
}, { }, {
^bb0(%arg0: memref<f32>): ^bb0:
"lmhlo.copy"(%arg0, %out) : (memref<f32>, memref<f32>) -> () "lmhlo.copy"(%operand_2, %out) : (memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> () "lmhlo.terminator"() : () -> ()
}, { }, {
^bb0(%arg0: memref<f32>): ^bb0:
"lmhlo.add"(%arg0, %arg0, %out) : (memref<f32>, memref<f32>, memref<f32>) -> () "lmhlo.add"(%operand_3, %operand_3, %out) : (memref<f32>, memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> () "lmhlo.terminator"() : () -> ()
} }
) {operand_segment_sizes = dense<[1, 3, 1]> : vector<3xi32>} ) : (memref<i32>) -> ()
: (memref<i32>, memref<f32>, memref<f32>, memref<f32>, memref<f32>) -> ()
return return
} }
@ -908,22 +907,22 @@ func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %
// ----- // -----
// CHECK-LABEL: func @while_memrefs // CHECK-LABEL: func @while_memrefs
func @while_memrefs(%arg0: memref<i64>, %arg_out: memref<i64>) -> () { func @while_memrefs(%arg0: memref<i64>, %arg_out: memref<i64>, %cond: memref<i1>) -> () {
"lmhlo.while"(%arg0, %arg_out) ( "lmhlo.while"(%cond) (
{ ^bb0(%arg: memref<i64>, %cond: memref<i1>): "lmhlo.terminator"() : () -> () }, { ^bb0: "lmhlo.terminator"() : () -> () },
{ ^bb0(%arg: memref<i64>, %body_out: memref<i64>): "lmhlo.terminator"() : () -> () } { ^bb0: "lmhlo.terminator"() : () -> () }
) : (memref<i64>, memref<i64>) -> () ) : (memref<i1>) -> ()
return return
} }
// ----- // -----
// CHECK-LABEL: func @while_memrefs // CHECK-LABEL: func @while_memrefs
func @while_memrefs(%arg0: memref<i64>, %arg1: memref<5xf32>, %arg0_out: memref<i64>, %arg1_out: memref<5xf32>) -> () { func @while_memrefs(%arg0: memref<i64>, %arg1: memref<5xf32>, %arg0_out: memref<i64>, %arg1_out: memref<5xf32>, %cond: memref<i1>) -> () {
"lmhlo.while"(%arg0, %arg1, %arg0_out, %arg1_out) ( "lmhlo.while"(%cond) (
{ ^bb0(%cur0: memref<i64>, %cur1: memref<5xf32>, %cond: memref<i1>): "lmhlo.terminator"() : () -> () }, { ^bb0: "lmhlo.terminator"() : () -> () },
{ ^bb0(%cur0: memref<i64>, %cur1: memref<5xf32>, %body_out0: memref<i64>, %body_out1: memref<5xf32>): "lmhlo.terminator"() : () -> () } { ^bb0: "lmhlo.terminator"() : () -> () }
) : (memref<i64>, memref<5xf32>, memref<i64>, memref<5xf32>) -> () ) : (memref<i1>) -> ()
return return
} }