Move unranked chlo lowering to transform_unranked_hlo.
Additionally:
- Forward listeners through the new if/else op builders. This corrects an error that led to incomplete legalization of the broadcasted op lowering.
- Use OpConversionPattern to ensure up-to-date operand values are used.

PiperOrigin-RevId: 339838833
parent e188ef10f2, commit 76b30fd426
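A note on the listener fix before the diff, since it is easy to miss: during dialect conversion, the ConversionPatternRewriter's listener is how the driver learns about newly created ops. The previous code built the scf.if branch bodies through plain OpBuilders from getThenBodyBuilder()/getElseBodyBuilder(), which carry no listener, so ops created inside the branches were invisible to the driver and never legalized. A minimal sketch of the difference, assuming the MLIR API this patch targets (the optional Listener* parameter on the scf::IfOp body builders and OpBuilder::getListener()); `emitThenBranch` and the index-typed `if_op` are illustrative only:

    #include "mlir/Dialect/SCF/SCF.h"
    #include "mlir/Dialect/StandardOps/IR/Ops.h"
    #include "mlir/Transforms/DialectConversion.h"

    using namespace mlir;

    // Sketch only: assumes `if_op` was created with a single index result.
    void emitThenBranch(scf::IfOp if_op, Location loc,
                        ConversionPatternRewriter &rewriter) {
      // Buggy variant: this builder has no listener, so ops created through it
      // are never reported to the conversion driver and stay un-legalized.
      //   OpBuilder then_builder = if_op.getThenBodyBuilder();

      // Fixed variant: forward the rewriter's listener so the driver observes
      // (and later legalizes) everything built into the branch.
      OpBuilder then_builder = if_op.getThenBodyBuilder(rewriter.getListener());
      Value zero = then_builder.create<ConstantIndexOp>(loc, 0);
      then_builder.create<scf::YieldOp>(loc, zero);
    }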
new file: mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h
@@ -0,0 +1,97 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_HLO_OP_H_
+#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_HLO_OP_H_
+
+#include <type_traits>
+
+#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace chlo {
+
+struct HloComplexAdaptor {
+  static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type,
+                                  Value broadcasted_lhs, Value broadcasted_rhs,
+                                  OpBuilder &builder) {
+    return builder.create<mhlo::ComplexOp>(from_op.getLoc(), result_type,
+                                           broadcasted_lhs, broadcasted_rhs);
+  }
+};
+
+template <typename FromOpTy, typename ToOpTy>
+struct HloBinaryElementwiseAdaptor {
+  static ToOpTy CreateOp(FromOpTy from_op, Type result_type,
+                         Value broadcasted_lhs, Value broadcasted_rhs,
+                         OpBuilder &builder) {
+    return builder.create<ToOpTy>(from_op.getLoc(), result_type,
+                                  broadcasted_lhs, broadcasted_rhs);
+  }
+};
+
+struct HloCompareAdaptor {
+  static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
+                                  Value broadcasted_lhs, Value broadcasted_rhs,
+                                  OpBuilder &builder) {
+    return builder.create<mhlo::CompareOp>(
+        from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs,
+        from_op.comparison_direction(), from_op.compare_typeAttr());
+  }
+};
+
+// Populate a pattern for each broadcasting CHLO op. This requires the pattern
+// to take a ChloOpTy, MhloOpTy, and an Adaptor as templated values.
+template <template <typename, typename, typename> class Pattern,
+          typename... ConstructorArgs>
+void PopulateForBroadcastingBinaryOp(MLIRContext *context,
+                                     OwningRewritePatternList *patterns,
+                                     ConstructorArgs &&...args) {
+#define POPULATE_BCAST(ChloOp, HloOp)                                      \
+  patterns->insert<                                                        \
+      Pattern<ChloOp, HloOp, HloBinaryElementwiseAdaptor<ChloOp, HloOp>>>( \
+      context, args...);
+
+  POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp);
+  POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp);
+  POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op);
+  POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp);
+  POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp);
+  POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp);
+  POPULATE_BCAST(BroadcastMulOp, mhlo::MulOp);
+  POPULATE_BCAST(BroadcastOrOp, mhlo::OrOp);
+  POPULATE_BCAST(BroadcastPowOp, mhlo::PowOp);
+  POPULATE_BCAST(BroadcastRemOp, mhlo::RemOp);
+  POPULATE_BCAST(BroadcastShiftLeftOp, mhlo::ShiftLeftOp);
+  POPULATE_BCAST(BroadcastShiftRightArithmeticOp, mhlo::ShiftRightArithmeticOp);
+  POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp);
+  POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp);
+  POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp);
+
+  // Broadcasting ops requiring special construction.
+  patterns
+      ->insert<Pattern<BroadcastComplexOp, mhlo::ComplexOp, HloComplexAdaptor>>(
+          context, args...);
+  patterns
+      ->insert<Pattern<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>>(
+          context, args...);
+
+#undef POPULATE_BCAST
+}
+
+} // namespace chlo
+} // namespace mlir
+
+#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_HLO_OP_H_
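For context on how this header is consumed (the call sites appear later in this commit): PopulateForBroadcastingBinaryOp instantiates the given pattern template once per broadcasting CHLO op and forwards any trailing constructor arguments, typically a pattern benefit, to every instance. A sketch of the minimal pattern shape it expects; `ExampleBroadcastingPattern` and `populateExample` are placeholders, not part of the commit, and a real pattern would check broadcasting preconditions before rewriting:

    #include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
    #include "mlir/Transforms/DialectConversion.h"

    using namespace mlir;

    // Placeholder pattern with the <ChloOpTy, HloOpTy, Adaptor> arity the
    // helper requires. The Adaptor's CreateOp builds the mapped mhlo op.
    template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
    struct ExampleBroadcastingPattern : public OpConversionPattern<ChloOpTy> {
      using OpConversionPattern<ChloOpTy>::OpConversionPattern;
      LogicalResult matchAndRewrite(
          ChloOpTy op, ArrayRef<Value> operands,
          ConversionPatternRewriter &rewriter) const override {
        typename ChloOpTy::Adaptor transformed(operands);
        rewriter.replaceOp(
            op, {Adaptor::CreateOp(op, op.getResult().getType(),
                                   transformed.lhs(), transformed.rhs(),
                                   rewriter)});
        return success();
      }
    };

    void populateExample(MLIRContext *context,
                         OwningRewritePatternList *patterns) {
      // The trailing 10 reaches every instantiated pattern as its benefit.
      chlo::PopulateForBroadcastingBinaryOp<ExampleBroadcastingPattern>(
          context, patterns, 10);
    }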
chlo_legalize_to_hlo.cc
@@ -17,6 +17,7 @@ limitations under the License.
 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
 #include "mlir-hlo/utils/broadcast_utils.h"
 #include "mlir/Dialect/SCF/SCF.h"
@@ -69,13 +70,18 @@ struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> {
 // Converts binary ops that statically are determined to not broadcast directly
 // to the corresponding mhlo non-broadcasting op.
 template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
-struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
-  using OpRewritePattern<ChloOpTy>::OpRewritePattern;
-  LogicalResult matchAndRewrite(ChloOpTy op,
-                                PatternRewriter &rewriter) const override {
+struct ConvertTrivialNonBroadcastBinaryOp
+    : public OpConversionPattern<ChloOpTy> {
+  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      ChloOpTy op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
     // Only rewrite for statically determinable non-broadcasting cases.
-    auto lhs_type = op.lhs().getType().template dyn_cast<RankedTensorType>();
-    auto rhs_type = op.rhs().getType().template dyn_cast<RankedTensorType>();
+    typename ChloOpTy::Adaptor transformed(operands);
+    auto lhs_type =
+        transformed.lhs().getType().template dyn_cast<RankedTensorType>();
+    auto rhs_type =
+        transformed.rhs().getType().template dyn_cast<RankedTensorType>();
     if (!lhs_type || !rhs_type) return failure();

     // Requires rank broadcast.
@@ -93,8 +99,9 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
       }
     }

-    rewriter.replaceOp(op, {Adaptor::CreateOp(op, op.getResult().getType(),
-                                              op.lhs(), op.rhs(), rewriter)});
+    rewriter.replaceOp(
+        op, {Adaptor::CreateOp(op, op.getResult().getType(), operands[0],
+                               operands[1], rewriter)});
     return success();
   }
 };
@@ -113,13 +120,15 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
 // `shape.broadcast` op, which only supports prefix-padding.
 template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
 struct ConvertRankedDynamicBroadcastBinaryOp
-    : public OpRewritePattern<ChloOpTy> {
-  using OpRewritePattern<ChloOpTy>::OpRewritePattern;
-  LogicalResult matchAndRewrite(ChloOpTy op,
-                                PatternRewriter &rewriter) const override {
+    : public OpConversionPattern<ChloOpTy> {
+  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      ChloOpTy op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
     // Only support ranked operands.
-    Value lhs = op.lhs();
-    Value rhs = op.rhs();
+    typename ChloOpTy::Adaptor transformed(operands);
+    Value lhs = transformed.lhs();
+    Value rhs = transformed.rhs();
     auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
     auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
     auto result_type =
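The OpRewritePattern-to-OpConversionPattern changes above are what the commit message means by "up-to-date operand values": inside a dialect conversion, op.lhs()/op.rhs() still reference the op's original operands, which earlier patterns may already have replaced, while the operands array passed to matchAndRewrite holds the remapped values. A minimal sketch of the idiom, with a hypothetical single-result binary op type as the template argument:

    #include "mlir/Transforms/DialectConversion.h"

    using namespace mlir;

    template <typename BinaryOpTy>
    struct UsesRemappedOperands : public OpConversionPattern<BinaryOpTy> {
      using OpConversionPattern<BinaryOpTy>::OpConversionPattern;
      LogicalResult matchAndRewrite(
          BinaryOpTy op, ArrayRef<Value> operands,
          ConversionPatternRewriter &rewriter) const override {
        // Stale during conversion: op.lhs() can point at a value that has
        // already been replaced and is pending erasure.
        //   Value lhs = op.lhs();
        // Up to date: the generated operand adaptor views the remapped
        // operands handed in by the conversion driver.
        typename BinaryOpTy::Adaptor transformed(operands);
        Value lhs = transformed.lhs();
        Value rhs = transformed.rhs();
        (void)lhs;
        (void)rhs;
        // A real pattern would build the replacement op from lhs/rhs here.
        return failure();
      }
    };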
@@ -193,324 +202,6 @@ struct ConvertRankedDynamicBroadcastBinaryOp
   }
 };

-// Converts a broadcasting binary operation with a scalar operand and an
-// unranked operand to a ranked broadcasting operation by dynamically reshaping
-// the unranked operand to a 1D tensor. This will always be safe because
-// broadcasting from a scalar to another shape always works.
-template <typename ChloOpTy, typename HloOpTy>
-struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
-    : public OpRewritePattern<ChloOpTy> {
-  using OpRewritePattern<ChloOpTy>::OpRewritePattern;
-  LogicalResult matchAndRewrite(ChloOpTy op,
-                                PatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
-    Value lhs = op.lhs();
-    Value rhs = op.rhs();
-
-    auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
-    auto lhs_unranked_type = lhs.getType().dyn_cast<UnrankedTensorType>();
-
-    auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
-    auto rhs_unranked_type = rhs.getType().dyn_cast<UnrankedTensorType>();
-
-    bool lhs_is_scalar = lhs_ranked_type &&
-                         lhs_ranked_type.getShape().empty() &&
-                         rhs_unranked_type;
-    bool rhs_is_scalar = rhs_ranked_type &&
-                         rhs_ranked_type.getShape().empty() &&
-                         lhs_unranked_type;
-
-    // Only support the case where exactly one operand is scalar and the other
-    // is unranked. Other patterns in this file will create more efficient
-    // lowerings for cases where both ranks are known or will handle the more
-    // generic case of both inputs being unranked.
-    if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure();
-
-    auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
-
-    // Reshape the non-scalar value into a dynamically sized, rank-1 tensor
-    Value shape =
-        rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs);
-    Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape);
-    Value size_tensor =
-        rewriter.create<TensorFromElementsOp>(loc, num_elements);
-    Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
-        loc, RankedTensorType::get({-1}, result_type.getElementType()),
-        lhs_is_scalar ? rhs : lhs, size_tensor);
-
-    // Create a new ranked Chlo op that will be further lowered by other
-    // patterns into Mhlo.
-    SmallVector<Value, 2> operands{lhs_is_scalar ? lhs : reshaped,
-                                   rhs_is_scalar ? rhs : reshaped};
-    Value computed = rewriter.create<ChloOpTy>(
-        loc, SmallVector<Type, 1>{reshaped.getType()}, operands, op.getAttrs());
-
-    // Reshape the result back into an unranked tensor.
-    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
-                                                        computed, shape);
-
-    return success();
-  }
-};
-
-// Handles lowering of the following pattern to patterns that will be further
-// matched by other patterns until they result in LHLO:
-//   %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy>
-//
-// The sequence of specializations this handles is:
-//   - Either operand being scalar
-//   - Operands having equal shapes
-//   - The resulting value being any of ranks [2,6]
-template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
-struct ConvertUnrankedDynamicBroadcastBinaryOp
-    : public OpRewritePattern<ChloOpTy> {
-  using OpRewritePattern<ChloOpTy>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ChloOpTy op,
-                                PatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
-    Value lhs = op.lhs();
-    Value rhs = op.rhs();
-    auto lhs_type = lhs.getType().dyn_cast<UnrankedTensorType>();
-    auto rhs_type = rhs.getType().dyn_cast<UnrankedTensorType>();
-    auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
-
-    // Only support unranked operands. If either operand is ranked, another
-    // pattern will handle the lowering.
-    if (!lhs_type || !rhs_type) return failure();
-
-    // If lhs is scalar
-    auto if_op = rewriter.create<scf::IfOp>(
-        loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
-    OpBuilder if_lhs_scalar_builder = if_op.getThenBodyBuilder();
-    Value reshaped_lhs = if_lhs_scalar_builder.create<TensorCastOp>(
-        loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
-    Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
-        loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
-        op.getAttrs());
-    if_lhs_scalar_builder.create<scf::YieldOp>(loc, if_lhs_scalar_result);
-
-    // If lhs is NOT scalar
-    //
-    // See if rhs is scalar
-    OpBuilder else_lhs_scalar_builder = if_op.getElseBodyBuilder();
-    auto if_rhs_scalar_op = else_lhs_scalar_builder.create<scf::IfOp>(
-        loc, result_type, IsScalarTensor(else_lhs_scalar_builder, op, rhs),
-        true);
-    else_lhs_scalar_builder.create<scf::YieldOp>(loc,
-                                                 if_rhs_scalar_op.getResult(0));
-    OpBuilder if_rhs_scalar_builder = if_rhs_scalar_op.getThenBodyBuilder();
-    Value reshaped_rhs = if_rhs_scalar_builder.create<TensorCastOp>(
-        loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
-    Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
-        loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
-        op.getAttrs());
-    if_rhs_scalar_builder.create<scf::YieldOp>(loc, if_rhs_scalar_result);
-
-    // If NEITHER shape is scalar
-    //
-    // See if shapes are equal.
-    OpBuilder else_no_scalars_builder = if_rhs_scalar_op.getElseBodyBuilder();
-    Value shape_of_lhs =
-        else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs);
-    Value shape_of_rhs =
-        else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs);
-    Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
-        loc, shape_of_lhs, shape_of_rhs);
-
-    auto if_eq_shapes_op = else_no_scalars_builder.create<scf::IfOp>(
-        loc, result_type, equal_shapes, true);
-    else_no_scalars_builder.create<scf::YieldOp>(loc,
-                                                 if_eq_shapes_op.getResult(0));
-
-    OpBuilder if_eq_shapes_builder = if_eq_shapes_op.getThenBodyBuilder();
-    Value non_broadcast_op =
-        Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder);
-    if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);
-
-    // If shapes are not scalar, nor equal
-    //
-    // See if values are of a rank that we support.
-    OpBuilder if_neq_shapes_builder = if_eq_shapes_op.getElseBodyBuilder();
-    if_neq_shapes_builder.create<scf::YieldOp>(
-        loc, HandleBroadcastAndOp(if_neq_shapes_builder, op, lhs, rhs));
-
-    rewriter.replaceOp(op, {if_op.getResult(0)});
-    return success();
-  }
-
- private:
-  // Returns the dynamic result of checking the given value is a scalar
-  // tensor.
-  Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const {
-    auto loc = op.getLoc();
-
-    Value shape_of_tensor = rewriter.create<shape::ShapeOfOp>(loc, tensor);
-    Value rank_tensor = rewriter.create<shape::RankOp>(
-        loc, rewriter.getIndexType(), shape_of_tensor);
-    return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq,
-                                   rank_tensor,
-                                   rewriter.create<ConstantIndexOp>(loc, 0));
-  }
-
-  // Create the if statement and code for a broadcasting op with a result of a
-  // given rank.
-  scf::IfOp createRankSpecializedBroadcastAndOp(OpBuilder &builder, ChloOpTy op,
-                                                Value lhs, Value rhs,
-                                                Value actual_rank,
-                                                int targeted_rank) const {
-    auto loc = op.getLoc();
-
-    // Create the if block to place the current specialized logic in.
-    Value greater_rank_is_n = builder.create<CmpIOp>(
-        loc, CmpIPredicate::eq, actual_rank,
-        builder.create<ConstantIndexOp>(loc, targeted_rank));
-    auto if_op =
-        builder.create<scf::IfOp>(loc, lhs.getType(), greater_rank_is_n, true);
-    OpBuilder if_builder = if_op.getThenBodyBuilder();
-
-    // Handle shape broadcasting and inference.
-    Value lhs_shape = if_builder.create<shape::ShapeOfOp>(loc, lhs);
-    Value rhs_shape = if_builder.create<shape::ShapeOfOp>(loc, rhs);
-    SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
-    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
-        {RankedTensorType::kDynamicSize}, builder.getIndexType());
-    auto known_rank_extent_tensor_type =
-        RankedTensorType::get({targeted_rank}, builder.getIndexType());
-    auto reshaped_type = RankedTensorType::get(
-        llvm::SmallVector<int64_t, 6>(targeted_rank,
-                                      RankedTensorType::kDynamicSize),
-        lhs.getType().template dyn_cast<TensorType>().getElementType());
-    Value ranked_shape_val = if_builder.create<shape::ConstShapeOp>(
-        loc, known_rank_extent_tensor_type,
-        mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
-                                        ranked_shape));
-    Value extended_lhs = if_builder.create<shape::BroadcastOp>(
-        loc, unknown_rank_extent_tensor_type, lhs_shape, ranked_shape_val,
-        nullptr);
-    Value extended_lhs_casted = if_builder.create<TensorCastOp>(
-        loc, known_rank_extent_tensor_type, extended_lhs);
-    Value extended_rhs = if_builder.create<shape::BroadcastOp>(
-        loc, unknown_rank_extent_tensor_type, rhs_shape, ranked_shape_val,
-        nullptr);
-    Value extended_rhs_casted = if_builder.create<TensorCastOp>(
-        loc, known_rank_extent_tensor_type, extended_rhs);
-
-    // 1. Reshape operands to the given rank (with the same number of elements)
-    // 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops
-    //    can be broadcasted and do the actual broadcasting)
-    // 3. Type erase the output back to unranked
-    Value reshaped_lhs = if_builder.create<mhlo::DynamicReshapeOp>(
-        loc, reshaped_type, lhs, extended_lhs_casted);
-    Value reshaped_rhs = if_builder.create<mhlo::DynamicReshapeOp>(
-        loc, reshaped_type, rhs, extended_rhs_casted);
-    Value result = if_builder.create<ChloOpTy>(
-        loc, ArrayRef<Type>{reshaped_type},
-        ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
-    Value reshaped_result = if_builder.create<TensorCastOp>(
-        loc, UnrankedTensorType::get(reshaped_type.getElementType()), result);
-    if_builder.create<scf::YieldOp>(loc, reshaped_result);
-
-    // Return the if_op, so the result can be used and the else block can be
-    // used for the next rank specialized step.
-    return if_op;
-  }
-
-  // Iterates over the desired ranks to be specialized and generates the code
-  // snippet for each case.
-  Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs,
-                             Value rhs) const {
-    constexpr int max_rank_specialization = 7;
-    auto loc = op.getLoc();
-
-    // Find the larger rank of the 2 operands.
-    auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
-                                                    rewriter.getIndexType());
-    Value lhs_shape =
-        rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, lhs);
-    Value rhs_shape =
-        rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs);
-    Value lhs_rank =
-        rewriter.create<RankOp>(loc, rewriter.getIndexType(), lhs_shape);
-    Value rhs_rank =
-        rewriter.create<RankOp>(loc, rewriter.getIndexType(), rhs_shape);
-    Value greater_rank_lhs =
-        rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank);
-    Value greater_rank =
-        rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank);
-
-    // Generate a list of nested if/else statements to handle rank
-    // specializations from 2-6.
-    scf::IfOp if_op = createRankSpecializedBroadcastAndOp(rewriter, op, lhs,
-                                                          rhs, greater_rank, 2);
-
-    // Put each subsequent rank specialization inside the else statement of the
-    // previous one.
-    OpBuilder else_builder = if_op.getElseBodyBuilder();
-    for (int i = 3; i < max_rank_specialization; i++) {
-      auto inner_if = createRankSpecializedBroadcastAndOp(else_builder, op, lhs,
-                                                          rhs, greater_rank, i);
-
-      else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
-      else_builder = inner_if.getElseBodyBuilder();
-    }
-
-    // Fire an assertion if none of the rank specializations applied (one of the
-    // ranks was greater than 6).
-    else_builder.create<AssertOp>(
-        loc, else_builder.create<ConstantIntOp>(loc, 0, 1),
-        "Input for dynamic binary op lowering was of a rank greater than 6");
-    else_builder.create<scf::YieldOp>(loc, lhs);
-
-    // Return the result of the outermost if statement.
-    return if_op.getResult(0);
-  }
-};
-
-template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
-void PopulateForBinaryOp(MLIRContext *context,
-                         OwningRewritePatternList *patterns) {
-  patterns
-      ->insert<ConvertTrivialNonBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
-          context, 10);
-  patterns->insert<
-      ConvertRankedDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
-      context, 5);
-  patterns->insert<
-      ConvertUnrankedScalarDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy>,
-      ConvertUnrankedDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
-      context);
-}
-
-template <typename FromOpTy, typename ToOpTy>
-struct HloBinaryElementwiseAdaptor {
-  static ToOpTy CreateOp(FromOpTy from_op, Type result_type,
-                         Value broadcasted_lhs, Value broadcasted_rhs,
-                         OpBuilder &builder) {
-    return builder.create<ToOpTy>(from_op.getLoc(), result_type,
-                                  broadcasted_lhs, broadcasted_rhs);
-  }
-};
-
-struct HloComplexAdaptor {
-  static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type,
-                                  Value broadcasted_lhs, Value broadcasted_rhs,
-                                  OpBuilder &builder) {
-    return builder.create<mhlo::ComplexOp>(from_op.getLoc(), result_type,
-                                           broadcasted_lhs, broadcasted_rhs);
-  }
-};
-
-struct HloCompareAdaptor {
-  static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
-                                  Value broadcasted_lhs, Value broadcasted_rhs,
-                                  OpBuilder &builder) {
-    return builder.create<mhlo::CompareOp>(
-        from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs,
-        from_op.comparison_direction(), from_op.compare_typeAttr());
-  }
-};
-
 #include "generated_chlo_legalize_to_hlo.inc"
 } // namespace
@@ -521,32 +212,10 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
   // Instantiate conversion templates for conforming binary elementwise ops
   // that do not have different dtypes between operands and results and do
   // not have special attributes that need to be preserved.
-#define POPULATE_BCAST(ChloOp, HloOp)                                      \
-  PopulateForBinaryOp<ChloOp, HloOp,                                       \
-                      HloBinaryElementwiseAdaptor<ChloOp, HloOp>>(context, \
-                                                                  patterns);
-
-  POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp);
-  POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp);
-  POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op);
-  POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp);
-  POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp);
-  POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp);
-  POPULATE_BCAST(BroadcastMulOp, mhlo::MulOp);
-  POPULATE_BCAST(BroadcastOrOp, mhlo::OrOp);
-  POPULATE_BCAST(BroadcastPowOp, mhlo::PowOp);
-  POPULATE_BCAST(BroadcastRemOp, mhlo::RemOp);
-  POPULATE_BCAST(BroadcastShiftLeftOp, mhlo::ShiftLeftOp);
-  POPULATE_BCAST(BroadcastShiftRightArithmeticOp, mhlo::ShiftRightArithmeticOp);
-  POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp);
-  POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp);
-  POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp);
-
-  // Broadcasting ops requiring special construction.
-  PopulateForBinaryOp<BroadcastComplexOp, mhlo::ComplexOp, HloComplexAdaptor>(
-      context, patterns);
-  PopulateForBinaryOp<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>(
-      context, patterns);
-
+  PopulateForBroadcastingBinaryOp<ConvertTrivialNonBroadcastBinaryOp>(
+      context, patterns, 10);
+  PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
+      context, patterns, 5);

   // Other patterns.
   patterns->insert<ConvertConstantLikeOp>(context);
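The 10 and 5 above are forwarded to each pattern constructor as its benefit, so the trivial non-broadcast lowering is preferred over the ranked dynamic one, matching the priorities the removed PopulateForBinaryOp hard-coded; the unranked patterns are now registered in transform_unranked_hlo.cc below instead. Expanded by hand for a single op pair, the first call reduces to roughly the following (a sketch, as if written inside this file's anonymous namespace; `ExpandOneRegistrationByHand` is illustrative only):

    // One POPULATE_BCAST instantiation performed by
    // PopulateForBroadcastingBinaryOp<ConvertTrivialNonBroadcastBinaryOp>(
    //     context, patterns, 10):
    void ExpandOneRegistrationByHand(MLIRContext *context,
                                     OwningRewritePatternList *patterns) {
      // The trailing 10 is consumed by the pattern constructor as its
      // PatternBenefit; the same insert is repeated for every CHLO op.
      patterns->insert<ConvertTrivialNonBroadcastBinaryOp<
          BroadcastAddOp, mhlo::AddOp,
          HloBinaryElementwiseAdaptor<BroadcastAddOp, mhlo::AddOp>>>(context,
                                                                     10);
    }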
transform_unranked_hlo.cc
@@ -16,7 +16,9 @@ limitations under the License.
 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Function.h"
@@ -126,6 +128,291 @@ struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
   }
 };

+// Converts a broadcasting binary operation with a scalar operand and an
+// unranked operand to a ranked broadcasting operation by dynamically reshaping
+// the unranked operand to a 1D tensor. This will always be safe because
+// broadcasting from a scalar to another shape always works.
+template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
+struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
+    : public OpConversionPattern<ChloOpTy> {
+  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      ChloOpTy op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    typename ChloOpTy::Adaptor transformed(operands);
+    Value lhs = transformed.lhs();
+    Value rhs = transformed.rhs();
+
+    auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
+    auto lhs_unranked_type = lhs.getType().dyn_cast<UnrankedTensorType>();
+
+    auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
+    auto rhs_unranked_type = rhs.getType().dyn_cast<UnrankedTensorType>();
+
+    bool lhs_is_scalar = lhs_ranked_type &&
+                         lhs_ranked_type.getShape().empty() &&
+                         rhs_unranked_type;
+    bool rhs_is_scalar = rhs_ranked_type &&
+                         rhs_ranked_type.getShape().empty() &&
+                         lhs_unranked_type;
+
+    // Only support the case where exactly one operand is scalar and the other
+    // is unranked. Other patterns in chlo-to-hlo legalization will create more
+    // efficient lowerings for cases where both ranks are known or will handle
+    // the more generic case of both inputs being unranked.
+    if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure();
+
+    auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
+
+    // Reshape the non-scalar value into a dynamically sized, rank-1 tensor
+    Value shape =
+        rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs);
+    Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape);
+    Value size_tensor =
+        rewriter.create<TensorFromElementsOp>(loc, num_elements);
+    Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
+        loc, RankedTensorType::get({-1}, result_type.getElementType()),
+        lhs_is_scalar ? rhs : lhs, size_tensor);
+
+    // Create a new ranked Chlo op that will be further lowered by other
+    // patterns into Mhlo.
+    SmallVector<Value, 2> new_operands{lhs_is_scalar ? lhs : reshaped,
+                                       rhs_is_scalar ? rhs : reshaped};
+    Value computed =
+        rewriter.create<ChloOpTy>(loc, SmallVector<Type, 1>{reshaped.getType()},
+                                  new_operands, op.getAttrs());
+
+    // Reshape the result back into an unranked tensor.
+    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
+                                                        computed, shape);
+
+    return success();
+  }
+};
+
+// Handles lowering of the following pattern to patterns that will be further
+// matched by other patterns until they result in LHLO:
+//   %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy>
+//
+// The sequence of specializations this handles is:
+//   - Either operand being scalar
+//   - Operands having equal shapes
+//   - The resulting value being any of ranks [2,6]
+template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
+struct ConvertUnrankedDynamicBroadcastBinaryOp
+    : public OpConversionPattern<ChloOpTy> {
+  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      ChloOpTy op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    typename ChloOpTy::Adaptor transformed(operands);
+    Value lhs = transformed.lhs();
+    Value rhs = transformed.rhs();
+    auto lhs_type = lhs.getType().dyn_cast<UnrankedTensorType>();
+    auto rhs_type = rhs.getType().dyn_cast<UnrankedTensorType>();
+    auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
+
+    // Only support unranked operands. If either operand is ranked, another
+    // pattern will handle the lowering.
+    if (!lhs_type || !rhs_type) return failure();
+
+    // If lhs is scalar
+    auto if_op = rewriter.create<scf::IfOp>(
+        loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
+    OpBuilder if_lhs_scalar_builder =
+        if_op.getThenBodyBuilder(rewriter.getListener());
+    Value reshaped_lhs = if_lhs_scalar_builder.create<TensorCastOp>(
+        loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
+    Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
+        loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
+        op.getAttrs());
+    if_lhs_scalar_builder.create<scf::YieldOp>(loc, if_lhs_scalar_result);
+
+    // If lhs is NOT scalar
+    //
+    // See if rhs is scalar
+    OpBuilder else_lhs_scalar_builder =
+        if_op.getElseBodyBuilder(rewriter.getListener());
+    auto if_rhs_scalar_op = else_lhs_scalar_builder.create<scf::IfOp>(
+        loc, result_type, IsScalarTensor(else_lhs_scalar_builder, op, rhs),
+        true);
+    else_lhs_scalar_builder.create<scf::YieldOp>(loc,
+                                                 if_rhs_scalar_op.getResult(0));
+    OpBuilder if_rhs_scalar_builder =
+        if_rhs_scalar_op.getThenBodyBuilder(rewriter.getListener());
+    Value reshaped_rhs = if_rhs_scalar_builder.create<TensorCastOp>(
+        loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
+    Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
+        loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
+        op.getAttrs());
+    if_rhs_scalar_builder.create<scf::YieldOp>(loc, if_rhs_scalar_result);
+
+    // If NEITHER shape is scalar
+    //
+    // See if shapes are equal.
+    OpBuilder else_no_scalars_builder =
+        if_rhs_scalar_op.getElseBodyBuilder(rewriter.getListener());
+    Value shape_of_lhs =
+        else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs);
+    Value shape_of_rhs =
+        else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs);
+    Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
+        loc, shape_of_lhs, shape_of_rhs);
+
+    auto if_eq_shapes_op = else_no_scalars_builder.create<scf::IfOp>(
+        loc, result_type, equal_shapes, true);
+    else_no_scalars_builder.create<scf::YieldOp>(loc,
+                                                 if_eq_shapes_op.getResult(0));
+
+    OpBuilder if_eq_shapes_builder =
+        if_eq_shapes_op.getThenBodyBuilder(rewriter.getListener());
+    Value non_broadcast_op =
+        Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder);
+    if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);
+
+    // If shapes are not scalar, nor equal
+    //
+    // See if values are of a rank that we support.
+    OpBuilder if_neq_shapes_builder =
+        if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener());
+    if_neq_shapes_builder.create<scf::YieldOp>(
+        loc, HandleBroadcastAndOp(if_neq_shapes_builder, op, lhs, rhs));
+
+    rewriter.replaceOp(op, {if_op.getResult(0)});
+    return success();
+  }
+
+ private:
+  // Returns the dynamic result of checking the given value is a scalar
+  // tensor.
+  Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const {
+    auto loc = op.getLoc();
+
+    Value shape_of_tensor = rewriter.create<shape::ShapeOfOp>(loc, tensor);
+    Value rank_tensor = rewriter.create<shape::RankOp>(
+        loc, rewriter.getIndexType(), shape_of_tensor);
+    return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq,
+                                   rank_tensor,
+                                   rewriter.create<ConstantIndexOp>(loc, 0));
+  }
+
+  // Create the if statement and code for a broadcasting op with a result of a
+  // given rank.
+  scf::IfOp createRankSpecializedBroadcastAndOp(OpBuilder &builder, ChloOpTy op,
+                                                Value lhs, Value rhs,
+                                                Value actual_rank,
+                                                int targeted_rank) const {
+    auto loc = op.getLoc();
+
+    // Create the if block to place the current specialized logic in.
+    Value greater_rank_is_n = builder.create<CmpIOp>(
+        loc, CmpIPredicate::eq, actual_rank,
+        builder.create<ConstantIndexOp>(loc, targeted_rank));
+    auto if_op =
+        builder.create<scf::IfOp>(loc, lhs.getType(), greater_rank_is_n, true);
+    OpBuilder if_builder = if_op.getThenBodyBuilder(builder.getListener());
+
+    // Handle shape broadcasting and inference.
+    Value lhs_shape = if_builder.create<shape::ShapeOfOp>(loc, lhs);
+    Value rhs_shape = if_builder.create<shape::ShapeOfOp>(loc, rhs);
+    SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
+    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
+        {RankedTensorType::kDynamicSize}, builder.getIndexType());
+    auto known_rank_extent_tensor_type =
+        RankedTensorType::get({targeted_rank}, builder.getIndexType());
+    auto reshaped_type = RankedTensorType::get(
+        llvm::SmallVector<int64_t, 6>(targeted_rank,
+                                      RankedTensorType::kDynamicSize),
+        lhs.getType().template dyn_cast<TensorType>().getElementType());
+    Value ranked_shape_val = if_builder.create<shape::ConstShapeOp>(
+        loc, known_rank_extent_tensor_type,
+        mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
+                                        ranked_shape));
+    Value extended_lhs = if_builder.create<shape::BroadcastOp>(
+        loc, unknown_rank_extent_tensor_type, lhs_shape, ranked_shape_val,
+        nullptr);
+    Value extended_lhs_casted = if_builder.create<TensorCastOp>(
+        loc, known_rank_extent_tensor_type, extended_lhs);
+    Value extended_rhs = if_builder.create<shape::BroadcastOp>(
+        loc, unknown_rank_extent_tensor_type, rhs_shape, ranked_shape_val,
+        nullptr);
+    Value extended_rhs_casted = if_builder.create<TensorCastOp>(
+        loc, known_rank_extent_tensor_type, extended_rhs);
+
+    // 1. Reshape operands to the given rank (with the same number of elements)
+    // 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops
+    //    can be broadcasted and do the actual broadcasting)
+    // 3. Type erase the output back to unranked
+    Value reshaped_lhs = if_builder.create<mhlo::DynamicReshapeOp>(
+        loc, reshaped_type, lhs, extended_lhs_casted);
+    Value reshaped_rhs = if_builder.create<mhlo::DynamicReshapeOp>(
+        loc, reshaped_type, rhs, extended_rhs_casted);
+    Value result = if_builder.create<ChloOpTy>(
+        loc, ArrayRef<Type>{reshaped_type},
+        ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
+    Value reshaped_result = if_builder.create<TensorCastOp>(
+        loc, UnrankedTensorType::get(reshaped_type.getElementType()), result);
+    if_builder.create<scf::YieldOp>(loc, reshaped_result);
+
+    // Return the if_op, so the result can be used and the else block can be
+    // used for the next rank specialized step.
+    return if_op;
+  }
+
+  // Iterates over the desired ranks to be specialized and generates the code
+  // snippet for each case.
+  Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs,
+                             Value rhs) const {
+    constexpr int max_rank_specialization = 7;
+    auto loc = op.getLoc();
+
+    // Find the larger rank of the 2 operands.
+    auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
+                                                    rewriter.getIndexType());
+    Value lhs_shape =
+        rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, lhs);
+    Value rhs_shape =
+        rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs);
+    Value lhs_rank =
+        rewriter.create<RankOp>(loc, rewriter.getIndexType(), lhs_shape);
+    Value rhs_rank =
+        rewriter.create<RankOp>(loc, rewriter.getIndexType(), rhs_shape);
+    Value greater_rank_lhs =
+        rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank);
+    Value greater_rank =
+        rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank);
+
+    // Generate a list of nested if/else statements to handle rank
+    // specializations from 2-6.
+    scf::IfOp if_op = createRankSpecializedBroadcastAndOp(rewriter, op, lhs,
+                                                          rhs, greater_rank, 2);
+
+    // Put each subsequent rank specialization inside the else statement of the
+    // previous one.
+    OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
+    for (int i = 3; i < max_rank_specialization; i++) {
+      auto inner_if = createRankSpecializedBroadcastAndOp(else_builder, op, lhs,
+                                                          rhs, greater_rank, i);
+
+      else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
+      else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
+    }
+
+    // Fire an assertion if none of the rank specializations applied (one of the
+    // ranks was greater than 6).
+    else_builder.create<AssertOp>(
+        loc, else_builder.create<ConstantIntOp>(loc, 0, 1),
+        "Input for dynamic binary op lowering was of a rank greater than 6");
+    else_builder.create<scf::YieldOp>(loc, lhs);
+
+    // Return the result of the outermost if statement.
+    return if_op.getResult(0);
+  }
+};
+
 struct TransformUnrankedHloPass
     : public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -137,7 +424,7 @@ struct TransformUnrankedHloPass
     MLIRContext &ctx = getContext();
     ConversionTarget target(ctx);
     target.addLegalDialect<mhlo::MhloDialect, StandardOpsDialect,
-                           shape::ShapeDialect>();
+                           shape::ShapeDialect, scf::SCFDialect>();
     target.addLegalOp<FuncOp>();
 #define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target)
 #define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target)
@@ -148,6 +435,12 @@ struct TransformUnrankedHloPass
 #undef ADD_LEGAL_CHLO
     AddLegalOpOnRankedTensor<mhlo::CompareOp>(&target);
     AddLegalOpOnRankedTensor<mhlo::SelectOp>(&target);
+    target.addDynamicallyLegalDialect<chlo::HloClientDialect>(
+        [](Operation *op) {
+          return !llvm::any_of(op->getOperandTypes(), [](Type type) {
+            return type.isa<UnrankedTensorType>();
+          });
+        });

     // Populate rewrite patterns.
     OwningRewritePatternList patterns;
@@ -180,6 +473,10 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
 #undef MAP_BINARY
 #undef MAP_CHLO_UNARY
 #undef COMMA
+  chlo::PopulateForBroadcastingBinaryOp<
+      ConvertUnrankedDynamicBroadcastBinaryOp>(context, patterns);
+  chlo::PopulateForBroadcastingBinaryOp<
+      ConvertUnrankedScalarDynamicBroadcastBinaryOp>(context, patterns);
 }

 std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {
@ -237,209 +237,3 @@ func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4x
|
||||||
%0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
|
%0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
|
||||||
return %0 : tensor<4xi1>
|
return %0 : tensor<4xi1>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
|
||||||
func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
|
||||||
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<f32>, tensor<*xf32>)
|
|
||||||
-> tensor<*xf32>
|
|
||||||
return %0 : tensor<*xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @addScalarUnranked(
|
|
||||||
// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>,
|
|
||||||
// CHECK-SAME: %[[ARG_1:.*]]: tensor<*xf32>
|
|
||||||
// CHECK-SAME: ) -> tensor<*xf32> {
|
|
||||||
// First handle the dynamic reshaping of the unranked operand
|
|
||||||
// to a 1D tensor.
|
|
||||||
// CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32>
|
|
||||||
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]] : tensor<?xindex> -> index
|
|
||||||
// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
|
|
||||||
// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
|
||||||
// The assuming region is part of the second stage of lowering
|
|
||||||
// with ranked broadcasting logic.
|
|
||||||
// CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<f32>
|
|
||||||
// CHECK: %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32>
|
|
||||||
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_0]], %[[SHAPE_RESHAPED]]
|
|
||||||
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
|
|
||||||
// CHECK: %[[SCALAR_SHAPE:.*]] = shape.const_shape []
|
|
||||||
// CHECK: %[[BROADCASTED_SHAPE:.*]] = shape.broadcast %[[SCALAR_SHAPE]], %[[SHAPE_RESHAPED]]
|
|
||||||
// CHECK: %[[SHAPE_TENSOR:.*]] = tensor_cast %[[BROADCASTED_SHAPE]] : tensor<?xindex> to tensor<1xindex>
|
|
||||||
// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_0]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
|
|
||||||
// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
|
|
||||||
// CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
|
|
||||||
// CHECK: shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor<?xf32>
|
|
||||||
// CHECK: }
|
|
||||||
// As part of the unranked logic, the result is reshaped back
|
|
||||||
// to an unranked tensor.
|
|
||||||
// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
|
||||||
// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
|
|
||||||
// CHECK: }
|
|
||||||
|
|
||||||
// -----
|
|
||||||
func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf32> {
|
|
||||||
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<f32>)
|
|
||||||
-> tensor<*xf32>
|
|
||||||
return %0 : tensor<*xf32>
|
|
||||||
}
|
|
||||||
// CHECK-LABEL: func @addUnrankedScalar(
|
|
||||||
// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32>,
|
|
||||||
// CHECK-SAME: %[[ARG_1:.*]]: tensor<f32>) -> tensor<*xf32> {
|
|
||||||
// First handle the dynamic reshaping of the unranked operand
|
|
||||||
// to a 1D tensor.
|
|
||||||
// CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32>
|
|
||||||
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]] : tensor<?xindex> -> index
|
|
||||||
// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
|
|
||||||
// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
|
||||||
// The assuming region is part of the second stage of lowering
|
|
||||||
// with ranked broadcasting logic.
|
|
||||||
// CHECK: %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32>
|
|
||||||
// CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<f32>
|
|
||||||
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_RESHAPED]], %[[SHAPE_1]]
|
|
||||||
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
|
|
||||||
// CHECK: %[[ASTENSOR:.*]] = tensor_cast %[[SHAPE_RESHAPED]]
|
|
||||||
// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[ASTENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
|
|
||||||
// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_1]], %[[ASTENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
|
|
||||||
// CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
|
|
||||||
// CHECK: shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor<?xf32>
|
|
||||||
// CHECK: }
|
|
||||||
// As part of the unranked logic, the result is reshaped back
|
|
||||||
// to an unranked tensor.
|
|
||||||
// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_0]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
|
||||||
// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
|
|
||||||
// CHECK: }
|
|
||||||
|
|
||||||
// -----
|
|
||||||
func @addUnrankedUnranked(
|
|
||||||
%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
|
||||||
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>)
|
|
||||||
-> tensor<*xf32>
|
|
||||||
return %0 : tensor<*xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @addUnrankedUnranked(
|
|
||||||
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
|
|
||||||
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
|
|
||||||
// CHECK: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
|
|
||||||
// CHECK: %[[RANK_LHS:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
|
|
||||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
|
||||||
// CHECK: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index
|
|
||||||
// Handle scalar LHS case
|
|
||||||
// CHECK: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) {
|
|
||||||
// CHECK: %[[SCALAR_LHS:.*]] = tensor_cast %[[LHS]] : tensor<*xf32> to tensor<f32>
|
|
||||||
// CHECK: %[[VAL_10:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RHS]] : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
|
|
||||||
// CHECK: scf.yield %[[VAL_10]] : tensor<*xf32>
|
|
||||||
// CHECK: } else {
|
|
||||||
// CHECK: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
|
|
||||||
// CHECK: %[[RANK_RHS:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
|
|
||||||
// CHECK: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index
|
|
||||||
// Handle scalar RHS case
|
|
||||||
// CHECK: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
|
|
||||||
// CHECK: %[[SCALAR_RHS:.*]] = tensor_cast %[[RHS]] : tensor<*xf32> to tensor<f32>
|
|
||||||
// CHECK: %[[VAL_16:.*]] = chlo.broadcast_add %[[LHS]], %[[SCALAR_RHS]] : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
|
|
||||||
// CHECK: scf.yield %[[VAL_16]] : tensor<*xf32>
|
|
||||||
// CHECK: } else {
|
|
||||||
// CHECK: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
|
|
||||||
// Handle scalar RHS case
|
|
||||||
// CHECK: %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) {
|
|
||||||
// CHECK: %[[VAL_19:.*]] = mhlo.add %[[LHS]], %[[RHS]] : tensor<*xf32>
|
|
||||||
// CHECK: scf.yield %[[VAL_19]] : tensor<*xf32>
|
|
||||||
// CHECK: } else {
|
|
||||||
// CHECK: %[[LHS_RANK:.*]] = rank %[[LHS_SHAPE]] : tensor<?xindex>
// CHECK: %[[RHS_RANK:.*]] = rank %[[RHS_SHAPE]] : tensor<?xindex>
// CHECK: %[[LHS_RANK_GREATER:.*]] = cmpi "sgt", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[GREATEST_RANK_IS_2:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C2]] : index
// Handle rank 2 specialization
// CHECK: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
// CHECK: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
// CHECK: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
// CHECK: %[[CASTED_LHS_2:.*]] = tensor_cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
// CHECK: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
// CHECK: %[[CASTED_RHS_2:.*]] = tensor_cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
// CHECK: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT_2:.*]] = tensor_cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
// CHECK: scf.yield %[[RESULT_2]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %[[C3:.*]] = constant 3 : index
// CHECK: %[[GREATEST_RANK_IS_3:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C3]] : index
// Handle rank 3 specialization
// CHECK: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
// CHECK: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
// CHECK: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
// CHECK: %[[CASTED_LHS_3:.*]] = tensor_cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
// CHECK: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
// CHECK: %[[CASTED_RHS_3:.*]] = tensor_cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
// CHECK: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[RESULT_3:.*]] = tensor_cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
// CHECK: scf.yield %[[RESULT_3]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %[[C4:.*]] = constant 4 : index
// CHECK: %[[GREATEST_RANK_IS_4:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C4]] : index
// Handle rank 4 specialization
// CHECK: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
// CHECK: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
// CHECK: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
// CHECK: %[[CASTED_LHS_4:.*]] = tensor_cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
// CHECK: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
// CHECK: %[[CASTED_RHS_4:.*]] = tensor_cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
// CHECK: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[RESULT_4:.*]] = tensor_cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
// CHECK: scf.yield %[[RESULT_4]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %[[C5:.*]] = constant 5 : index
// CHECK: %[[GREATEST_RANK_IS_5:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C5]] : index
// Handle rank 5 specialization
// CHECK: %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) {
// CHECK: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
// CHECK: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK: %[[CASTED_LHS_5:.*]] = tensor_cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK: %[[CASTED_RHS_5:.*]] = tensor_cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
// CHECK: %[[RESULT_5:.*]] = tensor_cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
// CHECK: scf.yield %[[RESULT_5]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %[[C6:.*]] = constant 6 : index
// CHECK: %[[GREATEST_RANK_IS_6:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C6]] : index
// Handle rank 6 specialization
// CHECK: %[[VAL_58:.*]] = scf.if %[[GREATEST_RANK_IS_6]] -> (tensor<*xf32>) {
// CHECK: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
// CHECK: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
// CHECK: %[[CASTED_LHS_6:.*]] = tensor_cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
// CHECK: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
// CHECK: %[[CASTED_RHS_6:.*]] = tensor_cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex>
// CHECK: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
// CHECK: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
// CHECK: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
// CHECK: %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
// CHECK: scf.yield %[[RESULT_6]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %false = constant false
// CHECK: assert %false
// CHECK: scf.yield %[[LHS]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_64:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_65:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_66:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_67:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_68:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_69:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_70:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: return %[[VAL_71:.*]] : tensor<*xf32>
// CHECK: }
@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt --transform-unranked-hlo --split-input-file %s | FileCheck %s
+// RUN: mlir-hlo-opt --transform-unranked-hlo --cse --split-input-file %s | FileCheck %s

// Check the validity of expected IR.
// CHECK-LABEL: @sqr_transform_result
@ -96,3 +96,203 @@ func @tan(%a : tensor<*xf32>) -> tensor<*xf32> {
  %result = chlo.tan %a : tensor<*xf32>
  return %result : tensor<*xf32>
}

// -----

func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<f32>, tensor<*xf32>)
      -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

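// When one operand is known to be a scalar, no rank specialization is needed:
// the unranked operand is flattened to a rank-1 tensor of its total element
// count, the broadcasting add is applied to the now-ranked operands, and the
// result is reshaped back to the original operand shape.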
// CHECK-LABEL: func @addScalarUnranked(
// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<*xf32>
// CHECK-SAME: ) -> tensor<*xf32> {
// First handle the dynamic reshaping of the unranked operand
// to a 1D tensor.
// CHECK-NEXT: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32>
// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[BROADCASTED_RESULT:.*]] = chlo.broadcast_add %[[ARG_0]], %[[RESHAPED]] : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
// As part of the unranked logic, the result is reshaped back
// to an unranked tensor.
// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[BROADCASTED_RESULT]], %[[SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: return %[[RESHAPED_RESULT]] : tensor<*xf32>
// CHECK-NEXT: }

// -----

func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf32> {
  %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<f32>)
      -> tensor<*xf32>
  return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @addUnrankedScalar(
// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<f32>) -> tensor<*xf32> {
// First handle the dynamic reshaping of the unranked operand
// to a 1D tensor.
// CHECK-NEXT: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32>
// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// The assuming region is part of the second stage of lowering
// with ranked broadcasting logic.
// CHECK-NEXT: %[[BROADCASTED_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED]], %[[ARG_1]] : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
// As part of the unranked logic, the result is reshaped back
// to an unranked tensor.
// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[BROADCASTED_RESULT]], %[[SHAPE_0]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: return %[[RESHAPED_RESULT]] : tensor<*xf32>
// CHECK-NEXT: }

// -----

func @addUnrankedUnranked(
    %arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>)
      -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

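// With both operands unranked, the lowering emits a cascade of scf.if cases:
// scalar LHS, scalar RHS, equal shapes (no broadcast needed), and finally
// rank specializations for the greatest operand rank, from 2 through 6.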
// CHECK-LABEL: func @addUnrankedUnranked(
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[RANK_LHS:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index
// Handle scalar LHS case
// CHECK-NEXT: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[SCALAR_LHS:.*]] = tensor_cast %[[LHS]] : tensor<*xf32> to tensor<f32>
// CHECK-NEXT: %[[RHS_SHAPE_1:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE_1]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[NUM_TENS_RHS:.*]] = tensor_from_elements %[[NUM_RHS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[NUM_TENS_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[LHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RESHAPED_RHS]] : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESHAPED_LHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[LHS_SCALAR_RESULT]], %[[RHS_SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_LHS_SCALAR_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[RANK_RHS:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index
// Handle scalar RHS case
// CHECK-NEXT: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[SCALAR_RHS:.*]] = tensor_cast %[[RHS]] : tensor<*xf32> to tensor<f32>
// CHECK-NEXT: %[[NUM_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[NUM_TENS_LHS:.*]] = tensor_from_elements %[[NUM_LHS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[RHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED_LHS]], %[[SCALAR_RHS]] : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RHS_SCALAR_RESULT]], %[[LHS_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_RHS_SCALAR_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
// Handle equal shapes case
// CHECK-NEXT: %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[ANY_SHAPE:.*]] = shape.any %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[ANY_NUM:.*]] = shape.num_elements %[[ANY_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[ANY_TENSOR:.*]] = tensor_from_elements %[[ANY_NUM]] : tensor<1xindex>
// CHECK-NEXT: %[[FLATTENED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLATTENED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = mhlo.add %[[FLATTENED_LHS]], %[[FLATTENED_RHS]] : tensor<?xf32>
// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } else {
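// Rank specialization: each shape is broadcast with an all-ones constant
// shape of the candidate rank, and the operands are reshaped to that rank.
// For instance, an LHS of shape [4] and an RHS of shape [3, 4] take the
// rank-2 branch: the extended shapes are [1, 4] and [3, 4], and the ranked
// chlo.broadcast_add handles the remaining broadcast on tensor<?x?xf32>.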
// CHECK-NEXT: %[[LHS_RANK:.*]] = rank %[[LHS_SHAPE]] : tensor<?xindex>
// CHECK-NEXT: %[[RHS_RANK:.*]] = rank %[[RHS_SHAPE]] : tensor<?xindex>
// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi "sgt", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK-NEXT: %[[C2:.*]] = constant 2 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_2:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C2]] : index
// Handle rank 2 specialization
// CHECK-NEXT: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor_cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor_cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT_2:.*]] = tensor_cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_2]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C3:.*]] = constant 3 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_3:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C3]] : index
// Handle rank 3 specialization
// CHECK-NEXT: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor_cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor_cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[RESULT_3:.*]] = tensor_cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_3]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C4:.*]] = constant 4 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_4:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C4]] : index
// Handle rank 4 specialization
// CHECK-NEXT: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor_cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor_cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_4:.*]] = tensor_cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_4]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C5:.*]] = constant 5 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_5:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C5]] : index
// Handle rank 5 specialization
// CHECK-NEXT: %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor_cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor_cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_5:.*]] = tensor_cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C6:.*]] = constant 6 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_6:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C6]] : index
// Handle rank 6 specialization
// CHECK-NEXT: %[[VAL_58:.*]] = scf.if %[[GREATEST_RANK_IS_6]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_6:.*]] = tensor_cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
// CHECK-NEXT: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_6:.*]] = tensor_cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_6]] : tensor<*xf32>
// CHECK-NEXT: } else {
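// Ranks above 6 are not supported: the final branch fails a runtime assert
// and yields the LHS operand only to satisfy the scf.if result type.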
// CHECK-NEXT: %false = constant false
// CHECK-NEXT: assert %false
// CHECK-NEXT: scf.yield %[[LHS]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_64:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_65:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_66:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_67:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_68:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_69:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_70:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return %[[VAL_71:.*]] : tensor<*xf32>
// CHECK-NEXT: }