Avoid creating tuple type only for verification

Make the error message a bit more verbose & it is cheaper to verify the elements rather than creating a (potentially) new type.

PiperOrigin-RevId: 363073909
This commit is contained in:
Jacques Pienaar 2021-03-15 17:57:17 -07:00 committed by TensorFlow MLIR Team
parent 01d729d35d
commit 3de2024a9b
2 changed files with 14 additions and 8 deletions

View File

@ -597,12 +597,18 @@ OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
static LogicalResult Verify(TupleOp op) {
SmallVector<Type, 8> operandTypes = {op.operand_type_begin(),
op.operand_type_end()};
auto expectedType = TupleType::get(op.getContext(), operandTypes);
if (op.getType() != expectedType) {
return op.emitOpError(llvm::formatv("has return type {0}, but expected {1}",
op.getType(), expectedType));
auto opType = op.getType().dyn_cast<TupleType>();
if (!opType) return op.emitOpError("tuple op with non-tuple result");
if (op.getNumOperands() != opType.size())
return op.emitOpError(
"number of operands to tuple expected to match number of types in "
"resultant tuple type");
for (auto it : llvm::enumerate(
llvm::zip_first(op.getOperandTypes(), opType.getTypes()))) {
if (std::get<0>(it.value()) != std::get<1>(it.value()))
return op.emitOpError("has return type mismatch at ")
<< it.index() << "th value (" << std::get<0>(it.value())
<< " != " << std::get<1>(it.value()) << ")";
}
return success();
}

View File

@ -953,7 +953,7 @@ func @tuple_token(%arg0: tensor<f32>, %arg1: !mhlo.token) -> tuple<tensor<f32>,
// -----
func @tuple_arg_size_mismatch(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>> {
// expected-error@+1 {{has return type tuple<tensor<f32>, tensor<f32>, tensor<f32>>, but expected tuple<tensor<f32>, tensor<f32>>}}
// expected-error@+1 {{number of operands to tuple expected to match number of types}}
%0 = "mhlo.tuple"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>>
return %0 : tuple<tensor<f32>, tensor<f32>, tensor<f32>>
}
@ -961,7 +961,7 @@ func @tuple_arg_size_mismatch(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tuple<t
// -----
func @tuple_type_mismatch(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tuple<tensor<f32>, tensor<i32>> {
// expected-error@+1 {{has return type tuple<tensor<f32>, tensor<i32>>, but expected tuple<tensor<f32>, tensor<f32>>}}
// expected-error@+1 {{op has return type mismatch at 1th value}}
%0 = "mhlo.tuple"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<i32>>
return %0 : tuple<tensor<f32>, tensor<i32>>
}