Restrict MHLO control flow ops to single-block regions
PiperOrigin-RevId: 365935824
This commit is contained in:
parent
e78c59d927
commit
7a9394dca5
|
@ -524,7 +524,9 @@ def HLO_AfterAllOp : HLO_Op<"after_all", [NoSideEffect]> {
|
||||||
// Xla Client API has two separate calls for indexed and predicated conditional,
|
// Xla Client API has two separate calls for indexed and predicated conditional,
|
||||||
// although both eventually map to kConditional HLO. IfOp maps to predicated
|
// although both eventually map to kConditional HLO. IfOp maps to predicated
|
||||||
// conditional use of kConditional HLO.
|
// conditional use of kConditional HLO.
|
||||||
def HLO_IfOp: HLO_Op<"if", [RecursiveSideEffects]> {
|
def HLO_IfOp: HLO_Op<"if", [
|
||||||
|
RecursiveSideEffects,
|
||||||
|
SingleBlockImplicitTerminator<"ReturnOp">]> {
|
||||||
string summary = "If operator";
|
string summary = "If operator";
|
||||||
|
|
||||||
string description = [{
|
string description = [{
|
||||||
|
@ -540,8 +542,8 @@ def HLO_IfOp: HLO_Op<"if", [RecursiveSideEffects]> {
|
||||||
HLO_TensorOrTuple:$false_arg
|
HLO_TensorOrTuple:$false_arg
|
||||||
);
|
);
|
||||||
|
|
||||||
let regions = (region AnyRegion:$true_branch,
|
let regions = (region SizedRegion<1>:$true_branch,
|
||||||
AnyRegion:$false_branch);
|
SizedRegion<1>:$false_branch);
|
||||||
|
|
||||||
let results = (outs HLO_TensorOrTuple);
|
let results = (outs HLO_TensorOrTuple);
|
||||||
|
|
||||||
|
@ -552,15 +554,17 @@ def HLO_IfOp: HLO_Op<"if", [RecursiveSideEffects]> {
|
||||||
// Xla Client API has two separate calls for indexed and predicated conditional,
|
// Xla Client API has two separate calls for indexed and predicated conditional,
|
||||||
// although both eventually map to kConditional HLO. CaseOp maps to indexed
|
// although both eventually map to kConditional HLO. CaseOp maps to indexed
|
||||||
// conditional use of kConditional HLO.
|
// conditional use of kConditional HLO.
|
||||||
def HLO_CaseOp: HLO_Op<"case", [RecursiveSideEffects]>,
|
def HLO_CaseOp: HLO_Op<"case", [
|
||||||
BASE_HLO_CaseOp {
|
RecursiveSideEffects,
|
||||||
|
SingleBlockImplicitTerminator<"ReturnOp">
|
||||||
|
]>, BASE_HLO_CaseOp {
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
I32Tensor:$index,
|
I32Tensor:$index,
|
||||||
Variadic<HLO_TensorOrTuple>:$branch_operands
|
Variadic<HLO_TensorOrTuple>:$branch_operands
|
||||||
);
|
);
|
||||||
|
|
||||||
let regions = (region VariadicRegion<AnyRegion>:$branches);
|
let regions = (region VariadicRegion<SizedRegion<1>>:$branches);
|
||||||
|
|
||||||
let results = (outs Variadic<HLO_TensorOrTuple>);
|
let results = (outs Variadic<HLO_TensorOrTuple>);
|
||||||
|
|
||||||
|
@ -568,12 +572,14 @@ def HLO_CaseOp: HLO_Op<"case", [RecursiveSideEffects]>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def HLO_WhileOp: HLO_Op<"while", [RecursiveSideEffects,
|
def HLO_WhileOp: HLO_Op<"while", [
|
||||||
SameOperandsAndResultType]>,
|
RecursiveSideEffects,
|
||||||
BASE_HLO_WhileOp {
|
SameOperandsAndResultType,
|
||||||
|
SingleBlockImplicitTerminator<"ReturnOp">
|
||||||
|
]>, BASE_HLO_WhileOp {
|
||||||
let arguments = (ins HLO_TensorOrTuple:$val);
|
let arguments = (ins HLO_TensorOrTuple:$val);
|
||||||
|
|
||||||
let regions = (region AnyRegion:$cond, AnyRegion:$body);
|
let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
|
||||||
|
|
||||||
let results = (outs HLO_TensorOrTuple);
|
let results = (outs HLO_TensorOrTuple);
|
||||||
|
|
||||||
|
|
|
@ -2099,8 +2099,6 @@ static LogicalResult Verify(CaseOp op) {
|
||||||
OperandRange branch_operands = op.branch_operands();
|
OperandRange branch_operands = op.branch_operands();
|
||||||
for (unsigned i = 0; i < num_branches; ++i) {
|
for (unsigned i = 0; i < num_branches; ++i) {
|
||||||
mlir::Region& branch_region = branches[i];
|
mlir::Region& branch_region = branches[i];
|
||||||
if (branch_region.empty())
|
|
||||||
return op.emitOpError() << "cannot have empty regions";
|
|
||||||
mlir::Block& entry_block = branch_region.front();
|
mlir::Block& entry_block = branch_region.front();
|
||||||
if (entry_block.getNumArguments() != 1)
|
if (entry_block.getNumArguments() != 1)
|
||||||
return op.emitOpError()
|
return op.emitOpError()
|
||||||
|
|
|
@ -58,89 +58,3 @@ func @conditional(%arg0: tensor<f32>) -> tensor<f32> {
|
||||||
return %1 : tensor<f32>
|
return %1 : tensor<f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @while_with_multiple_blocks_in_body(%arg0: tensor<i64>) -> tensor<i64> {
|
|
||||||
func @while_with_multiple_blocks_in_body(%arg0: tensor<i64>) -> tensor<i64> {
|
|
||||||
// CHECK: br ^[[COND_ENTRY:.+]](%arg0 : tensor<i64>)
|
|
||||||
// CHECK: ^[[COND_ENTRY]](%0: tensor<i64>):
|
|
||||||
// CHECK: %1 = "mhlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
|
||||||
// CHECK: %2 = tensor.extract %1[] : tensor<i1>
|
|
||||||
// CHECK: cond_br %2, ^[[BODY_ENTRY:.+]](%0 : tensor<i64>), ^[[EXIT:.+]](%0 : tensor<i64>)
|
|
||||||
// CHECK: ^[[BODY_ENTRY]](%3: tensor<i64>):
|
|
||||||
// CHECK: br ^[[BODY_SUCC:.+]](%3 : tensor<i64>)
|
|
||||||
// CHECK: ^[[BODY_SUCC]](%4: tensor<i64>):
|
|
||||||
// CHECK: %5 = mhlo.add %4, %4 : tensor<i64>
|
|
||||||
// CHECK: br ^[[COND_ENTRY]](%5 : tensor<i64>)
|
|
||||||
// CHECK: ^[[EXIT]](%6: tensor<i64>):
|
|
||||||
// CHECK: return %6 : tensor<i64>
|
|
||||||
// CHECK: }
|
|
||||||
%0 = "mhlo.while"(%arg0) ( {
|
|
||||||
^cond_entry(%arg1: tensor<i64>):
|
|
||||||
%1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
|
||||||
"mhlo.return"(%1) : (tensor<i1>) -> ()
|
|
||||||
}, {
|
|
||||||
^body_entry(%arg1: tensor<i64>):
|
|
||||||
br ^body_succ(%arg1: tensor<i64>)
|
|
||||||
^body_succ(%0: tensor<i64>):
|
|
||||||
%1 = mhlo.add %0, %0 : tensor<i64>
|
|
||||||
"mhlo.return"(%1) : (tensor<i64>) -> ()
|
|
||||||
}) : (tensor<i64>) -> tensor<i64>
|
|
||||||
|
|
||||||
return %0 : tensor<i64>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @while_with_multiple_blocks_in_cond(%arg0: tensor<i64>) -> tensor<i64> {
|
|
||||||
func @while_with_multiple_blocks_in_cond(%arg0: tensor<i64>) -> tensor<i64> {
|
|
||||||
// CHECK: br ^[[COND_ENTRY:.+]](%arg0 : tensor<i64>)
|
|
||||||
// CHECK: ^[[COND_ENTRY]](%0: tensor<i64>):
|
|
||||||
// CHECK: br ^[[COND_SUCC:.+]](%0 : tensor<i64>)
|
|
||||||
// CHECK: ^[[COND_SUCC]](%1: tensor<i64>):
|
|
||||||
// CHECK: %2 = "mhlo.compare"(%1, %1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
|
||||||
// CHECK: %3 = tensor.extract %2[] : tensor<i1>
|
|
||||||
// CHECK: cond_br %3, ^[[BODY_ENTRY:.+]](%0 : tensor<i64>), ^[[EXIT:.+]](%0 : tensor<i64>)
|
|
||||||
// CHECK: ^[[BODY_ENTRY]](%4: tensor<i64>):
|
|
||||||
// CHECK: br ^[[COND_ENTRY]](%4 : tensor<i64>)
|
|
||||||
// CHECK: ^[[EXIT]](%5: tensor<i64>):
|
|
||||||
// CHECK: return %5 : tensor<i64>
|
|
||||||
// CHECK: }
|
|
||||||
%0 = "mhlo.while"(%arg0) ( {
|
|
||||||
^cond_entry(%arg1: tensor<i64>):
|
|
||||||
br ^cond_succ(%arg1: tensor<i64>)
|
|
||||||
^cond_succ(%0: tensor<i64>):
|
|
||||||
%1 = "mhlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
|
||||||
"mhlo.return"(%1) : (tensor<i1>) -> ()
|
|
||||||
}, {
|
|
||||||
^body_entry(%arg1: tensor<i64>):
|
|
||||||
"mhlo.return"(%arg1) : (tensor<i64>) -> ()
|
|
||||||
}) : (tensor<i64>) -> tensor<i64>
|
|
||||||
|
|
||||||
return %0 : tensor<i64>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @conditional_with_multiple_blocks(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
|
|
||||||
func @conditional_with_multiple_blocks(%arg0: tensor<f32>, %arg1: tensor<f32>, %pred: tensor<i1>) -> tensor<f32> {
|
|
||||||
// CHECK: %0 = tensor.extract %arg2[] : tensor<i1>
|
|
||||||
// CHECK: cond_br %0, ^[[THEN_ENTRY:.+]](%arg0 : tensor<f32>), ^[[ELSE_ENTRY:.+]](%arg1 : tensor<f32>)
|
|
||||||
// CHECK: ^[[THEN_ENTRY]](%1: tensor<f32>):
|
|
||||||
// CHECK: br ^[[THEN_SUCC:.+]](%1 : tensor<f32>)
|
|
||||||
// CHECK: ^[[THEN_SUCC]](%2: tensor<f32>):
|
|
||||||
// CHECK: %3 = "mhlo.log"(%2) : (tensor<f32>) -> tensor<f32>
|
|
||||||
// CHECK: br ^[[EXIT:.+]](%3 : tensor<f32>)
|
|
||||||
// CHECK: ^[[ELSE_ENTRY]](%4: tensor<f32>):
|
|
||||||
// CHECK: %5 = "mhlo.exponential"(%4) : (tensor<f32>) -> tensor<f32>
|
|
||||||
// CHECK: br ^[[EXIT]](%5 : tensor<f32>)
|
|
||||||
// CHECK: ^[[EXIT]](%6: tensor<f32>):
|
|
||||||
// CHECK: return %6 : tensor<f32>
|
|
||||||
// CHECK: }
|
|
||||||
%1 = "mhlo.if"(%pred, %arg0, %arg1) ( {
|
|
||||||
^then_entry(%arg2: tensor<f32>):
|
|
||||||
br ^then_succ(%arg2: tensor<f32>)
|
|
||||||
^then_succ(%0: tensor<f32>):
|
|
||||||
%2 = "mhlo.log"(%0) : (tensor<f32>) -> tensor<f32>
|
|
||||||
"mhlo.return"(%2) : (tensor<f32>) -> ()
|
|
||||||
}, {
|
|
||||||
^else_entry(%arg2: tensor<f32>):
|
|
||||||
%2 = "mhlo.exponential"(%arg2) : (tensor<f32>) -> tensor<f32>
|
|
||||||
"mhlo.return"(%2) : (tensor<f32>) -> ()
|
|
||||||
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
|
||||||
return %1 : tensor<f32>
|
|
||||||
}
|
|
||||||
|
|
|
@ -288,14 +288,6 @@ func @case_mismatch_return_type(%index: tensor<i32>, %operand_1: tensor<f32>, %o
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @case_empty_region(%index: tensor<i32>, %operand_1: tensor<f32>) -> () {
|
|
||||||
// expected-error@+1 {{cannot have empty regions}}
|
|
||||||
"mhlo.case"(%index, %operand_1) ( {} ) : (tensor<i32>, tensor<f32>) -> tensor<f32>
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @comp_eq
|
// CHECK-LABEL: func @comp_eq
|
||||||
func @comp_eq(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi1> {
|
func @comp_eq(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi1> {
|
||||||
%0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
|
%0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
|
||||||
|
|
Loading…
Reference in New Issue