Move unranked chlo lowering to transform_unranked_hlo.
Additionally:
- Forward listeners through the new if/else op builders. This corrects an error that led to incomplete legalization of broadcasted op lowerings.
- Use OpConversionPattern to ensure up-to-date operand values are used.

PiperOrigin-RevId: 339838833
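For illustration, a minimal sketch of the listener-forwarding fix from the first bullet, assuming the scf::IfOp builder API used elsewhere in this change; rewriter, loc, result_type, and condition stand in for the pattern's local values:

// Ops created through a body builder that does not carry the conversion
// rewriter's listener are invisible to the conversion driver and can be left
// unlegalized. Forwarding the listener into the if/else body builders keeps
// every nested op tracked.
auto if_op = rewriter.create<scf::IfOp>(loc, result_type, condition,
                                        /*withElseRegion=*/true);
// Before: OpBuilder then_builder = if_op.getThenBodyBuilder();
// After: the rewriter's listener is forwarded into both regions.
OpBuilder then_builder = if_op.getThenBodyBuilder(rewriter.getListener());
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());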
This commit is contained in: parent e188ef10f2, commit 76b30fd426

@@ -0,0 +1,97 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_HLO_OP_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_HLO_OP_H_
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace chlo {
|
||||
|
||||
struct HloComplexAdaptor {
|
||||
static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type,
|
||||
Value broadcasted_lhs, Value broadcasted_rhs,
|
||||
OpBuilder &builder) {
|
||||
return builder.create<mhlo::ComplexOp>(from_op.getLoc(), result_type,
|
||||
broadcasted_lhs, broadcasted_rhs);
|
||||
}
|
||||
};
|
||||
template <typename FromOpTy, typename ToOpTy>
|
||||
struct HloBinaryElementwiseAdaptor {
|
||||
static ToOpTy CreateOp(FromOpTy from_op, Type result_type,
|
||||
Value broadcasted_lhs, Value broadcasted_rhs,
|
||||
OpBuilder &builder) {
|
||||
return builder.create<ToOpTy>(from_op.getLoc(), result_type,
|
||||
broadcasted_lhs, broadcasted_rhs);
|
||||
}
|
||||
};
|
||||
struct HloCompareAdaptor {
|
||||
static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
|
||||
Value broadcasted_lhs, Value broadcasted_rhs,
|
||||
OpBuilder &builder) {
|
||||
return builder.create<mhlo::CompareOp>(
|
||||
from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs,
|
||||
from_op.comparison_direction(), from_op.compare_typeAttr());
|
||||
}
|
||||
};
|
||||
|
||||
// Populate a pattern for each broadcasting CHLO op. This requires the pattern
|
||||
// to take a ChloOpTy, MhloOpTy, and an Adaptor as template parameters.
|
||||
template <template <typename, typename, typename> class Pattern,
|
||||
typename... ConstructorArgs>
|
||||
void PopulateForBroadcastingBinaryOp(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns,
|
||||
ConstructorArgs &&...args) {
|
||||
#define POPULATE_BCAST(ChloOp, HloOp) \
|
||||
patterns->insert< \
|
||||
Pattern<ChloOp, HloOp, HloBinaryElementwiseAdaptor<ChloOp, HloOp>>>( \
|
||||
context, args...);
|
||||
|
||||
POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp);
|
||||
POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp);
|
||||
POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op);
|
||||
POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp);
|
||||
POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp);
|
||||
POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp);
|
||||
POPULATE_BCAST(BroadcastMulOp, mhlo::MulOp);
|
||||
POPULATE_BCAST(BroadcastOrOp, mhlo::OrOp);
|
||||
POPULATE_BCAST(BroadcastPowOp, mhlo::PowOp);
|
||||
POPULATE_BCAST(BroadcastRemOp, mhlo::RemOp);
|
||||
POPULATE_BCAST(BroadcastShiftLeftOp, mhlo::ShiftLeftOp);
|
||||
POPULATE_BCAST(BroadcastShiftRightArithmeticOp, mhlo::ShiftRightArithmeticOp);
|
||||
POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp);
|
||||
POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp);
|
||||
POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp);
|
||||
|
||||
// Broadcasting ops requiring special construction.
|
||||
patterns
|
||||
->insert<Pattern<BroadcastComplexOp, mhlo::ComplexOp, HloComplexAdaptor>>(
|
||||
context, args...);
|
||||
patterns
|
||||
->insert<Pattern<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>>(
|
||||
context, args...);
|
||||
|
||||
#undef POPULATE_BCAST
|
||||
}
|
||||
|
||||
} // namespace chlo
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_HLO_OP_H_
|
|
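As a usage note for the helper defined in this new header: PopulateForBroadcastingBinaryOp instantiates a caller-supplied pattern template once per broadcasting CHLO/MHLO op pair. The sketch below is illustrative only; ExampleBroadcastPattern and its body are hypothetical placeholders, while the registration call mirrors the header above:

// The supplied pattern must be a class template over
// <ChloOpTy, HloOpTy, Adaptor>. Adaptor::CreateOp builds the corresponding
// non-broadcasting mhlo op once both operands have been broadcast.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ExampleBroadcastPattern : public OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ChloOpTy op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // ... broadcast the operands as needed, then emit the mhlo op, e.g.:
    // rewriter.replaceOp(op, {Adaptor::CreateOp(op, result_type, lhs, rhs,
    //                                           rewriter)});
    return failure();  // placeholder body
  }
};

void RegisterExamplePatterns(MLIRContext *context,
                             OwningRewritePatternList *patterns) {
  // Any extra constructor arguments (e.g. a PatternBenefit) are forwarded to
  // each instantiated pattern.
  chlo::PopulateForBroadcastingBinaryOp<ExampleBroadcastPattern>(context,
                                                                 patterns);
}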
@@ -17,6 +17,7 @@ limitations under the License.
|
|||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir-hlo/utils/broadcast_utils.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
|
@@ -69,13 +70,18 @@ struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> {
|
|||
// Converts binary ops that statically are determined to not broadcast directly
|
||||
// to the corresponding mhlo non-broadcasting op.
|
||||
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
||||
struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
|
||||
using OpRewritePattern<ChloOpTy>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(ChloOpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
struct ConvertTrivialNonBroadcastBinaryOp
|
||||
: public OpConversionPattern<ChloOpTy> {
|
||||
using OpConversionPattern<ChloOpTy>::OpConversionPattern;
|
||||
LogicalResult matchAndRewrite(
|
||||
ChloOpTy op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Only rewrite for statically determinable non-broadcasting cases.
|
||||
auto lhs_type = op.lhs().getType().template dyn_cast<RankedTensorType>();
|
||||
auto rhs_type = op.rhs().getType().template dyn_cast<RankedTensorType>();
|
||||
typename ChloOpTy::Adaptor transformed(operands);
|
||||
auto lhs_type =
|
||||
transformed.lhs().getType().template dyn_cast<RankedTensorType>();
|
||||
auto rhs_type =
|
||||
transformed.rhs().getType().template dyn_cast<RankedTensorType>();
|
||||
if (!lhs_type || !rhs_type) return failure();
|
||||
|
||||
// Requires rank broadcast.
|
||||
|
@@ -93,8 +99,9 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
|
|||
}
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, {Adaptor::CreateOp(op, op.getResult().getType(),
|
||||
op.lhs(), op.rhs(), rewriter)});
|
||||
rewriter.replaceOp(
|
||||
op, {Adaptor::CreateOp(op, op.getResult().getType(), operands[0],
|
||||
operands[1], rewriter)});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
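A brief note on the OpConversionPattern switch above: during dialect conversion the matched op may still reference values that have already been replaced, so inputs should be read through the adaptor built over the framework-supplied operands rather than through op.lhs()/op.rhs(). A condensed sketch of that idiom, using the same generated ChloOpTy::Adaptor accessors as the diff:

// 'operands' holds the converted, up-to-date values; the adaptor exposes
// them through named accessors matching the op's definition.
typename ChloOpTy::Adaptor transformed(operands);
Value lhs = transformed.lhs();  // may differ from op.lhs()
Value rhs = transformed.rhs();  // may differ from op.rhs()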
@@ -113,13 +120,15 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
|
|||
// `shape.broadcast` op, which only supports prefix-padding.
|
||||
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
||||
struct ConvertRankedDynamicBroadcastBinaryOp
|
||||
: public OpRewritePattern<ChloOpTy> {
|
||||
using OpRewritePattern<ChloOpTy>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(ChloOpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
: public OpConversionPattern<ChloOpTy> {
|
||||
using OpConversionPattern<ChloOpTy>::OpConversionPattern;
|
||||
LogicalResult matchAndRewrite(
|
||||
ChloOpTy op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Only support ranked operands.
|
||||
Value lhs = op.lhs();
|
||||
Value rhs = op.rhs();
|
||||
typename ChloOpTy::Adaptor transformed(operands);
|
||||
Value lhs = transformed.lhs();
|
||||
Value rhs = transformed.rhs();
|
||||
auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
|
||||
auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
|
||||
auto result_type =
|
||||
|
@@ -193,324 +202,6 @@ struct ConvertRankedDynamicBroadcastBinaryOp
|
|||
}
|
||||
};
|
||||
|
||||
// Converts a broadcasting binary operation with a scalar operand and an
|
||||
// unranked operand to a ranked broadcasting operation by dynamically reshaping
|
||||
// the unranked operand to a 1D tensor. This will always be safe because
|
||||
// broadcasting from a scalar to another shape always works.
|
||||
template <typename ChloOpTy, typename HloOpTy>
|
||||
struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
|
||||
: public OpRewritePattern<ChloOpTy> {
|
||||
using OpRewritePattern<ChloOpTy>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(ChloOpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
Value lhs = op.lhs();
|
||||
Value rhs = op.rhs();
|
||||
|
||||
auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
|
||||
auto lhs_unranked_type = lhs.getType().dyn_cast<UnrankedTensorType>();
|
||||
|
||||
auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
|
||||
auto rhs_unranked_type = rhs.getType().dyn_cast<UnrankedTensorType>();
|
||||
|
||||
bool lhs_is_scalar = lhs_ranked_type &&
|
||||
lhs_ranked_type.getShape().empty() &&
|
||||
rhs_unranked_type;
|
||||
bool rhs_is_scalar = rhs_ranked_type &&
|
||||
rhs_ranked_type.getShape().empty() &&
|
||||
lhs_unranked_type;
|
||||
|
||||
// Only support the case where exactly one operand is scalar and the other
|
||||
// is unranked. Other patterns in this file will create more efficient
|
||||
// lowerings for cases where both ranks are known or will handle the more
|
||||
// generic case of both inputs being unranked.
|
||||
if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure();
|
||||
|
||||
auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
|
||||
|
||||
// Reshape the non-scalar value into a dynamically sized, rank-1 tensor
|
||||
Value shape =
|
||||
rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs);
|
||||
Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape);
|
||||
Value size_tensor =
|
||||
rewriter.create<TensorFromElementsOp>(loc, num_elements);
|
||||
Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
loc, RankedTensorType::get({-1}, result_type.getElementType()),
|
||||
lhs_is_scalar ? rhs : lhs, size_tensor);
|
||||
|
||||
// Create a new ranked Chlo op that will be further lowered by other
|
||||
// patterns into Mhlo.
|
||||
SmallVector<Value, 2> operands{lhs_is_scalar ? lhs : reshaped,
|
||||
rhs_is_scalar ? rhs : reshaped};
|
||||
Value computed = rewriter.create<ChloOpTy>(
|
||||
loc, SmallVector<Type, 1>{reshaped.getType()}, operands, op.getAttrs());
|
||||
|
||||
// Reshape the result back into an unranked tensor.
|
||||
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
|
||||
computed, shape);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Handles lowering of the following pattern to patterns that will be further
|
||||
// matched by other patterns until they result in LHLO:
|
||||
// %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy>
|
||||
//
|
||||
// The sequence of specializations this handles is:
|
||||
// - Either operand being scalar
|
||||
// - Operands having equal shapes
|
||||
// - The resulting value being any of ranks [2,6]
|
||||
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
||||
struct ConvertUnrankedDynamicBroadcastBinaryOp
|
||||
: public OpRewritePattern<ChloOpTy> {
|
||||
using OpRewritePattern<ChloOpTy>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ChloOpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
Value lhs = op.lhs();
|
||||
Value rhs = op.rhs();
|
||||
auto lhs_type = lhs.getType().dyn_cast<UnrankedTensorType>();
|
||||
auto rhs_type = rhs.getType().dyn_cast<UnrankedTensorType>();
|
||||
auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
|
||||
|
||||
// Only support unranked operands. If either operand is ranked, another
|
||||
// pattern will handle the lowering.
|
||||
if (!lhs_type || !rhs_type) return failure();
|
||||
|
||||
// If lhs is scalar
|
||||
auto if_op = rewriter.create<scf::IfOp>(
|
||||
loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
|
||||
OpBuilder if_lhs_scalar_builder = if_op.getThenBodyBuilder();
|
||||
Value reshaped_lhs = if_lhs_scalar_builder.create<TensorCastOp>(
|
||||
loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
|
||||
Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
|
||||
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
|
||||
op.getAttrs());
|
||||
if_lhs_scalar_builder.create<scf::YieldOp>(loc, if_lhs_scalar_result);
|
||||
|
||||
// If lhs is NOT scalar
|
||||
//
|
||||
// See if rhs is scalar
|
||||
OpBuilder else_lhs_scalar_builder = if_op.getElseBodyBuilder();
|
||||
auto if_rhs_scalar_op = else_lhs_scalar_builder.create<scf::IfOp>(
|
||||
loc, result_type, IsScalarTensor(else_lhs_scalar_builder, op, rhs),
|
||||
true);
|
||||
else_lhs_scalar_builder.create<scf::YieldOp>(loc,
|
||||
if_rhs_scalar_op.getResult(0));
|
||||
OpBuilder if_rhs_scalar_builder = if_rhs_scalar_op.getThenBodyBuilder();
|
||||
Value reshaped_rhs = if_rhs_scalar_builder.create<TensorCastOp>(
|
||||
loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
|
||||
Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
|
||||
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
|
||||
op.getAttrs());
|
||||
if_rhs_scalar_builder.create<scf::YieldOp>(loc, if_rhs_scalar_result);
|
||||
|
||||
// If NEITHER shape is scalar
|
||||
//
|
||||
// See if shapes are equal.
|
||||
OpBuilder else_no_scalars_builder = if_rhs_scalar_op.getElseBodyBuilder();
|
||||
Value shape_of_lhs =
|
||||
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs);
|
||||
Value shape_of_rhs =
|
||||
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs);
|
||||
Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
|
||||
loc, shape_of_lhs, shape_of_rhs);
|
||||
|
||||
auto if_eq_shapes_op = else_no_scalars_builder.create<scf::IfOp>(
|
||||
loc, result_type, equal_shapes, true);
|
||||
else_no_scalars_builder.create<scf::YieldOp>(loc,
|
||||
if_eq_shapes_op.getResult(0));
|
||||
|
||||
OpBuilder if_eq_shapes_builder = if_eq_shapes_op.getThenBodyBuilder();
|
||||
Value non_broadcast_op =
|
||||
Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder);
|
||||
if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);
|
||||
|
||||
// If shapes are not scalar, nor equal
|
||||
//
|
||||
// See if values are of a rank that we support.
|
||||
OpBuilder if_neq_shapes_builder = if_eq_shapes_op.getElseBodyBuilder();
|
||||
if_neq_shapes_builder.create<scf::YieldOp>(
|
||||
loc, HandleBroadcastAndOp(if_neq_shapes_builder, op, lhs, rhs));
|
||||
|
||||
rewriter.replaceOp(op, {if_op.getResult(0)});
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
// Returns the dynamic result of checking whether the given value is a scalar
|
||||
// tensor.
|
||||
Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
Value shape_of_tensor = rewriter.create<shape::ShapeOfOp>(loc, tensor);
|
||||
Value rank_tensor = rewriter.create<shape::RankOp>(
|
||||
loc, rewriter.getIndexType(), shape_of_tensor);
|
||||
return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq,
|
||||
rank_tensor,
|
||||
rewriter.create<ConstantIndexOp>(loc, 0));
|
||||
}
|
||||
|
||||
// Create the if statement and code for a broadcasting op with a result of a
|
||||
// given rank.
|
||||
scf::IfOp createRankSpecializedBroadcastAndOp(OpBuilder &builder, ChloOpTy op,
|
||||
Value lhs, Value rhs,
|
||||
Value actual_rank,
|
||||
int targeted_rank) const {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
// Create the if block to place the current specialized logic in.
|
||||
Value greater_rank_is_n = builder.create<CmpIOp>(
|
||||
loc, CmpIPredicate::eq, actual_rank,
|
||||
builder.create<ConstantIndexOp>(loc, targeted_rank));
|
||||
auto if_op =
|
||||
builder.create<scf::IfOp>(loc, lhs.getType(), greater_rank_is_n, true);
|
||||
OpBuilder if_builder = if_op.getThenBodyBuilder();
|
||||
|
||||
// Handle shape broadcasting and inference.
|
||||
Value lhs_shape = if_builder.create<shape::ShapeOfOp>(loc, lhs);
|
||||
Value rhs_shape = if_builder.create<shape::ShapeOfOp>(loc, rhs);
|
||||
SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
|
||||
auto unknown_rank_extent_tensor_type = RankedTensorType::get(
|
||||
{RankedTensorType::kDynamicSize}, builder.getIndexType());
|
||||
auto known_rank_extent_tensor_type =
|
||||
RankedTensorType::get({targeted_rank}, builder.getIndexType());
|
||||
auto reshaped_type = RankedTensorType::get(
|
||||
llvm::SmallVector<int64_t, 6>(targeted_rank,
|
||||
RankedTensorType::kDynamicSize),
|
||||
lhs.getType().template dyn_cast<TensorType>().getElementType());
|
||||
Value ranked_shape_val = if_builder.create<shape::ConstShapeOp>(
|
||||
loc, known_rank_extent_tensor_type,
|
||||
mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
|
||||
ranked_shape));
|
||||
Value extended_lhs = if_builder.create<shape::BroadcastOp>(
|
||||
loc, unknown_rank_extent_tensor_type, lhs_shape, ranked_shape_val,
|
||||
nullptr);
|
||||
Value extended_lhs_casted = if_builder.create<TensorCastOp>(
|
||||
loc, known_rank_extent_tensor_type, extended_lhs);
|
||||
Value extended_rhs = if_builder.create<shape::BroadcastOp>(
|
||||
loc, unknown_rank_extent_tensor_type, rhs_shape, ranked_shape_val,
|
||||
nullptr);
|
||||
Value extended_rhs_casted = if_builder.create<TensorCastOp>(
|
||||
loc, known_rank_extent_tensor_type, extended_rhs);
|
||||
|
||||
// 1. Reshape operands to the given rank (with the same number of elements)
|
||||
// 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops
|
||||
// can be broadcasted and do the actual broadcasting)
|
||||
// 3. Type erase the output back to unranked
|
||||
Value reshaped_lhs = if_builder.create<mhlo::DynamicReshapeOp>(
|
||||
loc, reshaped_type, lhs, extended_lhs_casted);
|
||||
Value reshaped_rhs = if_builder.create<mhlo::DynamicReshapeOp>(
|
||||
loc, reshaped_type, rhs, extended_rhs_casted);
|
||||
Value result = if_builder.create<ChloOpTy>(
|
||||
loc, ArrayRef<Type>{reshaped_type},
|
||||
ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
|
||||
Value reshaped_result = if_builder.create<TensorCastOp>(
|
||||
loc, UnrankedTensorType::get(reshaped_type.getElementType()), result);
|
||||
if_builder.create<scf::YieldOp>(loc, reshaped_result);
|
||||
|
||||
// Return the if_op, so the result can be used and the else block can be
|
||||
// used for the next rank specialized step.
|
||||
return if_op;
|
||||
}
|
||||
|
||||
// Iterates over the desired ranks to be specialized and generates the code
|
||||
// snippet for each case.
|
||||
Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs,
|
||||
Value rhs) const {
|
||||
constexpr int max_rank_specialization = 7;
|
||||
auto loc = op.getLoc();
|
||||
|
||||
// Find the larger rank of the 2 operands.
|
||||
auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
|
||||
rewriter.getIndexType());
|
||||
Value lhs_shape =
|
||||
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, lhs);
|
||||
Value rhs_shape =
|
||||
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs);
|
||||
Value lhs_rank =
|
||||
rewriter.create<RankOp>(loc, rewriter.getIndexType(), lhs_shape);
|
||||
Value rhs_rank =
|
||||
rewriter.create<RankOp>(loc, rewriter.getIndexType(), rhs_shape);
|
||||
Value greater_rank_lhs =
|
||||
rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank);
|
||||
Value greater_rank =
|
||||
rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank);
|
||||
|
||||
// Generate a list of nested if/else statements to handle rank
|
||||
// specializations from 2-6.
|
||||
scf::IfOp if_op = createRankSpecializedBroadcastAndOp(rewriter, op, lhs,
|
||||
rhs, greater_rank, 2);
|
||||
|
||||
// Put each subsequent rank specialization inside the else statement of the
|
||||
// previous one.
|
||||
OpBuilder else_builder = if_op.getElseBodyBuilder();
|
||||
for (int i = 3; i < max_rank_specialization; i++) {
|
||||
auto inner_if = createRankSpecializedBroadcastAndOp(else_builder, op, lhs,
|
||||
rhs, greater_rank, i);
|
||||
|
||||
else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
|
||||
else_builder = inner_if.getElseBodyBuilder();
|
||||
}
|
||||
|
||||
// Fire an assertion if none of the rank specializations applied (one of the
|
||||
// ranks was greater than 6).
|
||||
else_builder.create<AssertOp>(
|
||||
loc, else_builder.create<ConstantIntOp>(loc, 0, 1),
|
||||
"Input for dynamic binary op lowering was of a rank greater than 6");
|
||||
else_builder.create<scf::YieldOp>(loc, lhs);
|
||||
|
||||
// Return the result of the outermost if statement.
|
||||
return if_op.getResult(0);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
||||
void PopulateForBinaryOp(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns) {
|
||||
patterns
|
||||
->insert<ConvertTrivialNonBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
|
||||
context, 10);
|
||||
patterns->insert<
|
||||
ConvertRankedDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
|
||||
context, 5);
|
||||
patterns->insert<
|
||||
ConvertUnrankedScalarDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy>,
|
||||
ConvertUnrankedDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
|
||||
context);
|
||||
}
|
||||
|
||||
template <typename FromOpTy, typename ToOpTy>
|
||||
struct HloBinaryElementwiseAdaptor {
|
||||
static ToOpTy CreateOp(FromOpTy from_op, Type result_type,
|
||||
Value broadcasted_lhs, Value broadcasted_rhs,
|
||||
OpBuilder &builder) {
|
||||
return builder.create<ToOpTy>(from_op.getLoc(), result_type,
|
||||
broadcasted_lhs, broadcasted_rhs);
|
||||
}
|
||||
};
|
||||
|
||||
struct HloComplexAdaptor {
|
||||
static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type,
|
||||
Value broadcasted_lhs, Value broadcasted_rhs,
|
||||
OpBuilder &builder) {
|
||||
return builder.create<mhlo::ComplexOp>(from_op.getLoc(), result_type,
|
||||
broadcasted_lhs, broadcasted_rhs);
|
||||
}
|
||||
};
|
||||
|
||||
struct HloCompareAdaptor {
|
||||
static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
|
||||
Value broadcasted_lhs, Value broadcasted_rhs,
|
||||
OpBuilder &builder) {
|
||||
return builder.create<mhlo::CompareOp>(
|
||||
from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs,
|
||||
from_op.comparison_direction(), from_op.compare_typeAttr());
|
||||
}
|
||||
};
|
||||
|
||||
#include "generated_chlo_legalize_to_hlo.inc"
|
||||
} // namespace
|
||||
|
||||
|
@@ -521,32 +212,10 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
|||
// Instantiate conversion templates for conforming binary elementwise ops
|
||||
// that do not have different dtypes between operands and results and do
|
||||
// not have special attributes that need to be preserved.
|
||||
#define POPULATE_BCAST(ChloOp, HloOp) \
|
||||
PopulateForBinaryOp<ChloOp, HloOp, \
|
||||
HloBinaryElementwiseAdaptor<ChloOp, HloOp>>(context, \
|
||||
patterns);
|
||||
|
||||
POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp);
|
||||
POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp);
|
||||
POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op);
|
||||
POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp);
|
||||
POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp);
|
||||
POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp);
|
||||
POPULATE_BCAST(BroadcastMulOp, mhlo::MulOp);
|
||||
POPULATE_BCAST(BroadcastOrOp, mhlo::OrOp);
|
||||
POPULATE_BCAST(BroadcastPowOp, mhlo::PowOp);
|
||||
POPULATE_BCAST(BroadcastRemOp, mhlo::RemOp);
|
||||
POPULATE_BCAST(BroadcastShiftLeftOp, mhlo::ShiftLeftOp);
|
||||
POPULATE_BCAST(BroadcastShiftRightArithmeticOp, mhlo::ShiftRightArithmeticOp);
|
||||
POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp);
|
||||
POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp);
|
||||
POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp);
|
||||
|
||||
// Broadcasting ops requiring special construction.
|
||||
PopulateForBinaryOp<BroadcastComplexOp, mhlo::ComplexOp, HloComplexAdaptor>(
|
||||
context, patterns);
|
||||
PopulateForBinaryOp<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>(
|
||||
context, patterns);
|
||||
PopulateForBroadcastingBinaryOp<ConvertTrivialNonBroadcastBinaryOp>(
|
||||
context, patterns, 10);
|
||||
PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
|
||||
context, patterns, 5);
|
||||
|
||||
// Other patterns.
|
||||
patterns->insert<ConvertConstantLikeOp>(context);
|
||||
|
|
|
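One detail in the hunk above worth spelling out: the trailing integer arguments (10 and 5) forwarded through PopulateForBroadcastingBinaryOp are, assuming they map onto the usual OpConversionPattern constructor parameter, PatternBenefit values, so the trivial non-broadcasting lowering is tried before the ranked dynamic-broadcast lowering whenever both match the same op. A condensed, annotated sketch of the registration calls:

// Higher benefit => tried first: the cheap statically-shaped lowering wins
// over the general dynamic-broadcast lowering when both patterns apply.
PopulateForBroadcastingBinaryOp<ConvertTrivialNonBroadcastBinaryOp>(
    context, patterns, /*benefit=*/10);
PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
    context, patterns, /*benefit=*/5);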
@@ -16,7 +16,9 @@ limitations under the License.
|
|||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
|
@@ -126,6 +128,291 @@ struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
|||
}
|
||||
};
|
||||
|
||||
// Converts a broadcasting binary operation with a scalar operand and an
|
||||
// unranked operand to a ranked broadcasting operation by dynamically reshaping
|
||||
// the unranked operand to a 1D tensor. This will always be safe because
|
||||
// broadcasting from a scalar to another shape always works.
|
||||
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
||||
struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
|
||||
: public OpConversionPattern<ChloOpTy> {
|
||||
using OpConversionPattern<ChloOpTy>::OpConversionPattern;
|
||||
LogicalResult matchAndRewrite(
|
||||
ChloOpTy op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
typename ChloOpTy::Adaptor transformed(operands);
|
||||
Value lhs = transformed.lhs();
|
||||
Value rhs = transformed.rhs();
|
||||
|
||||
auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
|
||||
auto lhs_unranked_type = lhs.getType().dyn_cast<UnrankedTensorType>();
|
||||
|
||||
auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
|
||||
auto rhs_unranked_type = rhs.getType().dyn_cast<UnrankedTensorType>();
|
||||
|
||||
bool lhs_is_scalar = lhs_ranked_type &&
|
||||
lhs_ranked_type.getShape().empty() &&
|
||||
rhs_unranked_type;
|
||||
bool rhs_is_scalar = rhs_ranked_type &&
|
||||
rhs_ranked_type.getShape().empty() &&
|
||||
lhs_unranked_type;
|
||||
|
||||
// Only support the case where exactly one operand is scalar and the other
|
||||
// is unranked. Other patterns in chlo-to-hlo legalization will create more
|
||||
// efficient lowerings for cases where both ranks are known or will handle
|
||||
// the more generic case of both inputs being unranked.
|
||||
if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure();
|
||||
|
||||
auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
|
||||
|
||||
// Reshape the non-scalar value into a dynamically sized, rank-1 tensor
|
||||
Value shape =
|
||||
rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs);
|
||||
Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape);
|
||||
Value size_tensor =
|
||||
rewriter.create<TensorFromElementsOp>(loc, num_elements);
|
||||
Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
loc, RankedTensorType::get({-1}, result_type.getElementType()),
|
||||
lhs_is_scalar ? rhs : lhs, size_tensor);
|
||||
|
||||
// Create a new ranked Chlo op that will be further lowered by other
|
||||
// patterns into Mhlo.
|
||||
SmallVector<Value, 2> new_operands{lhs_is_scalar ? lhs : reshaped,
|
||||
rhs_is_scalar ? rhs : reshaped};
|
||||
Value computed =
|
||||
rewriter.create<ChloOpTy>(loc, SmallVector<Type, 1>{reshaped.getType()},
|
||||
new_operands, op.getAttrs());
|
||||
|
||||
// Reshape the result back into an unranked tensor.
|
||||
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
|
||||
computed, shape);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Handles lowering of the following pattern to patterns that will be further
|
||||
// matched by other patterns until they result in LHLO:
|
||||
// %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy>
|
||||
//
|
||||
// The sequence of specializations this handles is:
|
||||
// - Either operand being scalar
|
||||
// - Operands having equal shapes
|
||||
// - The resulting value being any of ranks [2,6]
|
||||
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
||||
struct ConvertUnrankedDynamicBroadcastBinaryOp
|
||||
: public OpConversionPattern<ChloOpTy> {
|
||||
using OpConversionPattern<ChloOpTy>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
ChloOpTy op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
typename ChloOpTy::Adaptor transformed(operands);
|
||||
Value lhs = transformed.lhs();
|
||||
Value rhs = transformed.rhs();
|
||||
auto lhs_type = lhs.getType().dyn_cast<UnrankedTensorType>();
|
||||
auto rhs_type = rhs.getType().dyn_cast<UnrankedTensorType>();
|
||||
auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
|
||||
|
||||
// Only support unranked operands. If either operand is ranked, another
|
||||
// pattern will handle the lowering.
|
||||
if (!lhs_type || !rhs_type) return failure();
|
||||
|
||||
// If lhs is scalar
|
||||
auto if_op = rewriter.create<scf::IfOp>(
|
||||
loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
|
||||
OpBuilder if_lhs_scalar_builder =
|
||||
if_op.getThenBodyBuilder(rewriter.getListener());
|
||||
Value reshaped_lhs = if_lhs_scalar_builder.create<TensorCastOp>(
|
||||
loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
|
||||
Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
|
||||
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
|
||||
op.getAttrs());
|
||||
if_lhs_scalar_builder.create<scf::YieldOp>(loc, if_lhs_scalar_result);
|
||||
|
||||
// If lhs is NOT scalar
|
||||
//
|
||||
// See if rhs is scalar
|
||||
OpBuilder else_lhs_scalar_builder =
|
||||
if_op.getElseBodyBuilder(rewriter.getListener());
|
||||
auto if_rhs_scalar_op = else_lhs_scalar_builder.create<scf::IfOp>(
|
||||
loc, result_type, IsScalarTensor(else_lhs_scalar_builder, op, rhs),
|
||||
true);
|
||||
else_lhs_scalar_builder.create<scf::YieldOp>(loc,
|
||||
if_rhs_scalar_op.getResult(0));
|
||||
OpBuilder if_rhs_scalar_builder =
|
||||
if_rhs_scalar_op.getThenBodyBuilder(rewriter.getListener());
|
||||
Value reshaped_rhs = if_rhs_scalar_builder.create<TensorCastOp>(
|
||||
loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
|
||||
Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
|
||||
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
|
||||
op.getAttrs());
|
||||
if_rhs_scalar_builder.create<scf::YieldOp>(loc, if_rhs_scalar_result);
|
||||
|
||||
// If NEITHER shape is scalar
|
||||
//
|
||||
// See if shapes are equal.
|
||||
OpBuilder else_no_scalars_builder =
|
||||
if_rhs_scalar_op.getElseBodyBuilder(rewriter.getListener());
|
||||
Value shape_of_lhs =
|
||||
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs);
|
||||
Value shape_of_rhs =
|
||||
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs);
|
||||
Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
|
||||
loc, shape_of_lhs, shape_of_rhs);
|
||||
|
||||
auto if_eq_shapes_op = else_no_scalars_builder.create<scf::IfOp>(
|
||||
loc, result_type, equal_shapes, true);
|
||||
else_no_scalars_builder.create<scf::YieldOp>(loc,
|
||||
if_eq_shapes_op.getResult(0));
|
||||
|
||||
OpBuilder if_eq_shapes_builder =
|
||||
if_eq_shapes_op.getThenBodyBuilder(rewriter.getListener());
|
||||
Value non_broadcast_op =
|
||||
Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder);
|
||||
if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);
|
||||
|
||||
// If shapes are not scalar, nor equal
|
||||
//
|
||||
// See if values are of a rank that we support.
|
||||
OpBuilder if_neq_shapes_builder =
|
||||
if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener());
|
||||
if_neq_shapes_builder.create<scf::YieldOp>(
|
||||
loc, HandleBroadcastAndOp(if_neq_shapes_builder, op, lhs, rhs));
|
||||
|
||||
rewriter.replaceOp(op, {if_op.getResult(0)});
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
// Returns the dynamic result of checking whether the given value is a scalar
|
||||
// tensor.
|
||||
Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
Value shape_of_tensor = rewriter.create<shape::ShapeOfOp>(loc, tensor);
|
||||
Value rank_tensor = rewriter.create<shape::RankOp>(
|
||||
loc, rewriter.getIndexType(), shape_of_tensor);
|
||||
return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq,
|
||||
rank_tensor,
|
||||
rewriter.create<ConstantIndexOp>(loc, 0));
|
||||
}
|
||||
|
||||
// Create the if statement and code for a broadcasting op with a result of a
|
||||
// given rank.
|
||||
scf::IfOp createRankSpecializedBroadcastAndOp(OpBuilder &builder, ChloOpTy op,
|
||||
Value lhs, Value rhs,
|
||||
Value actual_rank,
|
||||
int targeted_rank) const {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
// Create the if block to place the current specialized logic in.
|
||||
Value greater_rank_is_n = builder.create<CmpIOp>(
|
||||
loc, CmpIPredicate::eq, actual_rank,
|
||||
builder.create<ConstantIndexOp>(loc, targeted_rank));
|
||||
auto if_op =
|
||||
builder.create<scf::IfOp>(loc, lhs.getType(), greater_rank_is_n, true);
|
||||
OpBuilder if_builder = if_op.getThenBodyBuilder(builder.getListener());
|
||||
|
||||
// Handle shape broadcasting and inference.
|
||||
Value lhs_shape = if_builder.create<shape::ShapeOfOp>(loc, lhs);
|
||||
Value rhs_shape = if_builder.create<shape::ShapeOfOp>(loc, rhs);
|
||||
SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
|
||||
auto unknown_rank_extent_tensor_type = RankedTensorType::get(
|
||||
{RankedTensorType::kDynamicSize}, builder.getIndexType());
|
||||
auto known_rank_extent_tensor_type =
|
||||
RankedTensorType::get({targeted_rank}, builder.getIndexType());
|
||||
auto reshaped_type = RankedTensorType::get(
|
||||
llvm::SmallVector<int64_t, 6>(targeted_rank,
|
||||
RankedTensorType::kDynamicSize),
|
||||
lhs.getType().template dyn_cast<TensorType>().getElementType());
|
||||
Value ranked_shape_val = if_builder.create<shape::ConstShapeOp>(
|
||||
loc, known_rank_extent_tensor_type,
|
||||
mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
|
||||
ranked_shape));
|
||||
Value extended_lhs = if_builder.create<shape::BroadcastOp>(
|
||||
loc, unknown_rank_extent_tensor_type, lhs_shape, ranked_shape_val,
|
||||
nullptr);
|
||||
Value extended_lhs_casted = if_builder.create<TensorCastOp>(
|
||||
loc, known_rank_extent_tensor_type, extended_lhs);
|
||||
Value extended_rhs = if_builder.create<shape::BroadcastOp>(
|
||||
loc, unknown_rank_extent_tensor_type, rhs_shape, ranked_shape_val,
|
||||
nullptr);
|
||||
Value extended_rhs_casted = if_builder.create<TensorCastOp>(
|
||||
loc, known_rank_extent_tensor_type, extended_rhs);
|
||||
|
||||
// 1. Reshape operands to the given rank (with the same number of elements)
|
||||
// 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops
|
||||
// can be broadcasted and do the actual broadcasting)
|
||||
// 3. Type erase the output back to unranked
|
||||
Value reshaped_lhs = if_builder.create<mhlo::DynamicReshapeOp>(
|
||||
loc, reshaped_type, lhs, extended_lhs_casted);
|
||||
Value reshaped_rhs = if_builder.create<mhlo::DynamicReshapeOp>(
|
||||
loc, reshaped_type, rhs, extended_rhs_casted);
|
||||
Value result = if_builder.create<ChloOpTy>(
|
||||
loc, ArrayRef<Type>{reshaped_type},
|
||||
ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
|
||||
Value reshaped_result = if_builder.create<TensorCastOp>(
|
||||
loc, UnrankedTensorType::get(reshaped_type.getElementType()), result);
|
||||
if_builder.create<scf::YieldOp>(loc, reshaped_result);
|
||||
|
||||
// Return the if_op, so the result can be used and the else block can be
|
||||
// used for the next rank specialized step.
|
||||
return if_op;
|
||||
}
|
||||
|
||||
// Iterates over the desired ranks to be specialized and generates the code
|
||||
// snippet for each case.
|
||||
Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs,
|
||||
Value rhs) const {
|
||||
constexpr int max_rank_specialization = 7;
|
||||
auto loc = op.getLoc();
|
||||
|
||||
// Find the larger rank of the 2 operands.
|
||||
auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
|
||||
rewriter.getIndexType());
|
||||
Value lhs_shape =
|
||||
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, lhs);
|
||||
Value rhs_shape =
|
||||
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs);
|
||||
Value lhs_rank =
|
||||
rewriter.create<RankOp>(loc, rewriter.getIndexType(), lhs_shape);
|
||||
Value rhs_rank =
|
||||
rewriter.create<RankOp>(loc, rewriter.getIndexType(), rhs_shape);
|
||||
Value greater_rank_lhs =
|
||||
rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank);
|
||||
Value greater_rank =
|
||||
rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank);
|
||||
|
||||
// Generate a list of nested if/else statements to handle rank
|
||||
// specializations from 2-6.
|
||||
scf::IfOp if_op = createRankSpecializedBroadcastAndOp(rewriter, op, lhs,
|
||||
rhs, greater_rank, 2);
|
||||
|
||||
// Put each subsequent rank specialization inside the else statement of the
|
||||
// previous one.
|
||||
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
|
||||
for (int i = 3; i < max_rank_specialization; i++) {
|
||||
auto inner_if = createRankSpecializedBroadcastAndOp(else_builder, op, lhs,
|
||||
rhs, greater_rank, i);
|
||||
|
||||
else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
|
||||
else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
|
||||
}
|
||||
|
||||
// Fire an assertion if none of the rank specializations applied (one of the
|
||||
// ranks was greater than 6).
|
||||
else_builder.create<AssertOp>(
|
||||
loc, else_builder.create<ConstantIntOp>(loc, 0, 1),
|
||||
"Input for dynamic binary op lowering was of a rank greater than 6");
|
||||
else_builder.create<scf::YieldOp>(loc, lhs);
|
||||
|
||||
// Return the result of the outermost if statement.
|
||||
return if_op.getResult(0);
|
||||
}
|
||||
};
|
||||
|
||||
struct TransformUnrankedHloPass
|
||||
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
|
@@ -137,7 +424,7 @@ struct TransformUnrankedHloPass
|
|||
MLIRContext &ctx = getContext();
|
||||
ConversionTarget target(ctx);
|
||||
target.addLegalDialect<mhlo::MhloDialect, StandardOpsDialect,
|
||||
shape::ShapeDialect>();
|
||||
shape::ShapeDialect, scf::SCFDialect>();
|
||||
target.addLegalOp<FuncOp>();
|
||||
#define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target)
|
||||
#define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target)
|
||||
|
@@ -148,6 +435,12 @@ struct TransformUnrankedHloPass
|
|||
#undef ADD_LEGAL_CHLO
|
||||
AddLegalOpOnRankedTensor<mhlo::CompareOp>(&target);
|
||||
AddLegalOpOnRankedTensor<mhlo::SelectOp>(&target);
|
||||
target.addDynamicallyLegalDialect<chlo::HloClientDialect>(
|
||||
[](Operation *op) {
|
||||
return !llvm::any_of(op->getOperandTypes(), [](Type type) {
|
||||
return type.isa<UnrankedTensorType>();
|
||||
});
|
||||
});
|
||||
|
||||
// Populate rewrite patterns.
|
||||
OwningRewritePatternList patterns;
|
||||
|
@@ -180,6 +473,10 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
|
|||
#undef MAP_BINARY
|
||||
#undef MAP_CHLO_UNARY
|
||||
#undef COMMA
|
||||
chlo::PopulateForBroadcastingBinaryOp<
|
||||
ConvertUnrankedDynamicBroadcastBinaryOp>(context, patterns);
|
||||
chlo::PopulateForBroadcastingBinaryOp<
|
||||
ConvertUnrankedScalarDynamicBroadcastBinaryOp>(context, patterns);
|
||||
}
|
||||
|
||||
std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {
|
||||
|
|
|
@@ -237,209 +237,3 @@ func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4x
|
|||
%0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
|
||||
return %0 : tensor<4xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<f32>, tensor<*xf32>)
|
||||
-> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @addScalarUnranked(
|
||||
// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>,
|
||||
// CHECK-SAME: %[[ARG_1:.*]]: tensor<*xf32>
|
||||
// CHECK-SAME: ) -> tensor<*xf32> {
|
||||
// First handle the dynamic reshaping of the unranked operand
|
||||
// to a 1D tensor.
|
||||
// CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32>
|
||||
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]] : tensor<?xindex> -> index
|
||||
// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
|
||||
// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// The assuming region is part of the second stage of lowering
|
||||
// with ranked broadcasting logic.
|
||||
// CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<f32>
|
||||
// CHECK: %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32>
|
||||
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_0]], %[[SHAPE_RESHAPED]]
|
||||
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
|
||||
// CHECK: %[[SCALAR_SHAPE:.*]] = shape.const_shape []
|
||||
// CHECK: %[[BROADCASTED_SHAPE:.*]] = shape.broadcast %[[SCALAR_SHAPE]], %[[SHAPE_RESHAPED]]
|
||||
// CHECK: %[[SHAPE_TENSOR:.*]] = tensor_cast %[[BROADCASTED_SHAPE]] : tensor<?xindex> to tensor<1xindex>
|
||||
// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_0]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
|
||||
// CHECK: shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor<?xf32>
|
||||
// CHECK: }
|
||||
// As part of the unranked logic, the result is reshaped back
|
||||
// to an unranked tensor.
|
||||
// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
|
||||
// CHECK: }
|
||||
|
||||
// -----
|
||||
func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf32> {
|
||||
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<f32>)
|
||||
-> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @addUnrankedScalar(
|
||||
// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32>,
|
||||
// CHECK-SAME: %[[ARG_1:.*]]: tensor<f32>) -> tensor<*xf32> {
|
||||
// First handle the dynamic reshaping of the unranked operand
|
||||
// to a 1D tensor.
|
||||
// CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32>
|
||||
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]] : tensor<?xindex> -> index
|
||||
// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
|
||||
// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// The assuming region is part of the second stage of lowering
|
||||
// with ranked broadcasting logic.
|
||||
// CHECK: %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32>
|
||||
// CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<f32>
|
||||
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_RESHAPED]], %[[SHAPE_1]]
|
||||
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
|
||||
// CHECK: %[[ASTENSOR:.*]] = tensor_cast %[[SHAPE_RESHAPED]]
|
||||
// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[ASTENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_1]], %[[ASTENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
|
||||
// CHECK: shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor<?xf32>
|
||||
// CHECK: }
|
||||
// As part of the unranked logic, the result is reshaped back
|
||||
// to an unranked tensor.
|
||||
// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_0]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
|
||||
// CHECK: }
|
||||
|
||||
// -----
|
||||
func @addUnrankedUnranked(
|
||||
%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>)
|
||||
-> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @addUnrankedUnranked(
|
||||
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
|
||||
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
|
||||
// CHECK: %[[RANK_LHS:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index
|
||||
// Handle scalar LHS case
|
||||
// CHECK: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) {
|
||||
// CHECK: %[[SCALAR_LHS:.*]] = tensor_cast %[[LHS]] : tensor<*xf32> to tensor<f32>
|
||||
// CHECK: %[[VAL_10:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RHS]] : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: scf.yield %[[VAL_10]] : tensor<*xf32>
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
|
||||
// CHECK: %[[RANK_RHS:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index
|
||||
// Handle scalar RHS case
|
||||
// CHECK: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
|
||||
// CHECK: %[[SCALAR_RHS:.*]] = tensor_cast %[[RHS]] : tensor<*xf32> to tensor<f32>
|
||||
// CHECK: %[[VAL_16:.*]] = chlo.broadcast_add %[[LHS]], %[[SCALAR_RHS]] : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
|
||||
// CHECK: scf.yield %[[VAL_16]] : tensor<*xf32>
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
|
||||
// Handle scalar RHS case
|
||||
// CHECK: %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) {
|
||||
// CHECK: %[[VAL_19:.*]] = mhlo.add %[[LHS]], %[[RHS]] : tensor<*xf32>
|
||||
// CHECK: scf.yield %[[VAL_19]] : tensor<*xf32>
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[LHS_RANK:.*]] = rank %[[LHS_SHAPE]] : tensor<?xindex>
|
||||
// CHECK: %[[RHS_RANK:.*]] = rank %[[RHS_SHAPE]] : tensor<?xindex>
|
||||
// CHECK: %[[LHS_RANK_GREATER:.*]] = cmpi "sgt", %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||
// CHECK: %[[GREATEST_RANK_IS_2:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C2]] : index
|
||||
// Handle rank 2 specialization
|
||||
// CHECK: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
|
||||
// CHECK: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
|
||||
// CHECK: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
||||
// CHECK: %[[CASTED_LHS_2:.*]] = tensor_cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
|
||||
// CHECK: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
||||
// CHECK: %[[CASTED_RHS_2:.*]] = tensor_cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
|
||||
// CHECK: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RESULT_2:.*]] = tensor_cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
|
||||
// CHECK: scf.yield %[[RESULT_2]] : tensor<*xf32>
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[C3:.*]] = constant 3 : index
|
||||
// CHECK: %[[GREATEST_RANK_IS_3:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C3]] : index
|
||||
// Handle rank 3 specialization
|
||||
// CHECK: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
|
||||
// CHECK: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
|
||||
// CHECK: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
|
||||
// CHECK: %[[CASTED_LHS_3:.*]] = tensor_cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
|
||||
// CHECK: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
|
||||
// CHECK: %[[CASTED_RHS_3:.*]] = tensor_cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
|
||||
// CHECK: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[RESULT_3:.*]] = tensor_cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK: scf.yield %[[RESULT_3]] : tensor<*xf32>
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[C4:.*]] = constant 4 : index
|
||||
// CHECK: %[[GREATEST_RANK_IS_4:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C4]] : index
|
||||
// Handle rank 4 specialization
|
||||
// CHECK: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
|
||||
// CHECK: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
|
||||
// CHECK: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
|
||||
// CHECK: %[[CASTED_LHS_4:.*]] = tensor_cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
|
||||
// CHECK: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
|
||||
// CHECK: %[[CASTED_RHS_4:.*]] = tensor_cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
|
||||
// CHECK: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[RESULT_4:.*]] = tensor_cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
// CHECK: scf.yield %[[RESULT_4]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %[[C5:.*]] = constant 5 : index
// CHECK: %[[GREATEST_RANK_IS_5:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C5]] : index
// Handle rank 5 specialization
// CHECK: %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) {
// CHECK: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
// CHECK: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK: %[[CASTED_LHS_5:.*]] = tensor_cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK: %[[CASTED_RHS_5:.*]] = tensor_cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
// CHECK: %[[RESULT_5:.*]] = tensor_cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
// CHECK: scf.yield %[[RESULT_5]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %[[C6:.*]] = constant 6 : index
// CHECK: %[[GREATEST_RANK_IS_6:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C6]] : index
// Handle rank 6 specialization
// CHECK: %[[VAL_58:.*]] = scf.if %[[GREATEST_RANK_IS_6]] -> (tensor<*xf32>) {
// CHECK: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
// CHECK: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
// CHECK: %[[CASTED_LHS_6:.*]] = tensor_cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
// CHECK: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
// CHECK: %[[CASTED_RHS_6:.*]] = tensor_cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex>
// CHECK: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
// CHECK: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
// CHECK: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
// CHECK: %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
// CHECK: scf.yield %[[RESULT_6]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %false = constant false
// CHECK: assert %false
// CHECK: scf.yield %[[LHS]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_64:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_65:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_66:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_67:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_68:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_69:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_70:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: return %[[VAL_71:.*]] : tensor<*xf32>
// CHECK: }
@@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt --transform-unranked-hlo --split-input-file %s | FileCheck %s
// RUN: mlir-hlo-opt --transform-unranked-hlo --cse --split-input-file %s | FileCheck %s

// Check the validity of expected IR.
// CHECK-LABEL: @sqr_transform_result
@@ -96,3 +96,203 @@ func @tan(%a : tensor<*xf32>) -> tensor<*xf32> {
  %result = chlo.tan %a : tensor<*xf32>
  return %result : tensor<*xf32>
}

// -----

func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<f32>, tensor<*xf32>)
      -> tensor<*xf32>
  return %0 : tensor<*xf32>
}
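
// The following CHECK lines verify the expected lowering: the unranked
// operand is flattened to a rank-1 tensor via its element count, the
// broadcasting add is applied to the ranked operands, and the result is
// reshaped back to the original unranked shape.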
// CHECK-LABEL: func @addScalarUnranked(
// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<*xf32>
// CHECK-SAME: ) -> tensor<*xf32> {
// First handle the dynamic reshaping of the unranked operand
// to a 1D tensor.
// CHECK-NEXT: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32>
// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[BROADCASTED_RESULT:.*]] = chlo.broadcast_add %[[ARG_0]], %[[RESHAPED]] : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
// As part of the unranked logic, the result is reshaped back
// to an unranked tensor.
// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[BROADCASTED_RESULT:.*]], %[[SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: return %[[RESHAPED_RESULT]] : tensor<*xf32>
// CHECK-NEXT: }

// -----

func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf32> {
  %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<f32>)
      -> tensor<*xf32>
  return %0 : tensor<*xf32>
}
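
// The following CHECK lines mirror the case above with the operands swapped:
// flatten the unranked operand to rank 1, apply the ranked broadcasting add,
// and reshape the result back to the operand's original shape.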
// CHECK-LABEL: func @addUnrankedScalar(
// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<f32>) -> tensor<*xf32> {
// First handle the dynamic reshaping of the unranked operand
// to a 1D tensor.
// CHECK-NEXT: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32>
// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// The assuming region is part of the second stage of lowering
// with ranked broadcasting logic.
// CHECK-NEXT: %[[BROADCASTED_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED]], %[[ARG_1]] : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
// As part of the unranked logic, the result is reshaped back
// to an unranked tensor.
// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[BROADCASTED_RESULT:.*]], %[[SHAPE_0]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: return %[[RESHAPED_RESULT]] : tensor<*xf32>
// CHECK-NEXT: }

// -----

func @addUnrankedUnranked(
    %arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>)
      -> tensor<*xf32>
  return %0 : tensor<*xf32>
}
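
// The following CHECK lines verify the rank-specialized lowering cascade for
// two unranked operands: the scalar LHS and scalar RHS cases come first, then
// the equal-shapes fast path, and finally scf.if specializations for ranks 2
// through 6 with an assert-false fallback for greater ranks.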
// CHECK-LABEL: func @addUnrankedUnranked(
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[RANK_LHS:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index
// Handle scalar LHS case
// CHECK-NEXT: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[SCALAR_LHS:.*]] = tensor_cast %[[LHS]] : tensor<*xf32> to tensor<f32>
// CHECK-NEXT: %[[RHS_SHAPE_1:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE_1]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[NUM_TENS_RHS:.*]] = tensor_from_elements %[[NUM_RHS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[NUM_TENS_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[LHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RESHAPED_RHS]] : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESHAPED_LHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[LHS_SCALAR_RESULT]], %[[RHS_SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_LHS_SCALAR_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[RANK_RHS:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index
// Handle scalar RHS case
// CHECK-NEXT: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[SCALAR_RHS:.*]] = tensor_cast %[[RHS]] : tensor<*xf32> to tensor<f32>
// CHECK-NEXT: %[[NUM_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[NUM_TENS_LHS:.*]] = tensor_from_elements %[[NUM_LHS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[RHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED_LHS]], %[[SCALAR_RHS]] : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RHS_SCALAR_RESULT:.*]], %[[LHS_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_RHS_SCALAR_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
// Handle equal shapes case
// CHECK-NEXT: %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[ANY_SHAPE:.*]] = shape.any %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[ANY_NUM:.*]] = shape.num_elements %[[ANY_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[ANY_TENSOR:.*]] = tensor_from_elements %[[ANY_NUM]] : tensor<1xindex>
// CHECK-NEXT: %[[FLATTENED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLATTENED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = mhlo.add %[[FLATTENED_LHS]], %[[FLATTENED_RHS]] : tensor<?xf32>
// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[LHS_RANK:.*]] = rank %[[LHS_SHAPE]] : tensor<?xindex>
// CHECK-NEXT: %[[RHS_RANK:.*]] = rank %[[RHS_SHAPE]] : tensor<?xindex>
// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi "sgt", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK-NEXT: %[[C2:.*]] = constant 2 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_2:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C2]] : index
// Handle rank 2 specialization
// CHECK-NEXT: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor_cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor_cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT_2:.*]] = tensor_cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_2]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C3:.*]] = constant 3 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_3:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C3]] : index
// Handle rank 3 specialization
// CHECK-NEXT: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor_cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor_cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[RESULT_3:.*]] = tensor_cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_3]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C4:.*]] = constant 4 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_4:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C4]] : index
// Handle rank 4 specialization
// CHECK-NEXT: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor_cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor_cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_4:.*]] = tensor_cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_4]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C5:.*]] = constant 5 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_5:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C5]] : index
// Handle rank 5 specialization
// CHECK-NEXT: %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor_cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor_cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_5:.*]] = tensor_cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C6:.*]] = constant 6 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_6:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C6]] : index
// Handle rank 6 specialization
// CHECK-NEXT: %[[VAL_58:.*]] = scf.if %[[GREATEST_RANK_IS_6]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_6:.*]] = tensor_cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
// CHECK-NEXT: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_6:.*]] = tensor_cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_6]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %false = constant false
// CHECK-NEXT: assert %false
// CHECK-NEXT: scf.yield %[[LHS]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_64:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_65:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_66:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_67:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_68:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_69:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_70:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return %[[VAL_71:.*]] : tensor<*xf32>
// CHECK-NEXT: }