[MLIR] Simplify and generalize `transform-unranked-hlo`
This refactoring allows to support a wider range of n-ary operations in future changes. PiperOrigin-RevId: 331953362
This commit is contained in:
		
							parent
							
								
									d1f85f32b1
								
							
						
					
					
						commit
						da43c8596b
					
				| 
						 | 
					@ -49,99 +49,70 @@ namespace {
 | 
				
			||||||
template <typename OpTy>
 | 
					template <typename OpTy>
 | 
				
			||||||
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
 | 
					inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
 | 
				
			||||||
  target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
 | 
					  target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
 | 
				
			||||||
    return llvm::all_of((op.getOperation())->getOperandTypes(),
 | 
					    return llvm::all_of(op.getOperation()->getOperandTypes(),
 | 
				
			||||||
                        [&](Type t) { return t.isa<RankedTensorType>(); });
 | 
					                        [&](Type t) { return t.isa<RankedTensorType>(); });
 | 
				
			||||||
  });
 | 
					  });
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/// Unary element-wise operations on unranked tensors can be applied to the
 | 
					/// Element-wise operations on unranked tensors can be applied to the flattened
 | 
				
			||||||
/// flattened tensor with the same effect.
 | 
					/// tensor operands with the same effect.  This pattern rewrites every such
 | 
				
			||||||
/// This pattern rewrites every such operation to
 | 
					/// operation to
 | 
				
			||||||
///   (i)   flatten the input tensor,
 | 
					///   (i)   flatten the input tensor,
 | 
				
			||||||
///   (ii)  apply the unary operation, and
 | 
					///   (ii)  apply the operation, and
 | 
				
			||||||
///   (iii) restore the original shape.
 | 
					///   (iii) restore the original shape.
 | 
				
			||||||
template <typename OpTy>
 | 
					template <typename OpTy>
 | 
				
			||||||
struct UnaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
 | 
					struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
 | 
				
			||||||
  explicit UnaryElementwiseOpConversion(MLIRContext *context)
 | 
					  explicit ElementwiseOpConversion(MLIRContext *context)
 | 
				
			||||||
      : OpRewritePattern<OpTy>(context) {}
 | 
					      : OpRewritePattern<OpTy>(context) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  LogicalResult matchAndRewrite(OpTy op,
 | 
					  LogicalResult matchAndRewrite(OpTy op,
 | 
				
			||||||
                                PatternRewriter &rewriter) const override {
 | 
					                                PatternRewriter &rewriter) const override {
 | 
				
			||||||
    // Don't apply conversion to ops with statically shaped operands.
 | 
					    // Don't apply conversion unless all operands are unranked.
 | 
				
			||||||
    Value operand = op.getOperand();
 | 
					    if (!llvm::all_of(op.getOperation()->getOperands(), [&](Value operand) {
 | 
				
			||||||
    auto operandTy = operand.getType().dyn_cast<TensorType>();
 | 
					          return operand.getType().isa<UnrankedTensorType>();
 | 
				
			||||||
    if (operandTy.hasRank()) return failure();
 | 
					        })) {
 | 
				
			||||||
 | 
					 | 
				
			||||||
    // Generate IR to flatten the operand.
 | 
					 | 
				
			||||||
    auto loc = op.getLoc();
 | 
					 | 
				
			||||||
    Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext());
 | 
					 | 
				
			||||||
    Value shape =
 | 
					 | 
				
			||||||
        rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, operand);
 | 
					 | 
				
			||||||
    Type indexTy = rewriter.getIndexType();
 | 
					 | 
				
			||||||
    Value numElements =
 | 
					 | 
				
			||||||
        rewriter.create<shape::NumElementsOp>(loc, indexTy, shape);
 | 
					 | 
				
			||||||
    Value flatShape = rewriter.create<TensorFromElementsOp>(loc, numElements);
 | 
					 | 
				
			||||||
    auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
 | 
					 | 
				
			||||||
                                              operandTy.getElementType());
 | 
					 | 
				
			||||||
    Value flatOperand = rewriter.create<mhlo::DynamicReshapeOp>(
 | 
					 | 
				
			||||||
        loc, flatTensorTy, operand, flatShape);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // Generate IR for the actual operation.
 | 
					 | 
				
			||||||
    Value flatResult = rewriter.create<OpTy>(loc, flatTensorTy, flatOperand);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // Generate IR to restore the original shape.
 | 
					 | 
				
			||||||
    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, operandTy,
 | 
					 | 
				
			||||||
                                                        flatResult, shape);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return success();
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
};
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
/// Binary element-wise operation on unranked tensors can be applied to the
 | 
					 | 
				
			||||||
