[MLIR] Simplify and generalize `transform-unranked-hlo`
This refactoring allows to support a wider range of n-ary operations in future changes. PiperOrigin-RevId: 331953362
This commit is contained in:
		
							parent
							
								
									d1f85f32b1
								
							
						
					
					
						commit
						da43c8596b
					
				|  | @ -49,99 +49,70 @@ namespace { | |||
| template <typename OpTy> | ||||
| inline void AddLegalOpOnRankedTensor(ConversionTarget *target) { | ||||
|   target->addDynamicallyLegalOp<OpTy>([](OpTy op) { | ||||
|     return llvm::all_of((op.getOperation())->getOperandTypes(), | ||||
|     return llvm::all_of(op.getOperation()->getOperandTypes(), | ||||
|                         [&](Type t) { return t.isa<RankedTensorType>(); }); | ||||
|   }); | ||||
| } | ||||
| 
 | ||||
| /// Unary element-wise operations on unranked tensors can be applied to the
 | ||||
| /// flattened tensor with the same effect.
 | ||||
| /// This pattern rewrites every such operation to
 | ||||
| /// Element-wise operations on unranked tensors can be applied to the flattened
 | ||||
| /// tensor operands with the same effect.  This pattern rewrites every such
 | ||||
| /// operation to
 | ||||
| ///   (i)   flatten the input tensor,
 | ||||
| ///   (ii)  apply the unary operation, and
 | ||||
| ///   (ii)  apply the operation, and
 | ||||
| ///   (iii) restore the original shape.
 | ||||
| template <typename OpTy> | ||||
| struct UnaryElementwiseOpConversion : public OpRewritePattern<OpTy> { | ||||
|   explicit UnaryElementwiseOpConversion(MLIRContext *context) | ||||
| struct ElementwiseOpConversion : public OpRewritePattern<OpTy> { | ||||
|   explicit ElementwiseOpConversion(MLIRContext *context) | ||||
|       : OpRewritePattern<OpTy>(context) {} | ||||
| 
 | ||||
|   LogicalResult matchAndRewrite(OpTy op, | ||||
|                                 PatternRewriter &rewriter) const override { | ||||
|     // Don't apply conversion to ops with statically shaped operands.
 | ||||
|     Value operand = op.getOperand(); | ||||
|     auto operandTy = operand.getType().dyn_cast<TensorType>(); | ||||
|     if (operandTy.hasRank()) return failure(); | ||||
| 
 | ||||
|     // Generate IR to flatten the operand.
 | ||||
|     auto loc = op.getLoc(); | ||||
|     Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext()); | ||||
|     Value shape = | ||||
|         rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, operand); | ||||
|     Type indexTy = rewriter.getIndexType(); | ||||
|     Value numElements = | ||||
|         rewriter.create<shape::NumElementsOp>(loc, indexTy, shape); | ||||
|     Value flatShape = rewriter.create<TensorFromElementsOp>(loc, numElements); | ||||
|     auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize}, | ||||
|                                               operandTy.getElementType()); | ||||
|     Value flatOperand = rewriter.create<mhlo::DynamicReshapeOp>( | ||||
|         loc, flatTensorTy, operand, flatShape); | ||||
| 
 | ||||
|     // Generate IR for the actual operation.
 | ||||
|     Value flatResult = rewriter.create<OpTy>(loc, flatTensorTy, flatOperand); | ||||
| 
 | ||||
|     // Generate IR to restore the original shape.
 | ||||
|     rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, operandTy, | ||||
|                                                         flatResult, shape); | ||||
| 
 | ||||
|     return success(); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| /// Binary element-wise operation on unranked tensors can be applied to the
 | ||||
| /// flattened operand tensors with the same effect.
 | ||||
| /// This pattern rewrites every such operation to
 | ||||
| ///   (i)   flatten the operand tensors,
 | ||||
| ///   (ii)  apply the binary operation, and
 | ||||
| //    (iii) restore the original shape.
 | ||||
| template <typename OpTy> | ||||
| struct BinaryElementwiseOpConversion : public OpRewritePattern<OpTy> { | ||||
|   explicit BinaryElementwiseOpConversion(MLIRContext *context) | ||||
|       : OpRewritePattern<OpTy>(context) {} | ||||
| 
 | ||||
|   LogicalResult matchAndRewrite(OpTy op, | ||||
|                                 PatternRewriter &rewriter) const override { | ||||
|     // Don't apply conversion unless both operands are unranked.
 | ||||
|     if (op.lhs().getType().template isa<RankedTensorType>() || | ||||
|         op.rhs().getType().template isa<RankedTensorType>()) { | ||||
|     // Don't apply conversion unless all operands are unranked.
 | ||||
|     if (!llvm::all_of(op.getOperation()->getOperands(), [&](Value operand) { | ||||
|           return operand.getType().isa<UnrankedTensorType>(); | ||||
|         })) { | ||||
|       return failure(); | ||||
|     } | ||||
| 
 | ||||
|     // Flatten operands.
 | ||||
|     // Get operands' shape.
 | ||||
|     auto loc = op.getLoc(); | ||||
|     Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext()); | ||||
|     Value shapeLhs = | ||||
|         rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, op.lhs()); | ||||
|     Value shapeRhs = | ||||
|         rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, op.rhs()); | ||||
|     Value shape = rewriter.create<shape::AnyOp>(loc, extentTensorTy, | ||||
|                                                 ValueRange{shapeLhs, shapeRhs}); | ||||
|     SmallVector<Value, 3> operandShapes; | ||||
|     for (Value operand : op.getOperation()->getOperands()) { | ||||
|       Value shape = | ||||
|           rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, operand); | ||||
|       operandShapes.push_back(shape); | ||||
|     } | ||||
|     Value shape = | ||||
|         operandShapes.size() == 1 | ||||
|             ? operandShapes.front() | ||||
|             : rewriter.create<shape::AnyOp>(loc, extentTensorTy, operandShapes); | ||||
| 
 | ||||
