2020-07-07 04:57:00 +08:00
|
|
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
limitations under the License.
|
|
|
|
|
|
|
|
==============================================================================*/
|
|
|
|
|
2020-09-17 00:48:43 +08:00
|
|
|
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
2020-07-29 07:12:08 +08:00
|
|
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
2020-10-30 17:55:49 +08:00
|
|
|
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
|
2020-07-29 07:12:08 +08:00
|
|
|
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
2020-10-30 17:55:49 +08:00
|
|
|
#include "mlir/Dialect/SCF/SCF.h"
|
2020-07-29 07:12:08 +08:00
|
|
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
|
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
|
|
#include "mlir/IR/Function.h"
|
|
|
|
#include "mlir/IR/MLIRContext.h"
|
|
|
|
#include "mlir/IR/Operation.h"
|
|
|
|
#include "mlir/IR/PatternMatch.h"
|
|
|
|
#include "mlir/IR/StandardTypes.h"
|
|
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
#include "mlir/Transforms/DialectConversion.h"
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
namespace mlir {
|
|
|
|
namespace {
|
|
|
|
|
2020-07-28 15:55:58 +08:00
|
|
|
// Enumerates the component-wise unary ops handled by this file. Expands to
// `fn(Op)` for every listed op name, with `sep` placed between consecutive
// expansions (e.g. `;` for statements or a comma macro for type lists).
// TODO(herhut): Generate these out of op definitions.
#define MAP_XLA_OPERATION_CWISE_UNARY(fn, sep)                 \
  fn(AbsOp) sep fn(CeilOp) sep fn(ClzOp) sep fn(CosOp)         \
  sep fn(ExpOp) sep fn(Expm1Op) sep fn(FloorOp)                \
  sep fn(ImagOp) sep fn(IsFiniteOp) sep fn(LogOp)              \
  sep fn(Log1pOp) sep fn(LogisticOp) sep fn(NotOp)             \
  sep fn(NegOp) sep fn(PopulationCountOp) sep fn(RealOp)       \
  sep fn(RoundOp) sep fn(RsqrtOp) sep fn(SignOp)               \
  sep fn(SinOp) sep fn(SqrtOp) sep fn(TanhOp)
|
|
|
|
|
|
|
|
// Enumerates the component-wise binary ops handled by this file. Expands to
// `fn(Op)` for every listed op name, with `sep` placed between consecutive
// expansions.
// TODO(herhut): Generate these out of op definitions.
#define MAP_XLA_OPERATION_CWISE_BINARY(fn, sep)                \
  fn(AddOp) sep fn(Atan2Op) sep fn(ComplexOp) sep fn(DivOp)    \
  sep fn(MaxOp) sep fn(MinOp) sep fn(MulOp) sep fn(PowOp)      \
  sep fn(RemOp) sep fn(ShiftLeftOp)                            \
  sep fn(ShiftRightArithmeticOp) sep fn(ShiftRightLogicalOp)   \
  sep fn(SubOp)
|
|
|
|
|
2020-09-17 00:48:43 +08:00
|
|
|
// Enumerates the component-wise unary CHLO ops handled by this file. Expands
// to `fn(Op)` for every listed op name, with `sep` placed between consecutive
// expansions.
// TODO(herhut): Generate these out of op definitions.
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep)                \
  fn(AcosOp) sep fn(AtanOp) sep fn(ConjOp) sep fn(ErfOp)       \
  sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp)
