From 69b80d8deb9071a40c904c8faee4f45216a5cffe Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 16 Sep 2020 09:48:43 -0700 Subject: [PATCH] [MLIR] Extend unranked transformation to CHLO dialect PiperOrigin-RevId: 332026604 --- include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td | 16 +++++---- .../mlir-hlo/Dialect/mhlo/transforms/passes.h | 6 ++-- .../Dialect/mhlo/transforms/rewriters.h | 6 ++++ .../mhlo/transforms/transform_unranked_hlo.cc | 36 +++++++++++-------- ...anked.mlir => hlo-transform-unranked.mlir} | 18 +++++++++- 5 files changed, 57 insertions(+), 25 deletions(-) rename tests/{mhlo-transform-unranked.mlir => hlo-transform-unranked.mlir} (80%) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index 04f984a..54b40fe 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -344,14 +344,16 @@ def HLOClient_BroadcastComplexOp : HLOClient_BroadcastBinaryElementwiseOp< //===----------------------------------------------------------------------===// class HLOClient_UnaryElementwiseOp traits, - Type TensorType> : HLOClient_Op { + Type TensorType> : HLOClient_Op { let arguments = (ins TensorType:$operand); - let results = (outs TensorType); + let results = (outs TensorType:$result); + + let assemblyFormat = "$operand attr-dict `:` type($operand)"; } -def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", - [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> { +def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [], + HLO_FpOrComplexTensor> { let summary = "Acos operator"; let description = [{ @@ -364,8 +366,8 @@ def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", }]; } -def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan", - [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> { +def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan", [], + HLO_FpOrComplexTensor> { let summary = "Tan operation"; let description = [{ diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h index bc6e3fd..fae79d9 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h @@ -30,6 +30,9 @@ template class OperationPass; class Pass; +// Transforms unranked HLO operations to ranked ones where possible. +std::unique_ptr createTransformUnrankedHloPass(); + namespace mhlo { /// Lowers HLO control flow ops to the Standard dialect. @@ -52,9 +55,6 @@ std::unique_ptr> createLegalizeToLhloPass( // Lowers from HLO dialect to Linalg dialect. std::unique_ptr> createLegalizeHloToLinalgPass(); -// Transforms unranked HLO operations to ranked ones where possible. -std::unique_ptr createTransformUnrankedHloPass(); - // Sinks constants implicitly captured in control flow regions. This is // necessary to export to XLA. std::unique_ptr> createSinkConstantsToControlFlowPass(); diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h index 3d0fc2a..cf21a95 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h @@ -28,6 +28,12 @@ class LLVMTypeConverter; class LowerToLLVMOptions; class OwningRewritePatternList; class BufferAssignmentPlacer; + +// Populates a collection of rewrite patterns to realize element-wise operations +// on ranked tensors where possible. +void PopulateTransformUnrankedHloPatterns(MLIRContext *context, + OwningRewritePatternList *patterns); + namespace mhlo { // Collection of rewrite patterns for lowering a general dot product. diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index f52239f..1500a96 100644 --- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir/Dialect/Shape/IR/Shape.h" @@ -27,7 +28,6 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" namespace mlir { -namespace mhlo { namespace { // TODO(herhut): Generate these out of op definitions. @@ -46,6 +46,9 @@ namespace { sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \ sep fn(ShiftRightLogicalOp) sep fn(SubOp) +// TODO(herhut): Generate these out of op definitions. +#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) fn(TanOp) sep fn(AcosOp) + template inline void AddLegalOpOnRankedTensor(ConversionTarget *target) { target->addDynamicallyLegalOp([](OpTy op) { @@ -101,8 +104,8 @@ struct ElementwiseOpConversion : public OpRewritePattern { operand.getType().template cast().getElementType(); Type flatTy = RankedTensorType::get({ShapedType::kDynamicSize}, operandElementTy); - Value flat = - rewriter.create(loc, flatTy, operand, flatShape); + Value flat = rewriter.create(loc, flatTy, operand, + flatShape); flatOperands.push_back(flat); } @@ -115,8 +118,8 @@ struct ElementwiseOpConversion : public OpRewritePattern { rewriter.create(loc, flatResultTy, flatOperands, op.getAttrs()); // Restore original shape. - rewriter.replaceOpWithNewOp(op, op.getType(), flatResult, - shape); + rewriter.replaceOpWithNewOp(op, op.getType(), + flatResult, shape); return success(); } @@ -132,13 +135,16 @@ struct TransformUnrankedHloPass // Setup conversion target. MLIRContext &ctx = getContext(); ConversionTarget target(ctx); - target.addLegalDialect(); target.addLegalOp(); -#define ADD_LEGAL(op) AddLegalOpOnRankedTensor(&target) - MAP_XLA_OPERATION_CWISE_UNARY(ADD_LEGAL, ;); - MAP_XLA_OPERATION_CWISE_BINARY(ADD_LEGAL, ;); -#undef ADD_LEGAL +#define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor(&target) +#define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor(&target) + MAP_XLA_OPERATION_CWISE_UNARY(ADD_LEGAL_MHLO, ;); + MAP_XLA_OPERATION_CWISE_BINARY(ADD_LEGAL_MHLO, ;); + MAP_CHLO_OPERATION_CWISE_UNARY(ADD_LEGAL_CHLO, ;); +#undef ADD_LEGAL_MHLO +#undef ADD_LEGAL_CHLO // Populate rewrite patterns. OwningRewritePatternList patterns; @@ -154,16 +160,19 @@ struct TransformUnrankedHloPass void PopulateTransformUnrankedHloPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { -#define MAP_UNARY(op) ElementwiseOpConversion -#define MAP_BINARY(op) ElementwiseOpConversion +#define MAP_UNARY(op) ElementwiseOpConversion +#define MAP_BINARY(op) ElementwiseOpConversion +#define MAP_CHLO_UNARY(op) ElementwiseOpConversion #define COMMA , // clang-format off patterns->insert< MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA), - MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA)>(context); + MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA), + MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO_UNARY, COMMA)>(context); // clang-format on #undef MAP_UNARY #undef MAP_BINARY +#undef MAP_CHLO_UNARY #undef COMMA } @@ -171,5 +180,4 @@ std::unique_ptr createTransformUnrankedHloPass() { return std::make_unique(); } -} // namespace mhlo } // namespace mlir diff --git a/tests/mhlo-transform-unranked.mlir b/tests/hlo-transform-unranked.mlir similarity index 80% rename from tests/mhlo-transform-unranked.mlir rename to tests/hlo-transform-unranked.mlir index 187e8f3..ae61fc8 100644 --- a/tests/mhlo-transform-unranked.mlir +++ b/tests/hlo-transform-unranked.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt -transform-unranked-hlo -split-input-file %s | FileCheck %s +// RUN: mlir-hlo-opt --transform-unranked-hlo --split-input-file %s | FileCheck %s // Check the validity of expected IR. // CHECK-LABEL: @sqr_transform_result @@ -80,3 +80,19 @@ func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> { %result = mhlo.add %a, %b : tensor<*xf32> return %result : tensor<*xf32> } + +// ----- + +// CHECK-LABEL: @tan +// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>) -> tensor<*xf32> +func @tan(%a : tensor<*xf32>) -> tensor<*xf32> { + // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> -> tensor + // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]] + // CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex> + // CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor + // CHECK: %[[FLAT_B:.*]] = chlo.tan %[[FLAT_A]] : tensor + // CHECK: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor, tensor) -> tensor<*xf32> + // CHECK: return %[[B]] : tensor<*xf32> + %result = chlo.tan %a : tensor<*xf32> + return %result : tensor<*xf32> +}