Lower ReluGrad via chlo::BroadcastSelect.

This makes it possible to drop the constraint that the operands have a static shape.

PiperOrigin-RevId: 371862452

commit 384b87fad0 (parent 4bf1904c86)
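For context, ReluGrad(gradients, features) forwards a gradient wherever the corresponding feature is positive and yields zero elsewhere. A minimal sketch of what a lowering via chlo::BroadcastSelect could look like on unranked inputs (hypothetical IR; the TF-to-HLO pattern itself is not part of this excerpt, and %gradients/%features are placeholder values):

  %zero = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // Predicate: features > 0, with implicit broadcast of the scalar zero.
  %pred = "chlo.broadcast_compare"(%features, %zero) {comparison_direction = "GT"}
      : (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
  // Pass the gradient through where the predicate holds, zero elsewhere.
  %result = "chlo.broadcast_select"(%pred, %gradients, %zero)
      : (tensor<*xi1>, tensor<*xf32>, tensor<f32>) -> tensor<*xf32>

Such a select only becomes lowerable on these inputs once the unranked transformation below stops insisting that every operand be unranked.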
@@ -81,8 +81,8 @@ struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
 
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    // Don't apply conversion unless all operands are unranked.
-    if (!llvm::all_of(op.getOperation()->getOperands(), [&](Value operand) {
+    // Only apply conversion if at least one operand is unranked.
+    if (!llvm::any_of(op.getOperation()->getOperands(), [&](Value operand) {
           return operand.getType().isa<UnrankedTensorType>();
         })) {
       return failure();
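The relaxed guard means the flattening conversion now also fires for a mix of ranked and unranked operands, not only for fully unranked ops. A hypothetical pair of inputs illustrating both sides of the new check:

  // Matched: a single unranked operand is enough.
  %0 = "mhlo.select"(%p, %t, %f)
      : (tensor<*xi1>, tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32>
  // Still rejected: all operands are ranked, so matchAndRewrite returns failure().
  %1 = "mhlo.select"(%q, %u, %g)
      : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>

The first case is exactly what the new @select_mixed test below exercises.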
@@ -227,15 +227,21 @@ struct ConvertUnrankedDynamicBroadcastNaryOp
     typename ChloOpTy::Adaptor transformed(operands);
     ValueRange transformed_operands = transformed.getOperands();
     auto num_operands = transformed_operands.size();
-    llvm::SmallVector<UnrankedTensorType, 3> operand_types;
-    operand_types.reserve(num_operands);
+    llvm::SmallVector<Type, 3> operand_element_types;
+    operand_element_types.reserve(num_operands);
+    bool has_unranked_tensor_type = false;
     for (int i = 0; i < num_operands; ++i) {
-      auto type =
-          transformed_operands[i].getType().dyn_cast<UnrankedTensorType>();
-      // Only support unranked operands.
-      if (!type) return failure();
-      operand_types.push_back(type);
+      if (auto type =
+              transformed_operands[i].getType().dyn_cast<TensorType>()) {
+        if (type.isa<UnrankedTensorType>()) {
+          has_unranked_tensor_type = true;
+        }
+        operand_element_types.push_back(type.getElementType());
+      } else {
+        return failure();
+      }
     }
+    if (!has_unranked_tensor_type) return failure();
     auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
 
     llvm::SmallVector<Value> shapes;
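Since ranked operands are now allowed through, the pattern can no longer hold on to UnrankedTensorType values; it records each operand's element type instead, which is all the later rank-1 flattening needs. A sketch of that reshape for a ranked operand (hypothetical values %operand and %flat_shape):

  // A ranked operand is flattened against the common 1-D shape just like an
  // unranked one; only the element type (here f32) comes from the operand.
  %flat = "mhlo.dynamic_reshape"(%operand, %flat_shape)
      : (tensor<2xf32>, tensor<1xindex>) -> tensor<?xf32>

The new has_unranked_tensor_type check keeps fully ranked ops out of this path, presumably leaving them to the existing ranked lowerings.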
@@ -278,7 +284,7 @@ struct ConvertUnrankedDynamicBroadcastNaryOp
           if_at_most_one_non_scalar_builder.create<mhlo::DynamicReshapeOp>(
               loc,
               RankedTensorType::get({RankedTensorType::kDynamicSize},
-                                    operand_types[i].getElementType()),
+                                    operand_element_types[i]),
               transformed_operands[i], size_tensor);
       reshaped_operands.push_back(reshaped);
     }
The remaining hunk adds a regression test for the mixed ranked/unranked case to the accompanying lit test file:
@@ -64,6 +64,28 @@ func @sqrt_static(%a: tensor<2x3xf32>) -> tensor<2x3xf32> {
 
 // -----
 
+// Transformed if there is a mix of unranked/static shapes.
+// CHECK-LABEL: @select_mixed
+// CHECK-SAME: (%[[PRED:.*]]: tensor<*xi1>, %[[ON_TRUE:.*]]: tensor<*xf32>, %[[ON_FALSE:.*]]: tensor<2xf32>)
+func @select_mixed(%pred: tensor<*xi1>, %on_true: tensor<*xf32>, %on_false: tensor<2xf32>)  -> tensor<*xf32> {
+  // CHECK: %[[SHAPE_PRED:.*]] = shape.shape_of %[[PRED]]
+  // CHECK: %[[SHAPE_ON_TRUE:.*]] = shape.shape_of %[[ON_TRUE]]
+  // CHECK: %[[SHAPE_ON_FALSE:.*]] = shape.shape_of %[[ON_FALSE]]
+  // CHECK: %[[SHAPE:.*]] = shape.any %[[SHAPE_PRED]], %[[SHAPE_ON_TRUE]], %[[SHAPE_ON_FALSE]]
+  // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
+  // CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
+  // CHECK: %[[FLAT_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[FLAT_SHAPE]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
+  // CHECK: %[[FLAT_ON_TRUE:.*]] = "mhlo.dynamic_reshape"(%[[ON_TRUE]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+  // CHECK: %[[FLAT_ON_FALSE:.*]] = "mhlo.dynamic_reshape"(%[[ON_FALSE]], %[[FLAT_SHAPE]]) : (tensor<2xf32>, tensor<1xindex>) -> tensor<?xf32>
+  // CHECK: %[[FLAT_RESULT:.*]] = "mhlo.select"(%[[FLAT_PRED]], %[[FLAT_ON_TRUE]], %[[FLAT_ON_FALSE]]) : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+  // CHECK: return %[[RESULT]] : tensor<*xf32>
+  %b = "mhlo.select"(%pred, %on_true, %on_false) : (tensor<*xi1>, tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32>
+  return %b : tensor<*xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @add_unranked
 // CHECK-SAME:  (%[[A:.*]]: tensor<*xf32>, %[[B:.*]]: tensor<*xf32>) -> tensor<*xf32>
 func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {