[MLIR] Simplify and generalize `transform-unranked-hlo`
This refactoring allows to support a wider range of n-ary operations in future changes. PiperOrigin-RevId: 331953362
This commit is contained in:
		
							parent
							
								
									d1f85f32b1
								
							
						
					
					
						commit
						da43c8596b
					
				|  | @ -49,99 +49,70 @@ namespace { | ||||||
| template <typename OpTy> | template <typename OpTy> | ||||||
| inline void AddLegalOpOnRankedTensor(ConversionTarget *target) { | inline void AddLegalOpOnRankedTensor(ConversionTarget *target) { | ||||||
|   target->addDynamicallyLegalOp<OpTy>([](OpTy op) { |   target->addDynamicallyLegalOp<OpTy>([](OpTy op) { | ||||||
|     return llvm::all_of((op.getOperation())->getOperandTypes(), |     return llvm::all_of(op.getOperation()->getOperandTypes(), | ||||||
|                         [&](Type t) { return t.isa<RankedTensorType>(); }); |                         [&](Type t) { return t.isa<RankedTensorType>(); }); | ||||||
|   }); |   }); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /// Unary element-wise operations on unranked tensors can be applied to the
 | /// Element-wise operations on unranked tensors can be applied to the flattened
 | ||||||
| /// flattened tensor with the same effect.
 | /// tensor operands with the same effect.  This pattern rewrites every such
 | ||||||
| /// This pattern rewrites every such operation to
 | /// operation to
 | ||||||
| ///   (i)   flatten the input tensor,
 | ///   (i)   flatten the input tensor,
 | ||||||
| ///   (ii)  apply the unary operation, and
 | ///   (ii)  apply the operation, and
 | ||||||
| ///   (iii) restore the original shape.
 | ///   (iii) restore the original shape.
 | ||||||
| template <typename OpTy> | template <typename OpTy> | ||||||
| struct UnaryElementwiseOpConversion : public OpRewritePattern<OpTy> { | struct ElementwiseOpConversion : public OpRewritePattern<OpTy> { | ||||||
|   explicit UnaryElementwiseOpConversion(MLIRContext *context) |   explicit ElementwiseOpConversion(MLIRContext *context) | ||||||
|       : OpRewritePattern<OpTy>(context) {} |       : OpRewritePattern<OpTy>(context) {} | ||||||
| 
 | 
 | ||||||
|   LogicalResult matchAndRewrite(OpTy op, |   LogicalResult matchAndRewrite(OpTy op, | ||||||
|                                 PatternRewriter &rewriter) const override { |                                 PatternRewriter &rewriter) const override { | ||||||
|     // Don't apply conversion to ops with statically shaped operands.
 |     // Don't apply conversion unless all operands are unranked.
 | ||||||
|     Value operand = op.getOperand(); |     if (!llvm::all_of(op.getOperation()->getOperands(), [&](Value operand) { | ||||||
|     auto operandTy = operand.getType().dyn_cast<TensorType>(); |           return operand.getType().isa<UnrankedTensorType>(); | ||||||
|     if (operandTy.hasRank()) return failure(); |         })) { | ||||||
| 
 |  | ||||||
|     // Generate IR to flatten the operand.
 |  | ||||||
|     auto loc = op.getLoc(); |  | ||||||
|     Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext()); |  | ||||||
|     Value shape = |  | ||||||
|         rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, operand); |  | ||||||
|     Type indexTy = rewriter.getIndexType(); |  | ||||||
|     Value numElements = |  | ||||||
|         rewriter.create<shape::NumElementsOp>(loc, indexTy, shape); |  | ||||||
|     Value flatShape = rewriter.create<TensorFromElementsOp>(loc, numElements); |  | ||||||
|     auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize}, |  | ||||||
|                                               operandTy.getElementType()); |  | ||||||
|     Value flatOperand = rewriter.create<mhlo::DynamicReshapeOp>( |  | ||||||
|         loc, flatTensorTy, operand, flatShape); |  | ||||||
| 
 |  | ||||||
|     // Generate IR for the actual operation.
 |  | ||||||
|     Value flatResult = rewriter.create<OpTy>(loc, flatTensorTy, flatOperand); |  | ||||||
| 
 |  | ||||||
|     // Generate IR to restore the original shape.
 |  | ||||||
|     rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, operandTy, |  | ||||||
|                                                         flatResult, shape); |  | ||||||
| 
 |  | ||||||
|     return success(); |  | ||||||
|   } |  | ||||||
| }; |  | ||||||
| 
 |  | ||||||
| /// Binary element-wise operation on unranked tensors can be applied to the
 |  | ||||||
| /// flattened operand tensors with the same effect.
 |  | ||||||
| /// This pattern rewrites every such operation to
 |  | ||||||
| ///   (i)   flatten the operand tensors,
 |  | ||||||
| ///   (ii)  apply the binary operation, and
 |  | ||||||
| //    (iii) restore the original shape.
 |  | ||||||
| template <typename OpTy> |  | ||||||
| struct BinaryElementwiseOpConversion : public OpRewritePattern<OpTy> { |  | ||||||
|   explicit BinaryElementwiseOpConversion(MLIRContext *context) |  | ||||||
|       : OpRewritePattern<OpTy>(context) {} |  | ||||||
| 
 |  | ||||||
|   LogicalResult matchAndRewrite(OpTy op, |  | ||||||
|                                 PatternRewriter &rewriter) const override { |  | ||||||
|     // Don't apply conversion unless both operands are unranked.
 |  | ||||||
|     if (op.lhs().getType().template isa<RankedTensorType>() || |  | ||||||
|         op.rhs().getType().template isa<RankedTensorType>()) { |  | ||||||
|       return failure(); |       return failure(); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     // Flatten operands.
 |     // Get operands' shape.
 | ||||||
|     auto loc = op.getLoc(); |     auto loc = op.getLoc(); | ||||||
|     Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext()); |     Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext()); | ||||||
|     Value shapeLhs = |     SmallVector<Value, 3> operandShapes; | ||||||
|         rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, op.lhs()); |     for (Value operand : op.getOperation()->getOperands()) { | ||||||
|     Value shapeRhs = |       Value shape = | ||||||
|         rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, op.rhs()); |           rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, operand); | ||||||
|     Value shape = rewriter.create<shape::AnyOp>(loc, extentTensorTy, |       operandShapes.push_back(shape); | ||||||
|                                                 ValueRange{shapeLhs, shapeRhs}); |     } | ||||||
|  |     Value shape = | ||||||
|  |         operandShapes.size() == 1 | ||||||
|  |             ? operandShapes.front() | ||||||
|  |             : rewriter.create<shape::AnyOp>(loc, extentTensorTy, operandShapes); | ||||||
|  | 
 | ||||||