/// flattened operand tensors with the same effect.
 | 
					 | 
				
			||||||
/// This pattern rewrites every such operation to
 | 
					 | 
				
			||||||
///   (i)   flatten the operand tensors,
 | 
					 | 
				
			||||||
///   (ii)  apply the binary operation, and
 | 
					 | 
				
			||||||
//    (iii) restore the original shape.
 | 
					 | 
				
			||||||
template <typename OpTy>
 | 
					 | 
				
			||||||
struct BinaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
 | 
					 | 
				
			||||||
  explicit BinaryElementwiseOpConversion(MLIRContext *context)
 | 
					 | 
				
			||||||
      : OpRewritePattern<OpTy>(context) {}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  LogicalResult matchAndRewrite(OpTy op,
 | 
					 | 
				
			||||||
                                PatternRewriter &rewriter) const override {
 | 
					 | 
				
			||||||
    // Don't apply conversion unless both operands are unranked.
 | 
					 | 
				
			||||||
    if (op.lhs().getType().template isa<RankedTensorType>() ||
 | 
					 | 
				
			||||||
        op.rhs().getType().template isa<RankedTensorType>()) {
 | 
					 | 
				
			||||||
      return failure();
 | 
					      return failure();
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Flatten operands.
 | 
					    // Get operands' shape.
 | 
				
			||||||
    auto loc = op.getLoc();
 | 
					    auto loc = op.getLoc();
 | 
				
			||||||
    Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext());
 | 
					    Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext());
 | 
				
			||||||
    Value shapeLhs =
 | 
					    SmallVector<Value, 3> operandShapes;
 | 
				
			||||||
        rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, op.lhs());
 | 
					    for (Value operand : op.getOperation()->getOperands()) {
 | 
				
			||||||
    Value shapeRhs =
 | 
					      Value shape =
 | 
				
			||||||
        rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, op.rhs());
 | 
					          rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, operand);
 | 
				
			||||||
    Value shape = rewriter.create<shape::AnyOp>(loc, extentTensorTy,
 | 
					      operandShapes.push_back(shape);
 | 
				
			||||||
                                                ValueRange{shapeLhs, shapeRhs});
 | 
					    }
 | 
				
			||||||
 | 
					    Value shape =
 | 
				
			||||||
 | 
					        operandShapes.size() == 1
 | 
				
			||||||
 | 
					            ? operandShapes.front()
 | 
				
			||||||
 | 
					            : rewriter.create<shape::AnyOp>(loc, extentTensorTy, operandShapes);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Derive flat shape.
 | 
				
			||||||
    Type indexTy = rewriter.getIndexType();
 | 
					    Type indexTy = rewriter.getIndexType();
 | 
				
			||||||
    Value numElements =
 | 
					    Value numElements =
 | 
				
			||||||
        rewriter.create<shape::NumElementsOp>(loc, indexTy, shape);
 | 
					        rewriter.create<shape::NumElementsOp>(loc, indexTy, shape);
 | 
				
			||||||
    Value flatShape = rewriter.create<TensorFromElementsOp>(loc, numElements);
 | 
					    Value flatShape = rewriter.create<TensorFromElementsOp>(loc, numElements);
 | 
				
			||||||
    TensorType lhsTy = op.lhs().getType().template cast<TensorType>();
 | 
					 | 
				
			||||||
    Type flatLhsTy = RankedTensorType::get({ShapedType::kDynamicSize},
 | 
					 | 
				
			||||||
                                           lhsTy.getElementType());
 | 
					 | 
				
			||||||
    Value flatLhs =
 | 
					 | 
				
			||||||
        rewriter.create<DynamicReshapeOp>(loc, flatLhsTy, op.lhs(), flatShape);
 | 
					 | 
				
			||||||
    TensorType rhsTy = op.rhs().getType().template cast<TensorType>();
 | 
					 | 
				
			||||||
    Type flatRhsTy = RankedTensorType::get({ShapedType::kDynamicSize},
 | 
					 | 
				
			||||||
                                           rhsTy.getElementType());
 | 
					 | 
				
			||||||
    Value flatRhs =
 | 
					 | 
				
			||||||
        rewriter.create<DynamicReshapeOp>(loc, flatRhsTy, op.rhs(), flatShape);
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Apply actual operation to flattened operands.
 | 
					    // Flatten operands.
 | 
				
			||||||
    Value flatResult = rewriter.create<OpTy>(loc, flatLhs, flatRhs);
 | 
					    SmallVector<Value, 3> flatOperands;
 | 
				
			||||||
 | 
					    for (Value operand : op.getOperation()->getOperands()) {
 | 
				
			||||||
 | 
					      Type operandElementTy =
 | 
				
			||||||
 | 
					          operand.getType().template cast<ShapedType>().getElementType();
 | 
				
			||||||
 | 
					      Type flatTy =
 | 
				
			||||||
 | 
					          RankedTensorType::get({ShapedType::kDynamicSize}, operandElementTy);
 | 
				
			||||||
 | 
					      Value flat =
 | 
				
			||||||
 | 
					          rewriter.create<DynamicReshapeOp>(loc, flatTy, operand, flatShape);
 | 
				
			||||||
 | 
					      flatOperands.push_back(flat);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Apply operation to flattened operands.
 | 
				
			||||||
 | 
					    Type resultElementTy =
 | 
				
			||||||
 | 
					        op.getType().template cast<ShapedType>().getElementType();
 | 
				
			||||||
 | 
					    Type flatResultTy =
 | 
				
			||||||
 | 
					        RankedTensorType::get({ShapedType::kDynamicSize}, resultElementTy);
 | 
				
			||||||
 | 
					    Value flatResult =
 | 
				
			||||||
 | 
					        rewriter.create<OpTy>(loc, flatResultTy, flatOperands, op.getAttrs());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Restore original shape.
 | 
					    // Restore original shape.
 | 
				
			||||||
    rewriter.replaceOpWithNewOp<DynamicReshapeOp>(op, op.getType(), flatResult,
 | 
					    rewriter.replaceOpWithNewOp<DynamicReshapeOp>(op, op.getType(), flatResult,
 | 
				
			||||||
| 
						 | 
					@ -154,7 +125,7 @@ struct BinaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
 | 
				
			||||||
struct TransformUnrankedHloPass
 | 
					struct TransformUnrankedHloPass
 | 
				
			||||||
    : public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
 | 
					    : public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
 | 
				
			||||||
  void getDependentDialects(DialectRegistry ®istry) const override {
 | 
					  void getDependentDialects(DialectRegistry ®istry) const override {
 | 
				
			||||||
    registry.insert<shape::ShapeDialect>();
 | 
					    registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  void runOnFunction() override {
 | 
					  void runOnFunction() override {
 | 
				
			||||||
| 
						 | 
					@ -183,19 +154,17 @@ struct TransformUnrankedHloPass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
 | 
					void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
 | 
				
			||||||
                                          OwningRewritePatternList *patterns) {
 | 
					                                          OwningRewritePatternList *patterns) {
 | 
				
			||||||
  // TODO(frgossen): Populate all unary and binary operations.
 | 
					#define MAP_UNARY(op) ElementwiseOpConversion<op>
 | 
				
			||||||
  // clang-format off
 | 
					#define MAP_BINARY(op) ElementwiseOpConversion<op>
 | 
				
			||||||
#define MAP_UNARY(op) UnaryElementwiseOpConversion<op>
 | 
					 | 
				
			||||||
#define MAP_BINARY(op) BinaryElementwiseOpConversion<op>
 | 
					 | 
				
			||||||
#define COMMA ,
 | 
					#define COMMA ,
 | 
				
			||||||
 | 
					  // clang-format off
 | 
				
			||||||
  patterns->insert<
 | 
					  patterns->insert<
 | 
				
			||||||
      MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA),
 | 
					      MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA),
 | 
				
			||||||
      MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA)
 | 
					      MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA)>(context);
 | 
				
			||||||
      >(context);
 | 
					  // clang-format on
 | 
				
			||||||
#undef MAP_UNARY
 | 
					#undef MAP_UNARY
 | 
				
			||||||
#undef MAP_BINARY
 | 
					#undef MAP_BINARY
 | 
				
			||||||
#undef COMMA
 | 
					#undef COMMA
 | 
				
			||||||
  // clang-format on
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {
 | 
					std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue