[MLIR] Extend unranked transformation to CHLO dialect

PiperOrigin-RevId: 332026604
This commit is contained in:
A. Unique TensorFlower 2020-09-16 09:48:43 -07:00 committed by TensorFlow MLIR Team
parent 2aa07b0091
commit 69b80d8deb
5 changed files with 57 additions and 25 deletions

View File

@ -344,14 +344,16 @@ def HLOClient_BroadcastComplexOp : HLOClient_BroadcastBinaryElementwiseOp<
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
class HLOClient_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits, class HLOClient_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
Type TensorType> : HLOClient_Op<mnemonic, Type TensorType> : HLOClient_Op<mnemonic, !listconcat(traits, [
!listconcat(traits, [InferFusibilityOpInterface])> { InferFusibilityOpInterface, NoSideEffect, SameOperandsAndResultType])> {
let arguments = (ins TensorType:$operand); let arguments = (ins TensorType:$operand);
let results = (outs TensorType); let results = (outs TensorType:$result);
let assemblyFormat = "$operand attr-dict `:` type($operand)";
} }
def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [],
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> { HLO_FpOrComplexTensor> {
let summary = "Acos operator"; let summary = "Acos operator";
let description = [{ let description = [{
@ -364,8 +366,8 @@ def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos",
}]; }];
} }
def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan", def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan", [],
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> { HLO_FpOrComplexTensor> {
let summary = "Tan operation"; let summary = "Tan operation";
let description = [{ let description = [{

View File

@ -30,6 +30,9 @@ template <typename T>
class OperationPass; class OperationPass;
class Pass; class Pass;
// Transforms unranked HLO operations to ranked ones where possible.
std::unique_ptr<FunctionPass> createTransformUnrankedHloPass();
namespace mhlo { namespace mhlo {
/// Lowers HLO control flow ops to the Standard dialect. /// Lowers HLO control flow ops to the Standard dialect.
@ -52,9 +55,6 @@ std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
// Lowers from HLO dialect to Linalg dialect. // Lowers from HLO dialect to Linalg dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass(); std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass();
// Transforms unranked HLO operations to ranked ones where possible.
std::unique_ptr<FunctionPass> createTransformUnrankedHloPass();
// Sinks constants implicitly captured in control flow regions. This is // Sinks constants implicitly captured in control flow regions. This is
// necessary to export to XLA. // necessary to export to XLA.
std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass(); std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();

View File

@ -28,6 +28,12 @@ class LLVMTypeConverter;
class LowerToLLVMOptions; class LowerToLLVMOptions;
class OwningRewritePatternList; class OwningRewritePatternList;
class BufferAssignmentPlacer; class BufferAssignmentPlacer;
// Populates a collection of rewrite patterns to realize element-wise operations
// on ranked tensors where possible.
void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
namespace mhlo { namespace mhlo {
// Collection of rewrite patterns for lowering a general dot product. // Collection of rewrite patterns for lowering a general dot product.

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Shape/IR/Shape.h"
@ -27,7 +28,6 @@ limitations under the License.
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
namespace mlir { namespace mlir {
namespace mhlo {
namespace { namespace {
// TODO(herhut): Generate these out of op definitions. // TODO(herhut): Generate these out of op definitions.
@ -46,6 +46,9 @@ namespace {
sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \ sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \
sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(ShiftRightLogicalOp) sep fn(SubOp)
// TODO(herhut): Generate these out of op definitions.
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) fn(TanOp) sep fn(AcosOp)
template <typename OpTy> template <typename OpTy>
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) { inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
target->addDynamicallyLegalOp<OpTy>([](OpTy op) { target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
@ -101,8 +104,8 @@ struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
operand.getType().template cast<ShapedType>().getElementType(); operand.getType().template cast<ShapedType>().getElementType();
Type flatTy = Type flatTy =
RankedTensorType::get({ShapedType::kDynamicSize}, operandElementTy); RankedTensorType::get({ShapedType::kDynamicSize}, operandElementTy);
Value flat = Value flat = rewriter.create<mhlo::DynamicReshapeOp>(loc, flatTy, operand,
rewriter.create<DynamicReshapeOp>(loc, flatTy, operand, flatShape); flatShape);
flatOperands.push_back(flat); flatOperands.push_back(flat);
} }
@ -115,8 +118,8 @@ struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
rewriter.create<OpTy>(loc, flatResultTy, flatOperands, op.getAttrs()); rewriter.create<OpTy>(loc, flatResultTy, flatOperands, op.getAttrs());
// Restore original shape. // Restore original shape.
rewriter.replaceOpWithNewOp<DynamicReshapeOp>(op, op.getType(), flatResult, rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, op.getType(),
shape); flatResult, shape);
return success(); return success();
} }
@ -132,13 +135,16 @@ struct TransformUnrankedHloPass
// Setup conversion target. // Setup conversion target.
MLIRContext &ctx = getContext(); MLIRContext &ctx = getContext();
ConversionTarget target(ctx); ConversionTarget target(ctx);
target.addLegalDialect<MhloDialect, StandardOpsDialect, target.addLegalDialect<mhlo::MhloDialect, StandardOpsDialect,
shape::ShapeDialect>(); shape::ShapeDialect>();
target.addLegalOp<FuncOp>(); target.addLegalOp<FuncOp>();
#define ADD_LEGAL(op) AddLegalOpOnRankedTensor<op>(&target) #define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target)
MAP_XLA_OPERATION_CWISE_UNARY(ADD_LEGAL, ;); #define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target)
MAP_XLA_OPERATION_CWISE_BINARY(ADD_LEGAL, ;); MAP_XLA_OPERATION_CWISE_UNARY(ADD_LEGAL_MHLO, ;);
#undef ADD_LEGAL MAP_XLA_OPERATION_CWISE_BINARY(ADD_LEGAL_MHLO, ;);
MAP_CHLO_OPERATION_CWISE_UNARY(ADD_LEGAL_CHLO, ;);
#undef ADD_LEGAL_MHLO
#undef ADD_LEGAL_CHLO
// Populate rewrite patterns. // Populate rewrite patterns.
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
@ -154,16 +160,19 @@ struct TransformUnrankedHloPass
void PopulateTransformUnrankedHloPatterns(MLIRContext *context, void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) { OwningRewritePatternList *patterns) {
#define MAP_UNARY(op) ElementwiseOpConversion<op> #define MAP_UNARY(op) ElementwiseOpConversion<mhlo::op>
#define MAP_BINARY(op) ElementwiseOpConversion<op> #define MAP_BINARY(op) ElementwiseOpConversion<mhlo::op>
#define MAP_CHLO_UNARY(op) ElementwiseOpConversion<chlo::op>
#define COMMA , #define COMMA ,
// clang-format off // clang-format off
patterns->insert< patterns->insert<
MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA), MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA),
MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA)>(context); MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA),
MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO_UNARY, COMMA)>(context);
// clang-format on // clang-format on
#undef MAP_UNARY #undef MAP_UNARY
#undef MAP_BINARY #undef MAP_BINARY
#undef MAP_CHLO_UNARY
#undef COMMA #undef COMMA
} }
@ -171,5 +180,4 @@ std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {
return std::make_unique<TransformUnrankedHloPass>(); return std::make_unique<TransformUnrankedHloPass>();
} }
} // namespace mhlo
} // namespace mlir } // namespace mlir

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt -transform-unranked-hlo -split-input-file %s | FileCheck %s // RUN: mlir-hlo-opt --transform-unranked-hlo --split-input-file %s | FileCheck %s
// Check the validity of expected IR. // Check the validity of expected IR.
// CHECK-LABEL: @sqr_transform_result // CHECK-LABEL: @sqr_transform_result
@ -80,3 +80,19 @@ func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
%result = mhlo.add %a, %b : tensor<*xf32> %result = mhlo.add %a, %b : tensor<*xf32>
return %result : tensor<*xf32> return %result : tensor<*xf32>
} }
// -----
// CHECK-LABEL: @tan
// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>) -> tensor<*xf32>
func @tan(%a : tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> -> tensor<?xindex>
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
// CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
// CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_B:.*]] = chlo.tan %[[FLAT_A]] : tensor<?xf32>
// CHECK: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[B]] : tensor<*xf32>
%result = chlo.tan %a : tensor<*xf32>
return %result : tensor<*xf32>
}