diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc
index a1f0480..6f453d1 100644
--- a/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -340,6 +340,33 @@ void DynamicIotaOp::getCanonicalizationPatterns(
   results.insert<DynamicIotaBroadcast>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// DynamicUpdateSliceOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult Verify(DynamicUpdateSliceOp op) {
+  OperandRange indices = op.start_indices();
+  if (indices.size() <= 1) return success();
+
+  // Note: start_indices is constrained to Variadic<HLO_ScalarIntTensor>, so it
+  // is OK to cast indices to ShapedType here.
+  auto idx_tensor = indices.take_front().front().getType().cast<ShapedType>();
+  Type first_elem_ty = idx_tensor.getElementType();
+  Type elem_ty;
+
+  for (auto idx : llvm::drop_begin(indices, 1)) {
+    idx_tensor = idx.getType().cast<ShapedType>();
+    elem_ty = idx_tensor.getElementType();
+
+    if (first_elem_ty != elem_ty) {
+      return op.emitOpError() << "start indices must have same element type "
+                                 "(encountered mismatch: "
+                              << first_elem_ty << " vs " << elem_ty << ")";
+    }
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // AbsOp
 //===----------------------------------------------------------------------===//
diff --git a/tests/ops.mlir b/tests/ops.mlir
index 3443f21..25c7d6a 100644
--- a/tests/ops.mlir
+++ b/tests/ops.mlir
@@ -754,6 +754,14 @@ func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tensor<2xi64>) -> tensor<3x4xi64> {
 
 // -----
 
+func @dynamic_update_slice_mismatched_start(%input: tensor<11x3x4xi32>, %update: tensor<1x3x4xi32>, %start1: tensor<i32>, %start2: tensor<i64>, %start3: tensor<i32>) -> tensor<11x3x4xi32> {
+  // expected-error@+1 {{start indices must have same element type (encountered mismatch: 'i32' vs 'i64')}}
+  %0 = "mhlo.dynamic-update-slice"(%input, %update, %start1, %start2, %start3) : (tensor<11x3x4xi32>, tensor<1x3x4xi32>, tensor<i32>, tensor<i64>, tensor<i32>) -> tensor<11x3x4xi32>
+  return %0 : tensor<11x3x4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @transpose
 func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
   %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>