Avoid creating tuple type only for verification
Make the error message a bit more verbose & it is cheaper to verify the elements rather than creating a (potentially) new type. PiperOrigin-RevId: 363073909
This commit is contained in:
parent
01d729d35d
commit
3de2024a9b
|
@ -597,12 +597,18 @@ OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(TupleOp op) {
|
||||
SmallVector<Type, 8> operandTypes = {op.operand_type_begin(),
|
||||
op.operand_type_end()};
|
||||
auto expectedType = TupleType::get(op.getContext(), operandTypes);
|
||||
if (op.getType() != expectedType) {
|
||||
return op.emitOpError(llvm::formatv("has return type {0}, but expected {1}",
|
||||
op.getType(), expectedType));
|
||||
auto opType = op.getType().dyn_cast<TupleType>();
|
||||
if (!opType) return op.emitOpError("tuple op with non-tuple result");
|
||||
if (op.getNumOperands() != opType.size())
|
||||
return op.emitOpError(
|
||||
"number of operands to tuple expected to match number of types in "
|
||||
"resultant tuple type");
|
||||
for (auto it : llvm::enumerate(
|
||||
llvm::zip_first(op.getOperandTypes(), opType.getTypes()))) {
|
||||
if (std::get<0>(it.value()) != std::get<1>(it.value()))
|
||||
return op.emitOpError("has return type mismatch at ")
|
||||
<< it.index() << "th value (" << std::get<0>(it.value())
|
||||
<< " != " << std::get<1>(it.value()) << ")";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -953,7 +953,7 @@ func @tuple_token(%arg0: tensor<f32>, %arg1: !mhlo.token) -> tuple<tensor<f32>,
|
|||
// -----
|
||||
|
||||
func @tuple_arg_size_mismatch(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>> {
|
||||
// expected-error@+1 {{has return type tuple<tensor<f32>, tensor<f32>, tensor<f32>>, but expected tuple<tensor<f32>, tensor<f32>>}}
|
||||
// expected-error@+1 {{number of operands to tuple expected to match number of types}}
|
||||
%0 = "mhlo.tuple"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>>
|
||||
return %0 : tuple<tensor<f32>, tensor<f32>, tensor<f32>>
|
||||
}
|
||||
|
@ -961,7 +961,7 @@ func @tuple_arg_size_mismatch(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tuple<t
|
|||
// -----
|
||||
|
||||
func @tuple_type_mismatch(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tuple<tensor<f32>, tensor<i32>> {
|
||||
// expected-error@+1 {{has return type tuple<tensor<f32>, tensor<i32>>, but expected tuple<tensor<f32>, tensor<f32>>}}
|
||||
// expected-error@+1 {{op has return type mismatch at 1th value}}
|
||||
%0 = "mhlo.tuple"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<i32>>
|
||||
return %0 : tuple<tensor<f32>, tensor<i32>>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue