[MLIR] Extend unranked transformation to CHLO dialect
PiperOrigin-RevId: 332026604
This commit is contained in:
parent
2aa07b0091
commit
69b80d8deb
|
@ -344,14 +344,16 @@ def HLOClient_BroadcastComplexOp : HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
class HLOClient_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
|
class HLOClient_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
|
||||||
Type TensorType> : HLOClient_Op<mnemonic,
|
Type TensorType> : HLOClient_Op<mnemonic, !listconcat(traits, [
|
||||||
!listconcat(traits, [InferFusibilityOpInterface])> {
|
InferFusibilityOpInterface, NoSideEffect, SameOperandsAndResultType])> {
|
||||||
let arguments = (ins TensorType:$operand);
|
let arguments = (ins TensorType:$operand);
|
||||||
let results = (outs TensorType);
|
let results = (outs TensorType:$result);
|
||||||
|
|
||||||
|
let assemblyFormat = "$operand attr-dict `:` type($operand)";
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos",
|
def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [],
|
||||||
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> {
|
HLO_FpOrComplexTensor> {
|
||||||
let summary = "Acos operator";
|
let summary = "Acos operator";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -364,8 +366,8 @@ def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos",
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan",
|
def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan", [],
|
||||||
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> {
|
HLO_FpOrComplexTensor> {
|
||||||
let summary = "Tan operation";
|
let summary = "Tan operation";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
|
|
|
@ -30,6 +30,9 @@ template <typename T>
|
||||||
class OperationPass;
|
class OperationPass;
|
||||||
class Pass;
|
class Pass;
|
||||||
|
|
||||||
|
// Transforms unranked HLO operations to ranked ones where possible.
|
||||||
|
std::unique_ptr<FunctionPass> createTransformUnrankedHloPass();
|
||||||
|
|
||||||
namespace mhlo {
|
namespace mhlo {
|
||||||
|
|
||||||
/// Lowers HLO control flow ops to the Standard dialect.
|
/// Lowers HLO control flow ops to the Standard dialect.
|
||||||
|
@ -52,9 +55,6 @@ std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
|
||||||
// Lowers from HLO dialect to Linalg dialect.
|
// Lowers from HLO dialect to Linalg dialect.
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass();
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass();
|
||||||
|
|
||||||
// Transforms unranked HLO operations to ranked ones where possible.
|
|
||||||
std::unique_ptr<FunctionPass> createTransformUnrankedHloPass();
|
|
||||||
|
|
||||||
// Sinks constants implicitly captured in control flow regions. This is
|
// Sinks constants implicitly captured in control flow regions. This is
|
||||||
// necessary to export to XLA.
|
// necessary to export to XLA.
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();
|
std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();
|
||||||
|
|
|
@ -28,6 +28,12 @@ class LLVMTypeConverter;
|
||||||
class LowerToLLVMOptions;
|
class LowerToLLVMOptions;
|
||||||
class OwningRewritePatternList;
|
class OwningRewritePatternList;
|
||||||
class BufferAssignmentPlacer;
|
class BufferAssignmentPlacer;
|
||||||
|
|
||||||
|
// Populates a collection of rewrite patterns to realize element-wise operations
|
||||||
|
// on ranked tensors where possible.
|
||||||
|
void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
|
||||||
|
OwningRewritePatternList *patterns);
|
||||||
|
|
||||||
namespace mhlo {
|
namespace mhlo {
|
||||||
|
|
||||||
// Collection of rewrite patterns for lowering a general dot product.
|
// Collection of rewrite patterns for lowering a general dot product.
|
||||||
|
|
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||||
|
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||||
|
@ -27,7 +28,6 @@ limitations under the License.
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace mhlo {
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// TODO(herhut): Generate these out of op definitions.
|
// TODO(herhut): Generate these out of op definitions.
|
||||||
|
@ -46,6 +46,9 @@ namespace {
|
||||||
sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \
|
sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \
|
||||||
sep fn(ShiftRightLogicalOp) sep fn(SubOp)
|
sep fn(ShiftRightLogicalOp) sep fn(SubOp)
|
||||||
|
|
||||||
|
// TODO(herhut): Generate these out of op definitions.
|
||||||
|
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) fn(TanOp) sep fn(AcosOp)
|
||||||
|
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
||||||
target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
|
target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
|
||||||
|
@ -101,8 +104,8 @@ struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
||||||
operand.getType().template cast<ShapedType>().getElementType();
|
operand.getType().template cast<ShapedType>().getElementType();
|
||||||
Type flatTy =
|
Type flatTy =
|
||||||
RankedTensorType::get({ShapedType::kDynamicSize}, operandElementTy);
|
RankedTensorType::get({ShapedType::kDynamicSize}, operandElementTy);
|
||||||
Value flat =
|
Value flat = rewriter.create<mhlo::DynamicReshapeOp>(loc, flatTy, operand,
|
||||||
rewriter.create<DynamicReshapeOp>(loc, flatTy, operand, flatShape);
|
flatShape);
|
||||||
flatOperands.push_back(flat);
|
flatOperands.push_back(flat);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -115,8 +118,8 @@ struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
||||||
rewriter.create<OpTy>(loc, flatResultTy, flatOperands, op.getAttrs());
|
rewriter.create<OpTy>(loc, flatResultTy, flatOperands, op.getAttrs());
|
||||||
|
|
||||||
// Restore original shape.
|
// Restore original shape.
|
||||||
rewriter.replaceOpWithNewOp<DynamicReshapeOp>(op, op.getType(), flatResult,
|
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, op.getType(),
|
||||||
shape);
|
flatResult, shape);
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -132,13 +135,16 @@ struct TransformUnrankedHloPass
|
||||||
// Setup conversion target.
|
// Setup conversion target.
|
||||||
MLIRContext &ctx = getContext();
|
MLIRContext &ctx = getContext();
|
||||||
ConversionTarget target(ctx);
|
ConversionTarget target(ctx);
|
||||||
target.addLegalDialect<MhloDialect, StandardOpsDialect,
|
target.addLegalDialect<mhlo::MhloDialect, StandardOpsDialect,
|
||||||
shape::ShapeDialect>();
|
shape::ShapeDialect>();
|
||||||
target.addLegalOp<FuncOp>();
|
target.addLegalOp<FuncOp>();
|
||||||
#define ADD_LEGAL(op) AddLegalOpOnRankedTensor<op>(&target)
|
#define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target)
|
||||||
MAP_XLA_OPERATION_CWISE_UNARY(ADD_LEGAL, ;);
|
#define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target)
|
||||||
MAP_XLA_OPERATION_CWISE_BINARY(ADD_LEGAL, ;);
|
MAP_XLA_OPERATION_CWISE_UNARY(ADD_LEGAL_MHLO, ;);
|
||||||
#undef ADD_LEGAL
|
MAP_XLA_OPERATION_CWISE_BINARY(ADD_LEGAL_MHLO, ;);
|
||||||
|
MAP_CHLO_OPERATION_CWISE_UNARY(ADD_LEGAL_CHLO, ;);
|
||||||
|
#undef ADD_LEGAL_MHLO
|
||||||
|
#undef ADD_LEGAL_CHLO
|
||||||
|
|
||||||
// Populate rewrite patterns.
|
// Populate rewrite patterns.
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
|
@ -154,16 +160,19 @@ struct TransformUnrankedHloPass
|
||||||
|
|
||||||
void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
|
void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
|
||||||
OwningRewritePatternList *patterns) {
|
OwningRewritePatternList *patterns) {
|
||||||
#define MAP_UNARY(op) ElementwiseOpConversion<op>
|
#define MAP_UNARY(op) ElementwiseOpConversion<mhlo::op>
|
||||||
#define MAP_BINARY(op) ElementwiseOpConversion<op>
|
#define MAP_BINARY(op) ElementwiseOpConversion<mhlo::op>
|
||||||
|
#define MAP_CHLO_UNARY(op) ElementwiseOpConversion<chlo::op>
|
||||||
#define COMMA ,
|
#define COMMA ,
|
||||||
// clang-format off
|
// clang-format off
|
||||||
patterns->insert<
|
patterns->insert<
|
||||||
MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA),
|
MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA),
|
||||||
MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA)>(context);
|
MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA),
|
||||||
|
MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO_UNARY, COMMA)>(context);
|
||||||
// clang-format on
|
// clang-format on
|
||||||
#undef MAP_UNARY
|
#undef MAP_UNARY
|
||||||
#undef MAP_BINARY
|
#undef MAP_BINARY
|
||||||
|
#undef MAP_CHLO_UNARY
|
||||||
#undef COMMA
|
#undef COMMA
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -171,5 +180,4 @@ std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {
|
||||||
return std::make_unique<TransformUnrankedHloPass>();
|
return std::make_unique<TransformUnrankedHloPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mhlo
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: mlir-hlo-opt -transform-unranked-hlo -split-input-file %s | FileCheck %s
|
// RUN: mlir-hlo-opt --transform-unranked-hlo --split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
// Check the validity of expected IR.
|
// Check the validity of expected IR.
|
||||||
// CHECK-LABEL: @sqr_transform_result
|
// CHECK-LABEL: @sqr_transform_result
|
||||||
|
@ -80,3 +80,19 @@ func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
|
||||||
%result = mhlo.add %a, %b : tensor<*xf32>
|
%result = mhlo.add %a, %b : tensor<*xf32>
|
||||||
return %result : tensor<*xf32>
|
return %result : tensor<*xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @tan
|
||||||
|
// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
func @tan(%a : tensor<*xf32>) -> tensor<*xf32> {
|
||||||
|
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> -> tensor<?xindex>
|
||||||
|
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
|
||||||
|
// CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
|
||||||
|
// CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
|
// CHECK: %[[FLAT_B:.*]] = chlo.tan %[[FLAT_A]] : tensor<?xf32>
|
||||||
|
// CHECK: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||||
|
// CHECK: return %[[B]] : tensor<*xf32>
|
||||||
|
%result = chlo.tan %a : tensor<*xf32>
|
||||||
|
return %result : tensor<*xf32>
|
||||||
|
}
|
Loading…
Reference in New Issue