diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index aab5981..122a90c 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -524,7 +524,9 @@ def HLO_AfterAllOp : HLO_Op<"after_all", [NoSideEffect]> { // Xla Client API has two separate calls for indexed and predicated conditional, // although both eventually map to kConditional HLO. IfOp maps to predicated // conditional use of kConditional HLO. -def HLO_IfOp: HLO_Op<"if", [RecursiveSideEffects]> { +def HLO_IfOp: HLO_Op<"if", [ + RecursiveSideEffects, + SingleBlockImplicitTerminator<"ReturnOp">]> { string summary = "If operator"; string description = [{ @@ -540,8 +542,8 @@ def HLO_IfOp: HLO_Op<"if", [RecursiveSideEffects]> { HLO_TensorOrTuple:$false_arg ); - let regions = (region AnyRegion:$true_branch, - AnyRegion:$false_branch); + let regions = (region SizedRegion<1>:$true_branch, + SizedRegion<1>:$false_branch); let results = (outs HLO_TensorOrTuple); @@ -552,15 +554,17 @@ def HLO_IfOp: HLO_Op<"if", [RecursiveSideEffects]> { // Xla Client API has two separate calls for indexed and predicated conditional, // although both eventually map to kConditional HLO. CaseOp maps to indexed // conditional use of kConditional HLO. -def HLO_CaseOp: HLO_Op<"case", [RecursiveSideEffects]>, - BASE_HLO_CaseOp { +def HLO_CaseOp: HLO_Op<"case", [ + RecursiveSideEffects, + SingleBlockImplicitTerminator<"ReturnOp"> + ]>, BASE_HLO_CaseOp { let arguments = (ins I32Tensor:$index, Variadic:$branch_operands ); - let regions = (region VariadicRegion:$branches); + let regions = (region VariadicRegion>:$branches); let results = (outs Variadic); @@ -568,12 +572,14 @@ def HLO_CaseOp: HLO_Op<"case", [RecursiveSideEffects]>, } -def HLO_WhileOp: HLO_Op<"while", [RecursiveSideEffects, - SameOperandsAndResultType]>, - BASE_HLO_WhileOp { +def HLO_WhileOp: HLO_Op<"while", [ + RecursiveSideEffects, + SameOperandsAndResultType, + SingleBlockImplicitTerminator<"ReturnOp"> + ]>, BASE_HLO_WhileOp { let arguments = (ins HLO_TensorOrTuple:$val); - let regions = (region AnyRegion:$cond, AnyRegion:$body); + let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body); let results = (outs HLO_TensorOrTuple); diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 16b987d..cdd2629 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -2099,8 +2099,6 @@ static LogicalResult Verify(CaseOp op) { OperandRange branch_operands = op.branch_operands(); for (unsigned i = 0; i < num_branches; ++i) { mlir::Region& branch_region = branches[i]; - if (branch_region.empty()) - return op.emitOpError() << "cannot have empty regions"; mlir::Block& entry_block = branch_region.front(); if (entry_block.getNumArguments() != 1) return op.emitOpError() diff --git a/tests/legalize-control-flow.mlir b/tests/legalize-control-flow.mlir index 8e5e18a..3e06271 100644 --- a/tests/legalize-control-flow.mlir +++ b/tests/legalize-control-flow.mlir @@ -58,89 +58,3 @@ func @conditional(%arg0: tensor) -> tensor { return %1 : tensor } -// CHECK-LABEL: func @while_with_multiple_blocks_in_body(%arg0: tensor) -> tensor { -func @while_with_multiple_blocks_in_body(%arg0: tensor) -> tensor { - // CHECK: br ^[[COND_ENTRY:.+]](%arg0 : tensor) - // CHECK: ^[[COND_ENTRY]](%0: tensor): - // CHECK: %1 = "mhlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor, tensor) -> tensor - // CHECK: %2 = tensor.extract %1[] : tensor - // CHECK: cond_br %2, ^[[BODY_ENTRY:.+]](%0 : tensor), ^[[EXIT:.+]](%0 : tensor) - // CHECK: ^[[BODY_ENTRY]](%3: tensor): - // CHECK: br ^[[BODY_SUCC:.+]](%3 : tensor) - // CHECK: ^[[BODY_SUCC]](%4: tensor): - // CHECK: %5 = mhlo.add %4, %4 : tensor - // CHECK: br ^[[COND_ENTRY]](%5 : tensor) - // CHECK: ^[[EXIT]](%6: tensor): - // CHECK: return %6 : tensor - // CHECK: } - %0 = "mhlo.while"(%arg0) ( { - ^cond_entry(%arg1: tensor): - %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor - "mhlo.return"(%1) : (tensor) -> () - }, { - ^body_entry(%arg1: tensor): - br ^body_succ(%arg1: tensor) - ^body_succ(%0: tensor): - %1 = mhlo.add %0, %0 : tensor - "mhlo.return"(%1) : (tensor) -> () - }) : (tensor) -> tensor - - return %0 : tensor -} - -// CHECK-LABEL: func @while_with_multiple_blocks_in_cond(%arg0: tensor) -> tensor { -func @while_with_multiple_blocks_in_cond(%arg0: tensor) -> tensor { - // CHECK: br ^[[COND_ENTRY:.+]](%arg0 : tensor) - // CHECK: ^[[COND_ENTRY]](%0: tensor): - // CHECK: br ^[[COND_SUCC:.+]](%0 : tensor) - // CHECK: ^[[COND_SUCC]](%1: tensor): - // CHECK: %2 = "mhlo.compare"(%1, %1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor - // CHECK: %3 = tensor.extract %2[] : tensor - // CHECK: cond_br %3, ^[[BODY_ENTRY:.+]](%0 : tensor), ^[[EXIT:.+]](%0 : tensor) - // CHECK: ^[[BODY_ENTRY]](%4: tensor): - // CHECK: br ^[[COND_ENTRY]](%4 : tensor) - // CHECK: ^[[EXIT]](%5: tensor): - // CHECK: return %5 : tensor - // CHECK: } - %0 = "mhlo.while"(%arg0) ( { - ^cond_entry(%arg1: tensor): - br ^cond_succ(%arg1: tensor) - ^cond_succ(%0: tensor): - %1 = "mhlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor, tensor) -> tensor - "mhlo.return"(%1) : (tensor) -> () - }, { - ^body_entry(%arg1: tensor): - "mhlo.return"(%arg1) : (tensor) -> () - }) : (tensor) -> tensor - - return %0 : tensor -} - -// CHECK-LABEL: func @conditional_with_multiple_blocks(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { -func @conditional_with_multiple_blocks(%arg0: tensor, %arg1: tensor, %pred: tensor) -> tensor { - // CHECK: %0 = tensor.extract %arg2[] : tensor - // CHECK: cond_br %0, ^[[THEN_ENTRY:.+]](%arg0 : tensor), ^[[ELSE_ENTRY:.+]](%arg1 : tensor) - // CHECK: ^[[THEN_ENTRY]](%1: tensor): - // CHECK: br ^[[THEN_SUCC:.+]](%1 : tensor) - // CHECK: ^[[THEN_SUCC]](%2: tensor): - // CHECK: %3 = "mhlo.log"(%2) : (tensor) -> tensor - // CHECK: br ^[[EXIT:.+]](%3 : tensor) - // CHECK: ^[[ELSE_ENTRY]](%4: tensor): - // CHECK: %5 = "mhlo.exponential"(%4) : (tensor) -> tensor - // CHECK: br ^[[EXIT]](%5 : tensor) - // CHECK: ^[[EXIT]](%6: tensor): - // CHECK: return %6 : tensor - // CHECK: } - %1 = "mhlo.if"(%pred, %arg0, %arg1) ( { - ^then_entry(%arg2: tensor): - br ^then_succ(%arg2: tensor) - ^then_succ(%0: tensor): - %2 = "mhlo.log"(%0) : (tensor) -> tensor - "mhlo.return"(%2) : (tensor) -> () - }, { - ^else_entry(%arg2: tensor): - %2 = "mhlo.exponential"(%arg2) : (tensor) -> tensor - "mhlo.return"(%2) : (tensor) -> () - }) : (tensor, tensor, tensor) -> tensor - return %1 : tensor -} diff --git a/tests/ops.mlir b/tests/ops.mlir index 32f20ea..5adad7d 100644 --- a/tests/ops.mlir +++ b/tests/ops.mlir @@ -288,14 +288,6 @@ func @case_mismatch_return_type(%index: tensor, %operand_1: tensor, %o // ----- -func @case_empty_region(%index: tensor, %operand_1: tensor) -> () { - // expected-error@+1 {{cannot have empty regions}} - "mhlo.case"(%index, %operand_1) ( {} ) : (tensor, tensor) -> tensor - return -} - -// ----- - // CHECK-LABEL: func @comp_eq func @comp_eq(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi1> { %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>