[MLIR] Extend unranked transformation to CHLO dialect

PiperOrigin-RevId: 332026604
Authored by A. Unique TensorFlower on 2020-09-16 09:48:43 -07:00; committed by TensorFlow MLIR Team
parent 2aa07b0091
commit 69b80d8deb
5 changed files with 57 additions and 25 deletions

View File

@@ -344,14 +344,16 @@ def HLOClient_BroadcastComplexOp : HLOClient_BroadcastBinaryElementwiseOp<
 //===----------------------------------------------------------------------===//
 class HLOClient_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
-    Type TensorType> : HLOClient_Op<mnemonic,
-    !listconcat(traits, [InferFusibilityOpInterface])> {
+    Type TensorType> : HLOClient_Op<mnemonic, !listconcat(traits, [
+    InferFusibilityOpInterface, NoSideEffect, SameOperandsAndResultType])> {
   let arguments = (ins TensorType:$operand);
-  let results = (outs TensorType);
+  let results = (outs TensorType:$result);
   let assemblyFormat = "$operand attr-dict `:` type($operand)";
 }

-def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos",
-    [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> {
+def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [],
+    HLO_FpOrComplexTensor> {
   let summary = "Acos operator";

   let description = [{
@@ -364,8 +366,8 @@ def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos",
   }];
 }

-def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan",
-    [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> {
+def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan", [],
+    HLO_FpOrComplexTensor> {
   let summary = "Tan operation";

   let description = [{
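As a point of reference (not part of this change), a minimal C++ sketch of constructing one of these unary CHLO ops from a builder; the helper name and surrounding variables are hypothetical, and the result type simply mirrors the operand type because the base class now carries SameOperandsAndResultType:

#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"

// Hypothetical helper: build a chlo.tan whose result type equals the operand
// type, which is what SameOperandsAndResultType guarantees for these ops.
static mlir::Value buildChloTan(mlir::OpBuilder &builder, mlir::Location loc,
                                mlir::Value operand) {
  return builder.create<mlir::chlo::TanOp>(loc, operand.getType(), operand);
}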

View File

@@ -30,6 +30,9 @@ template <typename T>
 class OperationPass;
 class Pass;

+// Transforms unranked HLO operations to ranked ones where possible.
+std::unique_ptr<FunctionPass> createTransformUnrankedHloPass();
+
 namespace mhlo {

 /// Lowers HLO control flow ops to the Standard dialect.
@@ -52,9 +55,6 @@ std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
 // Lowers from HLO dialect to Linalg dialect.
 std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass();

-// Transforms unranked HLO operations to ranked ones where possible.
-std::unique_ptr<FunctionPass> createTransformUnrankedHloPass();
-
 // Sinks constants implicitly captured in control flow regions. This is
 // necessary to export to XLA.
 std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();
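Since createTransformUnrankedHloPass now lives directly in the mlir namespace rather than mlir::mhlo, a caller schedules it roughly as sketched below; the pass-manager setup is an assumption for illustration, and header/API details varied across MLIR snapshots from this period:

#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/IR/Function.h"
#include "mlir/Pass/PassManager.h"

// Sketch: the pass is a FunctionPass, so nest it under function operations.
void addTransformUnrankedHlo(mlir::PassManager &pm) {
  pm.addNestedPass<mlir::FuncOp>(mlir::createTransformUnrankedHloPass());
}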

View File

@@ -28,6 +28,12 @@ class LLVMTypeConverter;
 class LowerToLLVMOptions;
 class OwningRewritePatternList;
 class BufferAssignmentPlacer;

+// Populates a collection of rewrite patterns to realize element-wise operations
+// on ranked tensors where possible.
+void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
+                                          OwningRewritePatternList *patterns);
+
 namespace mhlo {

 // Collection of rewrite patterns for lowering a general dot product.
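For consumers that want the rewrites without the standalone pass, a rough usage sketch of the exported entry point; the greedy-driver call is an assumption for illustration (the in-tree pass instead pairs these patterns with a conversion target, as shown in the .cc change below), and exact driver signatures shifted between MLIR snapshots of this era:

#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Sketch: populate the unranked-to-ranked patterns and apply them to a function.
void applyUnrankedHloPatterns(mlir::FuncOp func) {
  mlir::OwningRewritePatternList patterns;
  mlir::PopulateTransformUnrankedHloPatterns(func.getContext(), &patterns);
  (void)mlir::applyPatternsAndFoldGreedily(func, patterns);
}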

View File

@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/

+#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
@@ -27,7 +28,6 @@ limitations under the License.
 #include "mlir/Transforms/DialectConversion.h"

 namespace mlir {
-namespace mhlo {
 namespace {

 // TODO(herhut): Generate these out of op definitions.
@@ -46,6 +46,9 @@ namespace {
         sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp)                   \
             sep fn(ShiftRightLogicalOp) sep fn(SubOp)

+// TODO(herhut): Generate these out of op definitions.
+#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) fn(TanOp) sep fn(AcosOp)
+
 template <typename OpTy>
 inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
   target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
@@ -101,8 +104,8 @@ struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
           operand.getType().template cast<ShapedType>().getElementType();
       Type flatTy =
           RankedTensorType::get({ShapedType::kDynamicSize}, operandElementTy);
-      Value flat =
-          rewriter.create<DynamicReshapeOp>(loc, flatTy, operand, flatShape);
+      Value flat = rewriter.create<mhlo::DynamicReshapeOp>(loc, flatTy, operand,
+                                                           flatShape);
       flatOperands.push_back(flat);
     }
@@ -115,8 +118,8 @@ struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
         rewriter.create<OpTy>(loc, flatResultTy, flatOperands, op.getAttrs());

     // Restore original shape.
-    rewriter.replaceOpWithNewOp<DynamicReshapeOp>(op, op.getType(), flatResult,
-                                                  shape);
+    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, op.getType(),
+                                                        flatResult, shape);
     return success();
   }
@@ -132,13 +135,16 @@ struct TransformUnrankedHloPass
     // Setup conversion target.
     MLIRContext &ctx = getContext();
     ConversionTarget target(ctx);
-    target.addLegalDialect<MhloDialect, StandardOpsDialect,
+    target.addLegalDialect<mhlo::MhloDialect, StandardOpsDialect,
                            shape::ShapeDialect>();
     target.addLegalOp<FuncOp>();
-#define ADD_LEGAL(op) AddLegalOpOnRankedTensor<op>(&target)
-    MAP_XLA_OPERATION_CWISE_UNARY(ADD_LEGAL, ;);
-    MAP_XLA_OPERATION_CWISE_BINARY(ADD_LEGAL, ;);
-#undef ADD_LEGAL
+#define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target)
+#define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target)
+    MAP_XLA_OPERATION_CWISE_UNARY(ADD_LEGAL_MHLO, ;);
+    MAP_XLA_OPERATION_CWISE_BINARY(ADD_LEGAL_MHLO, ;);
+    MAP_CHLO_OPERATION_CWISE_UNARY(ADD_LEGAL_CHLO, ;);
+#undef ADD_LEGAL_MHLO
+#undef ADD_LEGAL_CHLO

     // Populate rewrite patterns.
     OwningRewritePatternList patterns;
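For readers unfamiliar with these X-macro maps, the new CHLO invocation above expands into two legality registrations (the MHLO maps expand the same way over their longer op lists); shown here purely as an illustration of the expansion:

// MAP_CHLO_OPERATION_CWISE_UNARY(ADD_LEGAL_CHLO, ;) expands, roughly, to:
AddLegalOpOnRankedTensor<chlo::TanOp>(&target);
AddLegalOpOnRankedTensor<chlo::AcosOp>(&target);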
@@ -154,16 +160,19 @@ struct TransformUnrankedHloPass
 void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
                                           OwningRewritePatternList *patterns) {
-#define MAP_UNARY(op) ElementwiseOpConversion<op>
-#define MAP_BINARY(op) ElementwiseOpConversion<op>
+#define MAP_UNARY(op) ElementwiseOpConversion<mhlo::op>
+#define MAP_BINARY(op) ElementwiseOpConversion<mhlo::op>
+#define MAP_CHLO_UNARY(op) ElementwiseOpConversion<chlo::op>
 #define COMMA ,
   // clang-format off
   patterns->insert<
       MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA),
-      MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA)>(context);
+      MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA),
+      MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO_UNARY, COMMA)>(context);
   // clang-format on
 #undef MAP_UNARY
 #undef MAP_BINARY
+#undef MAP_CHLO_UNARY
 #undef COMMA
 }
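Analogously, the pattern registration above amounts to the following after macro expansion; the MHLO entries are elided here since their op list is not part of this diff:

// Illustration only: the expanded patterns->insert< ... >(context) call.
patterns->insert<
    /* ...one ElementwiseOpConversion<mhlo::Op> per MHLO cwise op..., */
    ElementwiseOpConversion<chlo::TanOp>,
    ElementwiseOpConversion<chlo::AcosOp>>(context);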
@@ -171,5 +180,4 @@ std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {
   return std::make_unique<TransformUnrankedHloPass>();
 }

-}  // namespace mhlo
 }  // namespace mlir

View File

@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt -transform-unranked-hlo -split-input-file %s | FileCheck %s
+// RUN: mlir-hlo-opt --transform-unranked-hlo --split-input-file %s | FileCheck %s

 // Check the validity of expected IR.
 // CHECK-LABEL: @sqr_transform_result
@@ -80,3 +80,19 @@ func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
   %result = mhlo.add %a, %b : tensor<*xf32>
   return %result : tensor<*xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @tan
+// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>) -> tensor<*xf32>
+func @tan(%a : tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> -> tensor<?xindex>
+  // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
+  // CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
+  // CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+  // CHECK: %[[FLAT_B:.*]] = chlo.tan %[[FLAT_A]] : tensor<?xf32>
+  // CHECK: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+  // CHECK: return %[[B]] : tensor<*xf32>
+  %result = chlo.tan %a : tensor<*xf32>
+  return %result : tensor<*xf32>
+}