[MLIR] Simplify and generalize `transform-unranked-hlo`

This refactoring allows us to support a wider range of n-ary operations in
future changes.

PiperOrigin-RevId: 331953362
Authored by A. Unique TensorFlower on 2020-09-16 01:12:09 -07:00; committed by TensorFlow MLIR Team
parent d1f85f32b1
commit da43c8596b
1 changed file with 49 additions and 80 deletions
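The merged `ElementwiseOpConversion` pattern below no longer assumes `lhs`/`rhs` accessors; it iterates over `op.getOperation()->getOperands()`, so the same rewrite applies to operations of any arity. As a minimal, hypothetical sketch (not part of this commit), a ternary element-wise op such as `mhlo::SelectOp` could later be registered through the same template once the TODO in `PopulateTransformUnrankedHloPatterns` is addressed:

// Hypothetical registration sketch, not part of this change: because
// ElementwiseOpConversion<OpTy> walks all operands generically, ops of any
// arity can reuse it without a dedicated Unary/Binary variant.
void PopulateHypotheticalSelectPattern(MLIRContext *context,
                                       OwningRewritePatternList *patterns) {
  patterns->insert<ElementwiseOpConversion<mhlo::SelectOp>>(context);
}

The only per-op information the pattern still needs (operand and result element types) is obtained generically via `ShapedType`, so no op-specific specialization is required.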


@@ -49,99 +49,70 @@ namespace {
 template <typename OpTy>
 inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
   target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
-    return llvm::all_of((op.getOperation())->getOperandTypes(),
+    return llvm::all_of(op.getOperation()->getOperandTypes(),
                         [&](Type t) { return t.isa<RankedTensorType>(); });
   });
 }
-/// Unary element-wise operations on unranked tensors can be applied to the
-/// flattened tensor with the same effect.
-/// This pattern rewrites every such operation to
+/// Element-wise operations on unranked tensors can be applied to the flattened
+/// tensor operands with the same effect. This pattern rewrites every such
+/// operation to
 ///   (i) flatten the input tensor,
-///   (ii) apply the unary operation, and
+///   (ii) apply the operation, and
 ///   (iii) restore the original shape.
 template <typename OpTy>
-struct UnaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
-  explicit UnaryElementwiseOpConversion(MLIRContext *context)
+struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
+  explicit ElementwiseOpConversion(MLIRContext *context)
       : OpRewritePattern<OpTy>(context) {}
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    // Don't apply conversion to ops with statically shaped operands.
-    Value operand = op.getOperand();
-    auto operandTy = operand.getType().dyn_cast<TensorType>();
-    if (operandTy.hasRank()) return failure();
-    // Generate IR to flatten the operand.
-    auto loc = op.getLoc();
-    Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext());
-    Value shape =
-        rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, operand);
-    Type indexTy = rewriter.getIndexType();
-    Value numElements =
-        rewriter.create<shape::NumElementsOp>(loc, indexTy, shape);
-    Value flatShape = rewriter.create<TensorFromElementsOp>(loc, numElements);
-    auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
-                                              operandTy.getElementType());
-    Value flatOperand = rewriter.create<mhlo::DynamicReshapeOp>(
-        loc, flatTensorTy, operand, flatShape);
-    // Generate IR for the actual operation.
-    Value flatResult = rewriter.create<OpTy>(loc, flatTensorTy, flatOperand);
-    // Generate IR to restore the original shape.
-    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, operandTy,
-                                                        flatResult, shape);
-    return success();
-  }
-};
-/// Binary element-wise operation on unranked tensors can be applied to the
-/// flattened operand tensors with the same effect.
-/// This pattern rewrites every such operation to
-///   (i) flatten the operand tensors,
-///   (ii) apply the binary operation, and
-//    (iii) restore the original shape.
-template <typename OpTy>
-struct BinaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
-  explicit BinaryElementwiseOpConversion(MLIRContext *context)
-      : OpRewritePattern<OpTy>(context) {}
-  LogicalResult matchAndRewrite(OpTy op,
-                                PatternRewriter &rewriter) const override {
-    // Don't apply conversion unless both operands are unranked.
-    if (op.lhs().getType().template isa<RankedTensorType>() ||
-        op.rhs().getType().template isa<RankedTensorType>()) {
+    // Don't apply conversion unless all operands are unranked.
+    if (!llvm::all_of(op.getOperation()->getOperands(), [&](Value operand) {
+          return operand.getType().isa<UnrankedTensorType>();
+        })) {
       return failure();
     }
-    // Flatten operands.
+    // Get operands' shape.
     auto loc = op.getLoc();
     Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext());
-    Value shapeLhs =
-        rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, op.lhs());
-    Value shapeRhs =
-        rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, op.rhs());
-    Value shape = rewriter.create<shape::AnyOp>(loc, extentTensorTy,
-                                                ValueRange{shapeLhs, shapeRhs});
+    SmallVector<Value, 3> operandShapes;
+    for (Value operand : op.getOperation()->getOperands()) {
+      Value shape =
+          rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, operand);
+      operandShapes.push_back(shape);
+    }
+    Value shape =
+        operandShapes.size() == 1
+            ? operandShapes.front()
+            : rewriter.create<shape::AnyOp>(loc, extentTensorTy, operandShapes);
+    // Derive flat shape.
     Type indexTy = rewriter.getIndexType();
     Value numElements =
         rewriter.create<shape::NumElementsOp>(loc, indexTy, shape);
     Value flatShape = rewriter.create<TensorFromElementsOp>(loc, numElements);
-    TensorType lhsTy = op.lhs().getType().template cast<TensorType>();
-    Type flatLhsTy = RankedTensorType::get({ShapedType::kDynamicSize},
-                                           lhsTy.getElementType());
-    Value flatLhs =
-        rewriter.create<DynamicReshapeOp>(loc, flatLhsTy, op.lhs(), flatShape);
-    TensorType rhsTy = op.rhs().getType().template cast<TensorType>();
-    Type flatRhsTy = RankedTensorType::get({ShapedType::kDynamicSize},
-                                           rhsTy.getElementType());
-    Value flatRhs =
-        rewriter.create<DynamicReshapeOp>(loc, flatRhsTy, op.rhs(), flatShape);
-    // Apply actual operation to flattened operands.
-    Value flatResult = rewriter.create<OpTy>(loc, flatLhs, flatRhs);
+    // Flatten operands.
+    SmallVector<Value, 3> flatOperands;
+    for (Value operand : op.getOperation()->getOperands()) {
+      Type operandElementTy =
+          operand.getType().template cast<ShapedType>().getElementType();
+      Type flatTy =
+          RankedTensorType::get({ShapedType::kDynamicSize}, operandElementTy);
+      Value flat =
+          rewriter.create<DynamicReshapeOp>(loc, flatTy, operand, flatShape);
+      flatOperands.push_back(flat);
+    }
+    // Apply operation to flattened operands.
+    Type resultElementTy =
+        op.getType().template cast<ShapedType>().getElementType();
+    Type flatResultTy =
+        RankedTensorType::get({ShapedType::kDynamicSize}, resultElementTy);
+    Value flatResult =
+        rewriter.create<OpTy>(loc, flatResultTy, flatOperands, op.getAttrs());
     // Restore original shape.
     rewriter.replaceOpWithNewOp<DynamicReshapeOp>(op, op.getType(), flatResult,
@@ -154,7 +125,7 @@ struct BinaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
 struct TransformUnrankedHloPass
     : public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<shape::ShapeDialect>();
+    registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
   }
   void runOnFunction() override {
@@ -183,19 +154,17 @@ struct TransformUnrankedHloPass
 void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
                                           OwningRewritePatternList *patterns) {
   // TODO(frgossen): Populate all unary and binary operations.
-  // clang-format off
-#define MAP_UNARY(op) UnaryElementwiseOpConversion<op>
-#define MAP_BINARY(op) BinaryElementwiseOpConversion<op>
+#define MAP_UNARY(op) ElementwiseOpConversion<op>
+#define MAP_BINARY(op) ElementwiseOpConversion<op>
 #define COMMA ,
+  // clang-format off
   patterns->insert<
       MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA),
-      MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA)
-      >(context);
+      MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA)>(context);
+  // clang-format on
 #undef MAP_UNARY
 #undef MAP_BINARY
 #undef COMMA
-  // clang-format on
 }
 std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {