Verify mhlo.if region return types match op

This matches the behavior of mhlo.case. Additionally, fix the verification of CaseOp in the case of nested ops with mhlo.return-containing regions.

PiperOrigin-RevId: 365936672
This commit is contained in:
Geoffrey Martin-Noble 2021-03-30 17:55:12 -07:00 committed by TensorFlow MLIR Team
parent 3be9874d82
commit 2fb2a92c6e
2 changed files with 140 additions and 32 deletions

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
@ -2084,6 +2085,50 @@ LogicalResult ReplicaIdOp::inferReturnTypes(
return success();
}
//===----------------------------------------------------------------------===//
// If Op
//===----------------------------------------------------------------------===//
static LogicalResult VerifyConditionalBranch(Operation* op, Region& region,
Value operand,
llvm::Twine branchName,
llvm::Twine operandName) {
mlir::Block& entryBlock = region.front();
if (entryBlock.getNumArguments() != 1)
return op->emitOpError()
<< branchName << " block should have single argument, but found "
<< entryBlock.getNumArguments();
Type operandType = operand.getType();
Type branchArgType = entryBlock.getArgument(0).getType();
if (branchArgType != operandType)
return op->emitOpError()
<< operandName << " type (" << operandType << ") does not match "
<< branchName << " block arg type (" << branchArgType << ")";
TypeRange branchReturnTypes = entryBlock.getTerminator()->getOperandTypes();
if (branchReturnTypes != op->getResultTypes())
return op->emitOpError()
<< branchName << " returned types (" << branchReturnTypes
<< ") do not match op result types (" << op->getResultTypes() << ")";
return success();
}
static LogicalResult Verify(IfOp op) {
if (failed(VerifyConditionalBranch(op, op.true_branch(), op.true_arg(),
/*branchName=*/"true_branch",
/*operandName=*/"true_arg"))) {
return failure();
}
if (failed(VerifyConditionalBranch(op, op.false_branch(), op.false_arg(),
/*branchName=*/"false_branch",
/*operandName=*/"false_arg"))) {
return failure();
}
return success();
}
//===----------------------------------------------------------------------===//
// Case Op
//===----------------------------------------------------------------------===//
@ -2091,35 +2136,17 @@ LogicalResult ReplicaIdOp::inferReturnTypes(
static LogicalResult Verify(CaseOp op) {
auto num_branches = op.branches().size();
if (op.branch_operands().size() != num_branches)
return op.emitOpError() << "expects number of branches " << num_branches
<< " to be same as number of branch operands "
<< op.branch_operands().size();
return op.emitOpError() << " number of branches (" << num_branches
<< ") does not match number of branch operands ("
<< op.branch_operands().size() << ")";
for (unsigned i = 0; i < num_branches; ++i)
if (failed(VerifyConditionalBranch(
op, op.branches()[i], op.branch_operands()[i],
/*branchName=*/"branch " + Twine(i),
/*operandName=*/"branch_operand " + Twine(i))))
return failure();
MutableArrayRef<Region> branches = op.branches();
OperandRange branch_operands = op.branch_operands();
for (unsigned i = 0; i < num_branches; ++i) {
mlir::Region& branch_region = branches[i];
mlir::Block& entry_block = branch_region.front();
if (entry_block.getNumArguments() != 1)
return op.emitOpError()
<< "expects branch regions to have single argument, but found "
<< entry_block.getNumArguments() << " for branch " << i;
auto operand = branch_operands[i];
if (entry_block.getArgument(0).getType() != operand.getType())
return op.emitOpError()
<< "expects operand " << i + 1 << " to be of type "
<< entry_block.getArgument(0).getType() << ", but found "
<< operand.getType();
WalkResult walker = branch_region.walk([&](ReturnOp return_op) {
if (return_op.getOperands().getTypes() != op.getResultTypes())
return WalkResult::interrupt();
return WalkResult::advance();
});
if (walker.wasInterrupted())
return op.emitOpError()
<< "branch " << i
<< " returned values do not match op result types";
}
return success();
}

View File

@ -204,8 +204,89 @@ func @broadcast_in_dim_unranked_operand(%arg0 : tensor<*xf32>) -> tensor<2xf32>
// -----
// CHECK-LABEL: @if_nested_different_return_types(
func @if_nested_different_return_types(%pred : tensor<i1>, %branch_operand : tensor<f32>) {
%0 = "mhlo.if"(%pred, %branch_operand, %branch_operand) ({
^bb0(%arg0 : tensor<f32>):
"mhlo.return"(%arg0) : (tensor<f32>) -> ()
}, {
^bb1(%arg1 : tensor<f32>):
%2 = "mhlo.if"(%pred, %arg1, %arg1) ({
^bb0 (%arg2 : tensor<f32>):
"mhlo.return"(%pred) : (tensor<i1>) -> ()
}, {
^bb1 (%arg3 : tensor<f32>):
"mhlo.return"(%pred) : (tensor<i1>) -> ()
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%arg1) : (tensor<f32>) -> ()
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
return
}
// -----
func @if_mismatch_arg_type(%pred : tensor<i1>, %branch_operand : tensor<f32>, %wrong_type : tensor<3xf32>) {
// @expected-error@+1 {{true_arg type ('tensor<3xf32>') does not match true_branch block arg type ('tensor<f32>')}}
%0 = "mhlo.if"(%pred, %wrong_type, %branch_operand) ({
^bb0(%arg0 : tensor<f32>):
"mhlo.return"(%arg0) : (tensor<f32>) -> ()
}, {
^bb0(%arg1 : tensor<f32>):
"mhlo.return"(%arg1) : (tensor<f32>) -> ()
}) : (tensor<i1>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
return
}
// -----
func @if_mismatch_return_type(%pred : tensor<i1>, %branch_operand : tensor<f32>, %wrong_type : tensor<3xf32>) {
// @expected-error@+1 {{true_branch returned types ('tensor<3xf32>') do not match op result types ('tensor<f32>')}}
%0 = "mhlo.if"(%pred, %branch_operand, %branch_operand) ({
^bb0(%arg0 : tensor<f32>):
"mhlo.return"(%wrong_type) : (tensor<3xf32>) -> ()
}, {
^bb0(%arg1 : tensor<f32>):
"mhlo.return"(%arg1) : (tensor<f32>) -> ()
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
return
}
// -----
func @if_mismatch_num_return_types(%pred : tensor<i1>, %branch_operand : tensor<f32>) {
// @expected-error@+1 {{true_branch returned types ('tensor<f32>', 'tensor<f32>') do not match op result types ('tensor<f32>')}}
%0 = "mhlo.if"(%pred, %branch_operand, %branch_operand) ({
^bb0(%arg0 : tensor<f32>):
"mhlo.return"(%branch_operand, %branch_operand) : (tensor<f32>, tensor<f32>) -> ()
}, {
^bb0(%arg1 : tensor<f32>):
"mhlo.return"(%arg1) : (tensor<f32>) -> ()
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
return
}
// -----
// CHECK-LABEL: @case_nested_different_return_types(
func @case_nested_different_return_types(%index : tensor<i32>, %branch_operand : tensor<f32>) {
%0 = "mhlo.case"(%index, %branch_operand, %branch_operand) ({
^bb0(%arg0 : tensor<f32>):
"mhlo.return"(%arg0) : (tensor<f32>) -> ()
}, {
^bb1(%arg1 : tensor<f32>):
%2 = "mhlo.case"(%index, %arg1) ({
^bb0 (%arg2 : tensor<f32>):
"mhlo.return"(%index) : (tensor<i32>) -> ()
}) : (tensor<i32>, tensor<f32>) -> tensor<i32>
"mhlo.return"(%arg1) : (tensor<f32>) -> ()
}) : (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
return
}
// -----
func @case_mismatch_num_args(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> {
// expected-error@+1 {{expects branch regions to have single argument, but found 2 for branch 1}}
// expected-error@+1 {{branch 1 block should have single argument, but found 2}}
%0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( {
^bb0(%arg0: tensor<f32>):
%1 = "mhlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
@ -226,7 +307,7 @@ func @case_mismatch_num_args(%index: tensor<i32>, %operand_1: tensor<f32>, %oper
// -----
func @case_mismatch_num_results(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> {
// expected-error@+1 {{branch 1 returned values do not match op result types}}
// expected-error@+1 {{branch 1 returned types ('tensor<f32>', 'tensor<f32>') do not match op result types ('tensor<f32>')}}
%0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( {
^bb0(%arg0: tensor<f32>):
%1 = "mhlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
@ -247,7 +328,7 @@ func @case_mismatch_num_results(%index: tensor<i32>, %operand_1: tensor<f32>, %o
// -----
func @case_mismatch_arg_type(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> {
// expected-error@+1 {{expects operand 2 to be of type 'tensor<i32>', but found 'tensor<f32>'}}
// expected-error@+1 {{branch_operand 1 type ('tensor<f32>') does not match branch 1 block arg type ('tensor<i32>')}}
%0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( {
^bb0(%arg0: tensor<f32>):
%1 = "mhlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
@ -268,7 +349,7 @@ func @case_mismatch_arg_type(%index: tensor<i32>, %operand_1: tensor<f32>, %oper
// -----
func @case_mismatch_return_type(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> {
// expected-error@+1 {{branch 1 returned values do not match op result types}}
// expected-error@+1 {{branch 1 returned types ('tensor<i32>') do not match op result types ('tensor<f32>')}}
%0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( {
^bb0(%arg0: tensor<f32>):
%1 = "mhlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>