Properly handle if DynamicBroadcastInDimOp shape is not of index type.
* The op defines this to be index, any integer, or pred (i1). * Many TensorFlow legalizations produce integers for the shape. PiperOrigin-RevId: 374566113
This commit is contained in:
parent
0fe07e3814
commit
71394fb301
|
@ -536,6 +536,9 @@ class HloDynamicBroadcastInDimConverter
|
||||||
Value shape = adaptor.output_dimensions();
|
Value shape = adaptor.output_dimensions();
|
||||||
auto shape_type = shape.getType().cast<RankedTensorType>();
|
auto shape_type = shape.getType().cast<RankedTensorType>();
|
||||||
int64_t result_rank = shape_type.getDimSize(0);
|
int64_t result_rank = shape_type.getDimSize(0);
|
||||||
|
// HLO dimension types can be any integer, as well as index.
|
||||||
|
bool convert_to_index =
|
||||||
|
shape_type.getElementType() != rewriter.getIndexType();
|
||||||
|
|
||||||
auto result_type = op.getType().dyn_cast<RankedTensorType>();
|
auto result_type = op.getType().dyn_cast<RankedTensorType>();
|
||||||
if (!result_type) return failure();
|
if (!result_type) return failure();
|
||||||
|
@ -545,7 +548,11 @@ class HloDynamicBroadcastInDimConverter
|
||||||
for (int i = 0; i < result_rank; ++i) {
|
for (int i = 0; i < result_rank; ++i) {
|
||||||
if (!result_type.isDynamicDim(i)) continue;
|
if (!result_type.isDynamicDim(i)) continue;
|
||||||
Value index = rewriter.create<ConstantIndexOp>(loc, i);
|
Value index = rewriter.create<ConstantIndexOp>(loc, i);
|
||||||
dyn_dims.push_back(rewriter.create<tensor::ExtractOp>(loc, shape, index));
|
Value dim = rewriter.create<tensor::ExtractOp>(loc, shape, index);
|
||||||
|
if (convert_to_index) {
|
||||||
|
dim = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), dim);
|
||||||
|
}
|
||||||
|
dyn_dims.push_back(dim);
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t nloops = result_type.getRank();
|
int64_t nloops = result_type.getRank();
|
||||||
|
|
|
@ -1054,6 +1054,19 @@ func @dynamic_broadcast_in_dim(%vector: tensor<42xf32>, %shape: tensor<3xindex>)
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @dynamic_broadcast_in_dim(
|
||||||
|
// Note: this test requires no checks. The linalg init_tensor verifier will
|
||||||
|
// fail if the %shape i32 -> index cast is not performed properly.
|
||||||
|
func @dynamic_broadcast_in_dim(%scalar: tensor<f32>, %shape: tensor<2xi32>)
|
||||||
|
-> tensor<?x32xf32> {
|
||||||
|
%result = "mhlo.dynamic_broadcast_in_dim"(%scalar, %shape) {
|
||||||
|
broadcast_dimensions = dense<> : tensor<0xi64>
|
||||||
|
} : (tensor<f32>, tensor<2xi32>) -> tensor<?x32xf32>
|
||||||
|
return %result : tensor<?x32xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func @dot_matmul(%arg0: tensor<2x3xf32>,
|
func @dot_matmul(%arg0: tensor<2x3xf32>,
|
||||||
%arg1: tensor<3x?xf32>) -> tensor<2x?xf32> {
|
%arg1: tensor<3x?xf32>) -> tensor<2x?xf32> {
|
||||||
%0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3xf32>,
|
%0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3xf32>,
|
||||||
|
|
Loading…
Reference in New Issue