|
2020-09-17 00:48:43 +08:00
|
|
|
|
2020-07-07 04:57:00 +08:00
|
|
|
template <typename OpTy>
|
|
|
|
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
|
|
|
target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
|
2020-09-16 16:12:09 +08:00
|
|
|
return llvm::all_of(op.getOperation()->getOperandTypes(),
|
2020-07-07 04:57:00 +08:00
|
|
|
[&](Type t) { return t.isa<RankedTensorType>(); });
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
2020-09-16 16:12:09 +08:00
|
|
|
/// Element-wise operations on unranked tensors can be applied to the flattened
|
|
|
|
/// tensor operands with the same effect. This pattern rewrites every such
|
|
|
|
/// operation to
|
2020-07-07 04:57:00 +08:00
|
|
|
/// (i) flatten the input tensor,
|
2020-09-16 16:12:09 +08:00
|
|
|
/// (ii) apply the operation, and
|
2020-07-07 04:57:00 +08:00
|
|
|
/// (iii) restore the original shape.
|
|
|
|
template <typename OpTy>
|
2020-09-16 16:12:09 +08:00
|
|
|
struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
|
|
|
explicit ElementwiseOpConversion(MLIRContext *context)
|
2020-07-07 04:57:00 +08:00
|
|
|
: OpRewritePattern<OpTy>(context) {}
|
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(OpTy op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
2020-09-16 16:12:09 +08:00
|
|
|
// Don't apply conversion unless all operands are unranked.
|
|
|
|
if (!llvm::all_of(op.getOperation()->getOperands(), [&](Value operand) {
|
|
|
|
return operand.getType().isa<UnrankedTensorType>();
|
|
|
|
})) {
|
|
|
|
return failure();
|
|
|
|
}
|
2020-07-07 04:57:00 +08:00
|
|
|
|
2020-09-16 16:12:09 +08:00
|
|
|
// Get operands' shape.
|
2020-07-07 04:57:00 +08:00
|
|
|
auto loc = op.getLoc();
|
2020-08-06 02:10:20 +08:00
|
|
|
Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext());
|
2020-09-16 16:12:09 +08:00
|
|
|
SmallVector<Value, 3> operandShapes;
|
|
|
|
for (Value operand : op.getOperation()->getOperands()) {
|
|
|
|
Value shape =
|
|
|
|
rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, operand);
|
|
|
|
operandShapes.push_back(shape);
|
|
|
|
}
|
2020-08-06 02:10:20 +08:00
|
|
|
Value shape =
|
2020-09-16 16:12:09 +08:00
|
|
|
operandShapes.size() == 1
|
|
|
|
? operandShapes.front()
|
|
|
|
: rewriter.create<shape::AnyOp>(loc, extentTensorTy, operandShapes);
|
|
|
|
|
|
|
|
// Derive flat shape.
|
2020-08-06 02:10:20 +08:00
|
|
|
Type indexTy = rewriter.getIndexType();
|
|
|
|
Value numElements =
|
|
|
|
rewriter.create<shape::NumElementsOp>(loc, indexTy, shape);
|
|
|
|
Value flatShape = rewriter.create<TensorFromElementsOp>(loc, numElements);
|
2020-07-07 04:57:00 +08:00
|
|
|
|
2020-09-16 16:12:09 +08:00
|
|
|
// Flatten operands.
|
|
|
|
SmallVector<Value, 3> flatOperands;
|
|
|
|
for (Value operand : op.getOperation()->getOperands()) {
|
|
|
|
Type operandElementTy =
|
|
|
|
operand.getType().template cast<ShapedType>().getElementType();
|
|
|
|
Type flatTy =
|
|
|
|
RankedTensorType::get({ShapedType::kDynamicSize}, operandElementTy);
|
2020-09-17 00:48:43 +08:00
|
|
|
Value flat = rewriter.create<mhlo::DynamicReshapeOp>(loc, flatTy, operand,
|
|
|
|
flatShape);
|
2020-09-16 16:12:09 +08:00
|
|
|
flatOperands.push_back(flat);
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
|
|
|
|
2020-09-16 16:12:09 +08:00
|
|
|
// Apply operation to flattened operands.
|
|
|
|
Type resultElementTy =
|
|
|
|
op.getType().template cast<ShapedType>().getElementType();
|
|
|
|
Type flatResultTy =
|
|
|
|
RankedTensorType::get({ShapedType::kDynamicSize}, resultElementTy);
|
|
|
|
Value flatResult =
|
|
|
|
rewriter.create<OpTy>(loc, flatResultTy, flatOperands, op.getAttrs());
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
// Restore original shape.
|
2020-09-17 00:48:43 +08:00
|
|
|
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, op.getType(),
|
|
|
|
flatResult, shape);
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2020-10-30 17:55:49 +08:00
|
|
|
// Converts a broadcasting binary operation with a scalar operand and an
|
|
|
|
// unranked operand to a ranked broadcasting operation by dynamically reshaping
|
|
|
|
// the unranked operand to a 1D tensor. This will always be safe because
|
|
|
|
// broadcasting from a scalar to another shape always works.
|
|
|
|
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
|
|
|
struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
|
|
|
|
: public OpConversionPattern<ChloOpTy> {
|
|
|
|
using OpConversionPattern<ChloOpTy>::OpConversionPattern;
|
|
|
|
LogicalResult matchAndRewrite(
|
|
|
|
ChloOpTy op, ArrayRef<Value> operands,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
auto loc = op.getLoc();
|
|
|
|
typename ChloOpTy::Adaptor transformed(operands);
|
|
|
|
Value lhs = transformed.lhs();
|
|
|
|
Value rhs = transformed.rhs();
|
|
|
|
|
|
|
|
auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
|
|
|
|
auto lhs_unranked_type = lhs.getType().dyn_cast<UnrankedTensorType>();
|
|
|
|
|
|
|
|
auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
|
|
|
|
auto rhs_unranked_type = rhs.getType().dyn_cast<UnrankedTensorType>();
|
|
|
|
|
|
|
|
bool lhs_is_scalar = lhs_ranked_type &&
|
|
|
|
lhs_ranked_type.getShape().empty() &&
|
|
|
|
rhs_unranked_type;
|
|
|
|
bool rhs_is_scalar = rhs_ranked_type &&
|
|
|
|
rhs_ranked_type.getShape().empty() &&
|
|
|
|
lhs_unranked_type;
|
|
|
|
|
|
|
|
// Only support the case where exactly one operand is scalar and the other
|
|
|
|
// is unranked. Other patterns in chlo-to-hlo legalization will create more
|
|
|
|
// efficient lowerings for cases where both ranks are known or will handle
|
|
|
|
// the more generic case of both inputs being unranked.
|
|
|
|
if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure();
|
|
|
|
|
2020-11-26 20:19:51 +08:00
|
|
|
auto scalar_element_type = lhs_is_scalar ? lhs_ranked_type.getElementType()
|
|
|
|
: rhs_ranked_type.getElementType();
|
2020-10-30 17:55:49 +08:00
|
|
|
auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
|
2020-11-26 20:19:51 +08:00
|
|
|
auto result_element_type = result_type.getElementType();
|
2020-10-30 17:55:49 +08:00
|
|
|
|
|
|
|
// Reshape the non-scalar value into a dynamically sized, rank-1 tensor
|
|
|
|
Value shape =
|
|
|
|
rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs);
|
|
|
|
Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape);
|
|
|
|
Value size_tensor =
|
|
|
|
rewriter.create<TensorFromElementsOp>(loc, num_elements);
|
|
|
|
Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
|
2020-11-26 20:19:51 +08:00
|
|
|
loc, RankedTensorType::get({-1}, scalar_element_type),
|
2020-10-30 17:55:49 +08:00
|
|
|
lhs_is_scalar ? rhs : lhs, size_tensor);
|
|
|
|
|
|
|
|
// Create a new ranked Chlo op that will be further lowered by other
|
|
|
|
// patterns into Mhlo.
|
|
|
|
SmallVector<Value, 2> new_operands{lhs_is_scalar ? lhs : reshaped,
|
|
|
|
rhs_is_scalar ? rhs : reshaped};
|
2020-11-26 20:19:51 +08:00
|
|
|
Value computed = rewriter.create<ChloOpTy>(
|
|
|
|
loc, TypeRange{RankedTensorType::get({-1}, result_element_type)},
|
|
|
|
new_operands, op.getAttrs());
|
2020-10-30 17:55:49 +08:00
|
|
|
|
|
|
|
// Reshape the result back into an unranked tensor.
|
|
|
|
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
|
|
|
|
computed, shape);
|
|
|
|
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
// Handles lowering of the following pattern to patterns that will be further
|
|
|
|
// matched by other patterns until they result in LHLO:
|
|
|
|
// %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy>
|
|
|
|
//
|
|
|
|
// The sequence of specializations this handles is:
|
|
|
|
// - Either operand being scalar
|
|
|
|
// - Operands having equal shapes
|
|
|
|
// - The resulting value being any of ranks [2,6]
|
|
|
|
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
|
|
|
struct ConvertUnrankedDynamicBroadcastBinaryOp
|
|
|
|
: public OpConversionPattern<ChloOpTy> {
|
|
|
|
using OpConversionPattern<ChloOpTy>::OpConversionPattern;
|
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(
|
|
|
|
ChloOpTy op, ArrayRef<Value> operands,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
auto loc = op.getLoc();
|
|
|
|
typename ChloOpTy::Adaptor transformed(operands);
|
|
|
|
Value lhs = transformed.lhs();
|
|
|
|
Value rhs = transformed.rhs();
|
|
|
|
auto lhs_type = lhs.getType().dyn_cast<UnrankedTensorType>();
|
|
|
|
auto rhs_type = rhs.getType().dyn_cast<UnrankedTensorType>();
|
|
|
|
auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
|
|
|
|
|
|
|
|
// Only support unranked operands. If either operand is ranked, another
|
|
|
|
// pattern will handle the lowering.
|
|
|
|
if (!lhs_type || !rhs_type) return failure();
|
|
|
|
|
|
|
|
// If lhs is scalar
|
|
|
|
auto if_op = rewriter.create<scf::IfOp>(
|
|
|
|
loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
|
|
|
|
OpBuilder if_lhs_scalar_builder =
|
|
|
|
if_op.getThenBodyBuilder(rewriter.getListener());
|
|
|
|
Value reshaped_lhs = if_lhs_scalar_builder.create<TensorCastOp>(
|
|
|
|
loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
|
|
|
|
Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
|
|
|
|
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
|
|
|
|
op.getAttrs());
|
|
|
|
if_lhs_scalar_builder.create<scf::YieldOp>(loc, if_lhs_scalar_result);
|
|
|
|
|
|
|
|
// If lhs is NOT scalar
|
|
|
|
//
|
|
|
|
// See if rhs is scalar
|
|
|
|
OpBuilder else_lhs_scalar_builder =
|
|
|
|
if_op.getElseBodyBuilder(rewriter.getListener());
|
|
|
|
auto if_rhs_scalar_op = else_lhs_scalar_builder.create<scf::IfOp>(
|
|
|
|
loc, result_type, IsScalarTensor(else_lhs_scalar_builder, op, rhs),
|
|
|
|
true);
|
|
|
|
else_lhs_scalar_builder.create<scf::YieldOp>(loc,
|
|
|
|
if_rhs_scalar_op.getResult(0));
|
|
|
|
OpBuilder if_rhs_scalar_builder =
|
|
|
|
if_rhs_scalar_op.getThenBodyBuilder(rewriter.getListener());
|
|
|
|
Value reshaped_rhs = if_rhs_scalar_builder.create<TensorCastOp>(
|
|
|
|
loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
|
|
|
|
Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
|
|
|
|
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
|
|
|
|
op.getAttrs());
|
|
|
|
if_rhs_scalar_builder.create<scf::YieldOp>(loc, if_rhs_scalar_result);
|
|
|
|
|
|
|
|
// If NEITHER shape is scalar
|
|
|
|
//
|
|
|
|
// See if shapes are equal.
|
|
|
|
OpBuilder else_no_scalars_builder =
|
|
|
|
if_rhs_scalar_op.getElseBodyBuilder(rewriter.getListener());
|
|
|
|
Value shape_of_lhs =
|
|
|
|
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs);
|
|
|
|
Value shape_of_rhs =
|
|
|
|
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs);
|
|
|
|
Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
|
|
|
|
loc, shape_of_lhs, shape_of_rhs);
|
|
|
|
|
|
|
|
auto if_eq_shapes_op = else_no_scalars_builder.create<scf::IfOp>(
|
|
|
|
loc, result_type, equal_shapes, true);
|
|
|
|
else_no_scalars_builder.create<scf::YieldOp>(loc,
|
|
|
|
if_eq_shapes_op.getResult(0));
|
|
|
|
|
|
|
|
OpBuilder if_eq_shapes_builder =
|
|
|
|
if_eq_shapes_op.getThenBodyBuilder(rewriter.getListener());
|
|
|
|
Value non_broadcast_op =
|
|
|
|
Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder);
|
|
|
|
if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);
|
|
|
|
|
|
|
|
// If shapes are not scalar, nor equal
|
|
|
|
//
|
|
|
|
// See if values are of a rank that we support.
|
|
|
|
OpBuilder if_neq_shapes_builder =
|
|
|
|
if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener());
|
|
|
|
if_neq_shapes_builder.create<scf::YieldOp>(
|
|
|
|
loc, HandleBroadcastAndOp(if_neq_shapes_builder, op, lhs, rhs));
|
|
|
|
|
|
|
|
rewriter.replaceOp(op, {if_op.getResult(0)});
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
private:
|
2020-11-26 20:19:51 +08:00
|
|
|
// Returns the dynamic result of checking the given value is a scalar tensor.
|
2020-10-30 17:55:49 +08:00
|
|
|
Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const {
|
|
|
|
auto loc = op.getLoc();
|
|
|
|
|
|
|
|
Value shape_of_tensor = rewriter.create<shape::ShapeOfOp>(loc, tensor);
|
|
|
|
Value rank_tensor = rewriter.create<shape::RankOp>(
|
|
|
|
loc, rewriter.getIndexType(), shape_of_tensor);
|
|
|
|
return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq,
|
|
|
|
rank_tensor,
|
|
|
|
rewriter.create<ConstantIndexOp>(loc, 0));
|
|
|
|
}
|
|
|
|
|
2020-11-26 20:19:51 +08:00
|
|
|
Value GreaterRankIsN(OpBuilder &builder, Location loc, Value actual_rank,
|
|
|
|
int targeted_rank) const {
|
|
|
|
return builder.create<CmpIOp>(
|
|
|
|
loc, CmpIPredicate::eq, actual_rank,
|
|
|
|
builder.create<ConstantIndexOp>(loc, targeted_rank));
|
|
|
|
}
|
|
|
|
|
|
|
|
scf::IfOp createIfOpForRankSpecializedBroadcastAndOp(
|
|
|
|
OpBuilder &builder, ChloOpTy op, Value actual_rank,
|
|
|
|
int targeted_rank) const {
|
|
|
|
// Create the if block to place the current specialized logic in.
|
|
|
|
Value greater_rank_is_n =
|
|
|
|
GreaterRankIsN(builder, op.getLoc(), actual_rank, targeted_rank);
|
|
|
|
return builder.create<scf::IfOp>(op.getLoc(), op.getResult().getType(),
|
|
|
|
greater_rank_is_n, true);
|
|
|
|
}
|
|
|
|
|
2020-10-30 17:55:49 +08:00
|
|
|
// Create the if statement and code for a broadcasting op with a result of a
|
|
|
|
// given rank.
|
2020-11-26 20:19:51 +08:00
|
|
|
void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder, ChloOpTy op,
|
|
|
|
Value lhs, Value rhs,
|
|
|
|
int targeted_rank) const {
|
2020-10-30 17:55:49 +08:00
|
|
|
auto loc = op.getLoc();
|
|
|
|
|
|
|
|
// Handle shape broadcasting and inferrence.
|
|
|
|
Value lhs_shape = if_builder.create<shape::ShapeOfOp>(loc, lhs);
|
|
|
|
Value rhs_shape = if_builder.create<shape::ShapeOfOp>(loc, rhs);
|
|
|
|
SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
|
|
|
|
auto unknown_rank_extent_tensor_type = RankedTensorType::get(
|
2020-11-26 20:19:51 +08:00
|
|
|
{RankedTensorType::kDynamicSize}, if_builder.getIndexType());
|
2020-10-30 17:55:49 +08:00
|
|
|
auto known_rank_extent_tensor_type =
|
2020-11-26 20:19:51 +08:00
|
|
|
RankedTensorType::get({targeted_rank}, if_builder.getIndexType());
|
2020-10-30 17:55:49 +08:00
|
|
|
auto reshaped_type = RankedTensorType::get(
|
|
|
|
llvm::SmallVector<int64_t, 6>(targeted_rank,
|
|
|
|
RankedTensorType::kDynamicSize),
|
|
|
|
lhs.getType().template dyn_cast<TensorType>().getElementType());
|
|
|
|
Value ranked_shape_val = if_builder.create<shape::ConstShapeOp>(
|
|
|
|
loc, known_rank_extent_tensor_type,
|
|
|
|
mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
|
|
|
|
ranked_shape));
|
|
|
|
Value extended_lhs = if_builder.create<shape::BroadcastOp>(
|
|
|
|
loc, unknown_rank_extent_tensor_type, lhs_shape, ranked_shape_val,
|
|
|
|
nullptr);
|
|
|
|
Value extended_lhs_casted = if_builder.create<TensorCastOp>(
|
|
|
|
loc, known_rank_extent_tensor_type, extended_lhs);
|
|
|
|
Value extended_rhs = if_builder.create<shape::BroadcastOp>(
|
|
|
|
loc, unknown_rank_extent_tensor_type, rhs_shape, ranked_shape_val,
|
|
|
|
nullptr);
|
|
|
|
Value extended_rhs_casted = if_builder.create<TensorCastOp>(
|
|
|
|
loc, known_rank_extent_tensor_type, extended_rhs);
|
|
|
|
|
|
|
|
// 1. Reshape operands to the given rank (with the same number of elements)
|
|
|
|
// 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops
|
|
|
|
// can be broadcasted and do the actual broadcasting)
|
|
|
|
// 3. Type erase the output back to unranked
|
|
|
|
Value reshaped_lhs = if_builder.create<mhlo::DynamicReshapeOp>(
|
|
|
|
loc, reshaped_type, lhs, extended_lhs_casted);
|
|
|
|
Value reshaped_rhs = if_builder.create<mhlo::DynamicReshapeOp>(
|
|
|
|
loc, reshaped_type, rhs, extended_rhs_casted);
|
2020-11-26 20:19:51 +08:00
|
|
|
auto result_element_type = op.getResult()
|
|
|
|
.getType()
|
|
|
|
.template dyn_cast<TensorType>()
|
|
|
|
.getElementType();
|
|
|
|
auto result_type = RankedTensorType::get(
|
|
|
|
llvm::SmallVector<int64_t, 6>(targeted_rank,
|
|
|
|
RankedTensorType::kDynamicSize),
|
|
|
|
result_element_type);
|
2020-10-30 17:55:49 +08:00
|
|
|
Value result = if_builder.create<ChloOpTy>(
|
2020-11-26 20:19:51 +08:00
|
|
|
loc, ArrayRef<Type>{result_type},
|
2020-10-30 17:55:49 +08:00
|
|
|
ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
|
|
|
|
Value reshaped_result = if_builder.create<TensorCastOp>(
|
2020-11-26 20:19:51 +08:00
|
|
|
loc, UnrankedTensorType::get(result_element_type), result);
|
2020-10-30 17:55:49 +08:00
|
|
|
if_builder.create<scf::YieldOp>(loc, reshaped_result);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Iterates over the desired ranks to be specialized and generates the code
|
|
|
|
// snippet for each case.
|
|
|
|
Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs,
|
|
|
|
Value rhs) const {
|
|
|
|
auto loc = op.getLoc();
|
|
|
|
|
|
|
|
// Find the larger rank of the 2 operands.
|
|
|
|
auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
|
|
|
|
rewriter.getIndexType());
|
|
|
|
Value lhs_shape =
|
|
|
|
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, lhs);
|
|
|
|
Value rhs_shape =
|
|
|
|
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs);
|
|
|
|
Value lhs_rank =
|
2020-10-31 00:58:48 +08:00
|
|
|
rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), lhs_shape);
|
2020-10-30 17:55:49 +08:00
|
|
|
Value rhs_rank =
|
2020-10-31 00:58:48 +08:00
|
|
|
rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), rhs_shape);
|
2020-10-30 17:55:49 +08:00
|
|
|
Value greater_rank_lhs =
|
|
|
|
rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank);
|
|
|
|
Value greater_rank =
|
|
|
|
rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank);
|
|
|
|
|
|
|
|
// Generate a list of nested if/else statements to handle rank
|
2020-11-30 17:40:08 +08:00
|
|
|
// specializations from 1 to `kMaxRankSpecialization`.
|
2020-11-26 20:19:51 +08:00
|
|
|
scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
|
|
|
|
rewriter, op, greater_rank, 1);
|
|
|
|
OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
|
|
|
|
createRankSpecializedBroadcastAndOp(if_builder, op, lhs, rhs, 1);
|
2020-10-30 17:55:49 +08:00
|
|
|
|
|
|
|
// Put each subsequent rank specialization inside the else statement of the
|
|
|
|
// previous one.
|
|
|
|
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
|
2020-11-26 20:19:51 +08:00
|
|
|
constexpr int kMaxRankSpecialization = 6;
|
|
|
|
for (int i = 2; i < kMaxRankSpecialization; i++) {
|
|
|
|
auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
|
|
|
|
else_builder, op, greater_rank, i);
|
|
|
|
if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
|
|
|
|
createRankSpecializedBroadcastAndOp(if_builder, op, lhs, rhs, i);
|
2020-10-30 17:55:49 +08:00
|
|
|
else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
|
|
|
|
else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
|
|
|
|
}
|
2020-11-26 20:19:51 +08:00
|
|
|
// Fire an assertion if none of the rank specializations applied (one of
|
2020-11-30 17:40:08 +08:00
|
|
|
// the ranks was greater than `kMaxRankSpecialization`).
|
2020-10-30 17:55:49 +08:00
|
|
|
else_builder.create<AssertOp>(
|
2020-11-26 20:19:51 +08:00
|
|
|
loc,
|
|
|
|
GreaterRankIsN(else_builder, op.getLoc(), greater_rank,
|
|
|
|
kMaxRankSpecialization),
|
2020-11-30 17:40:08 +08:00
|
|
|
"Input for dynamic binary op lowering was of a rank greater than " +
|
|
|
|
std::to_string(kMaxRankSpecialization));
|
2020-11-26 20:19:51 +08:00
|
|
|
// Add the rank 6 specialization to the innermost else block.
|
|
|
|
createRankSpecializedBroadcastAndOp(else_builder, op, lhs, rhs,
|
|
|
|
kMaxRankSpecialization);
|
2020-10-30 17:55:49 +08:00
|
|
|
|
|
|
|
// Return the result of the outermost if statement.
|
|
|
|
return if_op.getResult(0);
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2020-07-07 04:57:00 +08:00
|
|
|
struct TransformUnrankedHloPass
|
|
|
|
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
|
2020-08-26 11:30:05 +08:00
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
2020-09-16 16:12:09 +08:00
|
|
|
registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
|
2020-08-26 11:30:05 +08:00
|
|
|
}
|
|
|
|
|
2020-07-07 04:57:00 +08:00
|
|
|
void runOnFunction() override {
|
|
|
|
// Setup conversion target.
|
|
|
|
MLIRContext &ctx = getContext();
|
|
|
|
ConversionTarget target(ctx);
|
2020-09-17 00:48:43 +08:00
|
|
|
target.addLegalDialect<mhlo::MhloDialect, StandardOpsDialect,
|
2020-10-30 17:55:49 +08:00
|
|
|
shape::ShapeDialect, scf::SCFDialect>();
|
2020-07-07 04:57:00 +08:00
|
|
|
target.addLegalOp<FuncOp>();
|
2020-09-17 00:48:43 +08:00
|
|
|
#define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target)
|
|
|
|
#define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target)
|
|
|
|
MAP_XLA_OPERATION_CWISE_UNARY(ADD_LEGAL_MHLO, ;);
|
|
|
|
MAP_XLA_OPERATION_CWISE_BINARY(ADD_LEGAL_MHLO, ;);
|
|
|
|
MAP_CHLO_OPERATION_CWISE_UNARY(ADD_LEGAL_CHLO, ;);
|
|
|
|
#undef ADD_LEGAL_MHLO
|
|
|
|
#undef ADD_LEGAL_CHLO
|
2020-09-18 16:39:48 +08:00
|
|
|
AddLegalOpOnRankedTensor<mhlo::CompareOp>(&target);
|
|
|
|
AddLegalOpOnRankedTensor<mhlo::SelectOp>(&target);
|
2020-10-30 17:55:49 +08:00
|
|
|
target.addDynamicallyLegalDialect<chlo::HloClientDialect>(
|
|
|
|
[](Operation *op) {
|
|
|
|
return !llvm::any_of(op->getOperandTypes(), [](Type type) {
|
|
|
|
return type.isa<UnrankedTensorType>();
|
|
|
|
});
|
|
|
|
});
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
// Populate rewrite patterns.
|
|
|
|
OwningRewritePatternList patterns;
|
|
|
|
PopulateTransformUnrankedHloPatterns(&ctx, &patterns);
|
|
|
|
|
|
|
|
// Apply transformation.
|
2020-10-27 21:55:28 +08:00
|
|
|
if (failed(
|
|
|
|
applyPartialConversion(getFunction(), target, std::move(patterns))))
|
2020-07-07 04:57:00 +08:00
|
|
|
return signalPassFailure();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
/// Adds to `patterns` the rewrites of this file: one `ElementwiseOpConversion`
/// per element-wise MHLO/CHLO op listed by the MAP_* macros (plus
/// `mhlo::CompareOp` and `mhlo::SelectOp`), followed by the two unranked
/// broadcasting binary op conversions.
void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
                                          OwningRewritePatternList *patterns) {
  // Local helpers that turn an op name into the matching pattern type.
#define MHLO_ELEMENTWISE(op) ElementwiseOpConversion<mhlo::op>
#define CHLO_ELEMENTWISE(op) ElementwiseOpConversion<chlo::op>
#define LIST_SEPARATOR ,
  // clang-format off
  patterns->insert<
      MAP_XLA_OPERATION_CWISE_UNARY(MHLO_ELEMENTWISE, LIST_SEPARATOR),
      MAP_XLA_OPERATION_CWISE_BINARY(MHLO_ELEMENTWISE, LIST_SEPARATOR),
      MAP_CHLO_OPERATION_CWISE_UNARY(CHLO_ELEMENTWISE, LIST_SEPARATOR),
      ElementwiseOpConversion<mhlo::CompareOp>,
      ElementwiseOpConversion<mhlo::SelectOp>>(context);
  // clang-format on
#undef MHLO_ELEMENTWISE
#undef CHLO_ELEMENTWISE
#undef LIST_SEPARATOR

  // Broadcasting binary CHLO ops: first the fully unranked case, then the
  // scalar/unranked case.
  chlo::PopulateForBroadcastingBinaryOp<
      ConvertUnrankedDynamicBroadcastBinaryOp>(context, patterns);
  chlo::PopulateForBroadcastingBinaryOp<
      ConvertUnrankedScalarDynamicBroadcastBinaryOp>(context, patterns);
}
|
|
|
|
|
2020-09-08 21:05:50 +08:00
|
|
|
std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {
|
2020-07-29 07:12:08 +08:00
|
|
|
return std::make_unique<TransformUnrankedHloPass>();
|
|
|
|
}
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
} // namespace mlir
|