[MLIR] Simplify and generalize `transform-unranked-hlo`

This refactoring makes it possible to support a wider range of n-ary operations
in future changes: the previously separate unary and binary conversion patterns
are unified into a single ElementwiseOpConversion template (see the sketch
below).
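
As a rough illustration of the generalization (a hand-written sketch, not part
of the commit; the helper name and op choices are hypothetical), the unified
ElementwiseOpConversion template can be instantiated uniformly for ops of any
arity:

  // Sketch: one pattern template now serves unary and binary ops alike.
  void addExamplePatterns(MLIRContext *context,
                          OwningRewritePatternList *patterns) {
    patterns->insert<ElementwiseOpConversion<mhlo::SqrtOp>,  // one operand
                     ElementwiseOpConversion<mhlo::AddOp>>(  // two operands
        context);
  }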

PiperOrigin-RevId: 331953362
Authored by A. Unique TensorFlower on 2020-09-16 01:12:09 -07:00; committed by TensorFlow MLIR Team
parent d1f85f32b1
commit da43c8596b
1 changed file with 49 additions and 80 deletions


@@ -49,99 +49,70 @@ namespace {
 template <typename OpTy>
 inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
   target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
-    return llvm::all_of((op.getOperation())->getOperandTypes(),
+    return llvm::all_of(op.getOperation()->getOperandTypes(),
                         [&](Type t) { return t.isa<RankedTensorType>(); });
   });
 }
 
-/// Unary element-wise operations on unranked tensors can be applied to the
-/// flattened tensor with the same effect.
-/// This pattern rewrites every such operation to
+/// Element-wise operations on unranked tensors can be applied to the flattened
+/// tensor operands with the same effect. This pattern rewrites every such
+/// operation to
 /// (i) flatten the input tensor,
-/// (ii) apply the unary operation, and
+/// (ii) apply the operation, and
 /// (iii) restore the original shape.
 template <typename OpTy>
-struct UnaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
-  explicit UnaryElementwiseOpConversion(MLIRContext *context)
+struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
+  explicit ElementwiseOpConversion(MLIRContext *context)
       : OpRewritePattern<OpTy>(context) {}
 
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    // Don't apply conversion to ops with statically shaped operands.
-    Value operand = op.getOperand();
-    auto operandTy = operand.getType().dyn_cast<TensorType>();
-    if (operandTy.hasRank()) return failure();
-
-    // Generate IR to flatten the operand.
-    auto loc = op.getLoc();
-    Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext());
-    Value shape =
-        rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, operand);
-    Type indexTy = rewriter.getIndexType();
-    Value numElements =
-        rewriter.create<shape::NumElementsOp>(loc, indexTy, shape);
-    Value flatShape = rewriter.create<TensorFromElementsOp>(loc, numElements);
-    auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
-                                              operandTy.getElementType());
-    Value flatOperand = rewriter.create<mhlo::DynamicReshapeOp>(
-        loc, flatTensorTy, operand, flatShape);
-
-    // Generate IR for the actual operation.
-    Value flatResult = rewriter.create<OpTy>(loc, flatTensorTy, flatOperand);
-
-    // Generate IR to restore the original shape.
-    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, operandTy,
-                                                        flatResult, shape);
-
-    return success();
-  }
-};
-
-/// Binary element-wise operation on unranked tensors can be applied to the
-/// flattened operand tensors with the same effect.
-/// This pattern rewrites every such operation to
-/// (i) flatten the operand tensors,
-/// (ii) apply the binary operation, and
-// (iii) restore the original shape.
-template <typename OpTy>
-struct BinaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
-  explicit BinaryElementwiseOpConversion(MLIRContext *context)
-      : OpRewritePattern<OpTy>(context) {}
-
-  LogicalResult matchAndRewrite(OpTy op,
-                                PatternRewriter &rewriter) const override {
-    // Don't apply conversion unless both operands are unranked.
-    if (op.lhs().getType().template isa<RankedTensorType>() ||
-        op.rhs().getType().template isa<RankedTensorType>()) {
+    // Don't apply conversion unless all operands are unranked.
+    if (!llvm::all_of(op.getOperation()->getOperands(), [&](Value operand) {
+          return operand.getType().isa<UnrankedTensorType>();
+        })) {
       return failure();
     }
 
-    // Flatten operands.
+    // Get operands' shape.
     auto loc = op.getLoc();
     Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext());
-    Value shapeLhs =
-        rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, op.lhs());
-    Value shapeRhs =
-        rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, op.rhs());
-    Value shape = rewriter.create<shape::AnyOp>(loc, extentTensorTy,
-                                                ValueRange{shapeLhs, shapeRhs});
+    SmallVector<Value, 3> operandShapes;
+    for (Value operand : op.getOperation()->getOperands()) {
+      Value shape =
+          rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, operand);
+      operandShapes.push_back(shape);
+    }
+    Value shape =
+        operandShapes.size() == 1
+            ? operandShapes.front()
+            : rewriter.create<shape::AnyOp>(loc, extentTensorTy, operandShapes);
+
+    // Derive flat shape.
     Type indexTy = rewriter.getIndexType();
     Value numElements =
         rewriter.create<shape::NumElementsOp>(loc, indexTy, shape);
     Value flatShape = rewriter.create<TensorFromElementsOp>(loc, numElements);
-    TensorType lhsTy = op.lhs().getType().template cast<TensorType>();
-    Type flatLhsTy = RankedTensorType::get({ShapedType::kDynamicSize},
-                                           lhsTy.getElementType());
-    Value flatLhs =
-        rewriter.create<DynamicReshapeOp>(loc, flatLhsTy, op.lhs(), flatShape);
-    TensorType rhsTy = op.rhs().getType().template cast<TensorType>();
-    Type flatRhsTy = RankedTensorType::get({ShapedType::kDynamicSize},
-                                           rhsTy.getElementType());
-    Value flatRhs =
-        rewriter.create<DynamicReshapeOp>(loc, flatRhsTy, op.rhs(), flatShape);
 
-    // Apply actual operation to flattened operands.
-    Value flatResult = rewriter.create<OpTy>(loc, flatLhs, flatRhs);
+    // Flatten operands.
+    SmallVector<Value, 3> flatOperands;
+    for (Value operand : op.getOperation()->getOperands()) {
+      Type operandElementTy =
+          operand.getType().template cast<ShapedType>().getElementType();
+      Type flatTy =
+          RankedTensorType::get({ShapedType::kDynamicSize}, operandElementTy);
+      Value flat =
+          rewriter.create<DynamicReshapeOp>(loc, flatTy, operand, flatShape);
+      flatOperands.push_back(flat);
+    }
+
+    // Apply operation to flattened operands.
+    Type resultElementTy =
+        op.getType().template cast<ShapedType>().getElementType();
+    Type flatResultTy =
+        RankedTensorType::get({ShapedType::kDynamicSize}, resultElementTy);
+    Value flatResult =
+        rewriter.create<OpTy>(loc, flatResultTy, flatOperands, op.getAttrs());
 
     // Restore original shape.
     rewriter.replaceOpWithNewOp<DynamicReshapeOp>(op, op.getType(), flatResult,
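
The guard in the new matchAndRewrite mirrors AddLegalOpOnRankedTensor at the
top of the hunk: the pattern bails out unless every operand is unranked, while
the legality helper declares an op legal once all of its operand types are
ranked tensors. A minimal sketch of how the two could be wired together in a
partial conversion (hypothetical wiring in the MLIR API of this era; the
commit's actual runOnFunction body lies outside the hunks shown):

  // Inside runOnFunction, hypothetically:
  ConversionTarget target(getContext());
  target.addLegalDialect<mhlo::MhloDialect, StandardOpsDialect,
                         shape::ShapeDialect>();
  AddLegalOpOnRankedTensor<mhlo::AddOp>(&target);  // legal once ranked

  OwningRewritePatternList patterns;
  PopulateTransformUnrankedHloPatterns(&getContext(), &patterns);
  if (failed(applyPartialConversion(getFunction(), target, patterns)))
    signalPassFailure();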
@@ -154,7 +125,7 @@ struct BinaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
 struct TransformUnrankedHloPass
     : public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<shape::ShapeDialect>();
+    registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
   }
 
   void runOnFunction() override {
@@ -183,19 +154,17 @@ struct TransformUnrankedHloPass
 void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
                                           OwningRewritePatternList *patterns) {
-  // TODO(frgossen): Populate all unary and binary operations.
-  // clang-format off
-#define MAP_UNARY(op) UnaryElementwiseOpConversion<op>
-#define MAP_BINARY(op) BinaryElementwiseOpConversion<op>
+#define MAP_UNARY(op) ElementwiseOpConversion<op>
+#define MAP_BINARY(op) ElementwiseOpConversion<op>
 #define COMMA ,
+  // clang-format off
   patterns->insert<
       MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA),
-      MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA)
-      >(context);
+      MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA)>(context);
+  // clang-format on
 #undef MAP_UNARY
 #undef MAP_BINARY
 #undef COMMA
-  // clang-format on
 }
 
 std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {
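
For readers unfamiliar with the X-macro idiom in
PopulateTransformUnrankedHloPatterns above: MAP_XLA_OPERATION_CWISE_UNARY(fn,
sep) and its binary sibling are list macros that apply fn to every op in a
fixed op list, interleaving sep. A toy expansion, assuming a hypothetical
two-op unary list (the real op lists live in a shared header in this repo):

  // Hypothetical stand-in for MAP_XLA_OPERATION_CWISE_UNARY:
  #define TOY_CWISE_UNARY(fn, sep) fn(mhlo::AbsOp) sep fn(mhlo::SqrtOp)

  // With MAP_UNARY(op) defined as ElementwiseOpConversion<op> and COMMA as
  // ',', TOY_CWISE_UNARY(MAP_UNARY, COMMA) expands to
  //   ElementwiseOpConversion<mhlo::AbsOp>,
  //   ElementwiseOpConversion<mhlo::SqrtOp>
  // which is exactly the template argument list handed to patterns->insert.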