Lower ReluGrad via chlo::BroadcastSelect.
This allows us to get rid of the constraint that the input needs to have a static shape. PiperOrigin-RevId: 371862452
This commit is contained in:
		
							parent
							
								
									4bf1904c86
								
							
						
					
					
						commit
						384b87fad0
					
				|  | @ -81,8 +81,8 @@ struct ElementwiseOpConversion : public OpRewritePattern<OpTy> { | |||
| 
 | ||||
|   LogicalResult matchAndRewrite(OpTy op, | ||||
|                                 PatternRewriter &rewriter) const override { | ||||
|     // Don't apply conversion unless all operands are unranked.
 | ||||
|     if (!llvm::all_of(op.getOperation()->getOperands(), [&](Value operand) { | ||||
|     // Only apply conversion if at least one operand is unranked.
 | ||||
|     if (!llvm::any_of(op.getOperation()->getOperands(), [&](Value operand) { | ||||
|           return operand.getType().isa<UnrankedTensorType>(); | ||||
|         })) { | ||||
|       return failure(); | ||||
|  | @ -227,15 +227,21 @@ struct ConvertUnrankedDynamicBroadcastNaryOp | |||
|     typename ChloOpTy::Adaptor transformed(operands); | ||||
|     ValueRange transformed_operands = transformed.getOperands(); | ||||
|     auto num_operands = transformed_operands.size(); | ||||
|     llvm::SmallVector<UnrankedTensorType, 3> operand_types; | ||||
|     operand_types.reserve(num_operands); | ||||
|     llvm::SmallVector<Type, 3> operand_element_types; | ||||
|     operand_element_types.reserve(num_operands); | ||||
|     bool has_unranked_tensor_type = false; | ||||
|     for (int i = 0; i < num_operands; ++i) { | ||||
|       auto type = | ||||
|           transformed_operands[i].getType().dyn_cast<UnrankedTensorType>(); | ||||
|       // Only support unranked operands.
 | ||||
|       if (!type) return failure(); | ||||
|       operand_types.push_back(type); | ||||
|       if (auto type = | ||||
|               transformed_operands[i].getType().dyn_cast<TensorType>()) { | ||||
|         if (type.isa<UnrankedTensorType>()) { | ||||
|           has_unranked_tensor_type = true; | ||||
|         } | ||||
|         operand_element_types.push_back(type.getElementType()); | ||||
|       } else { | ||||
|         return failure(); | ||||
|       } | ||||
|     } | ||||
|     if (!has_unranked_tensor_type) return failure(); | ||||
|     auto result_type = op.getResult().getType().template dyn_cast<TensorType>(); | ||||
| 
 | ||||
|     llvm::SmallVector<Value> shapes; | ||||
|  | @ -278,7 +284,7 @@ struct ConvertUnrankedDynamicBroadcastNaryOp | |||
|           if_at_most_one_non_scalar_builder.create<mhlo::DynamicReshapeOp>( | ||||
|               loc, | ||||
|               RankedTensorType::get({RankedTensorType::kDynamicSize}, | ||||
|                                     operand_types[i].getElementType()), | ||||
|                                     operand_element_types[i]), | ||||
|               transformed_operands[i], size_tensor); | ||||
|       reshaped_operands.push_back(reshaped); | ||||
|     } | ||||
|  |  | |||
|  | @ -64,6 +64,28 @@ func @sqrt_static(%a: tensor<2x3xf32>) -> tensor<2x3xf32> { | |||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // Transformed if there is a mix of unranked/static shapes. | ||||
| // CHECK-LABEL: @select_mixed | ||||
| // CHECK-SAME: (%[[PRED:.*]]: tensor<*xi1>, %[[ON_TRUE:.*]]: tensor<*xf32>, %[[ON_FALSE:.*]]: tensor<2xf32>) | ||||
| func @select_mixed(%pred: tensor<*xi1>, %on_true: tensor<*xf32>, %on_false: tensor<2xf32>)  -> tensor<*xf32> { | ||||
|   // CHECK: %[[SHAPE_PRED:.*]] = shape.shape_of %[[PRED]] | ||||
|   // CHECK: %[[SHAPE_ON_TRUE:.*]] = shape.shape_of %[[ON_TRUE]] | ||||
|   // CHECK: %[[SHAPE_ON_FALSE:.*]] = shape.shape_of %[[ON_FALSE]] | ||||
|   // CHECK: %[[SHAPE:.*]] = shape.any %[[SHAPE_PRED]], %[[SHAPE_ON_TRUE]], %[[SHAPE_ON_FALSE]] | ||||
|   // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]] | ||||
|   // CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex> | ||||
|   // CHECK: %[[FLAT_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[FLAT_SHAPE]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1> | ||||
|   // CHECK: %[[FLAT_ON_TRUE:.*]] = "mhlo.dynamic_reshape"(%[[ON_TRUE]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> | ||||
|   // CHECK: %[[FLAT_ON_FALSE:.*]] = "mhlo.dynamic_reshape"(%[[ON_FALSE]], %[[FLAT_SHAPE]]) : (tensor<2xf32>, tensor<1xindex>) -> tensor<?xf32> | ||||
|   // CHECK: %[[FLAT_RESULT:.*]] = "mhlo.select"(%[[FLAT_PRED]], %[[FLAT_ON_TRUE]], %[[FLAT_ON_FALSE]]) : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> | ||||
|   // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32> | ||||
|   // CHECK: return %[[RESULT]] : tensor<*xf32> | ||||
|   %b = "mhlo.select"(%pred, %on_true, %on_false) : (tensor<*xi1>, tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> | ||||
|   return %b : tensor<*xf32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK-LABEL: @add_unranked | ||||
| // CHECK-SAME:  (%[[A:.*]]: tensor<*xf32>, %[[B:.*]]: tensor<*xf32>) -> tensor<*xf32> | ||||
| func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> { | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue