From 2fb2a92c6ef1c0a33c010e0b8fa9248374a6cee4 Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Tue, 30 Mar 2021 17:55:12 -0700 Subject: [PATCH] Verify mhlo.if region return types match op This matches the behavior of mhlo.case. Additionally, fix the verification of CaseOp in the case of nested ops with mhlo.return-containing regions. PiperOrigin-RevId: 365936672 --- lib/Dialect/mhlo/IR/hlo_ops.cc | 83 ++++++++++++++++++++----------- tests/ops.mlir | 89 ++++++++++++++++++++++++++++++++-- 2 files changed, 140 insertions(+), 32 deletions(-) diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index cdd2629..bdc296d 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -31,6 +31,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" @@ -2084,6 +2085,50 @@ LogicalResult ReplicaIdOp::inferReturnTypes( return success(); } +//===----------------------------------------------------------------------===// +// If Op +//===----------------------------------------------------------------------===// + +static LogicalResult VerifyConditionalBranch(Operation* op, Region& region, + Value operand, + llvm::Twine branchName, + llvm::Twine operandName) { + mlir::Block& entryBlock = region.front(); + if (entryBlock.getNumArguments() != 1) + return op->emitOpError() + << branchName << " block should have single argument, but found " + << entryBlock.getNumArguments(); + + Type operandType = operand.getType(); + Type branchArgType = entryBlock.getArgument(0).getType(); + if (branchArgType != operandType) + return op->emitOpError() + << operandName << " type (" << operandType << ") does not match " + << branchName << " block arg type (" << branchArgType << ")"; + TypeRange branchReturnTypes = entryBlock.getTerminator()->getOperandTypes(); + if (branchReturnTypes != op->getResultTypes()) + return op->emitOpError() + << branchName << " returned types (" << branchReturnTypes + << ") do not match op result types (" << op->getResultTypes() << ")"; + + return success(); +} + +static LogicalResult Verify(IfOp op) { + if (failed(VerifyConditionalBranch(op, op.true_branch(), op.true_arg(), + /*branchName=*/"true_branch", + /*operandName=*/"true_arg"))) { + return failure(); + } + + if (failed(VerifyConditionalBranch(op, op.false_branch(), op.false_arg(), + /*branchName=*/"false_branch", + /*operandName=*/"false_arg"))) { + return failure(); + } + return success(); +} + //===----------------------------------------------------------------------===// // Case Op //===----------------------------------------------------------------------===// @@ -2091,35 +2136,17 @@ LogicalResult ReplicaIdOp::inferReturnTypes( static LogicalResult Verify(CaseOp op) { auto num_branches = op.branches().size(); if (op.branch_operands().size() != num_branches) - return op.emitOpError() << "expects number of branches " << num_branches - << " to be same as number of branch operands " - << op.branch_operands().size(); + return op.emitOpError() << " number of branches (" << num_branches + << ") does not match number of branch operands (" + << op.branch_operands().size() << ")"; + + for (unsigned i = 0; i < num_branches; ++i) + if (failed(VerifyConditionalBranch( + op, op.branches()[i], op.branch_operands()[i], + /*branchName=*/"branch " + Twine(i), + /*operandName=*/"branch_operand " + Twine(i)))) + return failure(); - MutableArrayRef branches = op.branches(); - OperandRange branch_operands = op.branch_operands(); - for (unsigned i = 0; i < num_branches; ++i) { - mlir::Region& branch_region = branches[i]; - mlir::Block& entry_block = branch_region.front(); - if (entry_block.getNumArguments() != 1) - return op.emitOpError() - << "expects branch regions to have single argument, but found " - << entry_block.getNumArguments() << " for branch " << i; - auto operand = branch_operands[i]; - if (entry_block.getArgument(0).getType() != operand.getType()) - return op.emitOpError() - << "expects operand " << i + 1 << " to be of type " - << entry_block.getArgument(0).getType() << ", but found " - << operand.getType(); - WalkResult walker = branch_region.walk([&](ReturnOp return_op) { - if (return_op.getOperands().getTypes() != op.getResultTypes()) - return WalkResult::interrupt(); - return WalkResult::advance(); - }); - if (walker.wasInterrupted()) - return op.emitOpError() - << "branch " << i - << " returned values do not match op result types"; - } return success(); } diff --git a/tests/ops.mlir b/tests/ops.mlir index 5adad7d..5c83761 100644 --- a/tests/ops.mlir +++ b/tests/ops.mlir @@ -204,8 +204,89 @@ func @broadcast_in_dim_unranked_operand(%arg0 : tensor<*xf32>) -> tensor<2xf32> // ----- +// CHECK-LABEL: @if_nested_different_return_types( +func @if_nested_different_return_types(%pred : tensor, %branch_operand : tensor) { + %0 = "mhlo.if"(%pred, %branch_operand, %branch_operand) ({ + ^bb0(%arg0 : tensor): + "mhlo.return"(%arg0) : (tensor) -> () + }, { + ^bb1(%arg1 : tensor): + %2 = "mhlo.if"(%pred, %arg1, %arg1) ({ + ^bb0 (%arg2 : tensor): + "mhlo.return"(%pred) : (tensor) -> () + }, { + ^bb1 (%arg3 : tensor): + "mhlo.return"(%pred) : (tensor) -> () + }) : (tensor, tensor, tensor) -> tensor + "mhlo.return"(%arg1) : (tensor) -> () + }) : (tensor, tensor, tensor) -> tensor + return +} + +// ----- + +func @if_mismatch_arg_type(%pred : tensor, %branch_operand : tensor, %wrong_type : tensor<3xf32>) { + // @expected-error@+1 {{true_arg type ('tensor<3xf32>') does not match true_branch block arg type ('tensor')}} + %0 = "mhlo.if"(%pred, %wrong_type, %branch_operand) ({ + ^bb0(%arg0 : tensor): + "mhlo.return"(%arg0) : (tensor) -> () + }, { + ^bb0(%arg1 : tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) : (tensor, tensor<3xf32>, tensor) -> tensor + return +} + +// ----- + +func @if_mismatch_return_type(%pred : tensor, %branch_operand : tensor, %wrong_type : tensor<3xf32>) { + // @expected-error@+1 {{true_branch returned types ('tensor<3xf32>') do not match op result types ('tensor')}} + %0 = "mhlo.if"(%pred, %branch_operand, %branch_operand) ({ + ^bb0(%arg0 : tensor): + "mhlo.return"(%wrong_type) : (tensor<3xf32>) -> () + }, { + ^bb0(%arg1 : tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) : (tensor, tensor, tensor) -> tensor + return +} + +// ----- + +func @if_mismatch_num_return_types(%pred : tensor, %branch_operand : tensor) { + // @expected-error@+1 {{true_branch returned types ('tensor', 'tensor') do not match op result types ('tensor')}} + %0 = "mhlo.if"(%pred, %branch_operand, %branch_operand) ({ + ^bb0(%arg0 : tensor): + "mhlo.return"(%branch_operand, %branch_operand) : (tensor, tensor) -> () + }, { + ^bb0(%arg1 : tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) : (tensor, tensor, tensor) -> tensor + return +} + +// ----- + +// CHECK-LABEL: @case_nested_different_return_types( +func @case_nested_different_return_types(%index : tensor, %branch_operand : tensor) { + %0 = "mhlo.case"(%index, %branch_operand, %branch_operand) ({ + ^bb0(%arg0 : tensor): + "mhlo.return"(%arg0) : (tensor) -> () + }, { + ^bb1(%arg1 : tensor): + %2 = "mhlo.case"(%index, %arg1) ({ + ^bb0 (%arg2 : tensor): + "mhlo.return"(%index) : (tensor) -> () + }) : (tensor, tensor) -> tensor + "mhlo.return"(%arg1) : (tensor) -> () + }) : (tensor, tensor, tensor) -> tensor + return +} + +// ----- + func @case_mismatch_num_args(%index: tensor, %operand_1: tensor, %operand_2: tensor, %operand_3: tensor) -> tensor { - // expected-error@+1 {{expects branch regions to have single argument, but found 2 for branch 1}} + // expected-error@+1 {{branch 1 block should have single argument, but found 2}} %0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { ^bb0(%arg0: tensor): %1 = "mhlo.negate"(%arg0) : (tensor) -> tensor @@ -226,7 +307,7 @@ func @case_mismatch_num_args(%index: tensor, %operand_1: tensor, %oper // ----- func @case_mismatch_num_results(%index: tensor, %operand_1: tensor, %operand_2: tensor, %operand_3: tensor) -> tensor { - // expected-error@+1 {{branch 1 returned values do not match op result types}} + // expected-error@+1 {{branch 1 returned types ('tensor', 'tensor') do not match op result types ('tensor')}} %0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { ^bb0(%arg0: tensor): %1 = "mhlo.negate"(%arg0) : (tensor) -> tensor @@ -247,7 +328,7 @@ func @case_mismatch_num_results(%index: tensor, %operand_1: tensor, %o // ----- func @case_mismatch_arg_type(%index: tensor, %operand_1: tensor, %operand_2: tensor, %operand_3: tensor) -> tensor { - // expected-error@+1 {{expects operand 2 to be of type 'tensor', but found 'tensor'}} + // expected-error@+1 {{branch_operand 1 type ('tensor') does not match branch 1 block arg type ('tensor')}} %0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { ^bb0(%arg0: tensor): %1 = "mhlo.negate"(%arg0) : (tensor) -> tensor @@ -268,7 +349,7 @@ func @case_mismatch_arg_type(%index: tensor, %operand_1: tensor, %oper // ----- func @case_mismatch_return_type(%index: tensor, %operand_1: tensor, %operand_2: tensor, %operand_3: tensor) -> tensor { - // expected-error@+1 {{branch 1 returned values do not match op result types}} + // expected-error@+1 {{branch 1 returned types ('tensor') do not match op result types ('tensor')}} %0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { ^bb0(%arg0: tensor): %1 = "mhlo.negate"(%arg0) : (tensor) -> tensor