Verify mhlo.if region return types match op
This matches the behavior of mhlo.case. Additionally, fix the verification of CaseOp in the case of nested ops with mhlo.return-containing regions. PiperOrigin-RevId: 365936672
This commit is contained in:
parent
3be9874d82
commit
2fb2a92c6e
|
@ -31,6 +31,7 @@ limitations under the License.
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
#include "llvm/ADT/Twine.h"
|
||||||
#include "llvm/ADT/iterator_range.h"
|
#include "llvm/ADT/iterator_range.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
@ -2084,6 +2085,50 @@ LogicalResult ReplicaIdOp::inferReturnTypes(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// If Op
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
static LogicalResult VerifyConditionalBranch(Operation* op, Region& region,
|
||||||
|
Value operand,
|
||||||
|
llvm::Twine branchName,
|
||||||
|
llvm::Twine operandName) {
|
||||||
|
mlir::Block& entryBlock = region.front();
|
||||||
|
if (entryBlock.getNumArguments() != 1)
|
||||||
|
return op->emitOpError()
|
||||||
|
<< branchName << " block should have single argument, but found "
|
||||||
|
<< entryBlock.getNumArguments();
|
||||||
|
|
||||||
|
Type operandType = operand.getType();
|
||||||
|
Type branchArgType = entryBlock.getArgument(0).getType();
|
||||||
|
if (branchArgType != operandType)
|
||||||
|
return op->emitOpError()
|
||||||
|
<< operandName << " type (" << operandType << ") does not match "
|
||||||
|
<< branchName << " block arg type (" << branchArgType << ")";
|
||||||
|
TypeRange branchReturnTypes = entryBlock.getTerminator()->getOperandTypes();
|
||||||
|
if (branchReturnTypes != op->getResultTypes())
|
||||||
|
return op->emitOpError()
|
||||||
|
<< branchName << " returned types (" << branchReturnTypes
|
||||||
|
<< ") do not match op result types (" << op->getResultTypes() << ")";
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult Verify(IfOp op) {
|
||||||
|
if (failed(VerifyConditionalBranch(op, op.true_branch(), op.true_arg(),
|
||||||
|
/*branchName=*/"true_branch",
|
||||||
|
/*operandName=*/"true_arg"))) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (failed(VerifyConditionalBranch(op, op.false_branch(), op.false_arg(),
|
||||||
|
/*branchName=*/"false_branch",
|
||||||
|
/*operandName=*/"false_arg"))) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Case Op
|
// Case Op
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -2091,35 +2136,17 @@ LogicalResult ReplicaIdOp::inferReturnTypes(
|
||||||
static LogicalResult Verify(CaseOp op) {
|
static LogicalResult Verify(CaseOp op) {
|
||||||
auto num_branches = op.branches().size();
|
auto num_branches = op.branches().size();
|
||||||
if (op.branch_operands().size() != num_branches)
|
if (op.branch_operands().size() != num_branches)
|
||||||
return op.emitOpError() << "expects number of branches " << num_branches
|
return op.emitOpError() << " number of branches (" << num_branches
|
||||||
<< " to be same as number of branch operands "
|
<< ") does not match number of branch operands ("
|
||||||
<< op.branch_operands().size();
|
<< op.branch_operands().size() << ")";
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < num_branches; ++i)
|
||||||
|
if (failed(VerifyConditionalBranch(
|
||||||
|
op, op.branches()[i], op.branch_operands()[i],
|
||||||
|
/*branchName=*/"branch " + Twine(i),
|
||||||
|
/*operandName=*/"branch_operand " + Twine(i))))
|
||||||
|
return failure();
|
||||||
|
|
||||||
MutableArrayRef<Region> branches = op.branches();
|
|
||||||
OperandRange branch_operands = op.branch_operands();
|
|
||||||
for (unsigned i = 0; i < num_branches; ++i) {
|
|
||||||
mlir::Region& branch_region = branches[i];
|
|
||||||
mlir::Block& entry_block = branch_region.front();
|
|
||||||
if (entry_block.getNumArguments() != 1)
|
|
||||||
return op.emitOpError()
|
|
||||||
<< "expects branch regions to have single argument, but found "
|
|
||||||
<< entry_block.getNumArguments() << " for branch " << i;
|
|
||||||
auto operand = branch_operands[i];
|
|
||||||
if (entry_block.getArgument(0).getType() != operand.getType())
|
|
||||||
return op.emitOpError()
|
|
||||||
<< "expects operand " << i + 1 << " to be of type "
|
|
||||||
<< entry_block.getArgument(0).getType() << ", but found "
|
|
||||||
<< operand.getType();
|
|
||||||
WalkResult walker = branch_region.walk([&](ReturnOp return_op) {
|
|
||||||
if (return_op.getOperands().getTypes() != op.getResultTypes())
|
|
||||||
return WalkResult::interrupt();
|
|
||||||
return WalkResult::advance();
|
|
||||||
});
|
|
||||||
if (walker.wasInterrupted())
|
|
||||||
return op.emitOpError()
|
|
||||||
<< "branch " << i
|
|
||||||
<< " returned values do not match op result types";
|
|
||||||
}
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -204,8 +204,89 @@ func @broadcast_in_dim_unranked_operand(%arg0 : tensor<*xf32>) -> tensor<2xf32>
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @if_nested_different_return_types(
|
||||||
|
func @if_nested_different_return_types(%pred : tensor<i1>, %branch_operand : tensor<f32>) {
|
||||||
|
%0 = "mhlo.if"(%pred, %branch_operand, %branch_operand) ({
|
||||||
|
^bb0(%arg0 : tensor<f32>):
|
||||||
|
"mhlo.return"(%arg0) : (tensor<f32>) -> ()
|
||||||
|
}, {
|
||||||
|
^bb1(%arg1 : tensor<f32>):
|
||||||
|
%2 = "mhlo.if"(%pred, %arg1, %arg1) ({
|
||||||
|
^bb0 (%arg2 : tensor<f32>):
|
||||||
|
"mhlo.return"(%pred) : (tensor<i1>) -> ()
|
||||||
|
}, {
|
||||||
|
^bb1 (%arg3 : tensor<f32>):
|
||||||
|
"mhlo.return"(%pred) : (tensor<i1>) -> ()
|
||||||
|
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
"mhlo.return"(%arg1) : (tensor<f32>) -> ()
|
||||||
|
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @if_mismatch_arg_type(%pred : tensor<i1>, %branch_operand : tensor<f32>, %wrong_type : tensor<3xf32>) {
|
||||||
|
// @expected-error@+1 {{true_arg type ('tensor<3xf32>') does not match true_branch block arg type ('tensor<f32>')}}
|
||||||
|
%0 = "mhlo.if"(%pred, %wrong_type, %branch_operand) ({
|
||||||
|
^bb0(%arg0 : tensor<f32>):
|
||||||
|
"mhlo.return"(%arg0) : (tensor<f32>) -> ()
|
||||||
|
}, {
|
||||||
|
^bb0(%arg1 : tensor<f32>):
|
||||||
|
"mhlo.return"(%arg1) : (tensor<f32>) -> ()
|
||||||
|
}) : (tensor<i1>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @if_mismatch_return_type(%pred : tensor<i1>, %branch_operand : tensor<f32>, %wrong_type : tensor<3xf32>) {
|
||||||
|
// @expected-error@+1 {{true_branch returned types ('tensor<3xf32>') do not match op result types ('tensor<f32>')}}
|
||||||
|
%0 = "mhlo.if"(%pred, %branch_operand, %branch_operand) ({
|
||||||
|
^bb0(%arg0 : tensor<f32>):
|
||||||
|
"mhlo.return"(%wrong_type) : (tensor<3xf32>) -> ()
|
||||||
|
}, {
|
||||||
|
^bb0(%arg1 : tensor<f32>):
|
||||||
|
"mhlo.return"(%arg1) : (tensor<f32>) -> ()
|
||||||
|
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @if_mismatch_num_return_types(%pred : tensor<i1>, %branch_operand : tensor<f32>) {
|
||||||
|
// @expected-error@+1 {{true_branch returned types ('tensor<f32>', 'tensor<f32>') do not match op result types ('tensor<f32>')}}
|
||||||
|
%0 = "mhlo.if"(%pred, %branch_operand, %branch_operand) ({
|
||||||
|
^bb0(%arg0 : tensor<f32>):
|
||||||
|
"mhlo.return"(%branch_operand, %branch_operand) : (tensor<f32>, tensor<f32>) -> ()
|
||||||
|
}, {
|
||||||
|
^bb0(%arg1 : tensor<f32>):
|
||||||
|
"mhlo.return"(%arg1) : (tensor<f32>) -> ()
|
||||||
|
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @case_nested_different_return_types(
|
||||||
|
func @case_nested_different_return_types(%index : tensor<i32>, %branch_operand : tensor<f32>) {
|
||||||
|
%0 = "mhlo.case"(%index, %branch_operand, %branch_operand) ({
|
||||||
|
^bb0(%arg0 : tensor<f32>):
|
||||||
|
"mhlo.return"(%arg0) : (tensor<f32>) -> ()
|
||||||
|
}, {
|
||||||
|
^bb1(%arg1 : tensor<f32>):
|
||||||
|
%2 = "mhlo.case"(%index, %arg1) ({
|
||||||
|
^bb0 (%arg2 : tensor<f32>):
|
||||||
|
"mhlo.return"(%index) : (tensor<i32>) -> ()
|
||||||
|
}) : (tensor<i32>, tensor<f32>) -> tensor<i32>
|
||||||
|
"mhlo.return"(%arg1) : (tensor<f32>) -> ()
|
||||||
|
}) : (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func @case_mismatch_num_args(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> {
|
func @case_mismatch_num_args(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> {
|
||||||
// expected-error@+1 {{expects branch regions to have single argument, but found 2 for branch 1}}
|
// expected-error@+1 {{branch 1 block should have single argument, but found 2}}
|
||||||
%0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( {
|
%0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( {
|
||||||
^bb0(%arg0: tensor<f32>):
|
^bb0(%arg0: tensor<f32>):
|
||||||
%1 = "mhlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
|
%1 = "mhlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||||
|
@ -226,7 +307,7 @@ func @case_mismatch_num_args(%index: tensor<i32>, %operand_1: tensor<f32>, %oper
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @case_mismatch_num_results(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> {
|
func @case_mismatch_num_results(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> {
|
||||||
// expected-error@+1 {{branch 1 returned values do not match op result types}}
|
// expected-error@+1 {{branch 1 returned types ('tensor<f32>', 'tensor<f32>') do not match op result types ('tensor<f32>')}}
|
||||||
%0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( {
|
%0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( {
|
||||||
^bb0(%arg0: tensor<f32>):
|
^bb0(%arg0: tensor<f32>):
|
||||||
%1 = "mhlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
|
%1 = "mhlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||||
|
@ -247,7 +328,7 @@ func @case_mismatch_num_results(%index: tensor<i32>, %operand_1: tensor<f32>, %o
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @case_mismatch_arg_type(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> {
|
func @case_mismatch_arg_type(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> {
|
||||||
// expected-error@+1 {{expects operand 2 to be of type 'tensor<i32>', but found 'tensor<f32>'}}
|
// expected-error@+1 {{branch_operand 1 type ('tensor<f32>') does not match branch 1 block arg type ('tensor<i32>')}}
|
||||||
%0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( {
|
%0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( {
|
||||||
^bb0(%arg0: tensor<f32>):
|
^bb0(%arg0: tensor<f32>):
|
||||||
%1 = "mhlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
|
%1 = "mhlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||||
|
@ -268,7 +349,7 @@ func @case_mismatch_arg_type(%index: tensor<i32>, %operand_1: tensor<f32>, %oper
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @case_mismatch_return_type(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> {
|
func @case_mismatch_return_type(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> {
|
||||||
// expected-error@+1 {{branch 1 returned values do not match op result types}}
|
// expected-error@+1 {{branch 1 returned types ('tensor<i32>') do not match op result types ('tensor<f32>')}}
|
||||||
%0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( {
|
%0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( {
|
||||||
^bb0(%arg0: tensor<f32>):
|
^bb0(%arg0: tensor<f32>):
|
||||||
%1 = "mhlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
|
%1 = "mhlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||||
|
|
Loading…
Reference in New Issue