|  |     // Derive flat shape.
 | ||||||
|     Type indexTy = rewriter.getIndexType(); |     Type indexTy = rewriter.getIndexType(); | ||||||
|     Value numElements = |     Value numElements = | ||||||
|         rewriter.create<shape::NumElementsOp>(loc, indexTy, shape); |         rewriter.create<shape::NumElementsOp>(loc, indexTy, shape); | ||||||
|     Value flatShape = rewriter.create<TensorFromElementsOp>(loc, numElements); |     Value flatShape = rewriter.create<TensorFromElementsOp>(loc, numElements); | ||||||
|     TensorType lhsTy = op.lhs().getType().template cast<TensorType>(); |  | ||||||
|     Type flatLhsTy = RankedTensorType::get({ShapedType::kDynamicSize}, |  | ||||||
|                                            lhsTy.getElementType()); |  | ||||||
|     Value flatLhs = |  | ||||||
|         rewriter.create<DynamicReshapeOp>(loc, flatLhsTy, op.lhs(), flatShape); |  | ||||||
|     TensorType rhsTy = op.rhs().getType().template cast<TensorType>(); |  | ||||||
|     Type flatRhsTy = RankedTensorType::get({ShapedType::kDynamicSize}, |  | ||||||
|                                            rhsTy.getElementType()); |  | ||||||
|     Value flatRhs = |  | ||||||
|         rewriter.create<DynamicReshapeOp>(loc, flatRhsTy, op.rhs(), flatShape); |  | ||||||
| 
 | 
 | ||||||
