Verify that MHLO DynamicUpdateSlice start indices have matching element types.

HLO requires that the element types match for all start index parameters. Right now we don't catch this invalid case until export, so adding a check in the verifier so that we catch this sooner.

This also requires a small tweak to the TF InplaceUpdate op lowering.

PiperOrigin-RevId: 325463796
This commit is contained in:
Lucy Fox 2020-08-07 10:46:06 -07:00 committed by Geoffrey Martin-Noble
parent 5f8da992f2
commit d742477c02
2 changed files with 35 additions and 0 deletions

View File

@ -340,6 +340,33 @@ void DynamicIotaOp::getCanonicalizationPatterns(
results.insert<DynamicIotaBroadcast>(context);
}
//===----------------------------------------------------------------------===//
// DynamicUpdateSliceOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(DynamicUpdateSliceOp op) {
OperandRange indices = op.start_indices();
if (indices.size() <= 1) return success();
// Note: start_indices is constrained to Variadic<HLO_ScalarIntTensor>, so it
// is OK to cast indices to ShapedType here.
auto idx_tensor = indices.take_front().front().getType().cast<ShapedType>();
Type first_elem_ty = idx_tensor.getElementType();
Type elem_ty;
for (auto idx : llvm::drop_begin(indices, 1)) {
idx_tensor = idx.getType().cast<ShapedType>();
elem_ty = idx_tensor.getElementType();
if (first_elem_ty != elem_ty) {
return op.emitOpError() << "start indices must have same element type "
"(encountered mismatch: "
<< first_elem_ty << " vs " << elem_ty << ")";
}
}
return success();
}
//===----------------------------------------------------------------------===//
// AbsOp
//===----------------------------------------------------------------------===//

View File

@ -754,6 +754,14 @@ func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tenso
// -----
func @dynamic_update_slice_mismatched_start(%input: tensor<11x3x4xi32>, %update: tensor<1x3x4xi32>, %start1: tensor<i32>, %start2: tensor<i64>, %start3: tensor<i64>) -> tensor<11x3x4xi32> {
// expected-error@+1 {{start indices must have same element type (encountered mismatch: 'i32' vs 'i64')}}
%0 = "mhlo.dynamic-update-slice"(%input, %update, %start1, %start2, %start3) : (tensor<11x3x4xi32>, tensor<1x3x4xi32>, tensor<i32>, tensor<i64>, tensor<i64>) -> tensor<11x3x4xi32>
return %0 : tensor<11x3x4xi32>
}
// -----
// CHECK-LABEL: func @transpose
func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
%0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>