[HLO][Linalg] Support scalar broadcasts in point-wise converter

This is needed for operations that support this limited form of broadcasting,
namely `mhlo.select`.

PiperOrigin-RevId: 376655844
This commit is contained in:
A. Unique TensorFlower 2021-05-31 03:49:28 -07:00 committed by TensorFlow MLIR Team
parent 832f39b871
commit cc1b22e618
2 changed files with 73 additions and 46 deletions

View File

@ -231,76 +231,81 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
OpTy op, ArrayRef<Value> args, OpTy op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
auto loc = op.getLoc(); // Find maximum rank / number of loops.
ShapedType t0 = args[0].getType().template dyn_cast<ShapedType>(); auto get_rank = [](Value v) {
if (!t0) return failure(); return v.getType().cast<ShapedType>().getRank();
unsigned nloops = t0.getRank();
auto fail = [&](ShapedType t) {
return !t || !t.hasRank() || t.getRank() != nloops ||
!(t.getElementType().isSignlessIntOrFloat() ||
t.getElementType().isa<ComplexType>());
}; };
if (llvm::any_of(op.getOperation()->getResultTypes(), [&](Type t) { auto is_scalar = [&](Value v) { return get_rank(v) == 0; };
return fail(this->typeConverter->convertType(t) auto it = llvm::find_if_not(args, is_scalar);
.template dyn_cast<ShapedType>()); Value max_rank_arg = it != args.end() ? *it : args.front();
int64_t nloops = get_rank(max_rank_arg);
if (isLHLO && nloops == 0) return failure();
// Apply only if all operands are scalar or have the same rank. Some ops,
// like `mhlo.select`, support implicit broadcasting of scalars.
if (!llvm::all_of(args, [&](Value v) {
int64_t r = get_rank(v);
return r == 0 || r == nloops;
})) { })) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "mismatched operand/result types or iterator count"); op, "Operands must be os same rank or scalar.");
} }
// Construct the indexing maps needed for linalg.generic ops. // Find result type, if on tensors.
SmallVector<Type, 4> body_arg_types, body_result_types, op_result_types; Optional<ShapedType> result_ty;
if (!isLHLO) {
result_ty = this->typeConverter->convertType(op->getResultTypes().front())
.template dyn_cast<ShapedType>();
// This doesnt account for implicit broadcast, but the working assumption // Check result type compatibility.
// in HLO/LHLO is that are broadcasts are made explicit. if (!result_ty || !result_ty->hasRank() ||
result_ty->getRank() != nloops ||
!(result_ty->getElementType().isSignlessIntOrFloat() ||
result_ty->getElementType().isa<ComplexType>())) {
return rewriter.notifyMatchFailure(
op, "mismatched operand/result types or iterator count");
}
}
if (isLHLO && !nloops) return failure(); // Find input/output values and types.
auto loc = op.getLoc();
int num_inputs = (isLHLO ? args.size() - 1 : args.size()); ValueRange inputs = isLHLO ? args.drop_back() : args;
Value output;
ValueRange inputs(args.take_front(num_inputs));
for (Value in : inputs)
body_arg_types.emplace_back(getElementTypeOrSelf(in.getType()));
SmallVector<Value, 4> output_buffers;
if (isLHLO) { if (isLHLO) {
output_buffers.append(args.begin() + num_inputs, args.end()); output = args.back();
} else { } else {
Type result_type = this->typeConverter->convertType( auto dyn_sizes = ExtractDynamicSizes(rewriter, loc, max_rank_arg);
op.getOperation()->getResult(0).getType()); output = GetInitTensor(rewriter, loc, *result_ty, dyn_sizes);
auto dyn_sizes = ExtractDynamicSizes(rewriter, loc, args[0]);
output_buffers.push_back(GetInitTensor(
rewriter, loc, result_type.cast<ShapedType>(), dyn_sizes));
op_result_types.push_back(result_type);
} }
body_result_types = llvm::to_vector<4>(llvm::map_range(
output_buffers, [](Value v) { return getElementTypeOrSelf(v); }));
AffineMap common_indexing_map = // Create indexing maps.
nloops ? rewriter.getMultiDimIdentityMap(nloops) AffineMap scalar_map = AffineMap::get(nloops, 0, rewriter.getContext());
: AffineMap::get(nloops, 0, rewriter.getContext()); AffineMap id_map = rewriter.getMultiDimIdentityMap(nloops);
SmallVector<AffineMap, 2> indexing_maps(args.size() + (isLHLO ? 0 : 1), SmallVector<AffineMap, 4> maps;
common_indexing_map); for (Value v : inputs) maps.push_back(is_scalar(v) ? scalar_map : id_map);
maps.push_back(id_map);
// Build `linalg.generic` op.
bool failed = false; bool failed = false;
auto linalg_op = rewriter.create<linalg::GenericOp>( auto linalg_op = rewriter.create<linalg::GenericOp>(
loc, op_result_types, inputs, output_buffers, indexing_maps, loc, result_ty ? *result_ty : TypeRange{}, inputs, output, maps,
GetNParallelLoopsAttrs(nloops), GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) { [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
// TODO(ravishankarm) : For now use the method in lmhlo namespace. // TODO(ravishankarm) : For now use the method in lmhlo namespace.
// That method needs to be moved out of there. // That method needs to be moved out of there.
Value op_result = lmhlo::HloOpToStdScalarOp::map<OpTy>( Type inner_result_ty = getElementTypeOrSelf(output);
op, body_result_types, Value inner_result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
op, inner_result_ty,
llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter); llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter);
if (op_result == nullptr) { if (inner_result == nullptr) {
failed = true; failed = true;
} else { } else {
nested_builder.create<linalg::YieldOp>(loc, op_result); nested_builder.create<linalg::YieldOp>(loc, inner_result);
} }
}); });
if (failed) return failure(); if (failed) return failure();
rewriter.replaceOp(op, linalg_op.getOperation()->getResults()); rewriter.replaceOp(op, linalg_op->getResults());
return success(); return success();
} }
}; };

View File

@ -425,6 +425,28 @@ func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
// ----- // -----
// CHECK-DAG: #[[SCALAR_MAP:.*]] = affine_map<(d0, d1) -> ()>
// CHECK-DAG: #[[ID_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @select_scalar_pred_dyn
// CHECK-SAME: (%[[PRED:.*]]: tensor<i1>, %[[LHS:.*]]: tensor<2x?xf32>, %[[RHS:.*]]: tensor<2x?xf32>)
func @select_scalar_pred_dyn(%pred : tensor<i1>, %lhs: tensor<2x?xf32>, %rhs: tensor<2x?xf32>) -> tensor<2x?xf32> {
%0 = "mhlo.select"(%pred, %lhs, %rhs) : (tensor<i1>, tensor<2x?xf32>, tensor<2x?xf32>) -> (tensor<2x?xf32>)
return %0 : tensor<2x?xf32>
}
// CHECK-DAG: %[[C1:.*]] = constant 1
// CHECK-DAG: %[[DIM:.*]] = memref.dim %[[LHS]], %[[C1]]
// CHECK-DAG: %[[DST:.*]] = linalg.init_tensor [2, %[[DIM]]]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[SCALAR_MAP]], #[[ID_MAP]], #[[ID_MAP]], #[[ID_MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
// CHECK-SAME: ins(%[[PRED]], %[[LHS]], %[[RHS]] : tensor<i1>, tensor<2x?xf32>, tensor<2x?xf32>)
// CHECK-SAME: outs(%[[DST]] : tensor<2x?xf32>)
// CHECK: ^bb0(%[[PRED_:.*]]: i1, %[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32, %{{.*}}: f32):
// CHECK: %[[RES:.*]] = select %[[PRED_]], %[[LHS_]], %[[RHS_]] : f32
// CHECK: linalg.yield %[[RES]]
// -----
// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2) -> ()> // CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2) -> ()>
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @broadcast_scalar // CHECK-LABEL: func @broadcast_scalar