[mhlo] Make sure reifyResultTypes returns a tensor of index
Dynamic broadcast/reshape/iota take i32/i64 shape inputs, but users of reification expect index shapes. Insert an appropriate cast if necessary.

PiperOrigin-RevId: 380613128
commit 03d2cb606d
parent a6b8882739
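For context (not part of the diff itself): for an i32 shape operand, the new castToIndexTensor helper materializes a tensor.generate that casts each extent to index. A rough MLIR sketch of the IR it builds, where %shape stands in for a hypothetical tensor<?xi32> shape operand:

  %c0 = constant 0 : index
  %size = memref.dim %shape, %c0 : tensor<?xi32>
  %as_index = tensor.generate %size {
  ^bb0(%i : index):
    %extent = tensor.extract %shape[%i] : tensor<?xi32>
    %cast = index_cast %extent : i32 to index
    tensor.yield %cast : index
  } : tensor<?xindex>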
@@ -642,11 +642,33 @@ void DynamicIotaOp::getCanonicalizationPatterns(
   results.insert<DynamicIotaBroadcast>(context);
 }
 
+static Value castToIndexTensor(OpBuilder& builder, Location loc,
+                               Value shape_op) {
+  ShapedType result_ty = shape::getExtentTensorType(
+      builder.getContext(),
+      shape_op.getType().cast<ShapedType>().getDimSize(0));
+  if (shape_op.getType() == result_ty) return shape_op;  // Nothing to do.
+  // index_cast is not defined on tensors, so emit a tensor.generate instead.
+  return builder.create<tensor::GenerateOp>(
+      loc, result_ty,
+      result_ty.hasStaticShape()
+          ? ValueRange{}
+          : ValueRange{builder.create<memref::DimOp>(loc, shape_op, 0)},
+      [&](OpBuilder& b, Location loc, ValueRange args) {
+        Value dim = args.front();
+        Value extent = b.create<tensor::ExtractOp>(loc, shape_op, dim);
+        Value casted =
+            b.create<IndexCastOp>(loc, extent, result_ty.getElementType());
+        b.create<tensor::YieldOp>(loc, casted);
+      });
+}
+
 LogicalResult DynamicIotaOp::reifyReturnTypeShapes(
-    OpBuilder&, ValueRange operands,
+    OpBuilder& builder, ValueRange operands,
     SmallVectorImpl<Value>& reifiedReturnShapes) {
   DynamicIotaOp::Adaptor adaptor(operands);
-  reifiedReturnShapes.push_back(adaptor.output_shape());
+  reifiedReturnShapes.push_back(
+      castToIndexTensor(builder, getLoc(), adaptor.output_shape()));
   return success();
 }
 
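Note on the ValueRange branch above: tensor.generate only takes sizes for dynamic dimensions, so when the shape operand has a static rank-1 type the helper passes no operands at all. A minimal sketch of that static case, assuming a hypothetical tensor<3xi32> shape operand %shape:

  %as_index = tensor.generate {
  ^bb0(%i : index):
    %extent = tensor.extract %shape[%i] : tensor<3xi32>
    %cast = index_cast %extent : i32 to index
    tensor.yield %cast : index
  } : tensor<3xindex>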
@@ -1192,10 +1214,11 @@ void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
 }
 
 LogicalResult DynamicBroadcastInDimOp::reifyReturnTypeShapes(
-    OpBuilder&, ValueRange operands,
+    OpBuilder& builder, ValueRange operands,
     SmallVectorImpl<Value>& reifiedReturnShapes) {
   DynamicBroadcastInDimOp::Adaptor adaptor(operands);
-  reifiedReturnShapes.push_back(adaptor.output_dimensions());
+  reifiedReturnShapes.push_back(
+      castToIndexTensor(builder, getLoc(), adaptor.output_dimensions()));
   return success();
 }
 
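The early return in castToIndexTensor also covers shape operands that are already extent tensors, in which case reification reuses the operand directly. Assuming mhlo accepts an index-typed dimension tensor here (an assumption, not verified against the op definition), that no-op path would look like:

  // %shape : tensor<?xindex> already has the expected type, so it is
  // returned unchanged as the reified result shape.
  %0 = "mhlo.dynamic_broadcast_in_dim"(%operand, %shape)
      { broadcast_dimensions = dense<0> : tensor<1xi64> }
      : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>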
@@ -1627,10 +1650,11 @@ static LogicalResult Verify(DynamicReshapeOp op) {
 }
 
 LogicalResult DynamicReshapeOp::reifyReturnTypeShapes(
-    OpBuilder&, ValueRange operands,
+    OpBuilder& builder, ValueRange operands,
     SmallVectorImpl<Value>& reifiedReturnShapes) {
   DynamicReshapeOp::Adaptor adaptor(operands);
-  reifiedReturnShapes.push_back(adaptor.output_shape());
+  reifiedReturnShapes.push_back(
+      castToIndexTensor(builder, getLoc(), adaptor.output_shape()));
   return success();
 }
 
@@ -0,0 +1,50 @@
+// RUN: mlir-hlo-opt -resolve-shaped-type-result-dims -canonicalize \
+// RUN:   -split-input-file %s -o - | FileCheck %s
+
+// CHECK-LABEL: @dynamic_broadcast_i32_shape
+func @dynamic_broadcast_i32_shape(%arg0 : tensor<?xi32>, %arg1 : tensor<*xf32>)
+    -> index {
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[DIM:.*]] = tensor.extract %arg0[%[[C0]]] : tensor<?xi32>
+  // CHECK: %[[RESULT:.*]] = index_cast %[[DIM]] : i32 to index
+  // CHECK: return %[[RESULT]]
+  %c0 = constant 0 : index
+  %0 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %arg0)
+      { broadcast_dimensions = dense<0> : tensor<1xi64> }
+      : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
+  %1 = memref.dim %0, %c0 : tensor<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dynamic_iota_i32_shape
+func @dynamic_iota_i32_shape(%arg0 : tensor<?xi32>) -> index {
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[DIM:.*]] = tensor.extract %arg0[%[[C0]]] : tensor<?xi32>
+  // CHECK: %[[RESULT:.*]] = index_cast %[[DIM]] : i32 to index
+  // CHECK: return %[[RESULT]]
+  %c0 = constant 0 : index
+  %0 = "mhlo.dynamic_iota"(%arg0)
+      {iota_dimension = 0 : i64}
+      : (tensor<?xi32>) -> tensor<?xi32>
+  %1 = memref.dim %0, %c0 : tensor<?xi32>
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dynamic_reshape_i32_shape
+func @dynamic_reshape_i32_shape(%arg0 : tensor<?xi32>, %arg1 : tensor<*xf32>)
+    -> index {
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[DIM:.*]] = tensor.extract %arg0[%[[C0]]] : tensor<?xi32>
+  // CHECK: %[[RESULT:.*]] = index_cast %[[DIM]] : i32 to index
+  // CHECK: return %[[RESULT]]
+  %c0 = constant 0 : index
+  %0 = "mhlo.dynamic_reshape"(%arg1, %arg0)
+      { broadcast_dimensions = dense<0> : tensor<1xi64> }
+      : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
+  %1 = memref.dim %0, %c0 : tensor<*xf32>
+  return %1 : index
+}
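Reconstructed from the CHECK lines of the first test (and assuming canonicalization removes the now-dead broadcast op), the pipeline should reduce it to roughly this IR; a sketch, not verbatim FileCheck output:

  func @dynamic_broadcast_i32_shape(%arg0 : tensor<?xi32>, %arg1 : tensor<*xf32>)
      -> index {
    %c0 = constant 0 : index
    %dim = tensor.extract %arg0[%c0] : tensor<?xi32>
    %result = index_cast %dim : i32 to index
    return %result : index
  }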