[MLIR] Simplify and generalize `transform-unranked-hlo`
This refactoring allows to support a wider range of n-ary operations in future changes. PiperOrigin-RevId: 331953362
This commit is contained in:
parent
d1f85f32b1
commit
da43c8596b
|
@ -49,99 +49,70 @@ namespace {
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
||||||
target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
|
target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
|
||||||
return llvm::all_of((op.getOperation())->getOperandTypes(),
|
return llvm::all_of(op.getOperation()->getOperandTypes(),
|
||||||
[&](Type t) { return t.isa<RankedTensorType>(); });
|
[&](Type t) { return t.isa<RankedTensorType>(); });
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Unary element-wise operations on unranked tensors can be applied to the
|
/// Element-wise operations on unranked tensors can be applied to the flattened
|
||||||
/// flattened tensor with the same effect.
|
/// tensor operands with the same effect. This pattern rewrites every such
|
||||||
/// This pattern rewrites every such operation to
|
/// operation to
|
||||||
/// (i) flatten the input tensor,
|
/// (i) flatten the input tensor,
|
||||||
/// (ii) apply the unary operation, and
|
/// (ii) apply the operation, and
|
||||||
/// (iii) restore the original shape.
|
/// (iii) restore the original shape.
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
struct UnaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
||||||
explicit UnaryElementwiseOpConversion(MLIRContext *context)
|
explicit ElementwiseOpConversion(MLIRContext *context)
|
||||||
: OpRewritePattern<OpTy>(context) {}
|
: OpRewritePattern<OpTy>(context) {}
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(OpTy op,
|
LogicalResult matchAndRewrite(OpTy op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
// Don't apply conversion to ops with statically shaped operands.
|
// Don't apply conversion unless all operands are unranked.
|
||||||
Value operand = op.getOperand();
|
if (!llvm::all_of(op.getOperation()->getOperands(), [&](Value operand) {
|
||||||
auto operandTy = operand.getType().dyn_cast<TensorType>();
|
return operand.getType().isa<UnrankedTensorType>();
|
||||||
if (operandTy.hasRank()) return failure();
|
})) {
|
||||||
|
|
||||||
// Generate IR to flatten the operand.
|
|
||||||
auto loc = op.getLoc();
|
|
||||||
Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext());
|
|
||||||
Value shape =
|
|
||||||
rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, operand);
|
|
||||||
Type indexTy = rewriter.getIndexType();
|
|
||||||
Value numElements =
|
|
||||||
rewriter.create<shape::NumElementsOp>(loc, indexTy, shape);
|
|
||||||
Value flatShape = rewriter.create<TensorFromElementsOp>(loc, numElements);
|
|
||||||
auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
|
|
||||||
operandTy.getElementType());
|
|
||||||
Value flatOperand = rewriter.create<mhlo::DynamicReshapeOp>(
|
|
||||||
loc, flatTensorTy, operand, flatShape);
|
|
||||||
|
|
||||||
// Generate IR for the actual operation.
|
|
||||||
Value flatResult = rewriter.create<OpTy>(loc, flatTensorTy, flatOperand);
|
|
||||||
|
|
||||||
// Generate IR to restore the original shape.
|
|
||||||
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, operandTy,
|
|
||||||
flatResult, shape);
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Binary element-wise operation on unranked tensors can be applied to the
|
|
||||||
/// flattened operand tensors with the same effect.
|
|
||||||
/// This pattern rewrites every such operation to
|
|
||||||
/// (i) flatten the operand tensors,
|
|
||||||
/// (ii) apply the binary operation, and
|
|
||||||
// (iii) restore the original shape.
|
|
||||||
template <typename OpTy>
|
|
||||||
struct BinaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
|
||||||
explicit BinaryElementwiseOpConversion(MLIRContext *context)
|
|
||||||
: OpRewritePattern<OpTy>(context) {}
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(OpTy op,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
// Don't apply conversion unless both operands are unranked.
|
|
||||||
if (op.lhs().getType().template isa<RankedTensorType>() ||
|
|
||||||
op.rhs().getType().template isa<RankedTensorType>()) {
|
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flatten operands.
|
// Get operands' shape.
|
||||||
auto loc = op.getLoc();
|
auto loc = op.getLoc();
|
||||||
Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext());
|
Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext());
|
||||||
Value shapeLhs =
|
SmallVector<Value, 3> operandShapes;
|
||||||
rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, op.lhs());
|
for (Value operand : op.getOperation()->getOperands()) {
|
||||||
Value shapeRhs =
|
Value shape =
|
||||||
rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, op.rhs());
|
rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, operand);
|
||||||
Value shape = rewriter.create<shape::AnyOp>(loc, extentTensorTy,
|
operandShapes.push_back(shape);
|
||||||
ValueRange{shapeLhs, shapeRhs});
|
}
|
||||||
|
Value shape =
|
||||||
|
operandShapes.size() == 1
|
||||||
|
? operandShapes.front()
|
||||||
|
: rewriter.create<shape::AnyOp>(loc, extentTensorTy, operandShapes);
|
||||||
|
|
||||||
|
// Derive flat shape.
|
||||||
Type indexTy = rewriter.getIndexType();
|
Type indexTy = rewriter.getIndexType();
|
||||||
Value numElements =
|
Value numElements =
|
||||||
rewriter.create<shape::NumElementsOp>(loc, indexTy, shape);
|
rewriter.create<shape::NumElementsOp>(loc, indexTy, shape);
|
||||||
Value flatShape = rewriter.create<TensorFromElementsOp>(loc, numElements);
|
Value flatShape = rewriter.create<TensorFromElementsOp>(loc, numElements);
|
||||||
TensorType lhsTy = op.lhs().getType().template cast<TensorType>();
|
|
||||||
Type flatLhsTy = RankedTensorType::get({ShapedType::kDynamicSize},
|
|
||||||
lhsTy.getElementType());
|
|
||||||
Value flatLhs =
|
|
||||||
rewriter.create<DynamicReshapeOp>(loc, flatLhsTy, op.lhs(), flatShape);
|
|
||||||
TensorType rhsTy = op.rhs().getType().template cast<TensorType>();
|
|
||||||
Type flatRhsTy = RankedTensorType::get({ShapedType::kDynamicSize},
|
|
||||||
rhsTy.getElementType());
|
|
||||||
Value flatRhs =
|
|
||||||
rewriter.create<DynamicReshapeOp>(loc, flatRhsTy, op.rhs(), flatShape);
|
|
||||||
|
|
||||||
// Apply actual operation to flattened operands.
|
// Flatten operands.
|
||||||
Value flatResult = rewriter.create<OpTy>(loc, flatLhs, flatRhs);
|
SmallVector<Value, 3> flatOperands;
|
||||||
|
for (Value operand : op.getOperation()->getOperands()) {
|
||||||
|
Type operandElementTy =
|
||||||
|
operand.getType().template cast<ShapedType>().getElementType();
|
||||||
|
Type flatTy =
|
||||||
|
RankedTensorType::get({ShapedType::kDynamicSize}, operandElementTy);
|
||||||
|
Value flat =
|
||||||
|
rewriter.create<DynamicReshapeOp>(loc, flatTy, operand, flatShape);
|
||||||
|
flatOperands.push_back(flat);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply operation to flattened operands.
|
||||||
|
Type resultElementTy =
|
||||||
|
op.getType().template cast<ShapedType>().getElementType();
|
||||||
|
Type flatResultTy =
|
||||||
|
RankedTensorType::get({ShapedType::kDynamicSize}, resultElementTy);
|
||||||
|
Value flatResult =
|
||||||
|
rewriter.create<OpTy>(loc, flatResultTy, flatOperands, op.getAttrs());
|
||||||
|
|
||||||
// Restore original shape.
|
// Restore original shape.
|
||||||
rewriter.replaceOpWithNewOp<DynamicReshapeOp>(op, op.getType(), flatResult,
|
rewriter.replaceOpWithNewOp<DynamicReshapeOp>(op, op.getType(), flatResult,
|
||||||
|
@ -154,7 +125,7 @@ struct BinaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
||||||
struct TransformUnrankedHloPass
|
struct TransformUnrankedHloPass
|
||||||
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
|
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
registry.insert<shape::ShapeDialect>();
|
registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
|
||||||
}
|
}
|
||||||
|
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
|
@ -183,19 +154,17 @@ struct TransformUnrankedHloPass
|
||||||
|
|
||||||
void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
|
void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
|
||||||
OwningRewritePatternList *patterns) {
|
OwningRewritePatternList *patterns) {
|
||||||
// TODO(frgossen): Populate all unary and binary operations.
|
#define MAP_UNARY(op) ElementwiseOpConversion<op>
|
||||||
// clang-format off
|
#define MAP_BINARY(op) ElementwiseOpConversion<op>
|
||||||
#define MAP_UNARY(op) UnaryElementwiseOpConversion<op>
|
|
||||||
#define MAP_BINARY(op) BinaryElementwiseOpConversion<op>
|
|
||||||
#define COMMA ,
|
#define COMMA ,
|
||||||
|
// clang-format off
|
||||||
patterns->insert<
|
patterns->insert<
|
||||||
MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA),
|
MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA),
|
||||||
MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA)
|
MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA)>(context);
|
||||||
>(context);
|
// clang-format on
|
||||||
#undef MAP_UNARY
|
#undef MAP_UNARY
|
||||||
#undef MAP_BINARY
|
#undef MAP_BINARY
|
||||||
#undef COMMA
|
#undef COMMA
|
||||||
// clang-format on
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {
|
std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {
|
||||||
|
|
Loading…
Reference in New Issue