From da43c8596bff2f5b78c3f2c54be1fffeb31a4ee3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 16 Sep 2020 01:12:09 -0700 Subject: [PATCH] [MLIR] Simplify and generalize `transform-unranked-hlo` This refactoring allows to support a wider range of n-ary operations in future changes. PiperOrigin-RevId: 331953362 --- .../mhlo/transforms/transform_unranked_hlo.cc | 129 +++++++----------- 1 file changed, 49 insertions(+), 80 deletions(-) diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index bfa8cf5..f52239f 100644 --- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -49,99 +49,70 @@ namespace { template inline void AddLegalOpOnRankedTensor(ConversionTarget *target) { target->addDynamicallyLegalOp([](OpTy op) { - return llvm::all_of((op.getOperation())->getOperandTypes(), + return llvm::all_of(op.getOperation()->getOperandTypes(), [&](Type t) { return t.isa(); }); }); } -/// Unary element-wise operations on unranked tensors can be applied to the -/// flattened tensor with the same effect. -/// This pattern rewrites every such operation to +/// Element-wise operations on unranked tensors can be applied to the flattened +/// tensor operands with the same effect. This pattern rewrites every such +/// operation to /// (i) flatten the input tensor, -/// (ii) apply the unary operation, and +/// (ii) apply the operation, and /// (iii) restore the original shape. template -struct UnaryElementwiseOpConversion : public OpRewritePattern { - explicit UnaryElementwiseOpConversion(MLIRContext *context) +struct ElementwiseOpConversion : public OpRewritePattern { + explicit ElementwiseOpConversion(MLIRContext *context) : OpRewritePattern(context) {} LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - // Don't apply conversion to ops with statically shaped operands. - Value operand = op.getOperand(); - auto operandTy = operand.getType().dyn_cast(); - if (operandTy.hasRank()) return failure(); - - // Generate IR to flatten the operand. - auto loc = op.getLoc(); - Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext()); - Value shape = - rewriter.create(loc, extentTensorTy, operand); - Type indexTy = rewriter.getIndexType(); - Value numElements = - rewriter.create(loc, indexTy, shape); - Value flatShape = rewriter.create(loc, numElements); - auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize}, - operandTy.getElementType()); - Value flatOperand = rewriter.create( - loc, flatTensorTy, operand, flatShape); - - // Generate IR for the actual operation. - Value flatResult = rewriter.create(loc, flatTensorTy, flatOperand); - - // Generate IR to restore the original shape. - rewriter.replaceOpWithNewOp(op, operandTy, - flatResult, shape); - - return success(); - } -}; - -/// Binary element-wise operation on unranked tensors can be applied to the -/// flattened operand tensors with the same effect. -/// This pattern rewrites every such operation to -/// (i) flatten the operand tensors, -/// (ii) apply the binary operation, and -// (iii) restore the original shape. -template -struct BinaryElementwiseOpConversion : public OpRewritePattern { - explicit BinaryElementwiseOpConversion(MLIRContext *context) - : OpRewritePattern(context) {} - - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { - // Don't apply conversion unless both operands are unranked. - if (op.lhs().getType().template isa() || - op.rhs().getType().template isa()) { + // Don't apply conversion unless all operands are unranked. + if (!llvm::all_of(op.getOperation()->getOperands(), [&](Value operand) { + return operand.getType().isa(); + })) { return failure(); } - // Flatten operands. + // Get operands' shape. auto loc = op.getLoc(); Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext()); - Value shapeLhs = - rewriter.create(loc, extentTensorTy, op.lhs()); - Value shapeRhs = - rewriter.create(loc, extentTensorTy, op.rhs()); - Value shape = rewriter.create(loc, extentTensorTy, - ValueRange{shapeLhs, shapeRhs}); + SmallVector operandShapes; + for (Value operand : op.getOperation()->getOperands()) { + Value shape = + rewriter.create(loc, extentTensorTy, operand); + operandShapes.push_back(shape); + } + Value shape = + operandShapes.size() == 1 + ? operandShapes.front() + : rewriter.create(loc, extentTensorTy, operandShapes); + + // Derive flat shape. Type indexTy = rewriter.getIndexType(); Value numElements = rewriter.create(loc, indexTy, shape); Value flatShape = rewriter.create(loc, numElements); - TensorType lhsTy = op.lhs().getType().template cast(); - Type flatLhsTy = RankedTensorType::get({ShapedType::kDynamicSize}, - lhsTy.getElementType()); - Value flatLhs = - rewriter.create(loc, flatLhsTy, op.lhs(), flatShape); - TensorType rhsTy = op.rhs().getType().template cast(); - Type flatRhsTy = RankedTensorType::get({ShapedType::kDynamicSize}, - rhsTy.getElementType()); - Value flatRhs = - rewriter.create(loc, flatRhsTy, op.rhs(), flatShape); - // Apply actual operation to flattened operands. - Value flatResult = rewriter.create(loc, flatLhs, flatRhs); + // Flatten operands. + SmallVector flatOperands; + for (Value operand : op.getOperation()->getOperands()) { + Type operandElementTy = + operand.getType().template cast().getElementType(); + Type flatTy = + RankedTensorType::get({ShapedType::kDynamicSize}, operandElementTy); + Value flat = + rewriter.create(loc, flatTy, operand, flatShape); + flatOperands.push_back(flat); + } + + // Apply operation to flattened operands. + Type resultElementTy = + op.getType().template cast().getElementType(); + Type flatResultTy = + RankedTensorType::get({ShapedType::kDynamicSize}, resultElementTy); + Value flatResult = + rewriter.create(loc, flatResultTy, flatOperands, op.getAttrs()); // Restore original shape. rewriter.replaceOpWithNewOp(op, op.getType(), flatResult, @@ -154,7 +125,7 @@ struct BinaryElementwiseOpConversion : public OpRewritePattern { struct TransformUnrankedHloPass : public PassWrapper { void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); } void runOnFunction() override { @@ -183,19 +154,17 @@ struct TransformUnrankedHloPass void PopulateTransformUnrankedHloPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { - // TODO(frgossen): Populate all unary and binary operations. - // clang-format off -#define MAP_UNARY(op) UnaryElementwiseOpConversion -#define MAP_BINARY(op) BinaryElementwiseOpConversion +#define MAP_UNARY(op) ElementwiseOpConversion +#define MAP_BINARY(op) ElementwiseOpConversion #define COMMA , + // clang-format off patterns->insert< MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA), - MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA) - >(context); + MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA)>(context); + // clang-format on #undef MAP_UNARY #undef MAP_BINARY #undef COMMA - // clang-format on } std::unique_ptr createTransformUnrankedHloPass() {