Verify mhlo.if region return types match op
This matches the behavior of mhlo.case. Additionally, fix the verification of CaseOp in the case of nested ops with mhlo.return-containing regions. PiperOrigin-RevId: 365936672
This commit is contained in:
		
							parent
							
								
									3be9874d82
								
							
						
					
					
						commit
						2fb2a92c6e
					
				|  | @ -31,6 +31,7 @@ limitations under the License. | |||
| #include "llvm/ADT/STLExtras.h" | ||||
| #include "llvm/ADT/SmallVector.h" | ||||
| #include "llvm/ADT/StringRef.h" | ||||
| #include "llvm/ADT/Twine.h" | ||||
| #include "llvm/ADT/iterator_range.h" | ||||
| #include "llvm/Support/Casting.h" | ||||
| #include "llvm/Support/FormatVariadic.h" | ||||
|  | @ -2084,6 +2085,50 @@ LogicalResult ReplicaIdOp::inferReturnTypes( | |||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // If Op
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| static LogicalResult VerifyConditionalBranch(Operation* op, Region& region, | ||||
|                                              Value operand, | ||||
|                                              llvm::Twine branchName, | ||||
|                                              llvm::Twine operandName) { | ||||
|   mlir::Block& entryBlock = region.front(); | ||||
|   if (entryBlock.getNumArguments() != 1) | ||||
|     return op->emitOpError() | ||||
|            << branchName << " block should have single argument, but found " | ||||
|            << entryBlock.getNumArguments(); | ||||
| 
 | ||||
|   Type operandType = operand.getType(); | ||||
|   Type branchArgType = entryBlock.getArgument(0).getType(); | ||||
|   if (branchArgType != operandType) | ||||
|     return op->emitOpError() | ||||
|            << operandName << " type (" << operandType << ") does not match " | ||||
|            << branchName << " block arg type (" << branchArgType << ")"; | ||||
|   TypeRange branchReturnTypes = entryBlock.getTerminator()->getOperandTypes(); | ||||
|   if (branchReturnTypes != op->getResultTypes()) | ||||
|     return op->emitOpError() | ||||
|            << branchName << " returned types (" << branchReturnTypes | ||||
|            << ") do not match op result types (" << op->getResultTypes() << ")"; | ||||
| 
 | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| static LogicalResult Verify(IfOp op) { | ||||
|   if (failed(VerifyConditionalBranch(op, op.true_branch(), op.true_arg(), | ||||
|                                      /*branchName=*/"true_branch", | ||||
|                                      /*operandName=*/"true_arg"))) { | ||||
|     return failure(); | ||||
|   } | ||||
| 
 | ||||
|   if (failed(VerifyConditionalBranch(op, op.false_branch(), op.false_arg(), | ||||
|                                      /*branchName=*/"false_branch", | ||||
|                                      /*operandName=*/"false_arg"))) { | ||||
|     return failure(); | ||||
|   } | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // Case Op
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
|  | @ -2091,35 +2136,17 @@ LogicalResult ReplicaIdOp::inferReturnTypes( | |||
| static LogicalResult Verify(CaseOp op) { | ||||
|   auto num_branches = op.branches().size(); | ||||
|   if (op.branch_operands().size() != num_branches) | ||||
|     return op.emitOpError() << "expects number of branches " << num_branches | ||||
|                             << " to be same as number of branch operands " | ||||
|                             << op.branch_operands().size(); | ||||
|     return op.emitOpError() << " number of branches (" << num_branches | ||||
|                             << ") does not match number of branch operands (" | ||||
|                             << op.branch_operands().size() << ")"; | ||||
| 
 | ||||
|   for (unsigned i = 0; i < num_branches; ++i) | ||||
|     if (failed(VerifyConditionalBranch( | ||||
|             op, op.branches()[i], op.branch_operands()[i], | ||||
|             /*branchName=*/"branch " + Twine(i), | ||||
|             /*operandName=*/"branch_operand " + Twine(i)))) | ||||
|       return failure(); | ||||
| 
 | ||||
|   MutableArrayRef<Region> branches = op.branches(); | ||||
|   OperandRange branch_operands = op.branch_operands(); | ||||
|   for (unsigned i = 0; i < num_branches; ++i) { | ||||
|     mlir::Region& branch_region = branches[i]; | ||||
|     mlir::Block& entry_block = branch_region.front(); | ||||
|     if (entry_block.getNumArguments() != 1) | ||||
|       return op.emitOpError() | ||||
|              << "expects branch regions to have single argument, but found " | ||||
|              << entry_block.getNumArguments() << " for branch " << i; | ||||
|     auto operand = branch_operands[i]; | ||||
|     if (entry_block.getArgument(0).getType() != operand.getType()) | ||||
|       return op.emitOpError() | ||||
|              << "expects operand " << i + 1 << " to be of type " | ||||
|              << entry_block.getArgument(0).getType() << ", but found " | ||||
|              << operand.getType(); | ||||
|     WalkResult walker = branch_region.walk([&](ReturnOp return_op) { | ||||
|       if (return_op.getOperands().getTypes() != op.getResultTypes()) | ||||
|         return WalkResult::interrupt(); | ||||
|       return WalkResult::advance(); | ||||
|     }); | ||||
|     if (walker.wasInterrupted()) | ||||
|       return op.emitOpError() | ||||
|              << "branch " << i | ||||
|              << " returned values do not match op result types"; | ||||
|   } | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -204,8 +204,89 @@ func @broadcast_in_dim_unranked_operand(%arg0 : tensor<*xf32>) -> tensor<2xf32> | |||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK-LABEL: @if_nested_different_return_types( | ||||
| func @if_nested_different_return_types(%pred : tensor<i1>, %branch_operand : tensor<f32>) { | ||||
|   %0 = "mhlo.if"(%pred, %branch_operand, %branch_operand) ({ | ||||
|     ^bb0(%arg0 : tensor<f32>): | ||||
|       "mhlo.return"(%arg0) : (tensor<f32>) -> () | ||||
|   }, { | ||||
|     ^bb1(%arg1 : tensor<f32>): | ||||
|       %2 = "mhlo.if"(%pred, %arg1, %arg1) ({ | ||||
|         ^bb0 (%arg2 : tensor<f32>): | ||||
|           "mhlo.return"(%pred) : (tensor<i1>) -> () | ||||
|       }, { | ||||
|         ^bb1 (%arg3 : tensor<f32>): | ||||
|           "mhlo.return"(%pred) : (tensor<i1>) -> () | ||||
|       }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<i1> | ||||
|       "mhlo.return"(%arg1) : (tensor<f32>) -> () | ||||
|   }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32> | ||||
|   return | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @if_mismatch_arg_type(%pred : tensor<i1>, %branch_operand : tensor<f32>, %wrong_type : tensor<3xf32>) { | ||||
|   // @expected-error@+1 {{true_arg type ('tensor<3xf32>') does not match true_branch block arg type ('tensor<f32>')}} | ||||
|   %0 = "mhlo.if"(%pred, %wrong_type, %branch_operand) ({ | ||||
|     ^bb0(%arg0 : tensor<f32>): | ||||
|       "mhlo.return"(%arg0) : (tensor<f32>) -> () | ||||
|     }, { | ||||
|     ^bb0(%arg1 : tensor<f32>): | ||||
|       "mhlo.return"(%arg1) : (tensor<f32>) -> () | ||||
|     }) : (tensor<i1>, tensor<3xf32>, tensor<f32>) -> tensor<f32> | ||||
|   return | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @if_mismatch_return_type(%pred : tensor<i1>, %branch_operand : tensor<f32>, %wrong_type : tensor<3xf32>) { | ||||
|   // @expected-error@+1 {{true_branch returned types ('tensor<3xf32>') do not match op result types ('tensor<f32>')}} | ||||
|   %0 = "mhlo.if"(%pred, %branch_operand, %branch_operand) ({ | ||||
|     ^bb0(%arg0 : tensor<f32>): | ||||
|       "mhlo.return"(%wrong_type) : (tensor<3xf32>) -> () | ||||
|     }, { | ||||
|     ^bb0(%arg1 : tensor<f32>): | ||||
|       "mhlo.return"(%arg1) : (tensor<f32>) -> () | ||||
|     }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32> | ||||
|   return | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @if_mismatch_num_return_types(%pred : tensor<i1>, %branch_operand : tensor<f32>) { | ||||
|   // @expected-error@+1 {{true_branch returned types ('tensor<f32>', 'tensor<f32>') do not match op result types ('tensor<f32>')}} | ||||
|   %0 = "mhlo.if"(%pred, %branch_operand, %branch_operand) ({ | ||||
|     ^bb0(%arg0 : tensor<f32>): | ||||
|       "mhlo.return"(%branch_operand, %branch_operand) : (tensor<f32>, tensor<f32>) -> () | ||||
|     }, { | ||||
|     ^bb0(%arg1 : tensor<f32>): | ||||
|       "mhlo.return"(%arg1) : (tensor<f32>) -> () | ||||
|     }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32> | ||||
|   return | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK-LABEL: @case_nested_different_return_types( | ||||
| func @case_nested_different_return_types(%index : tensor<i32>, %branch_operand : tensor<f32>) { | ||||
|   %0 = "mhlo.case"(%index, %branch_operand, %branch_operand) ({ | ||||
|     ^bb0(%arg0 : tensor<f32>): | ||||
|       "mhlo.return"(%arg0) : (tensor<f32>) -> () | ||||
|   }, { | ||||
|     ^bb1(%arg1 : tensor<f32>): | ||||
|       %2 = "mhlo.case"(%index, %arg1) ({ | ||||
|         ^bb0 (%arg2 : tensor<f32>): | ||||
|           "mhlo.return"(%index) : (tensor<i32>) -> () | ||||
|       }) : (tensor<i32>, tensor<f32>) -> tensor<i32> | ||||
|       "mhlo.return"(%arg1) : (tensor<f32>) -> () | ||||
|   }) : (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32> | ||||
|   return | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @case_mismatch_num_args(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> { | ||||
|   // expected-error@+1 {{expects branch regions to have single argument, but found 2 for branch 1}} | ||||
|   // expected-error@+1 {{branch 1 block should have single argument, but found 2}} | ||||
|   %0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { | ||||
|     ^bb0(%arg0: tensor<f32>): | ||||
|       %1 = "mhlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32> | ||||
|  | @ -226,7 +307,7 @@ func @case_mismatch_num_args(%index: tensor<i32>, %operand_1: tensor<f32>, %oper | |||
| // ----- | ||||
| 
 | ||||
| func @case_mismatch_num_results(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> { | ||||
|   // expected-error@+1 {{branch 1 returned values do not match op result types}} | ||||
|   // expected-error@+1 {{branch 1 returned types ('tensor<f32>', 'tensor<f32>') do not match op result types ('tensor<f32>')}} | ||||
|   %0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { | ||||
|     ^bb0(%arg0: tensor<f32>): | ||||
|       %1 = "mhlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32> | ||||
|  | @ -247,7 +328,7 @@ func @case_mismatch_num_results(%index: tensor<i32>, %operand_1: tensor<f32>, %o | |||
| // ----- | ||||
| 
 | ||||
| func @case_mismatch_arg_type(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> { | ||||
|   // expected-error@+1 {{expects operand 2 to be of type 'tensor<i32>', but found 'tensor<f32>'}} | ||||
|   // expected-error@+1 {{branch_operand 1 type ('tensor<f32>') does not match branch 1 block arg type ('tensor<i32>')}} | ||||
|   %0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { | ||||
|     ^bb0(%arg0: tensor<f32>): | ||||
|       %1 = "mhlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32> | ||||
|  | @ -268,7 +349,7 @@ func @case_mismatch_arg_type(%index: tensor<i32>, %operand_1: tensor<f32>, %oper | |||
| // ----- | ||||
| 
 | ||||
| func @case_mismatch_return_type(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> { | ||||
|   // expected-error@+1 {{branch 1 returned values do not match op result types}} | ||||
|   // expected-error@+1 {{branch 1 returned types ('tensor<i32>') do not match op result types ('tensor<f32>')}} | ||||
|   %0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { | ||||
|     ^bb0(%arg0: tensor<f32>): | ||||
|       %1 = "mhlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32> | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue