Avoid creating tuple type only for verification
Make the error message a bit more verbose & it is cheaper to verify the elements rather than creating a (potentially) new type. PiperOrigin-RevId: 363073909
This commit is contained in:
parent
01d729d35d
commit
3de2024a9b
|
@ -597,12 +597,18 @@ OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static LogicalResult Verify(TupleOp op) {
|
static LogicalResult Verify(TupleOp op) {
|
||||||
SmallVector<Type, 8> operandTypes = {op.operand_type_begin(),
|
auto opType = op.getType().dyn_cast<TupleType>();
|
||||||
op.operand_type_end()};
|
if (!opType) return op.emitOpError("tuple op with non-tuple result");
|
||||||
auto expectedType = TupleType::get(op.getContext(), operandTypes);
|
if (op.getNumOperands() != opType.size())
|
||||||
if (op.getType() != expectedType) {
|
return op.emitOpError(
|
||||||
return op.emitOpError(llvm::formatv("has return type {0}, but expected {1}",
|
"number of operands to tuple expected to match number of types in "
|
||||||
op.getType(), expectedType));
|
"resultant tuple type");
|
||||||
|
for (auto it : llvm::enumerate(
|
||||||
|
llvm::zip_first(op.getOperandTypes(), opType.getTypes()))) {
|
||||||
|
if (std::get<0>(it.value()) != std::get<1>(it.value()))
|
||||||
|
return op.emitOpError("has return type mismatch at ")
|
||||||
|
<< it.index() << "th value (" << std::get<0>(it.value())
|
||||||
|
<< " != " << std::get<1>(it.value()) << ")";
|
||||||
}
|
}
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -953,7 +953,7 @@ func @tuple_token(%arg0: tensor<f32>, %arg1: !mhlo.token) -> tuple<tensor<f32>,
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @tuple_arg_size_mismatch(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>> {
|
func @tuple_arg_size_mismatch(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>> {
|
||||||
// expected-error@+1 {{has return type tuple<tensor<f32>, tensor<f32>, tensor<f32>>, but expected tuple<tensor<f32>, tensor<f32>>}}
|
// expected-error@+1 {{number of operands to tuple expected to match number of types}}
|
||||||
%0 = "mhlo.tuple"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>>
|
%0 = "mhlo.tuple"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>>
|
||||||
return %0 : tuple<tensor<f32>, tensor<f32>, tensor<f32>>
|
return %0 : tuple<tensor<f32>, tensor<f32>, tensor<f32>>
|
||||||
}
|
}
|
||||||
|
@ -961,7 +961,7 @@ func @tuple_arg_size_mismatch(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tuple<t
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @tuple_type_mismatch(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tuple<tensor<f32>, tensor<i32>> {
|
func @tuple_type_mismatch(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tuple<tensor<f32>, tensor<i32>> {
|
||||||
// expected-error@+1 {{has return type tuple<tensor<f32>, tensor<i32>>, but expected tuple<tensor<f32>, tensor<f32>>}}
|
// expected-error@+1 {{op has return type mismatch at 1th value}}
|
||||||
%0 = "mhlo.tuple"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<i32>>
|
%0 = "mhlo.tuple"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<i32>>
|
||||||
return %0 : tuple<tensor<f32>, tensor<i32>>
|
return %0 : tuple<tensor<f32>, tensor<i32>>
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue