[mhlo] Make sure reifyResultTypes returns a tensor of index
Dynamic broadcast/reshape/iota take i32/i64 shape inputs, but users of reification expect index shapes. Insert an appropriate cast if necessary. PiperOrigin-RevId: 380613128
This commit is contained in:
		
							parent
							
								
									a6b8882739
								
							
						
					
					
						commit
						03d2cb606d
					
				|  | @ -642,11 +642,33 @@ void DynamicIotaOp::getCanonicalizationPatterns( | |||
|   results.insert<DynamicIotaBroadcast>(context); | ||||
| } | ||||
| 
 | ||||
| static Value castToIndexTensor(OpBuilder& builder, Location loc, | ||||
|                                Value shape_op) { | ||||
|   ShapedType result_ty = shape::getExtentTensorType( | ||||
|       builder.getContext(), | ||||
|       shape_op.getType().cast<ShapedType>().getDimSize(0)); | ||||
|   if (shape_op.getType() == result_ty) return shape_op;  // Nothing to do.
 | ||||
|   // index_cast is not defined on tensors, so emit a tensor.generate instead.
 | ||||
|   return builder.create<tensor::GenerateOp>( | ||||
|       loc, result_ty, | ||||
|       result_ty.hasStaticShape() | ||||
|           ? ValueRange{} | ||||
|           : ValueRange{builder.create<memref::DimOp>(loc, shape_op, 0)}, | ||||
|       [&](OpBuilder& b, Location loc, ValueRange args) { | ||||
|         Value dim = args.front(); | ||||
|         Value extent = b.create<tensor::ExtractOp>(loc, shape_op, dim); | ||||
|         Value casted = | ||||
|             b.create<IndexCastOp>(loc, extent, result_ty.getElementType()); | ||||
|         b.create<tensor::YieldOp>(loc, casted); | ||||
|       }); | ||||
| } | ||||
| 
 | ||||
| LogicalResult DynamicIotaOp::reifyReturnTypeShapes( | ||||
|     OpBuilder&, ValueRange operands, | ||||
|     OpBuilder& builder, ValueRange operands, | ||||
|     SmallVectorImpl<Value>& reifiedReturnShapes) { | ||||
|   DynamicIotaOp::Adaptor adaptor(operands); | ||||
|   reifiedReturnShapes.push_back(adaptor.output_shape()); | ||||
|   reifiedReturnShapes.push_back( | ||||
|       castToIndexTensor(builder, getLoc(), adaptor.output_shape())); | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
|  | @ -1192,10 +1214,11 @@ void DynamicBroadcastInDimOp::getCanonicalizationPatterns( | |||
| } | ||||
| 
 | ||||
| LogicalResult DynamicBroadcastInDimOp::reifyReturnTypeShapes( | ||||
|     OpBuilder&, ValueRange operands, | ||||
|     OpBuilder& builder, ValueRange operands, | ||||
|     SmallVectorImpl<Value>& reifiedReturnShapes) { | ||||
|   DynamicBroadcastInDimOp::Adaptor adaptor(operands); | ||||
|   reifiedReturnShapes.push_back(adaptor.output_dimensions()); | ||||
|   reifiedReturnShapes.push_back( | ||||
|       castToIndexTensor(builder, getLoc(), adaptor.output_dimensions())); | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
|  | @ -1627,10 +1650,11 @@ static LogicalResult Verify(DynamicReshapeOp op) { | |||
| } | ||||
| 
 | ||||
| LogicalResult DynamicReshapeOp::reifyReturnTypeShapes( | ||||
|     OpBuilder&, ValueRange operands, | ||||
|     OpBuilder& builder, ValueRange operands, | ||||
|     SmallVectorImpl<Value>& reifiedReturnShapes) { | ||||
|   DynamicReshapeOp::Adaptor adaptor(operands); | ||||
|   reifiedReturnShapes.push_back(adaptor.output_shape()); | ||||
|   reifiedReturnShapes.push_back( | ||||
|       castToIndexTensor(builder, getLoc(), adaptor.output_shape())); | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -0,0 +1,50 @@ | |||
| // RUN: mlir-hlo-opt -resolve-shaped-type-result-dims -canonicalize \ | ||||
| // RUN: -split-input-file %s -o - | FileCheck %s | ||||
| 
 | ||||
| // CHECK-LABEL: @dynamic_broadcast_i32_shape | ||||
| func @dynamic_broadcast_i32_shape(%arg0 : tensor<?xi32>, %arg1 : tensor<*xf32>) | ||||
|      -> index { | ||||
|   // CHECK: %[[C0:.*]] = constant 0 : index | ||||
|   // CHECK: %[[DIM:.*]] = tensor.extract %arg0[%[[C0]]] : tensor<?xi32> | ||||
|   // CHECK: %[[RESULT:.*]] = index_cast %[[DIM]] : i32 to index | ||||
|   // CHECK: return %[[RESULT]] | ||||
|   %c0 = constant 0 : index | ||||
|   %0 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %arg0) | ||||
|        { broadcast_dimensions = dense<0> : tensor<1xi64> } | ||||
|      : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32> | ||||
|   %1 = memref.dim %0, %c0 : tensor<*xf32> | ||||
|   return %1 : index | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK-LABEL: @dynamic_iota_i32_shape | ||||
| func @dynamic_iota_i32_shape(%arg0 : tensor<?xi32>) -> index { | ||||
|   // CHECK: %[[C0:.*]] = constant 0 : index | ||||
|   // CHECK: %[[DIM:.*]] = tensor.extract %arg0[%[[C0]]] : tensor<?xi32> | ||||
|   // CHECK: %[[RESULT:.*]] = index_cast %[[DIM]] : i32 to index | ||||
|   // CHECK: return %[[RESULT]] | ||||
|   %c0 = constant 0 : index | ||||
|   %0 = "mhlo.dynamic_iota"(%arg0) | ||||
|        {iota_dimension = 0 : i64} | ||||
|      : (tensor<?xi32>) -> tensor<?xi32> | ||||
|   %1 = memref.dim %0, %c0 : tensor<?xi32> | ||||
|   return %1 : index | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK-LABEL: @dynamic_reshape_i32_shape | ||||
| func @dynamic_reshape_i32_shape(%arg0 : tensor<?xi32>, %arg1 : tensor<*xf32>) | ||||
|      -> index { | ||||
|   // CHECK: %[[C0:.*]] = constant 0 : index | ||||
|   // CHECK: %[[DIM:.*]] = tensor.extract %arg0[%[[C0]]] : tensor<?xi32> | ||||
|   // CHECK: %[[RESULT:.*]] = index_cast %[[DIM]] : i32 to index | ||||
|   // CHECK: return %[[RESULT]] | ||||
|   %c0 = constant 0 : index | ||||
|   %0 = "mhlo.dynamic_reshape"(%arg1, %arg0) | ||||
|        { broadcast_dimensions = dense<0> : tensor<1xi64> } | ||||
|      : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32> | ||||
|   %1 = memref.dim %0, %c0 : tensor<*xf32> | ||||
|   return %1 : index | ||||
| } | ||||
		Loading…
	
		Reference in New Issue