From 3de2024a9bd8b0587cf60238162ab78d40f8694c Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Mon, 15 Mar 2021 17:57:17 -0700 Subject: [PATCH] Avoid creating tuple type only for verification Make the error message a bit more verbose & it is cheaper to verify the elements rather than creating a (potentially) new type. PiperOrigin-RevId: 363073909 --- lib/Dialect/mhlo/IR/hlo_ops.cc | 18 ++++++++++++------ tests/ops.mlir | 4 ++-- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index fd43805..94079b8 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -597,12 +597,18 @@ OpFoldResult GetTupleElementOp::fold(ArrayRef operands) { //===----------------------------------------------------------------------===// static LogicalResult Verify(TupleOp op) { - SmallVector operandTypes = {op.operand_type_begin(), - op.operand_type_end()}; - auto expectedType = TupleType::get(op.getContext(), operandTypes); - if (op.getType() != expectedType) { - return op.emitOpError(llvm::formatv("has return type {0}, but expected {1}", - op.getType(), expectedType)); + auto opType = op.getType().dyn_cast(); + if (!opType) return op.emitOpError("tuple op with non-tuple result"); + if (op.getNumOperands() != opType.size()) + return op.emitOpError( + "number of operands to tuple expected to match number of types in " + "resultant tuple type"); + for (auto it : llvm::enumerate( + llvm::zip_first(op.getOperandTypes(), opType.getTypes()))) { + if (std::get<0>(it.value()) != std::get<1>(it.value())) + return op.emitOpError("has return type mismatch at ") + << it.index() << "th value (" << std::get<0>(it.value()) + << " != " << std::get<1>(it.value()) << ")"; } return success(); } diff --git a/tests/ops.mlir b/tests/ops.mlir index 358b760..32f20ea 100644 --- a/tests/ops.mlir +++ b/tests/ops.mlir @@ -953,7 +953,7 @@ func @tuple_token(%arg0: tensor, %arg1: !mhlo.token) -> tuple, // ----- func @tuple_arg_size_mismatch(%arg0: tensor, %arg1: tensor) -> tuple, tensor, tensor> { - // expected-error@+1 {{has return type tuple, tensor, tensor>, but expected tuple, tensor>}} + // expected-error@+1 {{number of operands to tuple expected to match number of types}} %0 = "mhlo.tuple"(%arg0, %arg1) : (tensor, tensor) -> tuple, tensor, tensor> return %0 : tuple, tensor, tensor> } @@ -961,7 +961,7 @@ func @tuple_arg_size_mismatch(%arg0: tensor, %arg1: tensor) -> tuple, %arg1: tensor) -> tuple, tensor> { - // expected-error@+1 {{has return type tuple, tensor>, but expected tuple, tensor>}} + // expected-error@+1 {{op has return type mismatch at 1th value}} %0 = "mhlo.tuple"(%arg0, %arg1) : (tensor, tensor) -> tuple, tensor> return %0 : tuple, tensor> }