|     // Derive flat shape.
 | ||||
|     Type indexTy = rewriter.getIndexType(); | ||||
|     Value numElements = | ||||
|         rewriter.create<shape::NumElementsOp>(loc, indexTy, shape); | ||||
|     Value flatShape = rewriter.create<TensorFromElementsOp>(loc, numElements); | ||||
|     TensorType lhsTy = op.lhs().getType().template cast<TensorType>(); | ||||
|     Type flatLhsTy = RankedTensorType::get({ShapedType::kDynamicSize}, | ||||
|                                            lhsTy.getElementType()); | ||||
|     Value flatLhs = | ||||
|         rewriter.create<DynamicReshapeOp>(loc, flatLhsTy, op.lhs(), flatShape); | ||||
|     TensorType rhsTy = op.rhs().getType().template cast<TensorType>(); | ||||
|     Type flatRhsTy = RankedTensorType::get({ShapedType::kDynamicSize}, | ||||
|                                            rhsTy.getElementType()); | ||||
|     Value flatRhs = | ||||
|         rewriter.create<DynamicReshapeOp>(loc, flatRhsTy, op.rhs(), flatShape); | ||||
| 
 | ||||
|     // Apply actual operation to flattened operands.
 | ||||
|     Value flatResult = rewriter.create<OpTy>(loc, flatLhs, flatRhs); | ||||
|     // Flatten operands.
 | ||||
|     SmallVector<Value, 3> flatOperands; | ||||
|     for (Value operand : op.getOperation()->getOperands()) { | ||||
|       Type operandElementTy = | ||||
|           operand.getType().template cast<ShapedType>().getElementType(); | ||||
|       Type flatTy = | ||||
|           RankedTensorType::get({ShapedType::kDynamicSize}, operandElementTy); | ||||
|       Value flat = | ||||
|           rewriter.create<DynamicReshapeOp>(loc, flatTy, operand, flatShape); | ||||
|       flatOperands.push_back(flat); | ||||
|     } | ||||
| 
 | ||||
|     // Apply operation to flattened operands.
 | ||||
|     Type resultElementTy = | ||||
|         op.getType().template cast<ShapedType>().getElementType(); | ||||
|     Type flatResultTy = | ||||
|         RankedTensorType::get({ShapedType::kDynamicSize}, resultElementTy); | ||||
|     Value flatResult = | ||||
|         rewriter.create<OpTy>(loc, flatResultTy, flatOperands, op.getAttrs()); | ||||
| 
 | ||||
|     // Restore original shape.
 | ||||
|     rewriter.replaceOpWithNewOp<DynamicReshapeOp>(op, op.getType(), flatResult, | ||||
|  | @ -154,7 +125,7 @@ struct BinaryElementwiseOpConversion : public OpRewritePattern<OpTy> { | |||
| struct TransformUnrankedHloPass | ||||
|     : public PassWrapper<TransformUnrankedHloPass, FunctionPass> { | ||||
|   void getDependentDialects(DialectRegistry ®istry) const override { | ||||
|     registry.insert<shape::ShapeDialect>(); | ||||
|     registry.insert<shape::ShapeDialect, mhlo::MhloDialect>(); | ||||
|   } | ||||
| 
 | ||||
|   void runOnFunction() override { | ||||
|  | @ -183,19 +154,17 @@ struct TransformUnrankedHloPass | |||
| 
 | ||||
| void PopulateTransformUnrankedHloPatterns(MLIRContext *context, | ||||
|                                           OwningRewritePatternList *patterns) { | ||||
|   // TODO(frgossen): Populate all unary and binary operations.
 | ||||
|   // clang-format off
 | ||||
| #define MAP_UNARY(op) UnaryElementwiseOpConversion<op> | ||||
| #define MAP_BINARY(op) BinaryElementwiseOpConversion<op> | ||||
| #define MAP_UNARY(op) ElementwiseOpConversion<op> | ||||
| #define MAP_BINARY(op) ElementwiseOpConversion<op> | ||||
| #define COMMA , | ||||
|   // clang-format off
 | ||||
|   patterns->insert< | ||||
|       MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA), | ||||
|       MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA) | ||||
|       >(context); | ||||
|       MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA)>(context); | ||||
|   // clang-format on
 | ||||
| #undef MAP_UNARY | ||||
| #undef MAP_BINARY | ||||
| #undef COMMA | ||||
|   // clang-format on
 | ||||
| } | ||||
| 
 | ||||
| std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() { | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue