diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index 07cf259..c49b5a9 100644
--- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -231,76 +231,81 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
   LogicalResult matchAndRewrite(
       OpTy op, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
-    auto loc = op.getLoc();
-    ShapedType t0 = args[0].getType().template dyn_cast<ShapedType>();
-    if (!t0) return failure();
-
-    unsigned nloops = t0.getRank();
-    auto fail = [&](ShapedType t) {
-      return !t || !t.hasRank() || t.getRank() != nloops ||
-             !(t.getElementType().isSignlessIntOrFloat() ||
-               t.getElementType().isa<ComplexType>());
+    // Find maximum rank / number of loops.
+    auto get_rank = [](Value v) {
+      return v.getType().cast<ShapedType>().getRank();
     };
-    if (llvm::any_of(op.getOperation()->getResultTypes(), [&](Type t) {
-          return fail(this->typeConverter->convertType(t)
-                          .template dyn_cast<ShapedType>());
+    auto is_scalar = [&](Value v) { return get_rank(v) == 0; };
+    auto it = llvm::find_if_not(args, is_scalar);
+    Value max_rank_arg = it != args.end() ? *it : args.front();
+    int64_t nloops = get_rank(max_rank_arg);
+
+    if (isLHLO && nloops == 0) return failure();
+
+    // Apply only if all operands are scalar or have the same rank. Some ops,
+    // like `mhlo.select`, support implicit broadcasting of scalars.
+    if (!llvm::all_of(args, [&](Value v) {
+          int64_t r = get_rank(v);
+          return r == 0 || r == nloops;
         })) {
       return rewriter.notifyMatchFailure(
-          op, "mismatched operand/result types or iterator count");
+          op, "Operands must be of same rank or scalar.");
     }
 
-    // Construct the indexing maps needed for linalg.generic ops.
-    SmallVector<Type, 4> body_arg_types, body_result_types, op_result_types;
+    // Find result type, if on tensors.
+    Optional<ShapedType> result_ty;
+    if (!isLHLO) {
+      result_ty = this->typeConverter->convertType(op->getResultTypes().front())
+                      .template dyn_cast<ShapedType>();
 
-    // This doesnt account for implicit broadcast, but the working assumption
-    // in HLO/LHLO is that are broadcasts are made explicit.
+      // Check result type compatibility.
+      if (!result_ty || !result_ty->hasRank() ||
+          result_ty->getRank() != nloops ||
+          !(result_ty->getElementType().isSignlessIntOrFloat() ||
+            result_ty->getElementType().isa<ComplexType>())) {
+        return rewriter.notifyMatchFailure(
+            op, "mismatched operand/result types or iterator count");
+      }
+    }
 
-    if (isLHLO && !nloops) return failure();
-
-    int num_inputs = (isLHLO ? args.size() - 1 : args.size());
-
-    ValueRange inputs(args.take_front(num_inputs));
-    for (Value in : inputs)
-      body_arg_types.emplace_back(getElementTypeOrSelf(in.getType()));
-
-    SmallVector<Value, 4> output_buffers;
+    // Find input/output values and types.
+    auto loc = op.getLoc();
+    ValueRange inputs = isLHLO ? args.drop_back() : args;
+    Value output;
     if (isLHLO) {
-      output_buffers.append(args.begin() + num_inputs, args.end());
+      output = args.back();
     } else {
-      Type result_type = this->typeConverter->convertType(
-          op.getOperation()->getResult(0).getType());
-      auto dyn_sizes = ExtractDynamicSizes(rewriter, loc, args[0]);
-      output_buffers.push_back(GetInitTensor(
-          rewriter, loc, result_type.cast<ShapedType>(), dyn_sizes));
-      op_result_types.push_back(result_type);
+      auto dyn_sizes = ExtractDynamicSizes(rewriter, loc, max_rank_arg);
+      output = GetInitTensor(rewriter, loc, *result_ty, dyn_sizes);
     }
-    body_result_types = llvm::to_vector<4>(llvm::map_range(
-        output_buffers, [](Value v) { return getElementTypeOrSelf(v); }));
 
-    AffineMap common_indexing_map =
-        nloops ? rewriter.getMultiDimIdentityMap(nloops)
-               : AffineMap::get(nloops, 0, rewriter.getContext());
-    SmallVector<AffineMap, 2> indexing_maps(args.size() + (isLHLO ? 0 : 1),
-                                            common_indexing_map);
+    // Create indexing maps.
+    AffineMap scalar_map = AffineMap::get(nloops, 0, rewriter.getContext());
+    AffineMap id_map = rewriter.getMultiDimIdentityMap(nloops);
+    SmallVector<AffineMap, 4> maps;
+    for (Value v : inputs) maps.push_back(is_scalar(v) ? scalar_map : id_map);
+    maps.push_back(id_map);
 
+    // Build `linalg.generic` op.
     bool failed = false;
     auto linalg_op = rewriter.create<linalg::GenericOp>(
-        loc, op_result_types, inputs, output_buffers, indexing_maps,
+        loc, result_ty ? *result_ty : TypeRange{}, inputs, output, maps,
        GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
          // TODO(ravishankarm) : For now use the method in lmhlo namespace.
          // That method needs to be moved out of there.
-          Value op_result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
-              op, body_result_types,
+          Type inner_result_ty = getElementTypeOrSelf(output);
+          Value inner_result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
+              op, inner_result_ty,
              llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter);
-          if (op_result == nullptr) {
+          if (inner_result == nullptr) {
            failed = true;
          } else {
-            nested_builder.create<linalg::YieldOp>(loc, op_result);
+            nested_builder.create<linalg::YieldOp>(loc, inner_result);
          }
        });
    if (failed) return failure();
-    rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
+    rewriter.replaceOp(op, linalg_op->getResults());
    return success();
  }
};
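Note: the essential change above is the per-operand indexing map. A rank-0 operand no longer fails the rank check; it gets a constant affine map instead, so linalg reads it once per output element. For nloops == 2, the two maps built under "Create indexing maps." print as the following MLIR (a sketch; the #-alias names here are illustrative, the MLIR printer assigns its own):

  // Scalar operand: indexed by the empty tuple, i.e. the same element is
  // read on every iteration (implicit broadcast).
  #scalar_map = affine_map<(d0, d1) -> ()>
  // Full-rank operands and the output: element-wise identity access.
  #id_map = affine_map<(d0, d1) -> (d0, d1)>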
diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir
index c49cf55..b8bc527 100644
--- a/tests/hlo-legalize-to-linalg.mlir
+++ b/tests/hlo-legalize-to-linalg.mlir
@@ -425,6 +425,28 @@ func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
 
 // -----
 
+// CHECK-DAG: #[[SCALAR_MAP:.*]] = affine_map<(d0, d1) -> ()>
+// CHECK-DAG: #[[ID_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @select_scalar_pred_dyn
+// CHECK-SAME: (%[[PRED:.*]]: tensor<i1>, %[[LHS:.*]]: tensor<2x?xf32>, %[[RHS:.*]]: tensor<2x?xf32>)
+func @select_scalar_pred_dyn(%pred : tensor<i1>, %lhs: tensor<2x?xf32>, %rhs: tensor<2x?xf32>) -> tensor<2x?xf32> {
+  %0 = "mhlo.select"(%pred, %lhs, %rhs) : (tensor<i1>, tensor<2x?xf32>, tensor<2x?xf32>) -> (tensor<2x?xf32>)
+  return %0 : tensor<2x?xf32>
+}
+// CHECK-DAG: %[[C1:.*]] = constant 1
+// CHECK-DAG: %[[DIM:.*]] = memref.dim %[[LHS]], %[[C1]]
+// CHECK-DAG: %[[DST:.*]] = linalg.init_tensor [2, %[[DIM]]]
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[SCALAR_MAP]], #[[ID_MAP]], #[[ID_MAP]], #[[ID_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[PRED]], %[[LHS]], %[[RHS]] : tensor<i1>, tensor<2x?xf32>, tensor<2x?xf32>)
+// CHECK-SAME: outs(%[[DST]] : tensor<2x?xf32>)
+// CHECK: ^bb0(%[[PRED_:.*]]: i1, %[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32, %{{.*}}: f32):
+// CHECK:   %[[RES:.*]] = select %[[PRED_]], %[[LHS_]], %[[RHS_]] : f32
+// CHECK:   linalg.yield %[[RES]]
+
+// -----
+
 // CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2) -> ()>
 // CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-LABEL: func @broadcast_scalar
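Spelled out in full, the lowering pinned down by the CHECK fragments above looks roughly like this (a sketch: SSA value names, the #-map aliases, and line breaks are illustrative, not verbatim pass output):

  #scalar_map = affine_map<(d0, d1) -> ()>
  #id_map = affine_map<(d0, d1) -> (d0, d1)>
  func @select_scalar_pred_dyn(%pred : tensor<i1>, %lhs: tensor<2x?xf32>,
                               %rhs: tensor<2x?xf32>) -> tensor<2x?xf32> {
    // Materialize the output tensor, sizing the dynamic dimension from %lhs,
    // the first non-scalar (maximum-rank) operand.
    %c1 = constant 1 : index
    %dim = memref.dim %lhs, %c1 : tensor<2x?xf32>
    %init = linalg.init_tensor [2, %dim] : tensor<2x?xf32>
    // The scalar %pred uses #scalar_map; the rank-2 operands and the output
    // use the identity map.
    %0 = linalg.generic {
           indexing_maps = [#scalar_map, #id_map, #id_map, #id_map],
           iterator_types = ["parallel", "parallel"]}
           ins(%pred, %lhs, %rhs : tensor<i1>, tensor<2x?xf32>, tensor<2x?xf32>)
           outs(%init : tensor<2x?xf32>) {
    ^bb0(%p: i1, %l: f32, %r: f32, %out: f32):
      %s = select %p, %l, %r : f32
      linalg.yield %s : f32
    } -> tensor<2x?xf32>
    return %0 : tensor<2x?xf32>
  }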