[HLO][Linalg] Support scalar broadcasts in point-wise converter
This is needed for operations that support this limited form of broadcasting, namely `mhlo.select`.

PiperOrigin-RevId: 376655844
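For illustration, this is the kind of op the point-wise converter previously rejected because one operand has rank 0 while the others are ranked; the snippet mirrors the `select_scalar_pred_dyn` test added below, and the shapes are only illustrative:

// A select whose predicate is a rank-0 tensor. With this change the converter
// accepts it and broadcasts the scalar implicitly through its indexing map.
%0 = "mhlo.select"(%pred, %lhs, %rhs)
    : (tensor<i1>, tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>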
parent 832f39b871
commit cc1b22e618
@@ -231,76 +231,81 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
   LogicalResult matchAndRewrite(
       OpTy op, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
-    auto loc = op.getLoc();
-    ShapedType t0 = args[0].getType().template dyn_cast<ShapedType>();
-    if (!t0) return failure();
-
-    unsigned nloops = t0.getRank();
-    auto fail = [&](ShapedType t) {
-      return !t || !t.hasRank() || t.getRank() != nloops ||
-             !(t.getElementType().isSignlessIntOrFloat() ||
-               t.getElementType().isa<ComplexType>());
+    // Find maximum rank / number of loops.
+    auto get_rank = [](Value v) {
+      return v.getType().cast<ShapedType>().getRank();
     };
-    if (llvm::any_of(op.getOperation()->getResultTypes(), [&](Type t) {
-          return fail(this->typeConverter->convertType(t)
-                          .template dyn_cast<ShapedType>());
+    auto is_scalar = [&](Value v) { return get_rank(v) == 0; };
+    auto it = llvm::find_if_not(args, is_scalar);
+    Value max_rank_arg = it != args.end() ? *it : args.front();
+    int64_t nloops = get_rank(max_rank_arg);
+
+    if (isLHLO && nloops == 0) return failure();
+
+    // Apply only if all operands are scalar or have the same rank. Some ops,
+    // like `mhlo.select`, support implicit broadcasting of scalars.
+    if (!llvm::all_of(args, [&](Value v) {
+          int64_t r = get_rank(v);
+          return r == 0 || r == nloops;
         })) {
       return rewriter.notifyMatchFailure(
-          op, "mismatched operand/result types or iterator count");
+          op, "Operands must be os same rank or scalar.");
     }
 
-    // Construct the indexing maps needed for linalg.generic ops.
-    SmallVector<Type, 4> body_arg_types, body_result_types, op_result_types;
-
-    // This doesnt account for implicit broadcast, but the working assumption
-    // in HLO/LHLO is that are broadcasts are made explicit.
-
-    if (isLHLO && !nloops) return failure();
-
-    int num_inputs = (isLHLO ? args.size() - 1 : args.size());
-
-    ValueRange inputs(args.take_front(num_inputs));
-    for (Value in : inputs)
-      body_arg_types.emplace_back(getElementTypeOrSelf(in.getType()));
-
-    SmallVector<Value, 4> output_buffers;
+    // Find result type, if on tensors.
+    Optional<ShapedType> result_ty;
+    if (!isLHLO) {
+      result_ty = this->typeConverter->convertType(op->getResultTypes().front())
+                      .template dyn_cast<ShapedType>();
+
+      // Check result type compatibility.
+      if (!result_ty || !result_ty->hasRank() ||
+          result_ty->getRank() != nloops ||
+          !(result_ty->getElementType().isSignlessIntOrFloat() ||
+            result_ty->getElementType().isa<ComplexType>())) {
+        return rewriter.notifyMatchFailure(
+            op, "mismatched operand/result types or iterator count");
+      }
+    }
+
+    // Find input/output values and types.
+    auto loc = op.getLoc();
+    ValueRange inputs = isLHLO ? args.drop_back() : args;
+    Value output;
     if (isLHLO) {
-      output_buffers.append(args.begin() + num_inputs, args.end());
+      output = args.back();
     } else {
-      Type result_type = this->typeConverter->convertType(
-          op.getOperation()->getResult(0).getType());
-      auto dyn_sizes = ExtractDynamicSizes(rewriter, loc, args[0]);
-      output_buffers.push_back(GetInitTensor(
-          rewriter, loc, result_type.cast<ShapedType>(), dyn_sizes));
-      op_result_types.push_back(result_type);
+      auto dyn_sizes = ExtractDynamicSizes(rewriter, loc, max_rank_arg);
+      output = GetInitTensor(rewriter, loc, *result_ty, dyn_sizes);
     }
-    body_result_types = llvm::to_vector<4>(llvm::map_range(
-        output_buffers, [](Value v) { return getElementTypeOrSelf(v); }));
 
-    AffineMap common_indexing_map =
-        nloops ? rewriter.getMultiDimIdentityMap(nloops)
-               : AffineMap::get(nloops, 0, rewriter.getContext());
-    SmallVector<AffineMap, 2> indexing_maps(args.size() + (isLHLO ? 0 : 1),
-                                            common_indexing_map);
+    // Create indexing maps.
+    AffineMap scalar_map = AffineMap::get(nloops, 0, rewriter.getContext());
+    AffineMap id_map = rewriter.getMultiDimIdentityMap(nloops);
+    SmallVector<AffineMap, 4> maps;
+    for (Value v : inputs) maps.push_back(is_scalar(v) ? scalar_map : id_map);
+    maps.push_back(id_map);
 
+    // Build `linalg.generic` op.
     bool failed = false;
     auto linalg_op = rewriter.create<linalg::GenericOp>(
-        loc, op_result_types, inputs, output_buffers, indexing_maps,
+        loc, result_ty ? *result_ty : TypeRange{}, inputs, output, maps,
         GetNParallelLoopsAttrs(nloops),
         [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
           // TODO(ravishankarm) : For now use the method in lmhlo namespace.
           // That method needs to be moved out of there.
-          Value op_result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
-              op, body_result_types,
+          Type inner_result_ty = getElementTypeOrSelf(output);
+          Value inner_result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
+              op, inner_result_ty,
              llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter);
-          if (op_result == nullptr) {
+          if (inner_result == nullptr) {
            failed = true;
          } else {
-            nested_builder.create<linalg::YieldOp>(loc, op_result);
+            nested_builder.create<linalg::YieldOp>(loc, inner_result);
          }
        });
     if (failed) return failure();
-    rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
+    rewriter.replaceOp(op, linalg_op->getResults());
     return success();
   }
 };
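In the new code above, each rank-0 operand contributes a zero-result indexing map while ranked operands and the init tensor use the identity map, so every iteration of the emitted linalg.generic reads the scalar's single element; that is what realizes the implicit broadcast. For the rank-2 case exercised by the new test below, the two maps look roughly like this (a sketch matching the CHECK lines that follow, not verbatim pass output):

#scalar_map = affine_map<(d0, d1) -> ()>        // rank-0 operand, e.g. the select predicate
#id_map     = affine_map<(d0, d1) -> (d0, d1)>  // ranked operands and the output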
@@ -425,6 +425,28 @@ func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
 
 // -----
 
+// CHECK-DAG: #[[SCALAR_MAP:.*]] = affine_map<(d0, d1) -> ()>
+// CHECK-DAG: #[[ID_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @select_scalar_pred_dyn
+// CHECK-SAME: (%[[PRED:.*]]: tensor<i1>, %[[LHS:.*]]: tensor<2x?xf32>, %[[RHS:.*]]: tensor<2x?xf32>)
+func @select_scalar_pred_dyn(%pred : tensor<i1>, %lhs: tensor<2x?xf32>, %rhs: tensor<2x?xf32>) -> tensor<2x?xf32> {
+  %0 = "mhlo.select"(%pred, %lhs, %rhs) : (tensor<i1>, tensor<2x?xf32>, tensor<2x?xf32>) -> (tensor<2x?xf32>)
+  return %0 : tensor<2x?xf32>
+}
+// CHECK-DAG: %[[C1:.*]] = constant 1
+// CHECK-DAG: %[[DIM:.*]] = memref.dim %[[LHS]], %[[C1]]
+// CHECK-DAG: %[[DST:.*]] = linalg.init_tensor [2, %[[DIM]]]
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[SCALAR_MAP]], #[[ID_MAP]], #[[ID_MAP]], #[[ID_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[PRED]], %[[LHS]], %[[RHS]] : tensor<i1>, tensor<2x?xf32>, tensor<2x?xf32>)
+// CHECK-SAME: outs(%[[DST]] : tensor<2x?xf32>)
+// CHECK: ^bb0(%[[PRED_:.*]]: i1, %[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32, %{{.*}}: f32):
+// CHECK:   %[[RES:.*]] = select %[[PRED_]], %[[LHS_]], %[[RHS_]] : f32
+// CHECK:   linalg.yield %[[RES]]
+
+// -----
+
 // CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2) -> ()>
 // CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-LABEL: func @broadcast_scalar