[MLIR] Extend unranked transformation to CHLO dialect
PiperOrigin-RevId: 332026604
This commit is contained in:
parent
2aa07b0091
commit
69b80d8deb
|
@ -344,14 +344,16 @@ def HLOClient_BroadcastComplexOp : HLOClient_BroadcastBinaryElementwiseOp<
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class HLOClient_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
|
||||
Type TensorType> : HLOClient_Op<mnemonic,
|
||||
!listconcat(traits, [InferFusibilityOpInterface])> {
|
||||
Type TensorType> : HLOClient_Op<mnemonic, !listconcat(traits, [
|
||||
InferFusibilityOpInterface, NoSideEffect, SameOperandsAndResultType])> {
|
||||
let arguments = (ins TensorType:$operand);
|
||||
let results = (outs TensorType);
|
||||
let results = (outs TensorType:$result);
|
||||
|
||||
let assemblyFormat = "$operand attr-dict `:` type($operand)";
|
||||
}
|
||||
|
||||
def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos",
|
||||
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> {
|
||||
def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [],
|
||||
HLO_FpOrComplexTensor> {
|
||||
let summary = "Acos operator";
|
||||
|
||||
let description = [{
|
||||
|
@ -364,8 +366,8 @@ def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos",
|
|||
}];
|
||||
}
|
||||
|
||||
def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan",
|
||||
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> {
|
||||
def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan", [],
|
||||
HLO_FpOrComplexTensor> {
|
||||
let summary = "Tan operation";
|
||||
|
||||
let description = [{
|
||||
|
|
|
@ -30,6 +30,9 @@ template <typename T>
|
|||
class OperationPass;
|
||||
class Pass;
|
||||
|
||||
// Transforms unranked HLO operations to ranked ones where possible.
|
||||
std::unique_ptr<FunctionPass> createTransformUnrankedHloPass();
|
||||
|
||||
namespace mhlo {
|
||||
|
||||
/// Lowers HLO control flow ops to the Standard dialect.
|
||||
|
@ -52,9 +55,6 @@ std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
|
|||
// Lowers from HLO dialect to Linalg dialect.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass();
|
||||
|
||||
// Transforms unranked HLO operations to ranked ones where possible.
|
||||
std::unique_ptr<FunctionPass> createTransformUnrankedHloPass();
|
||||
|
||||
// Sinks constants implicitly captured in control flow regions. This is
|
||||
// necessary to export to XLA.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();
|
||||
|
|
|
@ -28,6 +28,12 @@ class LLVMTypeConverter;
|
|||
class LowerToLLVMOptions;
|
||||
class OwningRewritePatternList;
|
||||
class BufferAssignmentPlacer;
|
||||
|
||||
// Populates a collection of rewrite patterns to realize element-wise operations
|
||||
// on ranked tensors where possible.
|
||||
void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns);
|
||||
|
||||
namespace mhlo {
|
||||
|
||||
// Collection of rewrite patterns for lowering a general dot product.
|
||||
|
|
|
@ -14,6 +14,7 @@ limitations under the License.
|
|||
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
|
@ -27,7 +28,6 @@ limitations under the License.
|
|||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace mhlo {
|
||||
namespace {
|
||||
|
||||
// TODO(herhut): Generate these out of op definitions.
|
||||
|
@ -46,6 +46,9 @@ namespace {
|
|||
sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \
|
||||
sep fn(ShiftRightLogicalOp) sep fn(SubOp)
|
||||
|
||||
// TODO(herhut): Generate these out of op definitions.
|
||||
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) fn(TanOp) sep fn(AcosOp)
|
||||
|
||||
template <typename OpTy>
|
||||
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
||||
target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
|
||||
|
@ -101,8 +104,8 @@ struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
|||
operand.getType().template cast<ShapedType>().getElementType();
|
||||
Type flatTy =
|
||||
RankedTensorType::get({ShapedType::kDynamicSize}, operandElementTy);
|
||||
Value flat =
|
||||
rewriter.create<DynamicReshapeOp>(loc, flatTy, operand, flatShape);
|
||||
Value flat = rewriter.create<mhlo::DynamicReshapeOp>(loc, flatTy, operand,
|
||||
flatShape);
|
||||
flatOperands.push_back(flat);
|
||||
}
|
||||
|
||||
|
@ -115,8 +118,8 @@ struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
|||
rewriter.create<OpTy>(loc, flatResultTy, flatOperands, op.getAttrs());
|
||||
|
||||
// Restore original shape.
|
||||
rewriter.replaceOpWithNewOp<DynamicReshapeOp>(op, op.getType(), flatResult,
|
||||
shape);
|
||||
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, op.getType(),
|
||||
flatResult, shape);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -132,13 +135,16 @@ struct TransformUnrankedHloPass
|
|||
// Setup conversion target.
|
||||
MLIRContext &ctx = getContext();
|
||||
ConversionTarget target(ctx);
|
||||
target.addLegalDialect<MhloDialect, StandardOpsDialect,
|
||||
target.addLegalDialect<mhlo::MhloDialect, StandardOpsDialect,
|
||||
shape::ShapeDialect>();
|
||||
target.addLegalOp<FuncOp>();
|
||||
#define ADD_LEGAL(op) AddLegalOpOnRankedTensor<op>(&target)
|
||||
MAP_XLA_OPERATION_CWISE_UNARY(ADD_LEGAL, ;);
|
||||
MAP_XLA_OPERATION_CWISE_BINARY(ADD_LEGAL, ;);
|
||||
#undef ADD_LEGAL
|
||||
#define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target)
|
||||
#define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target)
|
||||
MAP_XLA_OPERATION_CWISE_UNARY(ADD_LEGAL_MHLO, ;);
|
||||
MAP_XLA_OPERATION_CWISE_BINARY(ADD_LEGAL_MHLO, ;);
|
||||
MAP_CHLO_OPERATION_CWISE_UNARY(ADD_LEGAL_CHLO, ;);
|
||||
#undef ADD_LEGAL_MHLO
|
||||
#undef ADD_LEGAL_CHLO
|
||||
|
||||
// Populate rewrite patterns.
|
||||
OwningRewritePatternList patterns;
|
||||
|
@ -154,16 +160,19 @@ struct TransformUnrankedHloPass
|
|||
|
||||
void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns) {
|
||||
#define MAP_UNARY(op) ElementwiseOpConversion<op>
|
||||
#define MAP_BINARY(op) ElementwiseOpConversion<op>
|
||||
#define MAP_UNARY(op) ElementwiseOpConversion<mhlo::op>
|
||||
#define MAP_BINARY(op) ElementwiseOpConversion<mhlo::op>
|
||||
#define MAP_CHLO_UNARY(op) ElementwiseOpConversion<chlo::op>
|
||||
#define COMMA ,
|
||||
// clang-format off
|
||||
patterns->insert<
|
||||
MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA),
|
||||
MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA)>(context);
|
||||
MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA),
|
||||
MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO_UNARY, COMMA)>(context);
|
||||
// clang-format on
|
||||
#undef MAP_UNARY
|
||||
#undef MAP_BINARY
|
||||
#undef MAP_CHLO_UNARY
|
||||
#undef COMMA
|
||||
}
|
||||
|
||||
|
@ -171,5 +180,4 @@ std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {
|
|||
return std::make_unique<TransformUnrankedHloPass>();
|
||||
}
|
||||
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-hlo-opt -transform-unranked-hlo -split-input-file %s | FileCheck %s
|
||||
// RUN: mlir-hlo-opt --transform-unranked-hlo --split-input-file %s | FileCheck %s
|
||||
|
||||
// Check the validity of expected IR.
|
||||
// CHECK-LABEL: @sqr_transform_result
|
||||
|
@ -80,3 +80,19 @@ func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
|
|||
%result = mhlo.add %a, %b : tensor<*xf32>
|
||||
return %result : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @tan
|
||||
// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>) -> tensor<*xf32>
|
||||
func @tan(%a : tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> -> tensor<?xindex>
|
||||
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
|
||||
// CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
|
||||
// CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK: %[[FLAT_B:.*]] = chlo.tan %[[FLAT_A]] : tensor<?xf32>
|
||||
// CHECK: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK: return %[[B]] : tensor<*xf32>
|
||||
%result = chlo.tan %a : tensor<*xf32>
|
||||
return %result : tensor<*xf32>
|
||||
}
|
Loading…
Reference in New Issue