Verify that MHLO DynamicUpdateSlice start indices have matching element types.
HLO requires that the element types match for all start index parameters. Right now we don't catch this invalid case until export, so adding a check in the verifier so that we catch this sooner. This also requires a small tweak to the TF InplaceUpdate op lowering. PiperOrigin-RevId: 325463796
This commit is contained in:
parent
5f8da992f2
commit
d742477c02
|
@ -340,6 +340,33 @@ void DynamicIotaOp::getCanonicalizationPatterns(
|
|||
results.insert<DynamicIotaBroadcast>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DynamicUpdateSliceOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(DynamicUpdateSliceOp op) {
|
||||
OperandRange indices = op.start_indices();
|
||||
if (indices.size() <= 1) return success();
|
||||
|
||||
// Note: start_indices is constrained to Variadic<HLO_ScalarIntTensor>, so it
|
||||
// is OK to cast indices to ShapedType here.
|
||||
auto idx_tensor = indices.take_front().front().getType().cast<ShapedType>();
|
||||
Type first_elem_ty = idx_tensor.getElementType();
|
||||
Type elem_ty;
|
||||
|
||||
for (auto idx : llvm::drop_begin(indices, 1)) {
|
||||
idx_tensor = idx.getType().cast<ShapedType>();
|
||||
elem_ty = idx_tensor.getElementType();
|
||||
|
||||
if (first_elem_ty != elem_ty) {
|
||||
return op.emitOpError() << "start indices must have same element type "
|
||||
"(encountered mismatch: "
|
||||
<< first_elem_ty << " vs " << elem_ty << ")";
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AbsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -754,6 +754,14 @@ func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tenso
|
|||
|
||||
// -----
|
||||
|
||||
func @dynamic_update_slice_mismatched_start(%input: tensor<11x3x4xi32>, %update: tensor<1x3x4xi32>, %start1: tensor<i32>, %start2: tensor<i64>, %start3: tensor<i64>) -> tensor<11x3x4xi32> {
|
||||
// expected-error@+1 {{start indices must have same element type (encountered mismatch: 'i32' vs 'i64')}}
|
||||
%0 = "mhlo.dynamic-update-slice"(%input, %update, %start1, %start2, %start3) : (tensor<11x3x4xi32>, tensor<1x3x4xi32>, tensor<i32>, tensor<i64>, tensor<i64>) -> tensor<11x3x4xi32>
|
||||
return %0 : tensor<11x3x4xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @transpose
|
||||
func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
|
||||
%0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
|
||||
|
|
Loading…
Reference in New Issue