|     // Apply actual operation to flattened operands.
 |     // Flatten operands.
 | ||||||
|     Value flatResult = rewriter.create<OpTy>(loc, flatLhs, flatRhs); |     SmallVector<Value, 3> flatOperands; | ||||||
|  |     for (Value operand : op.getOperation()->getOperands()) { | ||||||
|  |       Type operandElementTy = | ||||||
|  |           operand.getType().template cast<ShapedType>().getElementType(); | ||||||
|  |       Type flatTy = | ||||||
|  |           RankedTensorType::get({ShapedType::kDynamicSize}, operandElementTy); | ||||||
|  |       Value flat = | ||||||
|  |           rewriter.create<DynamicReshapeOp>(loc, flatTy, operand, flatShape); | ||||||
|  |       flatOperands.push_back(flat); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Apply operation to flattened operands.
 | ||||||
|  |     Type resultElementTy = | ||||||
|  |         op.getType().template cast<ShapedType>().getElementType(); | ||||||
|  |     Type flatResultTy = | ||||||
|  |         RankedTensorType::get({ShapedType::kDynamicSize}, resultElementTy); | ||||||
|  |     Value flatResult = | ||||||
|  |         rewriter.create<OpTy>(loc, flatResultTy, flatOperands, op.getAttrs()); | ||||||
| 
 | 
 | ||||||
|     // Restore original shape.
 |     // Restore original shape.
 | ||||||
|     rewriter.replaceOpWithNewOp<DynamicReshapeOp>(op, op.getType(), flatResult, |     rewriter.replaceOpWithNewOp<DynamicReshapeOp>(op, op.getType(), flatResult, | ||||||
|  | @ -154,7 +125,7 @@ struct BinaryElementwiseOpConversion : public OpRewritePattern<OpTy> { | ||||||
| struct TransformUnrankedHloPass | struct TransformUnrankedHloPass | ||||||
|     : public PassWrapper<TransformUnrankedHloPass, FunctionPass> { |     : public PassWrapper<TransformUnrankedHloPass, FunctionPass> { | ||||||
|   void getDependentDialects(DialectRegistry ®istry) const override { |   void getDependentDialects(DialectRegistry ®istry) const override { | ||||||
|     registry.insert<shape::ShapeDialect>(); |     registry.insert<shape::ShapeDialect, mhlo::MhloDialect>(); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   void runOnFunction() override { |   void runOnFunction() override { | ||||||
|  | @ -183,19 +154,17 @@ struct TransformUnrankedHloPass | ||||||
| 
 | 
 | ||||||
| void PopulateTransformUnrankedHloPatterns(MLIRContext *context, | void PopulateTransformUnrankedHloPatterns(MLIRContext *context, | ||||||
|                                           OwningRewritePatternList *patterns) { |                                           OwningRewritePatternList *patterns) { | ||||||
|   // TODO(frgossen): Populate all unary and binary operations.
 | #define MAP_UNARY(op) ElementwiseOpConversion<op> | ||||||
|   // clang-format off
 | #define MAP_BINARY(op) ElementwiseOpConversion<op> | ||||||
| #define MAP_UNARY(op) UnaryElementwiseOpConversion<op> |  | ||||||
| #define MAP_BINARY(op) BinaryElementwiseOpConversion<op> |  | ||||||
| #define COMMA , | #define COMMA , | ||||||
|  |   // clang-format off
 | ||||||
|   patterns->insert< |   patterns->insert< | ||||||
|       MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA), |       MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA), | ||||||
|       MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA) |       MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA)>(context); | ||||||
|       >(context); |   // clang-format on
 | ||||||
| #undef MAP_UNARY | #undef MAP_UNARY | ||||||
| #undef MAP_BINARY | #undef MAP_BINARY | ||||||
| #undef COMMA | #undef COMMA | ||||||
|   // clang-format on
 |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() { | std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() { | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue