[MLIR] Simplify and generalize `transform-unranked-hlo`
This refactoring allows to support a wider range of n-ary operations in future changes. PiperOrigin-RevId: 331953362
This commit is contained in:
parent
d1f85f32b1
commit
da43c8596b
|
@ -49,99 +49,70 @@ namespace {
|
|||
template <typename OpTy>
|
||||
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
||||
target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
|
||||
return llvm::all_of((op.getOperation())->getOperandTypes(),
|
||||
return llvm::all_of(op.getOperation()->getOperandTypes(),
|
||||
[&](Type t) { return t.isa<RankedTensorType>(); });
|
||||
});
|
||||
}
|
||||
|
||||
/// Unary element-wise operations on unranked tensors can be applied to the
|
||||
/// flattened tensor with the same effect.
|
||||
/// This pattern rewrites every such operation to
|
||||
/// Element-wise operations on unranked tensors can be applied to the flattened
|
||||
/// tensor operands with the same effect. This pattern rewrites every such
|
||||
/// operation to
|
||||
/// (i) flatten the input tensor,
|
||||
/// (ii) apply the unary operation, and
|
||||
/// (ii) apply the operation, and
|
||||
/// (iii) restore the original shape.
|
||||
template <typename OpTy>
|
||||
struct UnaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
||||
explicit UnaryElementwiseOpConversion(MLIRContext *context)
|
||||
struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
||||
explicit ElementwiseOpConversion(MLIRContext *context)
|
||||
: OpRewritePattern<OpTy>(context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(OpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Don't apply conversion to ops with statically shaped operands.
|
||||
Value operand = op.getOperand();
|
||||
auto operandTy = operand.getType().dyn_cast<TensorType>();
|
||||
if (operandTy.hasRank()) return failure();
|
||||
|
||||
// Generate IR to flatten the operand.
|
||||
auto loc = op.getLoc();
|
||||
Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext());
|
||||
Value shape =
|
||||
rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, operand);
|
||||
Type indexTy = rewriter.getIndexType();
|
||||
Value numElements =
|
||||
rewriter.create<shape::NumElementsOp>(loc, indexTy, shape);
|
||||
Value flatShape = rewriter.create<TensorFromElementsOp>(loc, numElements);
|
||||
auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
|
||||
operandTy.getElementType());
|
||||
Value flatOperand = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
loc, flatTensorTy, operand, flatShape);
|
||||
|
||||
// Generate IR for the actual operation.
|
||||
Value flatResult = rewriter.create<OpTy>(loc, flatTensorTy, flatOperand);
|
||||
|
||||
// Generate IR to restore the original shape.
|
||||
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, operandTy,
|
||||
flatResult, shape);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Binary element-wise operation on unranked tensors can be applied to the
|
||||
/// flattened operand tensors with the same effect.
|
||||
/// This pattern rewrites every such operation to
|
||||
/// (i) flatten the operand tensors,
|
||||
/// (ii) apply the binary operation, and
|
||||
// (iii) restore the original shape.
|
||||
template <typename OpTy>
|
||||
struct BinaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
||||
explicit BinaryElementwiseOpConversion(MLIRContext *context)
|
||||
: OpRewritePattern<OpTy>(context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(OpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Don't apply conversion unless both operands are unranked.
|
||||
if (op.lhs().getType().template isa<RankedTensorType>() ||
|
||||
op.rhs().getType().template isa<RankedTensorType>()) {
|
||||
// Don't apply conversion unless all operands are unranked.
|
||||
if (!llvm::all_of(op.getOperation()->getOperands(), [&](Value operand) {
|
||||
return operand.getType().isa<UnrankedTensorType>();
|
||||
})) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Flatten operands.
|
||||
// Get operands' shape.
|
||||
auto loc = op.getLoc();
|
||||
Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext());
|
||||
Value shapeLhs =
|
||||
rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, op.lhs());
|
||||
Value shapeRhs =
|
||||
rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, op.rhs());
|
||||
Value shape = rewriter.create<shape::AnyOp>(loc, extentTensorTy,
|
||||
ValueRange{shapeLhs, shapeRhs});
|
||||
SmallVector<Value, 3> operandShapes;
|
||||
for (Value operand : op.getOperation()->getOperands()) {
|
||||
Value shape =
|
||||
rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, operand);
|
||||
operandShapes.push_back(shape);
|
||||
}
|
||||
Value shape =
|
||||
operandShapes.size() == 1
|
||||
? operandShapes.front()
|
||||
: rewriter.create<shape::AnyOp>(loc, extentTensorTy, operandShapes);
|
||||
|
||||
// Derive flat shape.
|
||||
Type indexTy = rewriter.getIndexType();
|
||||
Value numElements =
|
||||
rewriter.create<shape::NumElementsOp>(loc, indexTy, shape);
|
||||
Value flatShape = rewriter.create<TensorFromElementsOp>(loc, numElements);
|
||||
TensorType lhsTy = op.lhs().getType().template cast<TensorType>();
|
||||
Type flatLhsTy = RankedTensorType::get({ShapedType::kDynamicSize},
|
||||
lhsTy.getElementType());
|
||||
Value flatLhs =
|
||||
rewriter.create<DynamicReshapeOp>(loc, flatLhsTy, op.lhs(), flatShape);
|
||||
TensorType rhsTy = op.rhs().getType().template cast<TensorType>();
|
||||
Type flatRhsTy = RankedTensorType::get({ShapedType::kDynamicSize},
|
||||
rhsTy.getElementType());
|
||||
Value flatRhs =
|
||||
rewriter.create<DynamicReshapeOp>(loc, flatRhsTy, op.rhs(), flatShape);
|
||||
|
||||
// Apply actual operation to flattened operands.
|
||||
Value flatResult = rewriter.create<OpTy>(loc, flatLhs, flatRhs);
|
||||
// Flatten operands.
|
||||
SmallVector<Value, 3> flatOperands;
|
||||
for (Value operand : op.getOperation()->getOperands()) {
|
||||
Type operandElementTy =
|
||||
operand.getType().template cast<ShapedType>().getElementType();
|
||||
Type flatTy =
|
||||
RankedTensorType::get({ShapedType::kDynamicSize}, operandElementTy);
|
||||
Value flat =
|
||||
rewriter.create<DynamicReshapeOp>(loc, flatTy, operand, flatShape);
|
||||
flatOperands.push_back(flat);
|
||||
}
|
||||
|
||||
// Apply operation to flattened operands.
|
||||
Type resultElementTy =
|
||||
op.getType().template cast<ShapedType>().getElementType();
|
||||
Type flatResultTy =
|
||||
RankedTensorType::get({ShapedType::kDynamicSize}, resultElementTy);
|
||||
Value flatResult =
|
||||
rewriter.create<OpTy>(loc, flatResultTy, flatOperands, op.getAttrs());
|
||||
|
||||
// Restore original shape.
|
||||
rewriter.replaceOpWithNewOp<DynamicReshapeOp>(op, op.getType(), flatResult,
|
||||
|
@ -154,7 +125,7 @@ struct BinaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
|||
struct TransformUnrankedHloPass
|
||||
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<shape::ShapeDialect>();
|
||||
registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
|
||||
}
|
||||
|
||||
void runOnFunction() override {
|
||||
|
@ -183,19 +154,17 @@ struct TransformUnrankedHloPass
|
|||
|
||||
void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns) {
|
||||
// TODO(frgossen): Populate all unary and binary operations.
|
||||
// clang-format off
|
||||
#define MAP_UNARY(op) UnaryElementwiseOpConversion<op>
|
||||
#define MAP_BINARY(op) BinaryElementwiseOpConversion<op>
|
||||
#define MAP_UNARY(op) ElementwiseOpConversion<op>
|
||||
#define MAP_BINARY(op) ElementwiseOpConversion<op>
|
||||
#define COMMA ,
|
||||
// clang-format off
|
||||
patterns->insert<
|
||||
MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA),
|
||||
MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA)
|
||||
>(context);
|
||||
MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA)>(context);
|
||||
// clang-format on
|
||||
#undef MAP_UNARY
|
||||
#undef MAP_BINARY
|
||||
#undef COMMA
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {
|
||||
|
|
Loading…
Reference in New Issue