Move unranked chlo lowering to transform_unranked_hlo.

Additionally:
- Forward listeners through new if/else op builders. This corrects an error
  that led to incomplete legalization of broadcasted op lowering.
- Use OpConversionPattern to ensure up-to-date operand values are used.

PiperOrigin-RevId: 339838833
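The listener fix reduces to one builder argument at each scf.if construction site. Below is a minimal before/after sketch using the scf::IfOp builder API as of this revision; variable names are illustrative, only the builder calls come from this change:

    // Create an scf.if whose then/else bodies are filled in during a
    // conversion pattern.
    auto if_op = rewriter.create<scf::IfOp>(loc, result_type, condition,
                                            /*withElseRegion=*/true);
    // Before: the body builder silently dropped the rewriter's listener, so
    // ops created inside the region were never reported to the conversion
    // driver and legalization of the broadcasted lowering stopped early.
    //   OpBuilder then_builder = if_op.getThenBodyBuilder();
    // After: forward the listener so every created op is tracked.
    OpBuilder then_builder = if_op.getThenBodyBuilder(rewriter.getListener());
    OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());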
This commit is contained in:
parent e188ef10f2
commit 76b30fd426
@@ -0,0 +1,97 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_HLO_OP_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_HLO_OP_H_

#include <type_traits>

#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {
namespace chlo {

struct HloComplexAdaptor {
  static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type,
                                  Value broadcasted_lhs, Value broadcasted_rhs,
                                  OpBuilder &builder) {
    return builder.create<mhlo::ComplexOp>(from_op.getLoc(), result_type,
                                           broadcasted_lhs, broadcasted_rhs);
  }
};
template <typename FromOpTy, typename ToOpTy>
struct HloBinaryElementwiseAdaptor {
  static ToOpTy CreateOp(FromOpTy from_op, Type result_type,
                         Value broadcasted_lhs, Value broadcasted_rhs,
                         OpBuilder &builder) {
    return builder.create<ToOpTy>(from_op.getLoc(), result_type,
                                  broadcasted_lhs, broadcasted_rhs);
  }
};
struct HloCompareAdaptor {
  static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
                                  Value broadcasted_lhs, Value broadcasted_rhs,
                                  OpBuilder &builder) {
    return builder.create<mhlo::CompareOp>(
        from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs,
        from_op.comparison_direction(), from_op.compare_typeAttr());
  }
};

// Populate a pattern for each broadcasting CHLO op. This requires the pattern
// to take a ChloOpTy, MhloOpTy, and an Adaptor as templated values.
template <template <typename, typename, typename> class Pattern,
          typename... ConstructorArgs>
void PopulateForBroadcastingBinaryOp(MLIRContext *context,
                                     OwningRewritePatternList *patterns,
                                     ConstructorArgs &&...args) {
#define POPULATE_BCAST(ChloOp, HloOp)                                      \
  patterns->insert<                                                        \
      Pattern<ChloOp, HloOp, HloBinaryElementwiseAdaptor<ChloOp, HloOp>>>( \
      context, args...);

  POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp);
  POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp);
  POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op);
  POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp);
  POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp);
  POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp);
  POPULATE_BCAST(BroadcastMulOp, mhlo::MulOp);
  POPULATE_BCAST(BroadcastOrOp, mhlo::OrOp);
  POPULATE_BCAST(BroadcastPowOp, mhlo::PowOp);
  POPULATE_BCAST(BroadcastRemOp, mhlo::RemOp);
  POPULATE_BCAST(BroadcastShiftLeftOp, mhlo::ShiftLeftOp);
  POPULATE_BCAST(BroadcastShiftRightArithmeticOp, mhlo::ShiftRightArithmeticOp);
  POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp);
  POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp);
  POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp);

  // Broadcasting ops requiring special construction.
  patterns
      ->insert<Pattern<BroadcastComplexOp, mhlo::ComplexOp, HloComplexAdaptor>>(
          context, args...);
  patterns
      ->insert<Pattern<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>>(
          context, args...);

#undef POPULATE_BCAST
}

}  // namespace chlo
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_HLO_OP_H_
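For context, a minimal sketch of how a client registers patterns through this helper; MyBroadcastingPattern and populate are illustrative names, not part of this change (the real call sites appear later in this commit):

    // Any class template taking <ChloOpTy, HloOpTy, Adaptor> can be used.
    template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
    struct MyBroadcastingPattern : public OpConversionPattern<ChloOpTy> {
      using OpConversionPattern<ChloOpTy>::OpConversionPattern;
      // matchAndRewrite would call Adaptor::CreateOp(op, type, lhs, rhs, b).
    };

    void populate(MLIRContext *context, OwningRewritePatternList *patterns) {
      // Trailing arguments (e.g. a pattern benefit of 10) are forwarded to
      // every instantiated pattern's constructor.
      chlo::PopulateForBroadcastingBinaryOp<MyBroadcastingPattern>(context,
                                                                   patterns, 10);
    }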
@@ -17,6 +17,7 @@ limitations under the License.

#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir-hlo/utils/broadcast_utils.h"
#include "mlir/Dialect/SCF/SCF.h"
@@ -69,13 +70,18 @@ struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> {
// Converts binary ops that statically are determined to not broadcast directly
// to the corresponding mhlo non-broadcasting op.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
  using OpRewritePattern<ChloOpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(ChloOpTy op,
                                PatternRewriter &rewriter) const override {
struct ConvertTrivialNonBroadcastBinaryOp
    : public OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ChloOpTy op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Only rewrite for statically determinable non-broadcasting cases.
    auto lhs_type = op.lhs().getType().template dyn_cast<RankedTensorType>();
    auto rhs_type = op.rhs().getType().template dyn_cast<RankedTensorType>();
    typename ChloOpTy::Adaptor transformed(operands);
    auto lhs_type =
        transformed.lhs().getType().template dyn_cast<RankedTensorType>();
    auto rhs_type =
        transformed.rhs().getType().template dyn_cast<RankedTensorType>();
    if (!lhs_type || !rhs_type) return failure();

    // Requires rank broadcast.
@@ -93,8 +99,9 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
      }
    }

    rewriter.replaceOp(op, {Adaptor::CreateOp(op, op.getResult().getType(),
                                              op.lhs(), op.rhs(), rewriter)});
    rewriter.replaceOp(
        op, {Adaptor::CreateOp(op, op.getResult().getType(), operands[0],
                               operands[1], rewriter)});
    return success();
  }
};
@@ -113,13 +120,15 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
// `shape.broadcast` op, which only supports prefix-padding.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertRankedDynamicBroadcastBinaryOp
    : public OpRewritePattern<ChloOpTy> {
  using OpRewritePattern<ChloOpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(ChloOpTy op,
                                PatternRewriter &rewriter) const override {
    : public OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ChloOpTy op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Only support ranked operands.
    Value lhs = op.lhs();
    Value rhs = op.rhs();
    typename ChloOpTy::Adaptor transformed(operands);
    Value lhs = transformed.lhs();
    Value rhs = transformed.rhs();
    auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
    auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
    auto result_type =
@@ -193,324 +202,6 @@ struct ConvertRankedDynamicBroadcastBinaryOp
  }
};

// Converts a broadcasting binary operation with a scalar operand and an
// unranked operand to a ranked broadcasting operation by dynamically reshaping
// the unranked operand to a 1D tensor. This will always be safe because
// broadcasting from a scalar to another shape always works.
template <typename ChloOpTy, typename HloOpTy>
struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
    : public OpRewritePattern<ChloOpTy> {
  using OpRewritePattern<ChloOpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(ChloOpTy op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    Value lhs = op.lhs();
    Value rhs = op.rhs();

    auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
    auto lhs_unranked_type = lhs.getType().dyn_cast<UnrankedTensorType>();

    auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
    auto rhs_unranked_type = rhs.getType().dyn_cast<UnrankedTensorType>();

    bool lhs_is_scalar = lhs_ranked_type &&
                         lhs_ranked_type.getShape().empty() &&
                         rhs_unranked_type;
    bool rhs_is_scalar = rhs_ranked_type &&
                         rhs_ranked_type.getShape().empty() &&
                         lhs_unranked_type;

    // Only support the case where exactly one operand is scalar and the other
    // is unranked. Other patterns in this file will create more efficient
    // lowerings for cases where both ranks are known or will handle the more
    // generic case of both inputs being unranked.
    if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure();

    auto result_type = op.getResult().getType().template dyn_cast<TensorType>();

    // Reshape the non-scalar value into a dynamically sized, rank-1 tensor
    Value shape =
        rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs);
    Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape);
    Value size_tensor =
        rewriter.create<TensorFromElementsOp>(loc, num_elements);
    Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
        loc, RankedTensorType::get({-1}, result_type.getElementType()),
        lhs_is_scalar ? rhs : lhs, size_tensor);

    // Create a new ranked Chlo op that will be further lowered by other
    // patterns into Mhlo.
    SmallVector<Value, 2> operands{lhs_is_scalar ? lhs : reshaped,
                                   rhs_is_scalar ? rhs : reshaped};
    Value computed = rewriter.create<ChloOpTy>(
        loc, SmallVector<Type, 1>{reshaped.getType()}, operands, op.getAttrs());

    // Reshape the result back into an unranked tensor.
    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
                                                        computed, shape);

    return success();
  }
};

// Handles lowering of the following pattern to patterns that will be further
// matched by other patterns until they result in LHLO:
//   %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy>
//
// The sequence of specializations this handles is:
//   - Either operand being scalar
//   - Operands having equal shapes
//   - The resulting value being any of ranks [2,6]
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertUnrankedDynamicBroadcastBinaryOp
    : public OpRewritePattern<ChloOpTy> {
  using OpRewritePattern<ChloOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(ChloOpTy op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    Value lhs = op.lhs();
    Value rhs = op.rhs();
    auto lhs_type = lhs.getType().dyn_cast<UnrankedTensorType>();
    auto rhs_type = rhs.getType().dyn_cast<UnrankedTensorType>();
    auto result_type = op.getResult().getType().template dyn_cast<TensorType>();

    // Only support unranked operands. If either operand is ranked, another
    // pattern will handle the lowering.
    if (!lhs_type || !rhs_type) return failure();

    // If lhs is scalar
    auto if_op = rewriter.create<scf::IfOp>(
        loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
    OpBuilder if_lhs_scalar_builder = if_op.getThenBodyBuilder();
    Value reshaped_lhs = if_lhs_scalar_builder.create<TensorCastOp>(
        loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
    Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
        loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
        op.getAttrs());
    if_lhs_scalar_builder.create<scf::YieldOp>(loc, if_lhs_scalar_result);

    // If lhs is NOT scalar
    //
    // See if rhs is scalar
    OpBuilder else_lhs_scalar_builder = if_op.getElseBodyBuilder();
    auto if_rhs_scalar_op = else_lhs_scalar_builder.create<scf::IfOp>(
        loc, result_type, IsScalarTensor(else_lhs_scalar_builder, op, rhs),
        true);
    else_lhs_scalar_builder.create<scf::YieldOp>(loc,
                                                 if_rhs_scalar_op.getResult(0));
    OpBuilder if_rhs_scalar_builder = if_rhs_scalar_op.getThenBodyBuilder();
    Value reshaped_rhs = if_rhs_scalar_builder.create<TensorCastOp>(
        loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
    Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
        loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
        op.getAttrs());
    if_rhs_scalar_builder.create<scf::YieldOp>(loc, if_rhs_scalar_result);

    // If NEITHER shape is scalar
    //
    // See if shapes are equal.
    OpBuilder else_no_scalars_builder = if_rhs_scalar_op.getElseBodyBuilder();
    Value shape_of_lhs =
        else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs);
    Value shape_of_rhs =
        else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs);
    Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
        loc, shape_of_lhs, shape_of_rhs);

    auto if_eq_shapes_op = else_no_scalars_builder.create<scf::IfOp>(
        loc, result_type, equal_shapes, true);
    else_no_scalars_builder.create<scf::YieldOp>(loc,
                                                 if_eq_shapes_op.getResult(0));

    OpBuilder if_eq_shapes_builder = if_eq_shapes_op.getThenBodyBuilder();
    Value non_broadcast_op =
        Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder);
    if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);

    // If shapes are not scalar, nor equal
    //
    // See if values are of a rank that we support.
    OpBuilder if_neq_shapes_builder = if_eq_shapes_op.getElseBodyBuilder();
    if_neq_shapes_builder.create<scf::YieldOp>(
        loc, HandleBroadcastAndOp(if_neq_shapes_builder, op, lhs, rhs));

    rewriter.replaceOp(op, {if_op.getResult(0)});
    return success();
  }

 private:
  // Returns the dynamic result of checking whether the given value is a
  // scalar tensor.
  Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const {
    auto loc = op.getLoc();

    Value shape_of_tensor = rewriter.create<shape::ShapeOfOp>(loc, tensor);
    Value rank_tensor = rewriter.create<shape::RankOp>(
        loc, rewriter.getIndexType(), shape_of_tensor);
    return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq,
                                   rank_tensor,
                                   rewriter.create<ConstantIndexOp>(loc, 0));
  }

  // Create the if statement and code for a broadcasting op with a result of a
  // given rank.
  scf::IfOp createRankSpecializedBroadcastAndOp(OpBuilder &builder, ChloOpTy op,
                                                Value lhs, Value rhs,
                                                Value actual_rank,
                                                int targeted_rank) const {
    auto loc = op.getLoc();

    // Create the if block to place the current specialized logic in.
    Value greater_rank_is_n = builder.create<CmpIOp>(
        loc, CmpIPredicate::eq, actual_rank,
        builder.create<ConstantIndexOp>(loc, targeted_rank));
    auto if_op =
        builder.create<scf::IfOp>(loc, lhs.getType(), greater_rank_is_n, true);
    OpBuilder if_builder = if_op.getThenBodyBuilder();

    // Handle shape broadcasting and inference.
    Value lhs_shape = if_builder.create<shape::ShapeOfOp>(loc, lhs);
    Value rhs_shape = if_builder.create<shape::ShapeOfOp>(loc, rhs);
    SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
        {RankedTensorType::kDynamicSize}, builder.getIndexType());
    auto known_rank_extent_tensor_type =
        RankedTensorType::get({targeted_rank}, builder.getIndexType());
    auto reshaped_type = RankedTensorType::get(
        llvm::SmallVector<int64_t, 6>(targeted_rank,
                                      RankedTensorType::kDynamicSize),
        lhs.getType().template dyn_cast<TensorType>().getElementType());
    Value ranked_shape_val = if_builder.create<shape::ConstShapeOp>(
        loc, known_rank_extent_tensor_type,
        mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
                                        ranked_shape));
    Value extended_lhs = if_builder.create<shape::BroadcastOp>(
        loc, unknown_rank_extent_tensor_type, lhs_shape, ranked_shape_val,
        nullptr);
    Value extended_lhs_casted = if_builder.create<TensorCastOp>(
        loc, known_rank_extent_tensor_type, extended_lhs);
    Value extended_rhs = if_builder.create<shape::BroadcastOp>(
        loc, unknown_rank_extent_tensor_type, rhs_shape, ranked_shape_val,
        nullptr);
    Value extended_rhs_casted = if_builder.create<TensorCastOp>(
        loc, known_rank_extent_tensor_type, extended_rhs);

    // 1. Reshape operands to the given rank (with the same number of elements)
    // 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops
    //    can be broadcasted and do the actual broadcasting)
    // 3. Type erase the output back to unranked
    Value reshaped_lhs = if_builder.create<mhlo::DynamicReshapeOp>(
        loc, reshaped_type, lhs, extended_lhs_casted);
    Value reshaped_rhs = if_builder.create<mhlo::DynamicReshapeOp>(
        loc, reshaped_type, rhs, extended_rhs_casted);
    Value result = if_builder.create<ChloOpTy>(
        loc, ArrayRef<Type>{reshaped_type},
        ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
    Value reshaped_result = if_builder.create<TensorCastOp>(
        loc, UnrankedTensorType::get(reshaped_type.getElementType()), result);
    if_builder.create<scf::YieldOp>(loc, reshaped_result);

    // Return the if_op, so the result can be used and the else block can be
    // used for the next rank specialized step.
    return if_op;
  }

  // Iterates over the desired ranks to be specialized and generates the code
  // snippet for each case.
  Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs,
                             Value rhs) const {
    constexpr int max_rank_specialization = 7;
    auto loc = op.getLoc();

    // Find the larger rank of the 2 operands.
    auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
                                                    rewriter.getIndexType());
    Value lhs_shape =
        rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, lhs);
    Value rhs_shape =
        rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs);
    Value lhs_rank =
        rewriter.create<RankOp>(loc, rewriter.getIndexType(), lhs_shape);
    Value rhs_rank =
        rewriter.create<RankOp>(loc, rewriter.getIndexType(), rhs_shape);
    Value greater_rank_lhs =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank);
    Value greater_rank =
        rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank);

    // Generate a list of nested if/else statements to handle rank
    // specializations from 2-6.
    scf::IfOp if_op = createRankSpecializedBroadcastAndOp(rewriter, op, lhs,
                                                          rhs, greater_rank, 2);

    // Put each subsequent rank specialization inside the else statement of the
    // previous one.
    OpBuilder else_builder = if_op.getElseBodyBuilder();
    for (int i = 3; i < max_rank_specialization; i++) {
      auto inner_if = createRankSpecializedBroadcastAndOp(else_builder, op, lhs,
                                                          rhs, greater_rank, i);

      else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
      else_builder = inner_if.getElseBodyBuilder();
    }

    // Fire an assertion if none of the rank specializations applied (one of
    // the ranks was greater than 6).
    else_builder.create<AssertOp>(
        loc, else_builder.create<ConstantIntOp>(loc, 0, 1),
        "Input for dynamic binary op lowering was of a rank greater than 6");
    else_builder.create<scf::YieldOp>(loc, lhs);

    // Return the result of the outermost if statement.
    return if_op.getResult(0);
  }
};

template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
void PopulateForBinaryOp(MLIRContext *context,
                         OwningRewritePatternList *patterns) {
  patterns
      ->insert<ConvertTrivialNonBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
          context, 10);
  patterns->insert<
      ConvertRankedDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
      context, 5);
  patterns->insert<
      ConvertUnrankedScalarDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy>,
      ConvertUnrankedDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
      context);
}

template <typename FromOpTy, typename ToOpTy>
struct HloBinaryElementwiseAdaptor {
  static ToOpTy CreateOp(FromOpTy from_op, Type result_type,
                         Value broadcasted_lhs, Value broadcasted_rhs,
                         OpBuilder &builder) {
    return builder.create<ToOpTy>(from_op.getLoc(), result_type,
                                  broadcasted_lhs, broadcasted_rhs);
  }
};

struct HloComplexAdaptor {
  static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type,
                                  Value broadcasted_lhs, Value broadcasted_rhs,
                                  OpBuilder &builder) {
    return builder.create<mhlo::ComplexOp>(from_op.getLoc(), result_type,
                                           broadcasted_lhs, broadcasted_rhs);
  }
};

struct HloCompareAdaptor {
  static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
                                  Value broadcasted_lhs, Value broadcasted_rhs,
                                  OpBuilder &builder) {
    return builder.create<mhlo::CompareOp>(
        from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs,
        from_op.comparison_direction(), from_op.compare_typeAttr());
  }
};

#include "generated_chlo_legalize_to_hlo.inc"
}  // namespace

@@ -521,32 +212,10 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
  // Instantiate conversion templates for conforming binary elementwise ops
  // that do not have different dtypes between operands and results and do
  // not have special attributes that need to be preserved.
#define POPULATE_BCAST(ChloOp, HloOp)                                      \
  PopulateForBinaryOp<ChloOp, HloOp,                                       \
                      HloBinaryElementwiseAdaptor<ChloOp, HloOp>>(context, \
                                                                  patterns);

  POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp);
  POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp);
  POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op);
  POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp);
  POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp);
  POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp);
  POPULATE_BCAST(BroadcastMulOp, mhlo::MulOp);
  POPULATE_BCAST(BroadcastOrOp, mhlo::OrOp);
  POPULATE_BCAST(BroadcastPowOp, mhlo::PowOp);
  POPULATE_BCAST(BroadcastRemOp, mhlo::RemOp);
  POPULATE_BCAST(BroadcastShiftLeftOp, mhlo::ShiftLeftOp);
  POPULATE_BCAST(BroadcastShiftRightArithmeticOp, mhlo::ShiftRightArithmeticOp);
  POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp);
  POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp);
  POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp);

  // Broadcasting ops requiring special construction.
  PopulateForBinaryOp<BroadcastComplexOp, mhlo::ComplexOp, HloComplexAdaptor>(
      context, patterns);
  PopulateForBinaryOp<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>(
      context, patterns);
  PopulateForBroadcastingBinaryOp<ConvertTrivialNonBroadcastBinaryOp>(
      context, patterns, 10);
  PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
      context, patterns, 5);

  // Other patterns.
  patterns->insert<ConvertConstantLikeOp>(context);
@@ -16,7 +16,9 @@ limitations under the License.

#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Function.h"
@@ -126,6 +128,291 @@ struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
  }
};

// Converts a broadcasting binary operation with a scalar operand and an
// unranked operand to a ranked broadcasting operation by dynamically reshaping
// the unranked operand to a 1D tensor. This will always be safe because
// broadcasting from a scalar to another shape always works.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
    : public OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ChloOpTy op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    typename ChloOpTy::Adaptor transformed(operands);
    Value lhs = transformed.lhs();
    Value rhs = transformed.rhs();

    auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
    auto lhs_unranked_type = lhs.getType().dyn_cast<UnrankedTensorType>();

    auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
    auto rhs_unranked_type = rhs.getType().dyn_cast<UnrankedTensorType>();

    bool lhs_is_scalar = lhs_ranked_type &&
                         lhs_ranked_type.getShape().empty() &&
                         rhs_unranked_type;
    bool rhs_is_scalar = rhs_ranked_type &&
                         rhs_ranked_type.getShape().empty() &&
                         lhs_unranked_type;

    // Only support the case where exactly one operand is scalar and the other
    // is unranked. Other patterns in chlo-to-hlo legalization will create more
    // efficient lowerings for cases where both ranks are known or will handle
    // the more generic case of both inputs being unranked.
    if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure();

    auto result_type = op.getResult().getType().template dyn_cast<TensorType>();

    // Reshape the non-scalar value into a dynamically sized, rank-1 tensor
    Value shape =
        rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs);
    Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape);
    Value size_tensor =
        rewriter.create<TensorFromElementsOp>(loc, num_elements);
    Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
        loc, RankedTensorType::get({-1}, result_type.getElementType()),
        lhs_is_scalar ? rhs : lhs, size_tensor);

    // Create a new ranked Chlo op that will be further lowered by other
    // patterns into Mhlo.
    SmallVector<Value, 2> new_operands{lhs_is_scalar ? lhs : reshaped,
                                       rhs_is_scalar ? rhs : reshaped};
    Value computed =
        rewriter.create<ChloOpTy>(loc, SmallVector<Type, 1>{reshaped.getType()},
                                  new_operands, op.getAttrs());

    // Reshape the result back into an unranked tensor.
    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
                                                        computed, shape);

    return success();
  }
};

// Handles lowering of the following pattern to patterns that will be further
// matched by other patterns until they result in LHLO:
//   %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy>
//
// The sequence of specializations this handles is:
//   - Either operand being scalar
//   - Operands having equal shapes
//   - The resulting value being any of ranks [2,6]
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertUnrankedDynamicBroadcastBinaryOp
    : public OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      ChloOpTy op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    typename ChloOpTy::Adaptor transformed(operands);
    Value lhs = transformed.lhs();
    Value rhs = transformed.rhs();
    auto lhs_type = lhs.getType().dyn_cast<UnrankedTensorType>();
    auto rhs_type = rhs.getType().dyn_cast<UnrankedTensorType>();
    auto result_type = op.getResult().getType().template dyn_cast<TensorType>();

    // Only support unranked operands. If either operand is ranked, another
    // pattern will handle the lowering.
    if (!lhs_type || !rhs_type) return failure();

    // If lhs is scalar
    auto if_op = rewriter.create<scf::IfOp>(
        loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
    OpBuilder if_lhs_scalar_builder =
        if_op.getThenBodyBuilder(rewriter.getListener());
    Value reshaped_lhs = if_lhs_scalar_builder.create<TensorCastOp>(
        loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
    Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
        loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
        op.getAttrs());
    if_lhs_scalar_builder.create<scf::YieldOp>(loc, if_lhs_scalar_result);

    // If lhs is NOT scalar
    //
    // See if rhs is scalar
    OpBuilder else_lhs_scalar_builder =
        if_op.getElseBodyBuilder(rewriter.getListener());
    auto if_rhs_scalar_op = else_lhs_scalar_builder.create<scf::IfOp>(
        loc, result_type, IsScalarTensor(else_lhs_scalar_builder, op, rhs),
        true);
    else_lhs_scalar_builder.create<scf::YieldOp>(loc,
                                                 if_rhs_scalar_op.getResult(0));
    OpBuilder if_rhs_scalar_builder =
        if_rhs_scalar_op.getThenBodyBuilder(rewriter.getListener());
    Value reshaped_rhs = if_rhs_scalar_builder.create<TensorCastOp>(
        loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
    Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
        loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
        op.getAttrs());
    if_rhs_scalar_builder.create<scf::YieldOp>(loc, if_rhs_scalar_result);

    // If NEITHER shape is scalar
    //
    // See if shapes are equal.
    OpBuilder else_no_scalars_builder =
        if_rhs_scalar_op.getElseBodyBuilder(rewriter.getListener());
    Value shape_of_lhs =
        else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs);
    Value shape_of_rhs =
        else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs);
    Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
        loc, shape_of_lhs, shape_of_rhs);

    auto if_eq_shapes_op = else_no_scalars_builder.create<scf::IfOp>(
        loc, result_type, equal_shapes, true);
    else_no_scalars_builder.create<scf::YieldOp>(loc,
                                                 if_eq_shapes_op.getResult(0));

    OpBuilder if_eq_shapes_builder =
        if_eq_shapes_op.getThenBodyBuilder(rewriter.getListener());
    Value non_broadcast_op =
        Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder);
    if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);

    // If shapes are not scalar, nor equal
    //
    // See if values are of a rank that we support.
    OpBuilder if_neq_shapes_builder =
        if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener());
    if_neq_shapes_builder.create<scf::YieldOp>(
        loc, HandleBroadcastAndOp(if_neq_shapes_builder, op, lhs, rhs));

    rewriter.replaceOp(op, {if_op.getResult(0)});
    return success();
  }

 private:
  // Returns the dynamic result of checking whether the given value is a
  // scalar tensor.
  Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const {
    auto loc = op.getLoc();

    Value shape_of_tensor = rewriter.create<shape::ShapeOfOp>(loc, tensor);
    Value rank_tensor = rewriter.create<shape::RankOp>(
        loc, rewriter.getIndexType(), shape_of_tensor);
    return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq,
                                   rank_tensor,
                                   rewriter.create<ConstantIndexOp>(loc, 0));
  }

  // Create the if statement and code for a broadcasting op with a result of a
  // given rank.
  scf::IfOp createRankSpecializedBroadcastAndOp(OpBuilder &builder, ChloOpTy op,
                                                Value lhs, Value rhs,
                                                Value actual_rank,
                                                int targeted_rank) const {
    auto loc = op.getLoc();

    // Create the if block to place the current specialized logic in.
    Value greater_rank_is_n = builder.create<CmpIOp>(
        loc, CmpIPredicate::eq, actual_rank,
        builder.create<ConstantIndexOp>(loc, targeted_rank));
    auto if_op =
        builder.create<scf::IfOp>(loc, lhs.getType(), greater_rank_is_n, true);
    OpBuilder if_builder = if_op.getThenBodyBuilder(builder.getListener());

    // Handle shape broadcasting and inference.
    Value lhs_shape = if_builder.create<shape::ShapeOfOp>(loc, lhs);
    Value rhs_shape = if_builder.create<shape::ShapeOfOp>(loc, rhs);
    SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
        {RankedTensorType::kDynamicSize}, builder.getIndexType());
    auto known_rank_extent_tensor_type =
        RankedTensorType::get({targeted_rank}, builder.getIndexType());
    auto reshaped_type = RankedTensorType::get(
        llvm::SmallVector<int64_t, 6>(targeted_rank,
                                      RankedTensorType::kDynamicSize),
        lhs.getType().template dyn_cast<TensorType>().getElementType());
    Value ranked_shape_val = if_builder.create<shape::ConstShapeOp>(
        loc, known_rank_extent_tensor_type,
        mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
                                        ranked_shape));
    Value extended_lhs = if_builder.create<shape::BroadcastOp>(
        loc, unknown_rank_extent_tensor_type, lhs_shape, ranked_shape_val,
        nullptr);
    Value extended_lhs_casted = if_builder.create<TensorCastOp>(
        loc, known_rank_extent_tensor_type, extended_lhs);
    Value extended_rhs = if_builder.create<shape::BroadcastOp>(
        loc, unknown_rank_extent_tensor_type, rhs_shape, ranked_shape_val,
        nullptr);
    Value extended_rhs_casted = if_builder.create<TensorCastOp>(
        loc, known_rank_extent_tensor_type, extended_rhs);

    // 1. Reshape operands to the given rank (with the same number of elements)
    // 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops
    //    can be broadcasted and do the actual broadcasting)
    // 3. Type erase the output back to unranked
    Value reshaped_lhs = if_builder.create<mhlo::DynamicReshapeOp>(
        loc, reshaped_type, lhs, extended_lhs_casted);
    Value reshaped_rhs = if_builder.create<mhlo::DynamicReshapeOp>(
        loc, reshaped_type, rhs, extended_rhs_casted);
    Value result = if_builder.create<ChloOpTy>(
        loc, ArrayRef<Type>{reshaped_type},
        ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
    Value reshaped_result = if_builder.create<TensorCastOp>(
        loc, UnrankedTensorType::get(reshaped_type.getElementType()), result);
    if_builder.create<scf::YieldOp>(loc, reshaped_result);

    // Return the if_op, so the result can be used and the else block can be
    // used for the next rank specialized step.
    return if_op;
  }

  // Iterates over the desired ranks to be specialized and generates the code
  // snippet for each case.
  Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs,
                             Value rhs) const {
    constexpr int max_rank_specialization = 7;
    auto loc = op.getLoc();

    // Find the larger rank of the 2 operands.
    auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
                                                    rewriter.getIndexType());
    Value lhs_shape =
        rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, lhs);
    Value rhs_shape =
        rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs);
    Value lhs_rank =
        rewriter.create<RankOp>(loc, rewriter.getIndexType(), lhs_shape);
    Value rhs_rank =
        rewriter.create<RankOp>(loc, rewriter.getIndexType(), rhs_shape);
    Value greater_rank_lhs =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank);
    Value greater_rank =
        rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank);

    // Generate a list of nested if/else statements to handle rank
    // specializations from 2-6.
    scf::IfOp if_op = createRankSpecializedBroadcastAndOp(rewriter, op, lhs,
                                                          rhs, greater_rank, 2);

    // Put each subsequent rank specialization inside the else statement of the
    // previous one.
    OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
    for (int i = 3; i < max_rank_specialization; i++) {
      auto inner_if = createRankSpecializedBroadcastAndOp(else_builder, op, lhs,
                                                          rhs, greater_rank, i);

      else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
      else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
    }

    // Fire an assertion if none of the rank specializations applied (one of
    // the ranks was greater than 6).
    else_builder.create<AssertOp>(
        loc, else_builder.create<ConstantIntOp>(loc, 0, 1),
        "Input for dynamic binary op lowering was of a rank greater than 6");
    else_builder.create<scf::YieldOp>(loc, lhs);

    // Return the result of the outermost if statement.
    return if_op.getResult(0);
  }
};

struct TransformUnrankedHloPass
    : public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
@@ -137,7 +424,7 @@ struct TransformUnrankedHloPass
    MLIRContext &ctx = getContext();
    ConversionTarget target(ctx);
    target.addLegalDialect<mhlo::MhloDialect, StandardOpsDialect,
                           shape::ShapeDialect>();
                           shape::ShapeDialect, scf::SCFDialect>();
    target.addLegalOp<FuncOp>();
#define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target)
#define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target)
@@ -148,6 +435,12 @@ struct TransformUnrankedHloPass
#undef ADD_LEGAL_CHLO
    AddLegalOpOnRankedTensor<mhlo::CompareOp>(&target);
    AddLegalOpOnRankedTensor<mhlo::SelectOp>(&target);
    target.addDynamicallyLegalDialect<chlo::HloClientDialect>(
        [](Operation *op) {
          return !llvm::any_of(op->getOperandTypes(), [](Type type) {
            return type.isa<UnrankedTensorType>();
          });
        });

    // Populate rewrite patterns.
    OwningRewritePatternList patterns;
@@ -180,6 +473,10 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
#undef MAP_BINARY
#undef MAP_CHLO_UNARY
#undef COMMA
  chlo::PopulateForBroadcastingBinaryOp<
      ConvertUnrankedDynamicBroadcastBinaryOp>(context, patterns);
  chlo::PopulateForBroadcastingBinaryOp<
      ConvertUnrankedScalarDynamicBroadcastBinaryOp>(context, patterns);
}

std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {
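For orientation, a sketch of running the moved lowering as part of a pipeline; this assumes the factory lives in the mhlo namespace alongside the other declarations in rewriters.h and uses the function-pass APIs of this era of MLIR:

    // Schedule the unranked-HLO transformation on every function in a module.
    mlir::PassManager pm(&context);
    pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createTransformUnrankedHloPass());
    if (mlir::failed(pm.run(module))) {
      // Legalization failed; diagnostics were already emitted.
    }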
|  | @ -237,209 +237,3 @@ func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4x | |||
|   %0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> | ||||
|   return %0 : tensor<4xi1> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { | ||||
|   %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<f32>, tensor<*xf32>) | ||||
|                                          -> tensor<*xf32> | ||||
|   return %0 : tensor<*xf32> | ||||
| } | ||||
| 
 | ||||
| // CHECK-LABEL:   func @addScalarUnranked( | ||||
| // CHECK-SAME:                            %[[ARG_0:.*]]: tensor<f32>, | ||||
| // CHECK-SAME:                            %[[ARG_1:.*]]: tensor<*xf32> | ||||
| // CHECK-SAME:                            ) -> tensor<*xf32> { | ||||
| //                  First handle the dynamic reshaping of the unranked operand | ||||
| //                  to a 1D tensor. | ||||
| // CHECK:           %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32> | ||||
| // CHECK:           %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]] : tensor<?xindex> -> index | ||||
| // CHECK:           %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex> | ||||
| // CHECK:           %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> | ||||
| //                  The assuming region is part of the second stage of lowering | ||||
| //                  with ranked broadcasting logic. | ||||
| // CHECK:           %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<f32> | ||||
| // CHECK:           %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32> | ||||
| // CHECK:           %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_0]], %[[SHAPE_RESHAPED]] | ||||
| // CHECK:           %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) { | ||||
| // CHECK:             %[[SCALAR_SHAPE:.*]] = shape.const_shape [] | ||||
| // CHECK:             %[[BROADCASTED_SHAPE:.*]] = shape.broadcast %[[SCALAR_SHAPE]], %[[SHAPE_RESHAPED]] | ||||
| // CHECK:             %[[SHAPE_TENSOR:.*]] = tensor_cast %[[BROADCASTED_SHAPE]] : tensor<?xindex> to tensor<1xindex> | ||||
| // CHECK:             %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_0]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32> | ||||
| // CHECK:             %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32> | ||||
| // CHECK:             %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32> | ||||
| // CHECK:             shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor<?xf32> | ||||
| // CHECK:           } | ||||
| //                  As part of the unranked logic, the result is reshaped back | ||||
| //                  to an unranked tensor. | ||||
| // CHECK:           %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32> | ||||
| // CHECK:           return %[[RESHAPED_RESULT]] : tensor<*xf32> | ||||
| // CHECK:         } | ||||
| 
 | ||||
| // ----- | ||||
| func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf32> { | ||||
|   %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<f32>) | ||||
|                                          -> tensor<*xf32> | ||||
|   return %0 : tensor<*xf32> | ||||
| } | ||||
| // CHECK-LABEL:   func @addUnrankedScalar( | ||||
| // CHECK-SAME:                            %[[ARG_0:.*]]: tensor<*xf32>, | ||||
| // CHECK-SAME:                            %[[ARG_1:.*]]: tensor<f32>) -> tensor<*xf32> { | ||||
| //                  First handle the dynamic reshaping of the unranked operand | ||||
| //                  to a 1D tensor. | ||||
| // CHECK:           %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32> | ||||
| // CHECK:           %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]] : tensor<?xindex> -> index | ||||
| // CHECK:           %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex> | ||||
| // CHECK:           %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> | ||||
| //                  The assuming region is part of the second stage of lowering | ||||
| //                  with ranked broadcasting logic. | ||||
| // CHECK:           %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32> | ||||
| // CHECK:           %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<f32> | ||||
| // CHECK:           %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_RESHAPED]], %[[SHAPE_1]] | ||||
| // CHECK:           %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) { | ||||
| // CHECK:             %[[ASTENSOR:.*]] = tensor_cast %[[SHAPE_RESHAPED]] | ||||
| // CHECK:             %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[ASTENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32> | ||||
| // CHECK:             %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_1]], %[[ASTENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32> | ||||
| // CHECK:             %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32> | ||||
| // CHECK:             shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor<?xf32> | ||||
| // CHECK:           } | ||||
| //                  As part of the unranked logic, the result is reshaped back | ||||
| //                  to an unranked tensor. | ||||
| // CHECK:           %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_0]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32> | ||||
| // CHECK:           return %[[RESHAPED_RESULT]] : tensor<*xf32> | ||||
| // CHECK:         } | ||||
| 
 | ||||
| // ----- | ||||
| func @addUnrankedUnranked( | ||||
|       %arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { | ||||
|   %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>) | ||||
|                                          -> tensor<*xf32> | ||||
|   return %0 : tensor<*xf32> | ||||
| } | ||||
| 
 | ||||
| // CHECK-LABEL:   func @addUnrankedUnranked( | ||||
| // CHECK-SAME:          %[[LHS:.*]]: tensor<*xf32>, | ||||
| // CHECK-SAME:          %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> { | ||||
| // CHECK:           %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex> | ||||
| // CHECK:           %[[RANK_LHS:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index | ||||
| // CHECK:           %[[C0:.*]] = constant 0 : index | ||||
| // CHECK:           %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index | ||||
| //                  Handle scalar LHS case | ||||
| // CHECK:           %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) { | ||||
| // CHECK:             %[[SCALAR_LHS:.*]] = tensor_cast %[[LHS]] : tensor<*xf32> to tensor<f32> | ||||
| // CHECK:             %[[VAL_10:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RHS]] : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32> | ||||
| // CHECK:             scf.yield %[[VAL_10]] : tensor<*xf32> | ||||
| // CHECK:           } else { | ||||
| // CHECK:             %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex> | ||||
| // CHECK:             %[[RANK_RHS:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index | ||||
| // CHECK:             %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index | ||||
|   //                  Handle scalar RHS case | ||||
| // CHECK:             %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) { | ||||
| // CHECK:               %[[SCALAR_RHS:.*]] = tensor_cast %[[RHS]] : tensor<*xf32> to tensor<f32> | ||||
| // CHECK:               %[[VAL_16:.*]] = chlo.broadcast_add %[[LHS]], %[[SCALAR_RHS]] : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32> | ||||
| // CHECK:               scf.yield %[[VAL_16]] : tensor<*xf32> | ||||
| // CHECK:             } else { | ||||
| // CHECK:               %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> | ||||
| //                      Handle equal shapes case | ||||
| // CHECK:               %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) { | ||||
| // CHECK:                 %[[VAL_19:.*]] = mhlo.add %[[LHS]], %[[RHS]] : tensor<*xf32> | ||||
| // CHECK:                 scf.yield %[[VAL_19]] : tensor<*xf32> | ||||
| // CHECK:               } else { | ||||
| // CHECK:                 %[[LHS_RANK:.*]] = rank %[[LHS_SHAPE]] : tensor<?xindex> | ||||
| // CHECK:                 %[[RHS_RANK:.*]] = rank %[[RHS_SHAPE]] : tensor<?xindex> | ||||
| // CHECK:                 %[[LHS_RANK_GREATER:.*]] = cmpi "sgt", %[[LHS_RANK]], %[[RHS_RANK]] : index | ||||
| // CHECK:                 %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index | ||||
| // CHECK:                 %[[C2:.*]] = constant 2 : index | ||||
| // CHECK:                 %[[GREATEST_RANK_IS_2:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C2]] : index | ||||
| //                        Handle rank 2 specialization | ||||
| // CHECK:                 %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) { | ||||
| // CHECK:                   %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1] | ||||
| // CHECK:                   %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex> | ||||
| // CHECK:                   %[[CASTED_LHS_2:.*]] = tensor_cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex> | ||||
| // CHECK:                   %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex> | ||||
| // CHECK:                   %[[CASTED_RHS_2:.*]] = tensor_cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex> | ||||
| // CHECK:                   %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32> | ||||
| // CHECK:                   %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32> | ||||
| // CHECK:                   %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> | ||||
| // CHECK:                   %[[RESULT_2:.*]] = tensor_cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32> | ||||
| // CHECK:                   scf.yield %[[RESULT_2]] : tensor<*xf32> | ||||
| // CHECK:                 } else { | ||||
| // CHECK:                   %[[C3:.*]] = constant 3 : index | ||||
| // CHECK:                   %[[GREATEST_RANK_IS_3:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C3]] : index | ||||
| //                          Handle rank 3 specialization | ||||
| // CHECK:                   %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) { | ||||
| // CHECK:                     %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1] | ||||
| // CHECK:                     %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex> | ||||
| // CHECK:                     %[[CASTED_LHS_3:.*]] = tensor_cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex> | ||||
| // CHECK:                     %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex> | ||||
| // CHECK:                     %[[CASTED_RHS_3:.*]] = tensor_cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex> | ||||
| // CHECK:                     %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32> | ||||
| // CHECK:                     %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32> | ||||
| // CHECK:                     %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32> | ||||
| // CHECK:                     %[[RESULT_3:.*]] = tensor_cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32> | ||||
| // CHECK:                     scf.yield %[[RESULT_3]] : tensor<*xf32> | ||||
| // CHECK:                   } else { | ||||
| // CHECK:                     %[[C4:.*]] = constant 4 : index | ||||
| // CHECK:                     %[[GREATEST_RANK_IS_4:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C4]] : index | ||||
| //                            Handle rank 4 specialization | ||||
| // CHECK:                     %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) { | ||||
| // CHECK:                       %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1] | ||||
| // CHECK:                       %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex> | ||||
| // CHECK:                       %[[CASTED_LHS_4:.*]] = tensor_cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex> | ||||
| // CHECK:                       %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex> | ||||
| // CHECK:                       %[[CASTED_RHS_4:.*]] = tensor_cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex> | ||||
| // CHECK:                       %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32> | ||||
| // CHECK:                       %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32> | ||||
| // CHECK:                       %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> | ||||
| // CHECK:                       %[[RESULT_4:.*]] = tensor_cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32> | ||||
| // CHECK:                       scf.yield %[[RESULT_4]] : tensor<*xf32> | ||||
| // CHECK:                     } else { | ||||
| // CHECK:                       %[[C5:.*]] = constant 5 : index | ||||
| // CHECK:                       %[[GREATEST_RANK_IS_5:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C5]] : index | ||||
| //                              Handle rank 5 specialization | ||||
| // CHECK:                       %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) { | ||||
| // CHECK:                         %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1] | ||||
| // CHECK:                         %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex> | ||||
| // CHECK:                         %[[CASTED_LHS_5:.*]] = tensor_cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex> | ||||
| // CHECK:                         %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex> | ||||
| // CHECK:                         %[[CASTED_RHS_5:.*]] = tensor_cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex> | ||||
| // CHECK:                         %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32> | ||||
| // CHECK:                         %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32> | ||||
| // CHECK:                         %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> | ||||
| // CHECK:                         %[[RESULT_5:.*]] = tensor_cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32> | ||||
| // CHECK:                         scf.yield %[[RESULT_5]] : tensor<*xf32> | ||||
| // CHECK:                       } else { | ||||
| // CHECK:                         %[[C6:.*]] = constant 6 : index | ||||
| // CHECK:                         %[[GREATEST_RANK_IS_6:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C6]] : index | ||||
| //                                Handle rank 6 specialization | ||||
| // CHECK:                         %[[VAL_58:.*]] = scf.if %[[GREATEST_RANK_IS_6]] -> (tensor<*xf32>) { | ||||
| // CHECK:                           %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1] | ||||
| // CHECK:                           %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex> | ||||
| // CHECK:                           %[[CASTED_LHS_6:.*]] = tensor_cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex> | ||||
| // CHECK:                           %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex> | ||||
| // CHECK:                           %[[CASTED_RHS_6:.*]] = tensor_cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex> | ||||
| // CHECK:                           %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32> | ||||
| // CHECK:                           %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32> | ||||
| // CHECK:                           %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32> | ||||
| // CHECK:                           %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32> | ||||
| // CHECK:                           scf.yield %[[RESULT_6]] : tensor<*xf32> | ||||
| // CHECK:                         } else { | ||||
| // CHECK:                           %false = constant false | ||||
| // CHECK:                           assert %false | ||||
| // CHECK:                           scf.yield %[[LHS]] : tensor<*xf32> | ||||
| // CHECK:                         } | ||||
| // CHECK:                         scf.yield %[[VAL_64:.*]] : tensor<*xf32> | ||||
| // CHECK:                       } | ||||
| // CHECK:                       scf.yield %[[VAL_65:.*]] : tensor<*xf32> | ||||
| // CHECK:                     } | ||||
| // CHECK:                     scf.yield %[[VAL_66:.*]] : tensor<*xf32> | ||||
| // CHECK:                   } | ||||
| // CHECK:                   scf.yield %[[VAL_67:.*]] : tensor<*xf32> | ||||
| // CHECK:                 } | ||||
| // CHECK:                 scf.yield %[[VAL_68:.*]] : tensor<*xf32> | ||||
| // CHECK:               } | ||||
| // CHECK:               scf.yield %[[VAL_69:.*]] : tensor<*xf32> | ||||
| // CHECK:             } | ||||
| // CHECK:             scf.yield %[[VAL_70:.*]] : tensor<*xf32> | ||||
| // CHECK:           } | ||||
| // CHECK:           return %[[VAL_71:.*]] : tensor<*xf32> | ||||
| // CHECK:         } | ||||
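//                  Every rank specialization above follows the same recipe:
//                  extend both shapes to the target rank by broadcasting them
//                  with a constant shape of ones, reshape both operands to that
//                  rank, and defer to the ranked broadcast lowering. A sketch of
//                  the rank-2 case in isolation (hypothetical function name;
//                  the ops mirror the CHECK lines):
//
//                    func @rank2_add(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<*xf32> {
//                      %lhs_shape = shape.shape_of %lhs : tensor<*xf32> -> tensor<?xindex>
//                      %rhs_shape = shape.shape_of %rhs : tensor<*xf32> -> tensor<?xindex>
//                      // Pad both shapes to rank 2; [1, 1] is the broadcast identity.
//                      %ones = shape.const_shape [1, 1]
//                      %blhs = shape.broadcast %lhs_shape, %ones : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
//                      %clhs = tensor_cast %blhs : tensor<?xindex> to tensor<2xindex>
//                      %brhs = shape.broadcast %rhs_shape, %ones : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
//                      %crhs = tensor_cast %brhs : tensor<?xindex> to tensor<2xindex>
//                      // Materialize ranked operands and let the ranked lowering take over.
//                      %rlhs = "mhlo.dynamic_reshape"(%lhs, %clhs) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
//                      %rrhs = "mhlo.dynamic_reshape"(%rhs, %crhs) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
//                      %sum = chlo.broadcast_add %rlhs, %rrhs : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
//                      %res = tensor_cast %sum : tensor<?x?xf32> to tensor<*xf32>
//                      return %res : tensor<*xf32>
//                    }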
|  |  | |||
|  | @@ -1,4 +1,4 @@ | |||
| // RUN: mlir-hlo-opt --transform-unranked-hlo --split-input-file %s | FileCheck %s | ||||
| // RUN: mlir-hlo-opt --transform-unranked-hlo --cse --split-input-file %s | FileCheck %s | ||||
| 
 | ||||
| // Check the validity of expected IR. | ||||
| // CHECK-LABEL: @sqr_transform_result | ||||
|  | @@ -96,3 +96,203 @@ func @tan(%a : tensor<*xf32>) -> tensor<*xf32> { | |||
|   %result = chlo.tan %a : tensor<*xf32> | ||||
|   return %result : tensor<*xf32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { | ||||
|   %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<f32>, tensor<*xf32>) | ||||
|                                          -> tensor<*xf32> | ||||
|   return %0 : tensor<*xf32> | ||||
| } | ||||
| 
 | ||||
| // CHECK-LABEL:   func @addScalarUnranked( | ||||
| // CHECK-SAME:                            %[[ARG_0:.*]]: tensor<f32>, | ||||
| // CHECK-SAME:                            %[[ARG_1:.*]]: tensor<*xf32> | ||||
| // CHECK-SAME:                            ) -> tensor<*xf32> { | ||||
| //                  First handle the dynamic reshaping of the unranked operand | ||||
| //                  to a 1D tensor. | ||||
| // CHECK-NEXT:           %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32> | ||||
| // CHECK-NEXT:           %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]] : tensor<?xindex> -> index | ||||
| // CHECK-NEXT:           %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex> | ||||
| // CHECK-NEXT:           %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> | ||||
| // CHECK-NEXT:           %[[BROADCASTED_RESULT:.*]] = chlo.broadcast_add %[[ARG_0]], %[[RESHAPED]] : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32> | ||||
| //                  As part of the unranked logic, the result is reshaped back | ||||
| //                  to an unranked tensor. | ||||
| // CHECK-NEXT:           %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[BROADCASTED_RESULT:.*]], %[[SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32> | ||||
| // CHECK-NEXT:           return %[[RESHAPED_RESULT]] : tensor<*xf32> | ||||
| // CHECK-NEXT:         } | ||||
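//                  A sketch of the flattening idiom checked above, in isolation
//                  (hypothetical function name; the ops mirror the CHECK lines):
//
//                    func @flatten(%arg : tensor<*xf32>) -> tensor<?xf32> {
//                      %shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
//                      // Collapse the unknown rank into a single dimension.
//                      %num = shape.num_elements %shape : tensor<?xindex> -> index
//                      %size = tensor_from_elements %num : tensor<1xindex>
//                      %flat = "mhlo.dynamic_reshape"(%arg, %size) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
//                      return %flat : tensor<?xf32>
//                    }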
| 
 | ||||
| // ----- | ||||
| func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf32> { | ||||
|   %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<f32>) | ||||
|                                          -> tensor<*xf32> | ||||
|   return %0 : tensor<*xf32> | ||||
| } | ||||
| // CHECK-LABEL:   func @addUnrankedScalar( | ||||
| // CHECK-SAME:                            %[[ARG_0:.*]]: tensor<*xf32>, | ||||
| // CHECK-SAME:                            %[[ARG_1:.*]]: tensor<f32>) -> tensor<*xf32> { | ||||
| //                  First handle the dynamic reshaping of the unranked operand | ||||
| //                  to a 1D tensor. | ||||
| // CHECK-NEXT:           %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32> | ||||
| // CHECK-NEXT:           %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]] : tensor<?xindex> -> index | ||||
| // CHECK-NEXT:           %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex> | ||||
| // CHECK-NEXT:           %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> | ||||
| //                  The broadcasting chlo op is kept on the flattened 1D | ||||
| //                  operands; the ranked broadcasting logic is applied in a | ||||
| //                  second stage of lowering. | ||||
| // CHECK-NEXT:           %[[BROADCASTED_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED]], %[[ARG_1]] : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32> | ||||
| //                  As part of the unranked logic, the result is reshaped back | ||||
| //                  to an unranked tensor. | ||||
| // CHECK-NEXT:           %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[BROADCASTED_RESULT:.*]], %[[SHAPE_0]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32> | ||||
| // CHECK-NEXT:           return %[[RESHAPED_RESULT]] : tensor<*xf32> | ||||
| // CHECK-NEXT:         } | ||||
| 
 | ||||
| // ----- | ||||
| func @addUnrankedUnranked( | ||||
|       %arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { | ||||
|   %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>) | ||||
|                                          -> tensor<*xf32> | ||||
|   return %0 : tensor<*xf32> | ||||
| } | ||||
| 
 | ||||
| // CHECK-LABEL:   func @addUnrankedUnranked( | ||||
| // CHECK-SAME:          %[[LHS:.*]]: tensor<*xf32>, | ||||
| // CHECK-SAME:          %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> { | ||||
| // CHECK-NEXT:           %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex> | ||||
| // CHECK-NEXT:           %[[RANK_LHS:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index | ||||
| // CHECK-NEXT:           %[[C0:.*]] = constant 0 : index | ||||
| // CHECK-NEXT:           %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index | ||||
| //                       Handle scalar LHS case | ||||
| // CHECK-NEXT:           %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) { | ||||
| // CHECK-NEXT:             %[[SCALAR_LHS:.*]] = tensor_cast %[[LHS]] : tensor<*xf32> to tensor<f32> | ||||
| // CHECK-NEXT:             %[[RHS_SHAPE_1:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex> | ||||
| // CHECK-NEXT:             %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE_1]] : tensor<?xindex> -> index | ||||
| // CHECK-NEXT:             %[[NUM_TENS_RHS:.*]] = tensor_from_elements %[[NUM_RHS]] : tensor<1xindex> | ||||
| // CHECK-NEXT:             %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[NUM_TENS_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> | ||||
| // CHECK-NEXT:             %[[LHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RESHAPED_RHS]] : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32> | ||||
| // CHECK-NEXT:             %[[RESHAPED_LHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[LHS_SCALAR_RESULT]], %[[RHS_SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32> | ||||
| // CHECK-NEXT:             scf.yield %[[RESHAPED_LHS_SCALAR_RESULT]] : tensor<*xf32> | ||||
| // CHECK-NEXT:           } else { | ||||
| // CHECK-NEXT:             %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex> | ||||
| // CHECK-NEXT:             %[[RANK_RHS:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index | ||||
| // CHECK-NEXT:             %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index | ||||
| //                         Handle scalar RHS case | ||||
| // CHECK-NEXT:             %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) { | ||||
| // CHECK-NEXT:               %[[SCALAR_RHS:.*]] = tensor_cast %[[RHS]] : tensor<*xf32> to tensor<f32> | ||||
| // CHECK-NEXT:               %[[NUM_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index | ||||
| // CHECK-NEXT:               %[[NUM_TENS_LHS:.*]] = tensor_from_elements %[[NUM_LHS]] : tensor<1xindex> | ||||
| // CHECK-NEXT:               %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> | ||||
| // CHECK-NEXT:               %[[RHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED_LHS]], %[[SCALAR_RHS]] : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32> | ||||
| // CHECK-NEXT:               %[[RESHAPED_RHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RHS_SCALAR_RESULT:.*]], %[[LHS_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32> | ||||
| // CHECK-NEXT:               scf.yield %[[RESHAPED_RHS_SCALAR_RESULT]] : tensor<*xf32> | ||||
| // CHECK-NEXT:             } else { | ||||
| // CHECK-NEXT:               %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> | ||||
| //                           Handle equal shapes case | ||||
| // CHECK-NEXT:               %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) { | ||||
| // CHECK-NEXT:                 %[[ANY_SHAPE:.*]] = shape.any %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex> | ||||
| // CHECK-NEXT:                 %[[ANY_NUM:.*]] = shape.num_elements %[[ANY_SHAPE]] : tensor<?xindex> -> index | ||||
| // CHECK-NEXT:                 %[[ANY_TENSOR:.*]] = tensor_from_elements %[[ANY_NUM]] : tensor<1xindex> | ||||
| // CHECK-NEXT:                 %[[FLATTENED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> | ||||
| // CHECK-NEXT:                 %[[FLATTENED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> | ||||
| // CHECK-NEXT:                 %[[FLATTENED_RESULT:.*]] = mhlo.add %[[FLATTENED_LHS]], %[[FLATTENED_RHS]] : tensor<?xf32> | ||||
| // CHECK-NEXT:                 %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32> | ||||
| // CHECK-NEXT:                 scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32> | ||||
| // CHECK-NEXT:               } else { | ||||
| // CHECK-NEXT:                 %[[LHS_RANK:.*]] = rank %[[LHS_SHAPE]] : tensor<?xindex> | ||||
| // CHECK-NEXT:                 %[[RHS_RANK:.*]] = rank %[[RHS_SHAPE]] : tensor<?xindex> | ||||
| // CHECK-NEXT:                 %[[LHS_RANK_GREATER:.*]] = cmpi "sgt", %[[LHS_RANK]], %[[RHS_RANK]] : index | ||||
| // CHECK-NEXT:                 %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index | ||||
| // CHECK-NEXT:                 %[[C2:.*]] = constant 2 : index | ||||
| // CHECK-NEXT:                 %[[GREATEST_RANK_IS_2:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C2]] : index | ||||
| //                             Handle rank 2 specialization | ||||
| // CHECK-NEXT:                 %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) { | ||||
| // CHECK-NEXT:                   %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1] | ||||
| // CHECK-NEXT:                   %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex> | ||||
| // CHECK-NEXT:                   %[[CASTED_LHS_2:.*]] = tensor_cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex> | ||||
| // CHECK-NEXT:                   %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex> | ||||
| // CHECK-NEXT:                   %[[CASTED_RHS_2:.*]] = tensor_cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex> | ||||
| // CHECK-NEXT:                   %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32> | ||||
| // CHECK-NEXT:                   %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32> | ||||
| // CHECK-NEXT:                   %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> | ||||
| // CHECK-NEXT:                   %[[RESULT_2:.*]] = tensor_cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32> | ||||
| // CHECK-NEXT:                   scf.yield %[[RESULT_2]] : tensor<*xf32> | ||||
| // CHECK-NEXT:                 } else { | ||||
| // CHECK-NEXT:                   %[[C3:.*]] = constant 3 : index | ||||
| // CHECK-NEXT:                   %[[GREATEST_RANK_IS_3:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C3]] : index | ||||
| //                               Handle rank 3 specialization | ||||
| // CHECK-NEXT:                   %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) { | ||||
| // CHECK-NEXT:                     %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1] | ||||
| // CHECK-NEXT:                     %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex> | ||||
| // CHECK-NEXT:                     %[[CASTED_LHS_3:.*]] = tensor_cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex> | ||||
| // CHECK-NEXT:                     %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex> | ||||
| // CHECK-NEXT:                     %[[CASTED_RHS_3:.*]] = tensor_cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex> | ||||
| // CHECK-NEXT:                     %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32> | ||||
| // CHECK-NEXT:                     %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32> | ||||
| // CHECK-NEXT:                     %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32> | ||||
| // CHECK-NEXT:                     %[[RESULT_3:.*]] = tensor_cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32> | ||||
| // CHECK-NEXT:                     scf.yield %[[RESULT_3]] : tensor<*xf32> | ||||
| // CHECK-NEXT:                   } else { | ||||
| // CHECK-NEXT:                     %[[C4:.*]] = constant 4 : index | ||||
| // CHECK-NEXT:                     %[[GREATEST_RANK_IS_4:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C4]] : index | ||||
| //                                 Handle rank 4 specialization | ||||
| // CHECK-NEXT:                     %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) { | ||||
| // CHECK-NEXT:                       %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1] | ||||
| // CHECK-NEXT:                       %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex> | ||||
| // CHECK-NEXT:                       %[[CASTED_LHS_4:.*]] = tensor_cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex> | ||||
| // CHECK-NEXT:                       %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex> | ||||
| // CHECK-NEXT:                       %[[CASTED_RHS_4:.*]] = tensor_cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex> | ||||
| // CHECK-NEXT:                       %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32> | ||||
| // CHECK-NEXT:                       %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32> | ||||
| // CHECK-NEXT:                       %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> | ||||
| // CHECK-NEXT:                       %[[RESULT_4:.*]] = tensor_cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32> | ||||
| // CHECK-NEXT:                       scf.yield %[[RESULT_4]] : tensor<*xf32> | ||||
| // CHECK-NEXT:                     } else { | ||||
| // CHECK-NEXT:                       %[[C5:.*]] = constant 5 : index | ||||
| // CHECK-NEXT:                       %[[GREATEST_RANK_IS_5:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C5]] : index | ||||
| //                                   Handle rank 5 specialization | ||||
| // CHECK-NEXT:                       %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) { | ||||
| // CHECK-NEXT:                         %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1] | ||||
| // CHECK-NEXT:                         %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex> | ||||
| // CHECK-NEXT:                         %[[CASTED_LHS_5:.*]] = tensor_cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex> | ||||
| // CHECK-NEXT:                         %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex> | ||||
| // CHECK-NEXT:                         %[[CASTED_RHS_5:.*]] = tensor_cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex> | ||||
| // CHECK-NEXT:                         %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32> | ||||
| // CHECK-NEXT:                         %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32> | ||||
| // CHECK-NEXT:                         %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> | ||||
| // CHECK-NEXT:                         %[[RESULT_5:.*]] = tensor_cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32> | ||||
| // CHECK-NEXT:                         scf.yield %[[RESULT_5]] : tensor<*xf32> | ||||
| // CHECK-NEXT:                       } else { | ||||
| // CHECK-NEXT:                         %[[C6:.*]] = constant 6 : index | ||||
| // CHECK-NEXT:                         %[[GREATEST_RANK_IS_6:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C6]] : index | ||||
| //                                     Handle rank 6 specialization | ||||
| // CHECK-NEXT:                         %[[VAL_58:.*]] = scf.if %[[GREATEST_RANK_IS_6]] -> (tensor<*xf32>) { | ||||
| // CHECK-NEXT:                           %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1] | ||||
| // CHECK-NEXT:                           %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex> | ||||
| // CHECK-NEXT:                           %[[CASTED_LHS_6:.*]] = tensor_cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex> | ||||
| // CHECK-NEXT:                           %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex> | ||||
| // CHECK-NEXT:                           %[[CASTED_RHS_6:.*]] = tensor_cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex> | ||||
| // CHECK-NEXT:                           %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32> | ||||
| // CHECK-NEXT:                           %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32> | ||||
| // CHECK-NEXT:                           %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32> | ||||
| // CHECK-NEXT:                           %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32> | ||||
| // CHECK-NEXT:                           scf.yield %[[RESULT_6]] : tensor<*xf32> | ||||
| // CHECK-NEXT:                         } else { | ||||
| // CHECK-NEXT:                           %false = constant false | ||||
| // CHECK-NEXT:                           assert %false | ||||
| // CHECK-NEXT:                           scf.yield %[[LHS]] : tensor<*xf32> | ||||
| // CHECK-NEXT:                         } | ||||
| // CHECK-NEXT:                         scf.yield %[[VAL_64:.*]] : tensor<*xf32> | ||||
| // CHECK-NEXT:                       } | ||||
| // CHECK-NEXT:                       scf.yield %[[VAL_65:.*]] : tensor<*xf32> | ||||
| // CHECK-NEXT:                     } | ||||
| // CHECK-NEXT:                     scf.yield %[[VAL_66:.*]] : tensor<*xf32> | ||||
| // CHECK-NEXT:                   } | ||||
| // CHECK-NEXT:                   scf.yield %[[VAL_67:.*]] : tensor<*xf32> | ||||
| // CHECK-NEXT:                 } | ||||
| // CHECK-NEXT:                 scf.yield %[[VAL_68:.*]] : tensor<*xf32> | ||||
| // CHECK-NEXT:               } | ||||
| // CHECK-NEXT:               scf.yield %[[VAL_69:.*]] : tensor<*xf32> | ||||
| // CHECK-NEXT:             } | ||||
| // CHECK-NEXT:             scf.yield %[[VAL_70:.*]] : tensor<*xf32> | ||||
| // CHECK-NEXT:           } | ||||
| // CHECK-NEXT:           return %[[VAL_71:.*]] : tensor<*xf32> | ||||
| // CHECK-NEXT:         } | ||||
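//                  The equal-shapes branch above avoids broadcasting entirely:
//                  both operands are flattened to rank 1 using a shared shape
//                  and added directly. A sketch in isolation (hypothetical
//                  function name; the ops mirror the CHECK lines):
//
//                    func @same_shape_add(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<*xf32> {
//                      %lhs_shape = shape.shape_of %lhs : tensor<*xf32> -> tensor<?xindex>
//                      %rhs_shape = shape.shape_of %rhs : tensor<*xf32> -> tensor<?xindex>
//                      // Fold the two (equal) shapes into one value for both reshapes.
//                      %shape = shape.any %lhs_shape, %rhs_shape : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
//                      %num = shape.num_elements %shape : tensor<?xindex> -> index
//                      %size = tensor_from_elements %num : tensor<1xindex>
//                      %flat_lhs = "mhlo.dynamic_reshape"(%lhs, %size) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
//                      %flat_rhs = "mhlo.dynamic_reshape"(%rhs, %size) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
//                      %sum = mhlo.add %flat_lhs, %flat_rhs : tensor<?xf32>
//                      %res = "mhlo.dynamic_reshape"(%sum, %shape) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
//                      return %res : tensor<*xf32>
//                    }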