[HLO][Linalg] Support scalar broadcasts in point-wise converter
This is needed for operations that support this limited form of broadcasting, namely `mhlo.select`.

PiperOrigin-RevId: 376655844
commit cc1b22e618
parent 832f39b871
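The converter previously required every operand to have the same rank; with this change, rank-0 operands are accepted and fed through an empty indexing map. As a minimal illustration of the implicit scalar broadcast this enables (the function name and static shapes below are chosen for brevity; the new regression test further down exercises the dynamic-shape case):

```mlir
// `mhlo.select` allows a rank-0 predicate to be broadcast against rank-2
// operands; the point-wise converter can now lower this directly.
func @select_scalar_pred(%pred : tensor<i1>, %lhs : tensor<2x2xf32>,
                         %rhs : tensor<2x2xf32>) -> tensor<2x2xf32> {
  %0 = "mhlo.select"(%pred, %lhs, %rhs)
      : (tensor<i1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}
```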
@@ -231,76 +231,81 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
   LogicalResult matchAndRewrite(
       OpTy op, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
-    auto loc = op.getLoc();
-    ShapedType t0 = args[0].getType().template dyn_cast<ShapedType>();
-    if (!t0) return failure();
-
-    unsigned nloops = t0.getRank();
-    auto fail = [&](ShapedType t) {
-      return !t || !t.hasRank() || t.getRank() != nloops ||
-             !(t.getElementType().isSignlessIntOrFloat() ||
-               t.getElementType().isa<ComplexType>());
+    // Find maximum rank / number of loops.
+    auto get_rank = [](Value v) {
+      return v.getType().cast<ShapedType>().getRank();
     };
-    if (llvm::any_of(op.getOperation()->getResultTypes(), [&](Type t) {
-          return fail(this->typeConverter->convertType(t)
-                          .template dyn_cast<ShapedType>());
+    auto is_scalar = [&](Value v) { return get_rank(v) == 0; };
+    auto it = llvm::find_if_not(args, is_scalar);
+    Value max_rank_arg = it != args.end() ? *it : args.front();
+    int64_t nloops = get_rank(max_rank_arg);
+
+    if (isLHLO && nloops == 0) return failure();
+
+    // Apply only if all operands are scalar or have the same rank. Some ops,
+    // like `mhlo.select`, support implicit broadcasting of scalars.
+    if (!llvm::all_of(args, [&](Value v) {
+          int64_t r = get_rank(v);
+          return r == 0 || r == nloops;
+        })) {
+      return rewriter.notifyMatchFailure(
+          op, "Operands must be os same rank or scalar.");
+    }
+
+    // Find result type, if on tensors.
+    Optional<ShapedType> result_ty;
+    if (!isLHLO) {
+      result_ty = this->typeConverter->convertType(op->getResultTypes().front())
+                      .template dyn_cast<ShapedType>();
+
+      // Check result type compatibility.
+      if (!result_ty || !result_ty->hasRank() ||
+          result_ty->getRank() != nloops ||
+          !(result_ty->getElementType().isSignlessIntOrFloat() ||
+            result_ty->getElementType().isa<ComplexType>())) {
+        return rewriter.notifyMatchFailure(
+            op, "mismatched operand/result types or iterator count");
+      }
+    }
 
-    // Construct the indexing maps needed for linalg.generic ops.
-    SmallVector<Type, 4> body_arg_types, body_result_types, op_result_types;
-
-    // This doesnt account for implicit broadcast, but the working assumption
-    // in HLO/LHLO is that are broadcasts are made explicit.
-
-    if (isLHLO && !nloops) return failure();
-
-    int num_inputs = (isLHLO ? args.size() - 1 : args.size());
-
-    ValueRange inputs(args.take_front(num_inputs));
-    for (Value in : inputs)
-      body_arg_types.emplace_back(getElementTypeOrSelf(in.getType()));
-
-    SmallVector<Value, 4> output_buffers;
-    if (isLHLO) {
-      output_buffers.append(args.begin() + num_inputs, args.end());
-    } else {
-      Type result_type = this->typeConverter->convertType(
-          op.getOperation()->getResult(0).getType());
-      auto dyn_sizes = ExtractDynamicSizes(rewriter, loc, args[0]);
-      output_buffers.push_back(GetInitTensor(
-          rewriter, loc, result_type.cast<ShapedType>(), dyn_sizes));
-      op_result_types.push_back(result_type);
-    }
-    body_result_types = llvm::to_vector<4>(llvm::map_range(
-        output_buffers, [](Value v) { return getElementTypeOrSelf(v); }));
-
-    AffineMap common_indexing_map =
-        nloops ? rewriter.getMultiDimIdentityMap(nloops)
-               : AffineMap::get(nloops, 0, rewriter.getContext());
-    SmallVector<AffineMap, 2> indexing_maps(args.size() + (isLHLO ? 0 : 1),
-                                            common_indexing_map);
+    // Find input/output values and types.
+    auto loc = op.getLoc();
+    ValueRange inputs = isLHLO ? args.drop_back() : args;
+    Value output;
+    if (isLHLO) {
+      output = args.back();
+    } else {
+      auto dyn_sizes = ExtractDynamicSizes(rewriter, loc, max_rank_arg);
+      output = GetInitTensor(rewriter, loc, *result_ty, dyn_sizes);
+    }
 
+    // Create indexing maps.
+    AffineMap scalar_map = AffineMap::get(nloops, 0, rewriter.getContext());
+    AffineMap id_map = rewriter.getMultiDimIdentityMap(nloops);
+    SmallVector<AffineMap, 4> maps;
+    for (Value v : inputs) maps.push_back(is_scalar(v) ? scalar_map : id_map);
+    maps.push_back(id_map);
+
+    // Build `linalg.generic` op.
     bool failed = false;
     auto linalg_op = rewriter.create<linalg::GenericOp>(
-        loc, op_result_types, inputs, output_buffers, indexing_maps,
+        loc, result_ty ? *result_ty : TypeRange{}, inputs, output, maps,
         GetNParallelLoopsAttrs(nloops),
         [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
           // TODO(ravishankarm) : For now use the method in lmhlo namespace.
           // That method needs to be moved out of there.
-          Value op_result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
-              op, body_result_types,
+          Type inner_result_ty = getElementTypeOrSelf(output);
+          Value inner_result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
+              op, inner_result_ty,
               llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter);
-          if (op_result == nullptr) {
+          if (inner_result == nullptr) {
             failed = true;
           } else {
-            nested_builder.create<linalg::YieldOp>(loc, op_result);
+            nested_builder.create<linalg::YieldOp>(loc, inner_result);
           }
         });
     if (failed) return failure();
-    rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
+    rewriter.replaceOp(op, linalg_op->getResults());
     return success();
   }
 };
@@ -425,6 +425,28 @@ func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
 
 // -----
 
+// CHECK-DAG: #[[SCALAR_MAP:.*]] = affine_map<(d0, d1) -> ()>
+// CHECK-DAG: #[[ID_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @select_scalar_pred_dyn
+// CHECK-SAME: (%[[PRED:.*]]: tensor<i1>, %[[LHS:.*]]: tensor<2x?xf32>, %[[RHS:.*]]: tensor<2x?xf32>)
+func @select_scalar_pred_dyn(%pred : tensor<i1>, %lhs: tensor<2x?xf32>, %rhs: tensor<2x?xf32>) -> tensor<2x?xf32> {
+  %0 = "mhlo.select"(%pred, %lhs, %rhs) : (tensor<i1>, tensor<2x?xf32>, tensor<2x?xf32>) -> (tensor<2x?xf32>)
+  return %0 : tensor<2x?xf32>
+}
+// CHECK-DAG: %[[C1:.*]] = constant 1
+// CHECK-DAG: %[[DIM:.*]] = memref.dim %[[LHS]], %[[C1]]
+// CHECK-DAG: %[[DST:.*]] = linalg.init_tensor [2, %[[DIM]]]
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[SCALAR_MAP]], #[[ID_MAP]], #[[ID_MAP]], #[[ID_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[PRED]], %[[LHS]], %[[RHS]] : tensor<i1>, tensor<2x?xf32>, tensor<2x?xf32>)
+// CHECK-SAME: outs(%[[DST]] : tensor<2x?xf32>)
+// CHECK: ^bb0(%[[PRED_:.*]]: i1, %[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32, %{{.*}}: f32):
+// CHECK:   %[[RES:.*]] = select %[[PRED_]], %[[LHS_]], %[[RHS_]] : f32
+// CHECK:   linalg.yield %[[RES]]
+
+// -----
+
 // CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2) -> ()>
 // CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-LABEL: func @broadcast_scalar
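For reference, the IR shape those CHECK lines describe is roughly the following (a hand-written sketch assembled from the test expectations, not captured pass output; SSA names are illustrative). The scalar predicate is indexed through the empty `() -> ()` map while the tensor operands and the init tensor use the identity map:

```mlir
#scalar_map = affine_map<(d0, d1) -> ()>
#id_map = affine_map<(d0, d1) -> (d0, d1)>

func @select_scalar_pred_dyn(%pred: tensor<i1>, %lhs: tensor<2x?xf32>,
                             %rhs: tensor<2x?xf32>) -> tensor<2x?xf32> {
  // Materialize the dynamic output size from the highest-ranked operand.
  %c1 = constant 1 : index
  %dim = memref.dim %lhs, %c1 : tensor<2x?xf32>
  %init = linalg.init_tensor [2, %dim] : tensor<2x?xf32>
  %0 = linalg.generic {
         indexing_maps = [#scalar_map, #id_map, #id_map, #id_map],
         iterator_types = ["parallel", "parallel"]}
       ins(%pred, %lhs, %rhs : tensor<i1>, tensor<2x?xf32>, tensor<2x?xf32>)
       outs(%init : tensor<2x?xf32>) {
  ^bb0(%p: i1, %l: f32, %r: f32, %out: f32):
    %res = select %p, %l, %r : f32
    linalg.yield %res : f32
  } -> tensor<2x?xf32>
  return %0 : tensor<2x?xf32>
}
```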