[MLIR][HLO] Remove unused pass `TransformUnrankedHloPass`

The pass was replaced by the new generalized rank specialization and the two
passes `mhlo-rank-specialization-cluster` and `mhlo-rank-specialization-to-scf`.

PiperOrigin-RevId: 379935562

commit 470ac45f45
parent 10634ca3a6

BUILD | 24
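
For reference, a module that previously went through the removed pass can be
run through the two replacement passes named above instead. A minimal sketch of
the migrated invocation, reusing the `mlir-hlo-opt` driver from the deleted
test below (the pass flags come from the commit message; the input file name is
a placeholder):

  # Hypothetical migration of the old --mhlo-transform-unranked-hlo pipeline:
  mlir-hlo-opt --mhlo-rank-specialization-cluster \
               --mhlo-rank-specialization-to-scf input.mlir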
@@ -857,29 +857,6 @@ cc_library(
     alwayslink = 1,
 )
 
-cc_library(
-    name = "transform_unranked_hlo",
-    srcs = ["lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc"],
-    hdrs = [
-        "include/mlir-hlo/Dialect/mhlo/transforms/passes.h",
-        "include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h",
-    ],
-    deps = [
-        ":hlo",
-        ":map_chlo_to_hlo_op",
-        ":pass_details",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:Pass",
-        "@llvm-project//mlir:SCFDialect",
-        "@llvm-project//mlir:Shape",
-        "@llvm-project//mlir:StandardOps",
-        "@llvm-project//mlir:TensorDialect",
-        "@llvm-project//mlir:Transforms",
-    ],
-    alwayslink = 1,
-)
-
 cc_library(
     name = "broadcast_propagation",
     srcs = ["lib/Dialect/mhlo/transforms/broadcast_propagation.cc"],
@@ -1379,7 +1356,6 @@ cc_library(
         ":rank_specialization",
         ":sink_constants_to_control_flow",
         ":test_passes",
-        ":transform_unranked_hlo",
         "@llvm-project//mlir:Pass",
     ],
 )
@@ -112,12 +112,6 @@ def TestInferShapedTypeMethodsPass : FunctionPass<"mhlo-test-infer-shaped-type-m
   let constructor = "createTestInferShapedTypeMethodsPass()";
 }
 
-
-def TransformUnrankedHloPass : FunctionPass<"mhlo-transform-unranked-hlo"> {
-  let summary = "Realize element-wise operations on ranked tensors where possible.";
-  let constructor = "createTransformUnrankedHloPass()";
-}
-
 def BroadcastPropagationPass : FunctionPass<"mhlo-broadcast-propagation"> {
   let summary = "Move dynamic broadcasts up over element-wise operations and "
     "broadcast the operands rather than the result. This will eventually allow "
@@ -32,9 +32,6 @@ class Pass;
 
 namespace mhlo {
 
-// Transforms unranked HLO operations to ranked ones where possible.
-std::unique_ptr<FunctionPass> createTransformUnrankedHloPass();
-
 /// Lowers HLO control flow ops to the Standard dialect.
 std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass();
 
@@ -81,10 +81,6 @@ void SetupMaterializeBroadcastsLegality(MLIRContext *context,
 void PopulateMaterializeBroadcastsPatterns(MLIRContext *context,
                                            OwningRewritePatternList *patterns);
 
-// Sets up legality definitions for element-wise operations on ranked tensors.
-void SetupTransformUnrankedHloLegality(MLIRContext *context,
-                                       ConversionTarget *conversionTarget);
-
 // Populates a collection of rewrite patterns to realize element-wise operations
 // on ranked tensors where possible.
 void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
@@ -62,7 +62,6 @@ add_mlir_library(MhloPasses
   rank_specialization.cc
   sink_constants_to_control_flow.cc
   test_infer_shaped_type_pass.cc
-  transform_unranked_hlo.cc
   unfuse_batch_norm.cc
   unfuse_batch_norm_pass.cc
@@ -52,7 +52,7 @@ struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto result_ty = op.getType().cast<ShapedType>();
 
-    // Unranked uses are not supported. Consider `mhlo-transform-unranked-hlo`.
+    // Unranked uses are not supported.
     if (!result_ty.hasRank()) return failure();
 
     // Lower to MHLO constant if statically shaped.
@@ -1,619 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-
-==============================================================================*/
-
-#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
-#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
-#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
-#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
-#include "mlir/Dialect/SCF/SCF.h"
-#include "mlir/Dialect/Shape/IR/Shape.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace {
-
-// TODO(herhut): Generate these out of op definitions.
-#define MAP_XLA_OPERATION_CWISE_UNARY(fn, sep)                             \
-  fn(AbsOp) sep fn(CeilOp) sep fn(ClzOp) sep fn(ConvertOp) sep fn(CosOp)   \
-      sep fn(ExpOp) sep fn(Expm1Op) sep fn(FloorOp) sep fn(ImagOp)         \
-      sep fn(IsFiniteOp) sep fn(LogOp) sep fn(Log1pOp) sep fn(LogisticOp)  \
-      sep fn(NotOp) sep fn(NegOp) sep fn(PopulationCountOp)                \
-      sep fn(RealOp) sep fn(RoundOp) sep fn(RsqrtOp)                       \
-      sep fn(SignOp) sep fn(SinOp) sep fn(SqrtOp)                          \
-      sep fn(TanhOp)
-
-// TODO(herhut): Generate these out of op definitions.
-#define MAP_XLA_OPERATION_CWISE_BINARY(fn, sep)                            \
-  fn(AddOp) sep fn(AndOp) sep fn(Atan2Op) sep fn(ComplexOp) sep fn(DivOp)  \
-      sep fn(MaxOp) sep fn(MinOp) sep fn(MulOp) sep fn(OrOp) sep fn(PowOp) \
-      sep fn(RemOp) sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp)     \
-      sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp)
-
-// TODO(herhut): Generate these out of op definitions.
-#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep)                            \
-  fn(AcosOp) sep fn(AcoshOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) \
-      sep fn(AtanhOp) sep fn(ConjOp) sep fn(CoshOp) sep fn(DigammaOp)      \
-      sep fn(ErfOp) sep fn(ErfcOp) sep fn(IsInfOp) sep fn(LgammaOp)        \
-      sep fn(SinhOp) sep fn(TanOp)
-
-// TODO(herhut): Generate these out of op definitions.
-#define MAP_CHLO_OPERATION_CWISE_BINARY(fn, sep) fn(PolygammaOp) sep fn(ZetaOp)
-
-template <typename OpTy>
-inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
-  target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
-    return llvm::all_of(op.getOperation()->getOperandTypes(),
-                        [&](Type t) { return t.isa<RankedTensorType>(); });
-  });
-}
-
-/// Element-wise operations on unranked tensors can be applied to the flattened
-/// tensor operands with the same effect. This pattern rewrites every such
-/// operation to
-///   (i)   flatten the input tensor,
-///   (ii)  apply the operation, and
-///   (iii) restore the original shape.
-template <typename OpTy>
-struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
-  explicit ElementwiseOpConversion(MLIRContext *context)
-      : OpRewritePattern<OpTy>(context) {}
-
-  LogicalResult matchAndRewrite(OpTy op,
-                                PatternRewriter &rewriter) const override {
-    // Only apply conversion if at least one operand is unranked.
-    if (llvm::none_of(op.getOperation()->getOperands(), [&](Value operand) {
-          return operand.getType().isa<UnrankedTensorType>();
-        })) {
-      return failure();
-    }
-
-    // Get operands' shape.
-    auto loc = op.getLoc();
-    Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext());
-    SmallVector<Value, 3> operandShapes;
-    for (Value operand : op.getOperation()->getOperands()) {
-      Value shape =
-          rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, operand);
-      operandShapes.push_back(shape);
-    }
-    Value shape =
-        operandShapes.size() == 1
-            ? operandShapes.front()
-            : rewriter.create<shape::AnyOp>(loc, extentTensorTy, operandShapes);
-
-    // Derive flat shape.
-    Type indexTy = rewriter.getIndexType();
-    Value numElements =
-        rewriter.create<shape::NumElementsOp>(loc, indexTy, shape);
-    Value flatShape = rewriter.create<tensor::FromElementsOp>(loc, numElements);
-
-    // Flatten operands.
-    SmallVector<Value, 3> flatOperands;
-    for (Value operand : op.getOperation()->getOperands()) {
-      Type operandElementTy =
-          operand.getType().template cast<ShapedType>().getElementType();
-      Type flatTy =
-          RankedTensorType::get({ShapedType::kDynamicSize}, operandElementTy);
-      Value flat = rewriter.create<mhlo::DynamicReshapeOp>(loc, flatTy, operand,
-                                                           flatShape);
-      flatOperands.push_back(flat);
-    }
-
-    // Apply operation to flattened operands.
-    Type resultElementTy =
-        op.getType().template cast<ShapedType>().getElementType();
-    Type flatResultTy =
-        RankedTensorType::get({ShapedType::kDynamicSize}, resultElementTy);
-    Value flatResult =
-        rewriter.create<OpTy>(loc, flatResultTy, flatOperands, op->getAttrs());
-
-    // Restore original shape.
-    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, op.getType(),
-                                                        flatResult, shape);
-
-    return success();
-  }
-};
-
-// Converts a broadcasting binary operation with a scalar operand and an
-// unranked operand to a ranked broadcasting operation by dynamically reshaping
-// the unranked operand to a 1D tensor. This will always be safe because
-// broadcasting from a scalar to another shape always works.
-template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
-struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
-    : public OpConversionPattern<ChloOpTy> {
-  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
-  LogicalResult matchAndRewrite(
-      ChloOpTy op, ArrayRef<Value> operands,
-      ConversionPatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
-    typename ChloOpTy::Adaptor transformed(operands);
-    Value lhs = transformed.lhs();
-    Value rhs = transformed.rhs();
-
-    auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
-    auto lhs_unranked_type = lhs.getType().dyn_cast<UnrankedTensorType>();
-
-    auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
-    auto rhs_unranked_type = rhs.getType().dyn_cast<UnrankedTensorType>();
-
-    bool lhs_is_scalar = lhs_ranked_type &&
-                         lhs_ranked_type.getShape().empty() &&
-                         rhs_unranked_type;
-    bool rhs_is_scalar = rhs_ranked_type &&
-                         rhs_ranked_type.getShape().empty() &&
-                         lhs_unranked_type;
-
-    // Only support the case where exactly one operand is scalar and the other
-    // is unranked. Other patterns in chlo-to-hlo legalization will create more
-    // efficient lowerings for cases where both ranks are known or will handle
-    // the more generic case of both inputs being unranked.
-    if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure();
-
-    auto scalar_element_type = lhs_is_scalar ? lhs_ranked_type.getElementType()
-                                             : rhs_ranked_type.getElementType();
-    auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
-    auto result_element_type = result_type.getElementType();
-
-    // Reshape the non-scalar value into a dynamically sized, rank-1 tensor
-    Value shape =
-        rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs);
-    Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape);
-    Value size_tensor =
-        rewriter.create<tensor::FromElementsOp>(loc, num_elements);
-    Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
-        loc,
-        RankedTensorType::get({RankedTensorType::kDynamicSize},
-                              scalar_element_type),
-        lhs_is_scalar ? rhs : lhs, size_tensor);
-
-    // Create a new ranked Chlo op that will be further lowered by other
-    // patterns into Mhlo.
-    SmallVector<Value, 2> new_operands{lhs_is_scalar ? lhs : reshaped,
-                                       rhs_is_scalar ? rhs : reshaped};
-    Value computed = rewriter.create<ChloOpTy>(
-        loc,
-        TypeRange{RankedTensorType::get({RankedTensorType::kDynamicSize},
-                                        result_element_type)},
-        new_operands, op->getAttrs());
-
-    // Reshape the result back into an unranked tensor.
-    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
-                                                        computed, shape);
-
-    return success();
-  }
-};
-
-// Handles lowering of the following pattern to patterns that will be further
-// matched by other patterns until they result in LHLO:
-//   %result = "chlo.op"(%op0, %op1, ...) : (<*xTy>, <*xTy>, ...) -> <*xTy>
-//
-// The sequence of specializations this handles is:
-//   - At most one operand has a shape that does not consist of exactly one
-//     element.
-//   - All operands having equal shapes
-//   - The resulting minimized shapes being any of ranks [1,5]
-template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
-struct ConvertUnrankedDynamicBroadcastNaryOp
-    : public OpConversionPattern<ChloOpTy> {
-  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
-
-  LogicalResult matchAndRewrite(
-      ChloOpTy op, ArrayRef<Value> operands,
-      ConversionPatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
-    typename ChloOpTy::Adaptor transformed(operands);
-    ValueRange transformed_operands = transformed.getOperands();
-    auto num_operands = transformed_operands.size();
-    llvm::SmallVector<Type, 3> operand_element_types;
-    operand_element_types.reserve(num_operands);
-    bool has_unranked_tensor_type = false;
-    for (int i = 0; i < num_operands; ++i) {
-      if (auto type =
-              transformed_operands[i].getType().dyn_cast<TensorType>()) {
-        if (type.isa<UnrankedTensorType>()) {
-          has_unranked_tensor_type = true;
-        }
-        operand_element_types.push_back(type.getElementType());
-      } else {
-        return failure();
-      }
-    }
-    if (!has_unranked_tensor_type) return failure();
-    auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
-
-    llvm::SmallVector<Value> shapes;
-    shapes.reserve(num_operands);
-    for (int i = 0; i < num_operands; ++i) {
-      shapes.push_back(
-          rewriter.create<shape::ShapeOfOp>(loc, transformed_operands[i]));
-    }
-
-    // If at most one shape does not have exactly one element
-    Value counter = rewriter.create<ConstantIndexOp>(loc, 0);
-    Value one = rewriter.create<ConstantIndexOp>(loc, 1);
-    for (int i = 0; i < num_operands; ++i) {
-      Value is_scalar_like = IsSingleElementShape(rewriter, op, shapes[i]);
-      Value counter_plus_one = rewriter.create<AddIOp>(loc, counter, one);
-      counter = rewriter.create<SelectOp>(loc, is_scalar_like, counter_plus_one,
-                                          counter);
-    }
-    Value num_operands_minus_one =
-        rewriter.create<ConstantIndexOp>(loc, num_operands - 1);
-    Value at_most_one_non_scalar =
-        rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::uge,
-                                counter, num_operands_minus_one);
-
-    auto if_op = rewriter.create<scf::IfOp>(loc, result_type,
-                                            at_most_one_non_scalar, true);
-    OpBuilder if_at_most_one_non_scalar_builder =
-        if_op.getThenBodyBuilder(rewriter.getListener());
-
-    llvm::SmallVector<Value, 3> reshaped_operands;
-    reshaped_operands.reserve(num_operands);
-    for (int i = 0; i < num_operands; ++i) {
-      Value num_elements =
-          if_at_most_one_non_scalar_builder.create<shape::NumElementsOp>(
-              loc, shapes[i]);
-      Value size_tensor =
-          if_at_most_one_non_scalar_builder.create<tensor::FromElementsOp>(
-              loc, num_elements);
-      Value reshaped =
-          if_at_most_one_non_scalar_builder.create<mhlo::DynamicReshapeOp>(
-              loc,
-              RankedTensorType::get({RankedTensorType::kDynamicSize},
-                                    operand_element_types[i]),
-              transformed_operands[i], size_tensor);
-      reshaped_operands.push_back(reshaped);
-    }
-
-    auto rank_one_result_type = RankedTensorType::get(
-        {RankedTensorType::kDynamicSize}, result_type.getElementType());
-    Value if_at_most_one_non_scalar_result =
-        if_at_most_one_non_scalar_builder.create<ChloOpTy>(
-            loc, ArrayRef<Type>{rank_one_result_type}, reshaped_operands,
-            op->getAttrs());
-    Value extended_result = extendToBroadcastShape(
-        if_at_most_one_non_scalar_builder, loc, result_type,
-        if_at_most_one_non_scalar_result, shapes);
-    if_at_most_one_non_scalar_builder.create<scf::YieldOp>(loc,
-                                                           extended_result);
-
-    // If there is more than one shape which does not have exactly one element
-    //
-    // See if all shapes are equal.
-    OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
-    Value equal_shapes =
-        else_builder.create<shape::ShapeEqOp>(loc, shapes[0], shapes[1]);
-    for (int i = 2; i < num_operands; ++i) {
-      Value are_equal =
-          else_builder.create<shape::ShapeEqOp>(loc, shapes[0], shapes[i]);
-      equal_shapes = else_builder.create<AndOp>(loc, equal_shapes, are_equal);
-    }
-
-    auto if_eq_shapes_op =
-        else_builder.create<scf::IfOp>(loc, result_type, equal_shapes, true);
-    else_builder.create<scf::YieldOp>(loc, if_eq_shapes_op.getResult(0));
-
-    OpBuilder if_eq_shapes_builder =
-        if_eq_shapes_op.getThenBodyBuilder(rewriter.getListener());
-    Value non_broadcast_op = Adaptor::CreateOp(
-        op, result_type, transformed_operands, if_eq_shapes_builder);
-    if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);
-
-    // If shapes do not have exactly one element, nor are equal
-    //
-    // See if values are of a rank that we support.
-    OpBuilder if_neq_shapes_builder =
-        if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener());
-    if_neq_shapes_builder.create<scf::YieldOp>(
-        loc,
-        HandleBroadcastAndOp(if_neq_shapes_builder, op, transformed_operands));
-
-    rewriter.replaceOp(op, {if_op.getResult(0)});
-    return success();
-  }
-
- private:
-  // Returns the dynamic result of checking the given value is effectively a
-  // scalar shape (i.e. the number of elements is 1).
-  Value IsSingleElementShape(OpBuilder &rewriter, ChloOpTy op,
-                             Value shape_of_tensor) const {
-    auto loc = op.getLoc();
-
-    Value num_elements =
-        rewriter.create<shape::NumElementsOp>(loc, shape_of_tensor);
-    return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq,
-                                   num_elements,
-                                   rewriter.create<ConstantIndexOp>(loc, 1));
-  }
-
-  Value extendToBroadcastShape(OpBuilder &builder, Location loc,
-                               Type result_type, Value value,
-                               ValueRange shapes) const {
-    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
-        {RankedTensorType::kDynamicSize}, builder.getIndexType());
-    Value broadcast_shape = builder.create<shape::BroadcastOp>(
-        loc, unknown_rank_extent_tensor_type, shapes, nullptr);
-    return builder.create<mhlo::DynamicReshapeOp>(loc, result_type, value,
-                                                  broadcast_shape);
-  }
-
-  // Returns the dynamic result of checking the given value is effectively a
-  // scalar shape (i.e. the number of elements is 1).
-  Value GreaterRankIsN(OpBuilder &builder, Location loc, Value actual_rank,
-                       int targeted_rank) const {
-    return builder.create<CmpIOp>(
-        loc, CmpIPredicate::eq, actual_rank,
-        builder.create<ConstantIndexOp>(loc, targeted_rank));
-  }
-
-  scf::IfOp createIfOpForRankSpecializedBroadcastAndOp(
-      OpBuilder &builder, ChloOpTy op, Value actual_rank,
-      int targeted_rank) const {
-    // Create the if block to place the current specialized logic in.
-    Value greater_rank_is_n =
-        GreaterRankIsN(builder, op.getLoc(), actual_rank, targeted_rank);
-    return builder.create<scf::IfOp>(op.getLoc(), op.getResult().getType(),
-                                     greater_rank_is_n, true);
-  }
-
-  Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op, Value shape,
-                                   int targeted_rank) const {
-    auto loc = op.getLoc();
-    SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
-    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
-        {RankedTensorType::kDynamicSize}, builder.getIndexType());
-    auto known_rank_extent_tensor_type =
-        RankedTensorType::get({targeted_rank}, builder.getIndexType());
-    Value ranked_shape_val = builder.create<shape::ConstShapeOp>(
-        loc, known_rank_extent_tensor_type,
-        mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
-                                        ranked_shape));
-    Value extended_value = builder.create<shape::BroadcastOp>(
-        loc, unknown_rank_extent_tensor_type, shape, ranked_shape_val, nullptr);
-    return builder.create<tensor::CastOp>(loc, known_rank_extent_tensor_type,
-                                          extended_value);
-  }
-
-  // Create the if statement and code for a broadcasting op with a result of a
-  // given rank.
-  void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder, ChloOpTy op,
-                                           ValueRange operands,
-                                           ValueRange operand_shapes,
-                                           int targeted_rank) const {
-    auto loc = op.getLoc();
-    SmallVector<Value, 2> reshaped_operands;
-
-    auto dynamic_dimensions = llvm::SmallVector<int64_t, 6>(
-        targeted_rank, RankedTensorType::kDynamicSize);
-
-    for (auto it : llvm::zip(operands, operand_shapes)) {
-      Value operand, shape;
-      std::tie(operand, shape) = it;
-      // Handle shape broadcasting and inference.
-      Value extended_operand_casted =
-          createBroadcastToKnownRank(if_builder, op, shape, targeted_rank);
-
-      // 1. Reshape operands to the given rank (with the same number of
-      //    elements)
-      // 2. Compute the ranked-broadcasted ChloOp (which will assert that the
-      //    ops can be broadcasted and do the actual broadcasting)
-      // 3. Type erase the output back to unranked
-      auto reshaped_type = RankedTensorType::get(
-          dynamic_dimensions,
-          operand.getType().template dyn_cast<TensorType>().getElementType());
-      Value reshaped_operand = if_builder.create<mhlo::DynamicReshapeOp>(
-          loc, reshaped_type, operand, extended_operand_casted);
-      reshaped_operands.push_back(reshaped_operand);
-    }
-    auto result_element_type = op.getResult()
-                                   .getType()
-                                   .template dyn_cast<TensorType>()
-                                   .getElementType();
-    auto result_type =
-        RankedTensorType::get(dynamic_dimensions, result_element_type);
-    Value result = if_builder.create<ChloOpTy>(
-        loc, ArrayRef<Type>{result_type}, reshaped_operands, op->getAttrs());
-    Value reshaped_result = if_builder.create<tensor::CastOp>(
-        loc, UnrankedTensorType::get(result_element_type), result);
-    if_builder.create<scf::YieldOp>(loc, reshaped_result);
-  }
-
-  // Iterates over the desired ranks to be specialized and generates the code
-  // snippet for each case.
-  Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op,
-                             ValueRange operands) const {
-    auto loc = op.getLoc();
-
-    // Get the minimum broadcast shapes of the operands.
-    SmallVector<Value> shapes;
-    shapes.reserve(operands.size());
-    auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
-                                                    rewriter.getIndexType());
-    for (Value operand : operands) {
-      Value shape =
-          rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, operand);
-      shapes.push_back(shape);
-    }
-    auto broadcast_shape = rewriter.create<shape::BroadcastOp>(
-        loc, extent_tensor_type, shapes, nullptr);
-    SmallVector<Type> result_types(shapes.size(), extent_tensor_type);
-    auto reduced_shapes =
-        rewriter
-            .create<chlo::MinimumBroadcastShapesOp>(loc, result_types, shapes)
-            .results();
-    SmallVector<Value> reshaped_operands;
-    reshaped_operands.reserve(operands.size());
-    for (auto it : llvm::zip(operands, reduced_shapes)) {
-      Value operand;
-      Value reduced_shape;
-      std::tie(operand, reduced_shape) = it;
-      auto reshaped_operand = rewriter.create<mhlo::DynamicReshapeOp>(
-          loc, operand.getType(), operand, reduced_shape);
-      reshaped_operands.push_back(reshaped_operand);
-    }
-
-    // Find the largest rank of the operands.
-    Value greater_rank;
-    for (Value shape : reduced_shapes) {
-      Value rank =
-          rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), shape);
-      if (!greater_rank) {
-        greater_rank = rank;
-      } else {
-        Value greater_rank_compare = rewriter.create<CmpIOp>(
-            loc, CmpIPredicate::sgt, greater_rank, rank);
-        greater_rank = rewriter.create<SelectOp>(loc, greater_rank_compare,
-                                                 greater_rank, rank);
-      }
-    }
-
-    // Generate a list of nested if/else statements to handle rank
-    // specializations from 1 to `kMaxRankSpecialization`.
-    scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
-        rewriter, op, greater_rank, 1);
-    OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
-    createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
-                                        reduced_shapes, 1);
-
-    // Put each subsequent rank specialization inside the else statement of the
-    // previous one.
-    OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
-
-    // Tensorflow supports up to rank 8 for SelectOp (currently the only op
-    // with arity > 2 that we support), but only up to rank 5 for binary ops.
-    // We want to preserve this behavior.
-    const int kMaxRankSpecialization = operands.size() > 2 ? 8 : 5;
-    for (int i = 2; i < kMaxRankSpecialization; i++) {
-      auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
-          else_builder, op, greater_rank, i);
-      if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
-      createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
-                                          reduced_shapes, i);
-      else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
-      else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
-    }
-    // Fire an assertion if none of the rank specializations applied (one of
-    // the ranks was greater than `kMaxRankSpecialization`).
-    else_builder.create<AssertOp>(
-        loc,
-        GreaterRankIsN(else_builder, op.getLoc(), greater_rank,
-                       kMaxRankSpecialization),
-        "Input for dynamic binary op lowering was of a rank greater than " +
-            std::to_string(kMaxRankSpecialization));
-    // Add the rank 5 specialization to the innermost else block.
-    createRankSpecializedBroadcastAndOp(else_builder, op, reshaped_operands,
-                                        reduced_shapes,
-                                        kMaxRankSpecialization);
-
-    // Return the reshaped result of the outermost if statement.
-    auto result = if_op.getResult(0);
-    auto reshaped_result = rewriter.create<mhlo::DynamicReshapeOp>(
-        loc, result.getType(), result, broadcast_shape);
-    return reshaped_result;
-  }
-};
-
-struct TransformUnrankedHloPass
-    : public mhlo::TransformUnrankedHloPassBase<TransformUnrankedHloPass> {
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<chlo::HloClientDialect, mhlo::MhloDialect, scf::SCFDialect,
-                    shape::ShapeDialect>();
-  }
-
-  void runOnFunction() override {
-    // Setup conversion target.
-    MLIRContext &ctx = getContext();
-    ConversionTarget target(ctx);
-    target.addLegalDialect<chlo::HloClientDialect, mhlo::MhloDialect,
-                           StandardOpsDialect, shape::ShapeDialect,
-                           scf::SCFDialect, tensor::TensorDialect>();
-    target.addLegalOp<FuncOp>();
-#define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target)
-#define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target)
-    MAP_XLA_OPERATION_CWISE_UNARY(ADD_LEGAL_MHLO, ;);
-    MAP_XLA_OPERATION_CWISE_BINARY(ADD_LEGAL_MHLO, ;);
-    MAP_CHLO_OPERATION_CWISE_UNARY(ADD_LEGAL_CHLO, ;);
-#undef ADD_LEGAL_MHLO
-#undef ADD_LEGAL_CHLO
-    AddLegalOpOnRankedTensor<mhlo::CompareOp>(&target);
-    AddLegalOpOnRankedTensor<mhlo::SelectOp>(&target);
-    target.addDynamicallyLegalDialect<chlo::HloClientDialect>(
-        [](Operation *op) {
-          return llvm::none_of(op->getOperandTypes(), [](Type type) {
-            return type.isa<UnrankedTensorType>();
-          });
-        });
-
-    // Populate rewrite patterns.
-    OwningRewritePatternList patterns(&ctx);
-    mhlo::PopulateTransformUnrankedHloPatterns(&ctx, &patterns);
-
-    // Apply transformation.
-    if (failed(
-            applyPartialConversion(getFunction(), target, std::move(patterns))))
-      return signalPassFailure();
-  }
-};
-
-}  // namespace
-
-namespace mhlo {
-
-void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
-                                          OwningRewritePatternList *patterns) {
-#define MAP_HLO(op) ElementwiseOpConversion<mhlo::op>
-#define MAP_CHLO(op) ElementwiseOpConversion<chlo::op>
-#define COMMA ,
-  // clang-format off
-  patterns->insert<
-      MAP_XLA_OPERATION_CWISE_UNARY(MAP_HLO, COMMA),
-      MAP_XLA_OPERATION_CWISE_BINARY(MAP_HLO, COMMA),
-      MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO, COMMA),
-      MAP_CHLO_OPERATION_CWISE_BINARY(MAP_CHLO, COMMA),
-      ElementwiseOpConversion<mhlo::CompareOp>,
-      ElementwiseOpConversion<mhlo::SelectOp>>(context);
-  // clang-format on
-#undef MAP_HLO
-#undef MAP_CHLO
-#undef COMMA
-  chlo::PopulateForBroadcastingBinaryOp<ConvertUnrankedDynamicBroadcastNaryOp>(
-      context, patterns);
-  patterns->insert<ConvertUnrankedDynamicBroadcastNaryOp<
-      chlo::BroadcastSelectOp, mhlo::SelectOp,
-      chlo::HloNaryElementwiseAdaptor<chlo::BroadcastSelectOp,
-                                      mhlo::SelectOp>>>(context);
-  chlo::PopulateForBroadcastingBinaryOp<
-      ConvertUnrankedScalarDynamicBroadcastBinaryOp>(context, patterns);
-}
-
-std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {
-  return std::make_unique<TransformUnrankedHloPass>();
-}
-
-}  // namespace mhlo
-}  // namespace mlir
@ -1,410 +0,0 @@
|
||||||
// RUN: mlir-hlo-opt --mhlo-transform-unranked-hlo --cse --split-input-file %s | FileCheck %s
|
|
||||||
|
|
||||||
// Check the validity of expected IR.
|
|
||||||
// CHECK-LABEL: @sqr_transform_result
|
|
||||||
func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
|
|
||||||
|
|
||||||
// Flatten operand shape.
|
|
||||||
%shape = shape.shape_of %a : tensor<*xf32> -> tensor<?xindex>
|
|
||||||
%num_elements = shape.num_elements %shape : tensor<?xindex> -> index
|
|
||||||
%flat_shape = tensor.from_elements %num_elements : tensor<1xindex>
|
|
||||||
%flat_a = "mhlo.dynamic_reshape"(%a, %flat_shape)
|
|
||||||
: (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
|
||||||
|
|
||||||
// Apply operation.
|
|
||||||
%flat_b = "mhlo.sqrt"(%flat_a) : (tensor<?xf32>) -> tensor<?xf32>
|
|
||||||
|
|
||||||
// Restore original shape.
|
|
||||||
%b = "mhlo.dynamic_reshape"(%flat_b, %shape)
|
|
||||||
: (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
|
||||||
|
|
||||||
return %b : tensor<*xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// Check transformation of unranked code.
|
|
||||||
// CHECK-LABEL: @sqrt
|
|
||||||
// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>)
|
|
||||||
func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> {
|
|
||||||
// CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> -> tensor<?xindex>
|
|
||||||
// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
|
|
||||||
// CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
|
|
||||||
// CHECK-NEXT: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
|
||||||
// CHECK-NEXT: %[[FLAT_B:.*]] = "mhlo.sqrt"(%[[FLAT_A]]) : (tensor<?xf32>) -> tensor<?xf32>
|
|
||||||
// CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
|
||||||
// CHECK-NEXT: return %[[B]] : tensor<*xf32>
|
|
||||||
%b = "mhlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32>
|
|
||||||
return %b : tensor<*xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// Not transformed when ranked.
|
|
||||||
// CHECK-LABEL: @sqrt_ranked
|
|
||||||
// CHECK-SAME: (%[[A:.*]]: tensor<3x?xf32>)
|
|
||||||
func @sqrt_ranked(%a: tensor<3x?xf32>) -> tensor<3x?xf32> {
|
|
||||||
// CHECK-NEXT: %[[B:.*]] = "mhlo.sqrt"(%[[A]]) : (tensor<3x?xf32>) -> tensor<3x?xf32>
|
|
||||||
// CHECK-NEXT: return %[[B]] : tensor<3x?xf32>
|
|
||||||
%b = "mhlo.sqrt"(%a) : (tensor<3x?xf32>) -> tensor<3x?xf32>
|
|
||||||
return %b : tensor<3x?xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// Not transformed when statically shaped.
|
|
||||||
// CHECK-LABEL: @sqrt_static
|
|
||||||
// CHECK-SAME: (%[[A:.*]]: tensor<2x3xf32>)
|
|
||||||
func @sqrt_static(%a: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
|
||||||
// CHECK-NEXT: %[[B:.*]] = "mhlo.sqrt"(%[[A]]) : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
|
||||||
// CHECK-NEXT: return %[[B]] : tensor<2x3xf32>
|
|
||||||
%b = "mhlo.sqrt"(%a) : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
|
||||||
return %b : tensor<2x3xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// Transformed if there is a mix of unranked/static shapes.
|
|
||||||
// CHECK-LABEL: @select_mixed
|
|
||||||
// CHECK-SAME: (%[[PRED:.*]]: tensor<*xi1>, %[[ON_TRUE:.*]]: tensor<*xf32>, %[[ON_FALSE:.*]]: tensor<2xf32>)
|
|
||||||
func @select_mixed(%pred: tensor<*xi1>, %on_true: tensor<*xf32>, %on_false: tensor<2xf32>) -> tensor<*xf32> {
|
|
||||||
// CHECK: %[[SHAPE_PRED:.*]] = shape.shape_of %[[PRED]]
|
|
||||||
// CHECK: %[[SHAPE_ON_TRUE:.*]] = shape.shape_of %[[ON_TRUE]]
|
|
||||||
// CHECK: %[[SHAPE_ON_FALSE:.*]] = shape.shape_of %[[ON_FALSE]]
|
|
||||||
// CHECK: %[[SHAPE:.*]] = shape.any %[[SHAPE_PRED]], %[[SHAPE_ON_TRUE]], %[[SHAPE_ON_FALSE]]
|
|
||||||
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
|
|
||||||
// CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
|
|
||||||
// CHECK: %[[FLAT_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[FLAT_SHAPE]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
|
|
||||||
// CHECK: %[[FLAT_ON_TRUE:.*]] = "mhlo.dynamic_reshape"(%[[ON_TRUE]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
|
||||||
// CHECK: %[[FLAT_ON_FALSE:.*]] = "mhlo.dynamic_reshape"(%[[ON_FALSE]], %[[FLAT_SHAPE]]) : (tensor<2xf32>, tensor<1xindex>) -> tensor<?xf32>
|
|
||||||
// CHECK: %[[FLAT_RESULT:.*]] = "mhlo.select"(%[[FLAT_PRED]], %[[FLAT_ON_TRUE]], %[[FLAT_ON_FALSE]]) : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
|
||||||
// CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
|
||||||
// CHECK: return %[[RESULT]] : tensor<*xf32>
|
|
||||||
%b = "mhlo.select"(%pred, %on_true, %on_false) : (tensor<*xi1>, tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32>
|
|
||||||
return %b : tensor<*xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK-LABEL: @add_unranked
|
|
||||||
// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>, %[[B:.*]]: tensor<*xf32>) -> tensor<*xf32>
|
|
||||||
func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
|
|
||||||
// CHECK: %[[SHAPE_A:.*]] = shape.shape_of %[[A]]
|
|
||||||
// CHECK: %[[SHAPE_B:.*]] = shape.shape_of %[[B]]
|
|
||||||
// CHECK: %[[SHAPE:.*]] = shape.any %[[SHAPE_A]], %[[SHAPE_B]]
|
|
||||||
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
|
|
||||||
// CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
|
|
||||||
// CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
|
||||||
// CHECK: %[[FLAT_B:.*]] = "mhlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
|
||||||
// CHECK: %[[FLAT_RESULT:.*]] = mhlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor<?xf32>
|
|
||||||
// CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
|
||||||
// CHECK: return %[[RESULT]] : tensor<*xf32>
|
|
||||||
%result = mhlo.add %a, %b : tensor<*xf32>
|
|
||||||
return %result : tensor<*xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK-LABEL: @tan
|
|
||||||
// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>) -> tensor<*xf32>
|
|
||||||
func @tan(%a : tensor<*xf32>) -> tensor<*xf32> {
|
|
||||||
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> -> tensor<?xindex>
|
|
||||||
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
|
|
||||||
// CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
|
|
||||||
// CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
|
||||||
// CHECK: %[[FLAT_B:.*]] = chlo.tan %[[FLAT_A]] : tensor<?xf32> -> tensor<?xf32>
|
|
||||||
// CHECK: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
|
||||||
// CHECK: return %[[B]] : tensor<*xf32>
|
|
||||||
%result = chlo.tan %a : tensor<*xf32> -> tensor<*xf32>
|
|
||||||
return %result : tensor<*xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
|
||||||
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<f32>, tensor<*xf32>)
|
|
||||||
-> tensor<*xf32>
|
|
||||||
return %0 : tensor<*xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @addScalarUnranked(
|
|
||||||
// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>,
|
|
||||||
// CHECK-SAME: %[[ARG_1:.*]]: tensor<*xf32>
|
|
||||||
// CHECK-SAME: ) -> tensor<*xf32> {
|
|
||||||
// First handle the dynamic reshaping of the unranked operand
|
|
||||||
// to a 1D tensor.
|
|
||||||
// CHECK-NEXT: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32>
|
|
||||||
// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]] : tensor<?xindex> -> index
|
|
||||||
// CHECK-NEXT: %[[SIZE_TENSOR:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
|
|
||||||
// CHECK-NEXT: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
|
||||||
// CHECK-NEXT: %[[BROADCASTED_RESULT:.*]] = chlo.broadcast_add %[[ARG_0]], %[[RESHAPED]] : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
|
|
||||||
// As part of the unranked logic, the result is reshaped back
|
|
||||||
// to an unranked tensor.
|
|
||||||
// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[BROADCASTED_RESULT:.*]], %[[SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
|
||||||
// CHECK-NEXT: return %[[RESHAPED_RESULT]] : tensor<*xf32>
|
|
||||||
// CHECK-NEXT: }
|
|
||||||
|
|
||||||
// -----
|
|
||||||
func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf32> {
|
|
||||||
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<f32>)
|
|
||||||
-> tensor<*xf32>
|
|
||||||
return %0 : tensor<*xf32>
|
|
||||||
}
|
|
||||||
// CHECK-LABEL: func @addUnrankedScalar(
|
|
||||||
// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32>,
|
|
||||||
// CHECK-SAME: %[[ARG_1:.*]]: tensor<f32>) -> tensor<*xf32> {
|
|
||||||
// First handle the dynamic reshaping of the unranked operand
|
|
||||||
// to a 1D tensor.
|
|
||||||
// CHECK-NEXT: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32>
|
|
||||||
// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]] : tensor<?xindex> -> index
|
|
||||||
// CHECK-NEXT: %[[SIZE_TENSOR:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
|
|
||||||
// CHECK-NEXT: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
|
||||||
// The assuming region is part of the second stage of lowering
|
|
||||||
// with ranked broadcasting logic.
|
|
||||||
// CHECK-NEXT: %[[BROADCASTED_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED]], %[[ARG_1]] : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
|
|
||||||
// As part of the unranked logic, the result is reshaped back
|
|
||||||
// to an unranked tensor.
|
|
||||||
// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[BROADCASTED_RESULT:.*]], %[[SHAPE_0]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
|
||||||
// CHECK-NEXT: return %[[RESHAPED_RESULT]] : tensor<*xf32>
|
|
||||||
// CHECK-NEXT: }
|
|
||||||
|
|
||||||
// -----
|
|
||||||
func @addUnrankedUnranked(
|
|
||||||
%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
|
||||||
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>)
|
|
||||||
-> tensor<*xf32>
|
|
||||||
return %0 : tensor<*xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @addUnrankedUnranked(
|
|
||||||
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
|
|
||||||
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
|
|
||||||
// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
|
|
||||||
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
|
|
||||||
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
|
|
||||||
// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
|
|
||||||
// CHECK-NEXT: %[[NUM_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index
|
|
||||||
// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_LHS]], %[[C1]] : index
|
|
||||||
// CHECK-NEXT: %[[COUNTER_PLUS_ONE:.*]] = addi %[[C0]], %[[C1]] : index
|
|
||||||
// CHECK-NEXT: %[[COUNTER:.*]] = select %[[LHS_IS_SCALAR]], %[[COUNTER_PLUS_ONE]], %[[C0]] : index
|
|
||||||
// CHECK-NEXT: %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE]] : tensor<?xindex> -> index
|
|
||||||
// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_RHS]], %[[C1]] : index
|
|
||||||
// CHECK-NEXT: %[[COUNTER_PLUS_ONE2:.*]] = addi %[[COUNTER]], %[[C1]] : index
|
|
||||||
// CHECK-NEXT: %[[COUNTER2:.*]] = select %[[RHS_IS_SCALAR]], %[[COUNTER_PLUS_ONE2]], %[[COUNTER]] : index
|
|
||||||
// Handle scalar case
|
|
||||||
// CHECK-NEXT: %[[IS_SCALAR_CASE:.*]] = cmpi uge, %[[COUNTER2]], %[[C1]] : index
|
|
||||||
// CHECK-NEXT: %[[VAL_8:.*]] = scf.if %[[IS_SCALAR_CASE]] -> (tensor<*xf32>) {
|
|
||||||
// CHECK-NEXT: %[[NUM_TENS_LHS:.*]] = tensor.from_elements %[[NUM_LHS]] : tensor<1xindex>
|
|
||||||
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
|
||||||
// CHECK-NEXT: %[[NUM_TENS_RHS:.*]] = tensor.from_elements %[[NUM_RHS]] : tensor<1xindex>
|
|
||||||
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[NUM_TENS_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
|
||||||
// CHECK-NEXT: %[[SCALAR_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
|
||||||
// CHECK-NEXT: %[[SHAPE_BROADCAST:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
|
||||||
// CHECK-NEXT: %[[RESHAPED_EXTENDED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[SCALAR_RESULT]], %[[SHAPE_BROADCAST]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
|
||||||
// CHECK-NEXT: scf.yield %[[RESHAPED_EXTENDED_RESULT]] : tensor<*xf32>
|
|
||||||
// CHECK-NEXT: } else {
|
|
||||||
// CHECK-NEXT: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
|
|
||||||
// Handle equal shapes case
|
|
||||||
// CHECK-NEXT: %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) {
|
|
||||||
// CHECK-NEXT: %[[ANY_SHAPE:.*]] = shape.any %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
|
||||||
// CHECK-NEXT: %[[ANY_NUM:.*]] = shape.num_elements %[[ANY_SHAPE]] : tensor<?xindex> -> index
|
|
||||||
// CHECK-NEXT: %[[ANY_TENSOR:.*]] = tensor.from_elements %[[ANY_NUM]] : tensor<1xindex>
|
|
||||||
// CHECK-NEXT: %[[FLATTENED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
|
||||||
// CHECK-NEXT: %[[FLATTENED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
|
||||||
// CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = mhlo.add %[[FLATTENED_LHS]], %[[FLATTENED_RHS]] : tensor<?xf32>
|
|
||||||
// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
|
||||||
// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
|
|
||||||
// CHECK-NEXT: } else {
|
|
||||||
// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
|
||||||
// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:2 = chlo.minimum_broadcast_shapes %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>
|
|
||||||
// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
|
||||||
// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
|
||||||
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor<?xindex> -> index
|
|
||||||
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index
|
|
||||||
// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi sgt, %[[LHS_RANK]], %[[RHS_RANK]] : index
|
|
||||||
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
|
|
||||||
// Handle rank 1 specialization
|
|
||||||
// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C1]] : index
|
|
||||||
// CHECK-NEXT: %[[RESULT_RANK_SPECIALIZATION:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
|
|
||||||
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1]
|
|
||||||
// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
|
||||||
// CHECK-NEXT: %[[CASTED_LHS_1:.*]] = tensor.cast %[[BROADCASTED_LHS_1]] : tensor<?xindex> to tensor<1xindex>
|
|
||||||
// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
|
||||||
// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
|
||||||
// CHECK-NEXT: %[[CASTED_RHS_1:.*]] = tensor.cast %[[BROADCASTED_RHS_1]] : tensor<?xindex> to tensor<1xindex>
|
|
||||||
// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
|
||||||
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
|
||||||
// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C2:.*]] = constant 2 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_2:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C2]] : index
// Handle rank 2 specialization
// CHECK-NEXT: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor.cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT_2:.*]] = tensor.cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_2]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C3:.*]] = constant 3 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_3:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C3]] : index
// Handle rank 3 specialization
// CHECK-NEXT: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor.cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor.cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[RESULT_3:.*]] = tensor.cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_3]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C4:.*]] = constant 4 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_4:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C4]] : index
// Handle rank 4 specialization
// CHECK-NEXT: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor.cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor.cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_4:.*]] = tensor.cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_4]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C5:.*]] = constant 5 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_5:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C5]] : index
// CHECK-NEXT: assert %[[GREATEST_RANK_IS_5]]
// Handle rank 5 specialization
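// Rank 5 is the last specialization in this cascade: instead of opening yet
// another branch, the pass asserts that the greatest operand rank is 5 and
// then emits the rank-5 code unconditionally.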
// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_66:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_67:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_68:.*]] : tensor<*xf32>
// CHECK-NEXT: }
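// The nested scf.if ops yield the specialized value outward level by level,
// so whichever rank branch was taken, its result reaches the outermost if
// as a tensor<*xf32>.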
// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESULT_RANK_SPECIALIZATION]], %[[RESULT_SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_RESULT]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_71:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return %[[VAL_72:.*]] : tensor<*xf32>
// CHECK-NEXT: }
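// Summary of the binary case checked above: the greatest minimized operand
// rank selects one of the rank-1 through rank-5 specializations, and the
// specialized result is reshaped back to the dynamically computed result
// shape before being returned.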

// -----

func @selectUnrankedUnrankedUnranked(
    %arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>)
    -> tensor<*xf32> {
  %0 = chlo.broadcast_select %arg0, %arg1, %arg2
      : (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: func @selectUnrankedUnrankedUnranked(
// CHECK-SAME: %[[PRED:.*]]: tensor<*xi1>,
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NEXT: %[[PRED_SHAPE:.*]] = shape.shape_of %[[PRED]] : tensor<*xi1> -> tensor<?xindex>
// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %c0 = constant 0 : index
// CHECK-NEXT: %c1 = constant 1 : index
// CHECK-NEXT: %[[NUM_ELEMENTS_PRED:.*]] = shape.num_elements %[[PRED_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[PRED_IS_SCALAR:.*]] = cmpi eq, %[[NUM_ELEMENTS_PRED]], %c1 : index
// CHECK-NEXT: %[[COUNTER_PLUS_ONE:.*]] = addi %c0, %c1 : index
// CHECK-NEXT: %[[COUNTER:.*]] = select %[[PRED_IS_SCALAR]], %[[COUNTER_PLUS_ONE]], %c0 : index
// CHECK-NEXT: %[[NUM_ELEMENTS_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_ELEMENTS_LHS]], %c1 : index
// CHECK-NEXT: %[[COUNTER_PLUS_ONE:.*]] = addi %[[COUNTER]], %c1 : index
// CHECK-NEXT: %[[COUNTER2:.*]] = select %[[LHS_IS_SCALAR]], %[[COUNTER_PLUS_ONE]], %[[COUNTER]] : index
// CHECK-NEXT: %[[NUM_ELEMENTS_RHS:.*]] = shape.num_elements %[[RHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_ELEMENTS_RHS]], %c1 : index
// CHECK-NEXT: %[[COUNTER_PLUS_ONE:.*]] = addi %[[COUNTER2]], %c1 : index
// CHECK-NEXT: %[[COUNTER3:.*]] = select %[[RHS_IS_SCALAR]], %[[COUNTER_PLUS_ONE]], %[[COUNTER2]] : index
// CHECK-NEXT: %c2 = constant 2 : index
// CHECK-NEXT: %[[IS_SCALAR_CASE:.*]] = cmpi uge, %[[COUNTER3]], %c2 : index
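// The addi/select chain above counts how many of the three operands are
// effectively scalars (num_elements == 1); the flattened scalar path below
// is taken when at least two of them are (counter >= 2).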
// CHECK-NEXT: %[[IF_IS_SCALAR_CASE:.*]] = scf.if %[[IS_SCALAR_CASE]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[NUM_TENS_PRED:.*]] = tensor.from_elements %[[NUM_ELEMENTS_PRED]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[NUM_TENS_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
// CHECK-NEXT: %[[NUM_TENS_LHS:.*]] = tensor.from_elements %[[NUM_ELEMENTS_LHS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[NUM_TENS_RHS:.*]] = tensor.from_elements %[[NUM_ELEMENTS_RHS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[NUM_TENS_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[SCALAR_RESULT:.*]] = chlo.broadcast_select %[[RESHAPED_PRED]], %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[SHAPE_BROADCAST:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[RESHAPED_EXTENDED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[SCALAR_RESULT]], %[[SHAPE_BROADCAST]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_EXTENDED_RESULT]] : tensor<*xf32>
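// Scalar case: each operand is flattened to a rank-1 tensor of its element
// count, the select is computed there, and the result is expanded back to
// the broadcasted result shape.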
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[FIRST_SHAPES_EQUAL:.*]] = shape.shape_eq %[[PRED_SHAPE]], %[[LHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
// CHECK-NEXT: %[[SECOND_SHAPES_EQUAL:.*]] = shape.shape_eq %[[PRED_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
// CHECK-NEXT: %[[ALL_SHAPES_EQUAL:.*]] = and %[[FIRST_SHAPES_EQUAL]], %[[SECOND_SHAPES_EQUAL]] : i1
// CHECK-NEXT: %[[IF_EQUAL_CASE:.*]] = scf.if %[[ALL_SHAPES_EQUAL]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[ANY_SHAPE:.*]] = shape.any %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[ANY_NUM:.*]] = shape.num_elements %[[ANY_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[ANY_TENSOR:.*]] = tensor.from_elements %[[ANY_NUM]] : tensor<1xindex>
// CHECK-NEXT: %[[FLATTENED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[ANY_TENSOR]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
// CHECK-NEXT: %[[FLATTENED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLATTENED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = "mhlo.select"(%[[FLATTENED_PRED]], %[[FLATTENED_LHS]], %[[FLATTENED_RHS]]) : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
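// Equal-shapes case: no broadcasting is needed, so the operands are
// flattened to rank 1, a plain mhlo.select is applied, and the result is
// reshaped back to the shared shape.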
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:3 = chlo.minimum_broadcast_shapes %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
// CHECK-NEXT: %[[MINIMUM_RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xi1>, tensor<?xindex>) -> tensor<*xi1>
// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
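// chlo.minimum_broadcast_shapes collapses consecutive dimensions that
// broadcast alike, so each operand can be reshaped to the lowest rank that
// still expresses the broadcast before rank specialization kicks in.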
// CHECK-NEXT: %[[PRED_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor<?xindex> -> index
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index
// CHECK-NEXT: %[[GREATER_RANK_CMP:.*]] = cmpi sgt, %[[PRED_RANK]], %[[LHS_RANK]] : index
// CHECK-NEXT: %[[GREATER_RANK:.*]] = select %[[GREATER_RANK_CMP]], %[[PRED_RANK]], %[[LHS_RANK]] : index
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#2 : tensor<?xindex> -> index
// CHECK-NEXT: %[[GREATEST_RANK_CMP:.*]] = cmpi sgt, %[[GREATER_RANK]], %[[RHS_RANK]] : index
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[GREATEST_RANK_CMP]], %[[GREATER_RANK]], %[[RHS_RANK]] : index
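// The greatest of the three minimized ranks is computed with pairwise
// cmpi sgt/select and drives the rank-specialization cascade below.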
// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %c1 : index
// Handle rank 1 specialization
// CHECK-NEXT: scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1] : tensor<1xindex>
// CHECK-NEXT: %[[BROADCASTED_PRED:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_PRED:.*]] = tensor.cast %[[BROADCASTED_PRED]] : tensor<?xindex> to tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_PRED]], %[[CASTED_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
// CHECK-NEXT: %[[BROADCASTED_LHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS:.*]] = tensor.cast %[[BROADCASTED_LHS]] : tensor<?xindex> to tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#2, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS:.*]] = tensor.cast %[[BROADCASTED_RHS]] : tensor<?xindex> to tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_select %[[RESHAPED_PRED]], %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
// CHECK-NEXT: }

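// The rank 2 through rank 8 specializations follow the same pattern as
// rank 1; only the signature of each specialized broadcast_select is
// spot-checked below.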
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?xi1>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?xi1>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?xi1>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?xi1>, tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?x?xf32>
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?x?x?xf32>