/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
2020-09-17 00:48:43 +08:00
|
|
|
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
2020-07-29 07:12:08 +08:00
|
|
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
|
|
|
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
|
|
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
|
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
|
|
#include "mlir/IR/Function.h"
|
|
|
|
#include "mlir/IR/MLIRContext.h"
|
|
|
|
#include "mlir/IR/Operation.h"
|
|
|
|
#include "mlir/IR/PatternMatch.h"
|
|
|
|
#include "mlir/IR/StandardTypes.h"
|
|
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
#include "mlir/Transforms/DialectConversion.h"
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
namespace mlir {
|
|
|
|
namespace {
|
|
|
|
|
2020-07-28 15:55:58 +08:00
|
|
|
// TODO(herhut): Generate these out of op definitions.
// Element-wise unary MHLO ops: applies `fn` to each op name and joins the
// results with `sep`.
#define MAP_XLA_OPERATION_CWISE_UNARY(fn, sep)                       \
  fn(AbsOp) sep fn(CeilOp) sep fn(ClzOp) sep fn(CosOp) sep fn(ExpOp) \
  sep fn(Expm1Op) sep fn(FloorOp) sep fn(ImagOp) sep fn(IsFiniteOp)  \
  sep fn(LogOp) sep fn(Log1pOp) sep fn(LogisticOp) sep fn(NotOp)     \
  sep fn(NegOp) sep fn(PopulationCountOp) sep fn(RealOp)             \
  sep fn(RoundOp) sep fn(RsqrtOp) sep fn(SignOp) sep fn(SinOp)       \
  sep fn(SqrtOp) sep fn(TanhOp)

// TODO(herhut): Generate these out of op definitions.
// Element-wise binary MHLO ops.
#define MAP_XLA_OPERATION_CWISE_BINARY(fn, sep)                           \
  fn(AddOp) sep fn(Atan2Op) sep fn(ComplexOp) sep fn(DivOp) sep fn(MaxOp) \
  sep fn(MinOp) sep fn(MulOp) sep fn(PowOp) sep fn(RemOp)                 \
  sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp)                      \
  sep fn(ShiftRightLogicalOp) sep fn(SubOp)

// TODO(herhut): Generate these out of op definitions.
// Element-wise unary CHLO ops.
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
  fn(AcosOp) sep fn(AtanOp) sep fn(SinhOp) sep fn(TanOp)
|
2020-09-17 00:48:43 +08:00
|
|
|
|
2020-07-07 04:57:00 +08:00
|
|
|
template <typename OpTy>
|
|
|
|
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
|
|
|
target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
|
2020-09-16 16:12:09 +08:00
|
|
|
return llvm::all_of(op.getOperation()->getOperandTypes(),
|
2020-07-07 04:57:00 +08:00
|
|
|
[&](Type t) { return t.isa<RankedTensorType>(); });
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
2020-09-16 16:12:09 +08:00
|
|
|
/// Element-wise operations on unranked tensors can be applied to the flattened
|
|
|
|
/// tensor operands with the same effect. This pattern rewrites every such
|
|
|
|
/// operation to
|
2020-07-07 04:57:00 +08:00
|
|
|
/// (i) flatten the input tensor,
|
2020-09-16 16:12:09 +08:00
|
|
|
/// (ii) apply the operation, and
|
2020-07-07 04:57:00 +08:00
|
|
|
/// (iii) restore the original shape.
|
|
|
|
template <typename OpTy>
|
2020-09-16 16:12:09 +08:00
|
|
|
struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
|
|
|
explicit ElementwiseOpConversion(MLIRContext *context)
|
2020-07-07 04:57:00 +08:00
|
|
|
: OpRewritePattern<OpTy>(context) {}
|
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(OpTy op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
2020-09-16 16:12:09 +08:00
|
|
|
// Don't apply conversion unless all operands are unranked.
|
|
|
|
if (!llvm::all_of(op.getOperation()->getOperands(), [&](Value operand) {
|
|
|
|
return operand.getType().isa<UnrankedTensorType>();
|
|
|
|
})) {
|
|
|
|
return failure();
|
|
|
|
}
|
2020-07-07 04:57:00 +08:00
|
|
|
|
2020-09-16 16:12:09 +08:00
|
|
|
// Get operands' shape.
|
2020-07-07 04:57:00 +08:00
|
|
|
auto loc = op.getLoc();
|
2020-08-06 02:10:20 +08:00
|
|
|
Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext());
|
2020-09-16 16:12:09 +08:00
|
|
|
SmallVector<Value, 3> operandShapes;
|
|
|
|
for (Value operand : op.getOperation()->getOperands()) {
|
|
|
|
Value shape =
|
|
|
|
rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, operand);
|
|
|
|
operandShapes.push_back(shape);
|
|
|
|
}
|
2020-08-06 02:10:20 +08:00
|
|
|
Value shape =
|
2020-09-16 16:12:09 +08:00
|
|
|
operandShapes.size() == 1
|
|
|
|
? operandShapes.front()
|
|
|
|
: rewriter.create<shape::AnyOp>(loc, extentTensorTy, operandShapes);
|
|
|
|
|
|
|
|
// Derive flat shape.
|
2020-08-06 02:10:20 +08:00
|
|
|
Type indexTy = rewriter.getIndexType();
|
|
|
|
Value numElements =
|
|
|
|
rewriter.create<shape::NumElementsOp>(loc, indexTy, shape);
|
|
|
|
Value flatShape = rewriter.create<TensorFromElementsOp>(loc, numElements);
|
2020-07-07 04:57:00 +08:00
|
|
|
|
2020-09-16 16:12:09 +08:00
|
|
|
// Flatten operands.
|
|
|
|
SmallVector<Value, 3> flatOperands;
|
|
|
|
for (Value operand : op.getOperation()->getOperands()) {
|
|
|
|
Type operandElementTy =
|
|
|
|
operand.getType().template cast<ShapedType>().getElementType();
|
|
|
|
Type flatTy =
|
|
|
|
RankedTensorType::get({ShapedType::kDynamicSize}, operandElementTy);
|
2020-09-17 00:48:43 +08:00
|
|
|
Value flat = rewriter.create<mhlo::DynamicReshapeOp>(loc, flatTy, operand,
|
|
|
|
flatShape);
|
2020-09-16 16:12:09 +08:00
|
|
|
flatOperands.push_back(flat);
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
|
|
|
|
2020-09-16 16:12:09 +08:00
|
|
|
// Apply operation to flattened operands.
|
|
|
|
Type resultElementTy =
|
|
|
|
op.getType().template cast<ShapedType>().getElementType();
|
|
|
|
Type flatResultTy =
|
|
|
|
RankedTensorType::get({ShapedType::kDynamicSize}, resultElementTy);
|
|
|
|
Value flatResult =
|
|
|
|
rewriter.create<OpTy>(loc, flatResultTy, flatOperands, op.getAttrs());
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
// Restore original shape.
|
2020-09-17 00:48:43 +08:00
|
|
|
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, op.getType(),
|
|
|
|
flatResult, shape);
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
struct TransformUnrankedHloPass
|
|
|
|
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
|
2020-08-26 11:30:05 +08:00
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
2020-09-16 16:12:09 +08:00
|
|
|
registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
|
2020-08-26 11:30:05 +08:00
|
|
|
}
|
|
|
|
|
2020-07-07 04:57:00 +08:00
|
|
|
void runOnFunction() override {
|
|
|
|
// Setup conversion target.
|
|
|
|
MLIRContext &ctx = getContext();
|
|
|
|
ConversionTarget target(ctx);
|
2020-09-17 00:48:43 +08:00
|
|
|
target.addLegalDialect<mhlo::MhloDialect, StandardOpsDialect,
|
2020-07-07 04:57:00 +08:00
|
|
|
shape::ShapeDialect>();
|
|
|
|
target.addLegalOp<FuncOp>();
|
2020-09-17 00:48:43 +08:00
|
|
|
#define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target)
|
|
|
|
#define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target)
|
|
|
|
MAP_XLA_OPERATION_CWISE_UNARY(ADD_LEGAL_MHLO, ;);
|
|
|
|
MAP_XLA_OPERATION_CWISE_BINARY(ADD_LEGAL_MHLO, ;);
|
|
|
|
MAP_CHLO_OPERATION_CWISE_UNARY(ADD_LEGAL_CHLO, ;);
|
|
|
|
#undef ADD_LEGAL_MHLO
|
|
|
|
#undef ADD_LEGAL_CHLO
|
2020-09-18 16:39:48 +08:00
|
|
|
AddLegalOpOnRankedTensor<mhlo::CompareOp>(&target);
|
|
|
|
AddLegalOpOnRankedTensor<mhlo::SelectOp>(&target);
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
// Populate rewrite patterns.
|
|
|
|
OwningRewritePatternList patterns;
|
|
|
|
PopulateTransformUnrankedHloPatterns(&ctx, &patterns);
|
|
|
|
|
|
|
|
// Apply transformation.
|
2020-10-27 21:55:28 +08:00
|
|
|
if (failed(
|
|
|
|
applyPartialConversion(getFunction(), target, std::move(patterns))))
|
2020-07-07 04:57:00 +08:00
|
|
|
return signalPassFailure();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
// Registers one ElementwiseOpConversion instantiation for every element-wise
// MHLO/CHLO op handled by this transformation.
void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
                                          OwningRewritePatternList *patterns) {
#define INSTANTIATE_MHLO(op) ElementwiseOpConversion<mhlo::op>
#define INSTANTIATE_CHLO(op) ElementwiseOpConversion<chlo::op>
#define COMMA ,
  // clang-format off
  patterns->insert<
      MAP_XLA_OPERATION_CWISE_UNARY(INSTANTIATE_MHLO, COMMA),
      MAP_XLA_OPERATION_CWISE_BINARY(INSTANTIATE_MHLO, COMMA),
      MAP_CHLO_OPERATION_CWISE_UNARY(INSTANTIATE_CHLO, COMMA),
      ElementwiseOpConversion<mhlo::CompareOp>,
      ElementwiseOpConversion<mhlo::SelectOp>>(context);
  // clang-format on
#undef INSTANTIATE_MHLO
#undef INSTANTIATE_CHLO
#undef COMMA
}
|
|
|
|
|
2020-09-08 21:05:50 +08:00
|
|
|
std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {
|
2020-07-29 07:12:08 +08:00
|
|
|
return std::make_unique<TransformUnrankedHloPass>();
|
|
|
|
}
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
} // namespace mlir
|