[MLIR][HLO] Add rank specialization with multiple non-scalar operands
Add a lowering pattern for rank specialization clusters with more than one non-scalar operand. The lowering resembles that of the `TransformUnrankedHlo` pass: it switches on the maximal reduced operand rank, with cases for ranks 1 through 8.

PiperOrigin-RevId: 374377002
This commit is contained in:
parent 0168484eed
commit 6af3d2df91
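For orientation, the emitted control flow looks roughly like the hand-written MLIR sketch below. It is not part of this commit; the function name and all value names are illustrative, the yields are placeholders, and only the first two of the eight rank cases are spelled out.

func @rank_dispatch_sketch(%arg : tensor<*xf32>, %max_rank : index)
    -> tensor<*xf32> {
  %c1 = constant 1 : index
  %c2 = constant 2 : index
  %eq1 = cmpi eq, %max_rank, %c1 : index
  %res = scf.if %eq1 -> (tensor<*xf32>) {
    // Rank-1 specialization: reshape the operands to tensor<?xf32>, apply
    // the cluster body, and cast the result back to tensor<*xf32>.
    scf.yield %arg : tensor<*xf32>
  } else {
    %eq2 = cmpi eq, %max_rank, %c2 : index
    %inner = scf.if %eq2 -> (tensor<*xf32>) {
      // Rank-2 specialization, analogous.
      scf.yield %arg : tensor<*xf32>
    } else {
      // Cases for ranks 3 through 8 nest here; the last case is guarded
      // by an assert rather than another scf.if.
      scf.yield %arg : tensor<*xf32>
    }
    scf.yield %inner : tensor<*xf32>
  }
  return %res : tensor<*xf32>
}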
@@ -21,6 +21,7 @@ limitations under the License.
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -298,11 +299,160 @@ struct LowerSingleNonScalarOperandPattern
   }
 };
 
+Value MaterializeRankSpecialization(OpBuilder &b, Location loc,
+                                    chlo::RankSpecializationClusterOp op,
+                                    const SmallVector<Value, 8> &shapes,
+                                    int64_t target_rank) {
+  // Reshape operands to match the target rank.
+  MLIRContext *ctx = op->getContext();
+  llvm::SmallVector<int64_t, 8> ranked_ty_dynamic_dims(
+      target_rank, RankedTensorType::kDynamicSize);
+  RankedTensorType extent_tensor_ty =
+      shape::getExtentTensorType(ctx, target_rank);
+  Value all_ones_shape = b.create<shape::ConstShapeOp>(
+      loc, extent_tensor_ty,
+      mlir::DenseIntElementsAttr::get(extent_tensor_ty,
+                                      SmallVector<int64_t, 6>(target_rank, 1)));
+  SmallVector<Value, 2> ranked_operands;
+  for (auto it : llvm::zip(op.operands(), shapes)) {
+    Value operand, shape;
+    std::tie(operand, shape) = it;
+    Value ranked_shape = b.create<tensor::CastOp>(
+        loc, extent_tensor_ty,
+        b.create<shape::BroadcastOp>(loc, shape::getExtentTensorType(ctx),
+                                     shape, all_ones_shape,
+                                     /*error=*/nullptr));
+    Type element_ty = operand.getType().dyn_cast<TensorType>().getElementType();
+    auto ranked_ty = RankedTensorType::get(ranked_ty_dynamic_dims, element_ty);
+    ranked_operands.push_back(b.create<mhlo::DynamicReshapeOp>(
+        loc, ranked_ty, operand, ranked_shape));
+  }
+
+  // Materialize ranked versions of the element-wise operations.
+  BlockAndValueMapping bvm;
+  for (auto it : llvm::zip(op.body().front().getArguments(), ranked_operands))
+    bvm.map(std::get<0>(it), std::get<1>(it));
+
+  // Return as unranked for compatibility with other target ranks.
+  auto unshaped_result =
+      MaterializeRankedOperations(b, loc, bvm, op, target_rank).front();
+  return b.create<tensor::CastOp>(
+      loc, DeriveUnrankedTensorTypes(unshaped_result.getType()),
+      unshaped_result);
+}
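To make the reshape trick above concrete, here is a hand-written sketch for target rank 3 (not from the commit; names are illustrative). Broadcasting a reduced shape with a rank-3 all-ones shape pads it to rank 3 without changing any extent, so the subsequent reshape always produces a statically rank-3 tensor:

func @pad_to_rank_3_sketch(%operand : tensor<*xf32>, %shape : tensor<?xindex>)
    -> tensor<?x?x?xf32> {
  // [d0, d1] broadcast with [1, 1, 1] yields [1, d0, d1]: same extents,
  // now guaranteed to be of rank 3.
  %ones = shape.const_shape [1, 1, 1] : tensor<3xindex>
  %padded = shape.broadcast %shape, %ones
      : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
  %padded_ = tensor.cast %padded : tensor<?xindex> to tensor<3xindex>
  %ranked = "mhlo.dynamic_reshape"(%operand, %padded_)
      : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
  return %ranked : tensor<?x?x?xf32>
}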
+
+Value MaterializeAllRankSpecializations(OpBuilder &b, Location loc,
+                                        chlo::RankSpecializationClusterOp op,
+                                        const SmallVector<Value, 8> &shapes,
+                                        Value max_rank, int64_t min_target_rank,
+                                        int64_t max_target_rank) {
+  Value min_target_rank_predicate =
+      b.create<CmpIOp>(loc, CmpIPredicate::eq, max_rank,
+                       b.create<ConstantIndexOp>(loc, min_target_rank));
+
+  // If only one target rank remains, we can lower to an assert instead of
+  // the usual if operation.
+  if (min_target_rank == max_target_rank) {
+    b.create<AssertOp>(loc, min_target_rank_predicate,
+                       "Input for dynamic binary or n-ary op lowering was of "
+                       "a rank greater than " +
+                           std::to_string(max_target_rank));
+    return MaterializeRankSpecialization(b, loc, op, shapes, min_target_rank);
+  }
+
+  // Materialize IR for the smallest considered target rank.
+  auto if_op =
+      b.create<scf::IfOp>(loc, op->getResultTypes(), min_target_rank_predicate,
+                          /*withElseRegion=*/true);
+  auto then_builder = if_op.getThenBodyBuilder();
+  then_builder.create<scf::YieldOp>(
+      loc, MaterializeRankSpecialization(then_builder, loc, op, shapes,
+                                         min_target_rank));
+
+  // Recur for all remaining target ranks.
+  auto else_builder = if_op.getElseBodyBuilder();
+  else_builder.create<scf::YieldOp>(
+      loc,
+      MaterializeAllRankSpecializations(else_builder, loc, op, shapes, max_rank,
+                                        min_target_rank + 1, max_target_rank));
+
+  return if_op.results().front();
+}
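At the bottom of the recursion, min_target_rank equals max_target_rank, so the predicate is asserted rather than branched on. A hand-written sketch of that innermost case (illustrative names; rank 8 as configured in this pass):

func @innermost_case_sketch(%arg : tensor<*xf32>, %max_rank : index)
    -> tensor<*xf32> {
  %c8 = constant 8 : index
  %eq8 = cmpi eq, %max_rank, %c8 : index
  // No scf.if remains for the last target rank; the predicate is asserted
  // and the rank-8 specialization is emitted unconditionally after it.
  assert %eq8, "Input for dynamic binary or n-ary op lowering was of a rank greater than 8"
  // ... rank-8 specialization of the cluster body goes here ...
  return %arg : tensor<*xf32>
}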
+
+struct LowerMultipleNonScalarOperandPattern
+    : public OpRewritePattern<chlo::RankSpecializationClusterOp> {
+  using OpRewritePattern<chlo::RankSpecializationClusterOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
+                                PatternRewriter &rewriter) const override {
+    // We have a specialized pattern for the case in which all but one of the
+    // operands are scalars.
+    if (FindUniqueNonScalar(op.operands())) return failure();
+
+    // Restoring the result shape currently relies on all operands being used
+    // for a single result. The result shape is then the broadcasted shape of
+    // all operands.
+    if (op.getNumResults() != 1) return failure();
+
+    // Get the minimum broadcast shapes of the operands.
+    Location loc = op.getLoc();
+    SmallVector<Value, 8> shapes =
+        llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
+          return rewriter.create<shape::ShapeOfOp>(loc, v).result();
+        }));
+    ValueRange reduced_shapes =
+        rewriter
+            .create<chlo::MinimumBroadcastShapesOp>(
+                loc,
+                SmallVector<Type, 8>(shapes.size(),
+                                     shape::getExtentTensorType(getContext())),
+                shapes)
+            .results();
+    // TODO(frgossen): Avoid this reshape if it is redundant in all cases.
+    SmallVector<Value, 8> reshaped_args;
+    for (auto it : llvm::zip(op.operands(), reduced_shapes)) {
+      Value arg = std::get<0>(it);
+      Value reduced_shape = std::get<1>(it);
+      reshaped_args.push_back(rewriter.create<mhlo::DynamicReshapeOp>(
+          loc, arg.getType(), arg, reduced_shape));
+    }
+
+    // Find the maximum rank among the reduced operand shapes.
+    Value max_rank;
+    for (Value shape : reduced_shapes) {
+      Value rank =
+          rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), shape);
+      if (!max_rank) {
+        max_rank = rank;
+      } else {
+        max_rank = rewriter.create<mlir::SelectOp>(
+            loc,
+            rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, max_rank, rank),
+            max_rank, rank);
+      }
+    }
+
+    // Materialize rank specialization for ranks 1, ..., 8.
+    // TODO(frgossen): For clusters w/o a select operation, consider only ranks
+    // 1, ..., 5.
+    const int64_t kMinTargetRank = 1;
+    const int64_t kMaxTargetRank = 8;
+    Value unshaped_result = MaterializeAllRankSpecializations(
+        rewriter, loc, op, reduced_shapes, max_rank, kMinTargetRank,
+        kMaxTargetRank);
+
+    // Materialize the final reshape once, for all rank specialization cases.
+    rewriter.replaceOp(
+        op, MaterializeFinalReshape(rewriter, loc, op, unshaped_result));
+    return success();
+  }
+};
+
 struct RankSpecializationToSCFPass
     : public PassWrapper<RankSpecializationToSCFPass, FunctionPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<mhlo::MhloDialect, chlo::HloClientDialect,
-                    shape::ShapeDialect>();
+                    shape::ShapeDialect, scf::SCFDialect>();
   }
 
   void runOnFunction() override {
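The shape reduction and max-rank computation performed by LowerMultipleNonScalarOperandPattern can be pictured in IR as follows. This is a hand-written sketch for two operands; the exact chlo.minimum_broadcast_shapes assembly is assumed here rather than quoted from the commit:

func @max_reduced_rank_sketch(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>)
    -> index {
  %s0 = shape.shape_of %arg0 : tensor<*xf32> -> tensor<?xindex>
  %s1 = shape.shape_of %arg1 : tensor<*xf32> -> tensor<?xindex>
  // Collapse dimensions that broadcast trivially across all operands, so
  // the ranks to dispatch on are as small as possible.
  %red:2 = chlo.minimum_broadcast_shapes %s0, %s1
      : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>
  %r0 = shape.rank %red#0 : tensor<?xindex> -> index
  %r1 = shape.rank %red#1 : tensor<?xindex> -> index
  // Pairwise max over the reduced ranks.
  %gt = cmpi sgt, %r0, %r1 : index
  %max_rank = select %gt, %r0, %r1 : index
  return %max_rank : index
}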
@@ -325,7 +475,8 @@ void PopulateRankSpecializationClusterPatterns(
 
 void PopulateRankSpecializationToSCFPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
-  patterns->insert<LowerSingleNonScalarOperandPattern>(context);
+  patterns->insert<LowerSingleNonScalarOperandPattern,
+                   LowerMultipleNonScalarOperandPattern>(context);
 }
 
 std::unique_ptr<FunctionPass> createRankSpecializationClusterPass() {
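For reference, the kind of cluster these patterns match looks like the hand-written sketch below. The actual @add_mul test input lies outside this diff's context lines, so the function name and body here are an approximation inferred from the CHECK lines that follow:

func @add_mul_cluster_sketch(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
    %arg2 : tensor<*xf32>) -> tensor<*xf32> {
  // One cluster, three non-scalar operands, a single result.
  %0 = "chlo.rank_specialization_cluster"(%arg2, %arg0, %arg1) ({
  ^bb0(%arg2_ : tensor<*xf32>, %arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>):
    %1 = chlo.broadcast_multiply %arg0_, %arg1_
        : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
    %2 = chlo.broadcast_add %1, %arg2_
        : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
    "chlo.rank_specialization_cluster_yield"(%2) : (tensor<*xf32>) -> ()
  }) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}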
@@ -19,6 +19,185 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
   return %1 : tensor<*xf32>
 }
 
+// CHECK-SCF-LABEL: @add_mul
+// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>)
+// CHECK-SCF-DAG: %[[C1:.*]] = constant 1
+// CHECK-SCF-DAG: %[[C2:.*]] = constant 2
+// CHECK-SCF-DAG: %[[C3:.*]] = constant 3
+// CHECK-SCF-DAG: %[[C4:.*]] = constant 4
+// CHECK-SCF-DAG: %[[C5:.*]] = constant 5
+// CHECK-SCF-DAG: %[[C6:.*]] = constant 6
+// CHECK-SCF-DAG: %[[C7:.*]] = constant 7
+// CHECK-SCF-DAG: %[[C8:.*]] = constant 8
+// CHECK-SCF-DAG: %[[ONE_SHAPE_1:.*]] = shape.const_shape [1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_2:.*]] = shape.const_shape [1, 1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_7:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1, 1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_8:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1, 1, 1]
+// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]]
+// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]]
+// CHECK-SCF-DAG: %[[SHAPE_ARG2:.*]] = shape.shape_of %[[ARG2]]
+// Find the maximum reduced rank.
+// CHECK-SCF-DAG: %[[REDUCED_SHAPES:.*]]:3 = chlo.minimum_broadcast_shapes %[[SHAPE_ARG2]], %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_RANK0:.*]] = shape.rank %[[REDUCED_SHAPES]]#1
+// CHECK-SCF-DAG: %[[REDUCED_RANK1:.*]] = shape.rank %[[REDUCED_SHAPES]]#2
+// CHECK-SCF-DAG: %[[REDUCED_RANK2:.*]] = shape.rank %[[REDUCED_SHAPES]]#0
+// CHECK-SCF-DAG: %[[R2_GT_R0:.*]] = cmpi sgt, %[[REDUCED_RANK2]], %[[REDUCED_RANK0]]
+// CHECK-SCF-DAG: %[[R20:.*]] = select %[[R2_GT_R0]], %[[REDUCED_RANK2]], %[[REDUCED_RANK0]]
+// CHECK-SCF-DAG: %[[R20_GT_R1:.*]] = cmpi sgt, %[[R20]], %[[REDUCED_RANK1]]
+// CHECK-SCF-DAG: %[[MAX_RED_RANK:.*]] = select %[[R20_GT_R1]], %[[R20]], %[[REDUCED_RANK1]]
+// Case 1:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_1:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C1]]
+// CHECK-SCF: %[[UNSHAPED_RES_1:.*]] = scf.if %[[MAX_RED_RANK_EQ_1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
+// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?xf32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?xf32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Case 2:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_2:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C2]]
+// CHECK-SCF: %[[UNSHAPED_RES_2:.*]] = scf.if %[[MAX_RED_RANK_EQ_2]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_2]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_2]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_2]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
+// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Case 3:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_3:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C3]]
+// CHECK-SCF: %[[UNSHAPED_RES_3:.*]] = scf.if %[[MAX_RED_RANK_EQ_3]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_3]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_3]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_3]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
+// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Case 4:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_4:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C4]]
+// CHECK-SCF: %[[UNSHAPED_RES_4:.*]] = scf.if %[[MAX_RED_RANK_EQ_4]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_4]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_4]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_4]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
+// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Case 5:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_5:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C5]]
+// CHECK-SCF: %[[UNSHAPED_RES_5:.*]] = scf.if %[[MAX_RED_RANK_EQ_5]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_5]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_5]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_5]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
+// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Case 6:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_6:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C6]]
+// CHECK-SCF: %[[UNSHAPED_RES_6:.*]] = scf.if %[[MAX_RED_RANK_EQ_6]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_6]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_6]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_6]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
+// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Case 7:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_7:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C7]]
+// CHECK-SCF: %[[UNSHAPED_RES_7:.*]] = scf.if %[[MAX_RED_RANK_EQ_7]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_7]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_7]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_7]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
+// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Case 8:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_8:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C8]]
+// CHECK-SCF: assert %[[MAX_RED_RANK_EQ_8]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 8"
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_8]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_8]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_8]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
+// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_7]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_6]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_5]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_4]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_3]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_2]]
+// Reshape the result.
+// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]]
+// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]]
+// CHECK-SCF-DAG: %[[SHAPE_ARG2:.*]] = shape.shape_of %[[ARG2]]
+// CHECK-SCF-DAG: %[[RES_SHAPE:.*]] = shape.broadcast %[[SHAPE_ARG2]], %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES_1]], %[[RES_SHAPE]])
+// CHECK-SCF: return %[[RES]]
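The final reshape that these checks end with can be sketched as follows (hand-written, two operands instead of three, illustrative names): the broadcast of the original operand shapes is the result shape, so a single dynamic_reshape after the whole cascade restores it for every rank case at once.

func @final_reshape_sketch(%unshaped : tensor<*xf32>, %arg0 : tensor<*xf32>,
    %arg1 : tensor<*xf32>) -> tensor<*xf32> {
  %s0 = shape.shape_of %arg0 : tensor<*xf32> -> tensor<?xindex>
  %s1 = shape.shape_of %arg1 : tensor<*xf32> -> tensor<?xindex>
  // The single cluster result has the broadcasted shape of all operands.
  %res_shape = shape.broadcast %s0, %s1
      : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
  %res = "mhlo.dynamic_reshape"(%unshaped, %res_shape)
      : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  return %res : tensor<*xf32>
}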
 
 // -----
 
 // Unary MHLO operation.
@@ -64,6 +243,10 @@ func @sqrt_ranked(%arg: tensor<3x?xf32>) -> tensor<3x?xf32> {
   return %2 : tensor<3x?xf32>
 }
 
+// CHECK-SCF-LABEL: @sqrt_ranked
+// CHECK-SCF-NOT: dynamic_reshape
+// CHECK-SCF: return
+
 // -----
 
 // Ternary operation.
@@ -81,6 +264,9 @@ func @select_mixed(%pred: tensor<*xi1>, %arg1: tensor<*xf32>,
   return %0 : tensor<*xf32>
 }
 
+// CHECK-SCF-LABEL: @select_mixed
+// CHECK-SCF: chlo.broadcast_select %{{.*}}, %{{.*}}, %{{.*}} : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
+
 // -----
 
 // Unary CHLO operation.
@@ -141,6 +327,14 @@ func @mixed(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>)
   return %5 : tensor<*xf32>
 }
 
+// CHECK-SCF-LABEL: @mixed
+// CHECK-SCF-DAG: %[[TMP0:.*]] = chlo.tan %{{.*}} : tensor<?xf32>
+// CHECK-SCF-DAG: %[[TMP1:.*]] = "mhlo.sqrt"(%{{.*}}) : (tensor<?xf32>)
+// CHECK-SCF-DAG: %[[TMP2:.*]] = chlo.broadcast_multiply %[[TMP0]], %[[TMP1]] : (tensor<?xf32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[TMP3:.*]] = chlo.broadcast_add %[[TMP2]], %{{.*}} : (tensor<?xf32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[TMP4:.*]] = "mhlo.sqrt"(%[[TMP3]]) : (tensor<?xf32>)
+// CHECK-SCF: chlo.tan %[[TMP4]] : tensor<?xf32>
+
 // -----
 
 // Constant cluster operand.
@@ -228,3 +422,9 @@ func @xlogy(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
     : (tensor<*xi1>, tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
   return %5 : tensor<*xf32>
 }
+
+// CHECK-SCF-LABEL: @xlogy
+// CHECK-SCF-DAG: %[[PRED:.*]] = chlo.broadcast_compare %{{.*}}, %{{.*}} {comparison_direction = "EQ"} : (tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[TMP0:.*]] = "mhlo.log"(%{{.*}}) : (tensor<?x?x?x?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %{{.*}}, %[[TMP0]] : (tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
+// CHECK-SCF: chlo.broadcast_select %[[PRED]], %{{.*}}, %[[TMP1]] : (tensor<?x?x?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)