diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index 6530f7c..6ce42ed 100644
--- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -1473,8 +1473,8 @@ struct ReduceRegionXLAOpConversion : public OpConversionPattern<OpTy> {
         })) {
       return failure();
     }
-    Value result = lmhlo::HloOpToStdScalarOp::map<OpTy>(op, args[0].getType(),
-                                                        args, &rewriter);
+    Value result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
+        op, getElementTypeOrSelf(op.getType()), args, &rewriter);
     rewriter.replaceOp(op, result);
     return success();
   }
@@ -1518,58 +1518,65 @@ class ReduceOnTensorsConversion : public OpConversionPattern<mhlo::ReduceOp> {
                   ConversionPatternRewriter& rewriter) const final {
     Location loc = op.getLoc();
     mhlo::ReduceOp::Adaptor adaptor(args);
-    if (op.getNumOperands() != 2) {
-      return op.emitError("expects exactly two operands");
-    }
-    Value src = adaptor.inputs()[0];
-    auto src_type = src.getType().cast<ShapedType>();
+
+    int num_inputs = static_cast<int>(adaptor.inputs().size());
+    auto src_type = adaptor.inputs()[0].getType().cast<ShapedType>();
     int src_rank = src_type.getRank();
     if (!src_rank) {
       return rewriter.notifyMatchFailure(op, "expects known-rank args");
     }
 
-    // Check if init_value is constant. If so, inline the value into the region.
-    Value init_value = adaptor.init_values()[0];
-    Attribute init_const_val = GetInitValueAsConst(init_value);
-    if (init_const_val) {
-      init_value = rewriter.create<ConstantOp>(
-          init_value.getDefiningOp()->getLoc(), init_const_val);
-    } else {
-      init_value = rewriter.create<tensor::ExtractOp>(loc, init_value);
-    }
+    SmallVector<int64_t, 4> reduction_dims = Extract1DVector(op.dimensions());
+
+    SmallVector<Value> inputs, outputs;
+    SmallVector<AffineMap, 3> indexing_maps;
+    for (int i = 0; i < num_inputs; ++i) {
+      Value src = adaptor.inputs()[i];
+      if (src.getType() != src_type) return failure();
+
+      // Check if init_value is constant. If so, inline the value into the
+      // region.
+      Value init_value = adaptor.init_values()[i];
+      Attribute init_const_val = GetInitValueAsConst(init_value);
+      if (init_const_val) {
+        init_value = rewriter.create<ConstantOp>(
+            init_value.getDefiningOp()->getLoc(), init_const_val);
+      } else {
+        init_value = rewriter.create<tensor::ExtractOp>(loc, init_value);
+      }
+
+      inputs.push_back(src);
+      auto result_type = op.getResult(i).getType().cast<ShapedType>();
+      SmallVector<Value, 8> dyn_shape = GetReduceOpInitTensorDynSizes(
+          rewriter, loc, src, result_type, reduction_dims);
+      auto init_tensor = GetInitTensor(rewriter, loc, result_type, dyn_shape);
+      Value filled_tensor =
+          rewriter.create<linalg::FillOp>(loc, init_tensor, init_value)
+              .result();
+      outputs.push_back(filled_tensor);
+    }
 
-    // Prepare indexing maps for linalg generic op. The elements are for src and
-    // dst. Transpose `src` to make the reduction loops be the innermost,
+    // Prepare indexing maps for linalg generic op. The elements are for src
+    // and dst. Transpose `src` to make the reduction loops be the innermost,
     // because it's easier to fully utilize processors.
-    SmallVector<AffineMap, 3> indexing_maps;
-    SmallVector<int64_t, 4> reduction_dims = Extract1DVector(op.dimensions());
-    indexing_maps.emplace_back(GetTransposeMapForReduction(
-        rewriter.getContext(), src_rank, reduction_dims));
+    indexing_maps.append(
+        num_inputs, GetTransposeMapForReduction(rewriter.getContext(), src_rank,
+                                                reduction_dims));
 
     // The indexing map of `dst` should drop the reduction loops. Since the
     // reduction loops now are all in the innermost, drops
-    // `reduction_dims.size()` dimensions. We don't need an inverse permutation
-    // here because they are the same.
+    // `reduction_dims.size()` dimensions. We don't need an inverse
+    // permutation here because they are the same.
     SmallVector<AffineExpr, 2> exprs;
     for (int i = 0, e = src_rank - reduction_dims.size(); i < e; ++i)
       exprs.push_back(rewriter.getAffineDimExpr(i));
-    indexing_maps.emplace_back(AffineMap::get(src_rank, /*symbolCount=*/0,
-                                              exprs, rewriter.getContext()));
-
-    SmallVector<Value> inputs = {adaptor.inputs()[0]};
-    Type result_type = op.getResult(0).getType();
-    auto shaped_type = result_type.cast<ShapedType>();
-    SmallVector<Value, 8> dyn_shape = GetReduceOpInitTensorDynSizes(
-        rewriter, loc, adaptor.inputs()[0], result_type.cast<ShapedType>(),
-        reduction_dims);
-    auto init_tensor = GetInitTensor(rewriter, loc, shaped_type, dyn_shape);
-    Value filled_tensor =
-        rewriter.create<linalg::FillOp>(loc, init_tensor, init_value)
-            .getResult(0);
+    indexing_maps.append(num_inputs,
+                         AffineMap::get(src_rank, /*symbolCount=*/0, exprs,
+                                        rewriter.getContext()));
 
     auto linalg_op = rewriter.create<linalg::GenericOp>(
         loc, /*resultTensorTypes=*/op.getResultTypes(), inputs,
-        /*outputBuffers=*/ValueRange{filled_tensor}, indexing_maps,
+        /*outputBuffers=*/ValueRange{outputs}, indexing_maps,
        GetParallelAndReductionIterators(src_rank, reduction_dims.size()));
 
     // Convert the signature of the body. The reduce op region apply function
@@ -1579,10 +1586,10 @@ class ReduceOnTensorsConversion : public OpConversionPattern<mhlo::ReduceOp> {
     // be converted to "(f32, f32, f32)".
     Region& region = linalg_op.region();
     rewriter.inlineRegionBefore(op.body(), region, region.end());
-    TypeConverter::SignatureConversion signatureConverter(2);
-    signatureConverter.addInputs(0, src_type.getElementType());
-    signatureConverter.addInputs(1, src_type.getElementType());
-    rewriter.applySignatureConversion(&region, signatureConverter);
+    TypeConverter::SignatureConversion signature_converter(num_inputs * 2);
+    for (int i = 0; i < num_inputs * 2; ++i)
+      signature_converter.addInputs(i, src_type.getElementType());
+    rewriter.applySignatureConversion(&region, signature_converter);
     rewriter.replaceOp(op, linalg_op.getResults());
     return success();
   }
@@ -2272,6 +2279,8 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
                       ReduceRegionXLAOpConversion<mhlo::AddOp>,
                       ReduceRegionXLAOpConversion<mhlo::MinOp>,
                       ReduceRegionXLAOpConversion<mhlo::MaxOp>,
+                      ReduceRegionXLAOpConversion<mhlo::CompareOp>,
+                      ReduceRegionXLAOpConversion<mhlo::SelectOp>,
                       ReduceRegionReturnOpConversion>(context);
 }
diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir
index b853fc1..d4ab359 100644
--- a/tests/hlo-legalize-to-linalg.mlir
+++ b/tests/hlo-legalize-to-linalg.mlir
@@ -1506,6 +1506,49 @@ func @reduce_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<i32>) -> tensor<?xi32>
 
 // -----
 
+func @variadic_reduce(%arg0: tensor<9x2xi32>, %arg1: tensor<9x2xi32>) -> (tensor<2xi32>, tensor<2xi32>) {
+  %cst0 = mhlo.constant dense<-2147483648> : tensor<i32>
+  %cst1 = mhlo.constant dense<0> : tensor<i32>
+  %res0, %res1 = "mhlo.reduce"(%arg0, %arg1, %cst0, %cst1) ( {
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg15: tensor<i32>, %arg16: tensor<i32>):  // no predecessors
+    %669 = "mhlo.compare"(%arg2, %arg15) {comparison_direction = "GE"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %670 = "mhlo.select"(%669, %arg2, %arg15) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
+    %671 = "mhlo.compare"(%arg2, %arg15) {comparison_direction = "EQ"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %672 = mhlo.minimum %arg3, %arg16 : tensor<i32>
+    %673 = "mhlo.select"(%669, %arg3, %arg16) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
+    %674 = "mhlo.select"(%671, %672, %673) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
+    "mhlo.return"(%670, %674) : (tensor<i32>, tensor<i32>) -> ()
+  }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<9x2xi32>, tensor<9x2xi32>, tensor<i32>, tensor<i32>) -> (tensor<2xi32>, tensor<2xi32>)
+  return %res0, %res1 : tensor<2xi32>, tensor<2xi32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK: func @variadic_reduce
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+// CHECK: %[[CST0:.*]] = constant -2147483648 : i32
+// CHECK: %[[INIT0:.*]] = linalg.init_tensor [2] : tensor<2xi32>
+// CHECK: %[[FILL0:.*]] = linalg.fill(%[[INIT0]], %[[CST0]])
+// CHECK: %[[CST1:.*]] = constant 0 : i32
+// CHECK: %[[INIT1:.*]] = linalg.init_tensor [2] : tensor<2xi32>
+// CHECK: %[[FILL1:.*]] = linalg.fill(%[[INIT1]], %[[CST1]])
+// CHECK: %[[RES:.+]]:2 = linalg.generic {
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP1]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "reduction"]
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<9x2xi32>, tensor<9x2xi32>)
+// CHECK-SAME: outs(%[[FILL0]], %[[FILL1]] : tensor<2xi32>, tensor<2xi32>)
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: i32, %[[IN1:.*]]: i32, %[[OUT0:.*]]: i32, %[[OUT1:.*]]: i32):
+// CHECK-NEXT: %[[T1:.*]] = cmpi sge, %[[IN0]], %[[OUT0]] : i32
+// CHECK-NEXT: %[[T2:.*]] = select %[[T1]], %[[IN0]], %[[OUT0]] : i32
+// CHECK-NEXT: %[[T3:.*]] = cmpi eq, %[[IN0]], %[[OUT0]] : i32
+// CHECK-NEXT: %[[T4:.*]] = cmpi slt, %[[IN1]], %[[OUT1]] : i32
+// CHECK-NEXT: %[[T5:.*]] = select %[[T4]], %[[IN1]], %[[OUT1]] : i32
+// CHECK-NEXT: %[[T6:.*]] = select %[[T1]], %[[IN1]], %[[OUT1]] : i32
+// CHECK-NEXT: %[[T7:.*]] = select %[[T3]], %[[T5]], %[[T6]] : i32
+// CHECK-NEXT: linalg.yield %[[T2]], %[[T7]]
+
+// -----
+
 func @slice_whole_stride(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> {
   %0 = "mhlo.slice"(%arg0) {
     start_indices = dense<[1, 0]> : tensor<2xi64>,
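Note on the indexing maps the new test checks: for a 9x2 input reduced along dimension 0, each input gets #map0 = (d0, d1) -> (d1, d0), the transpose that moves the reduction loop innermost, and each output gets #map1 = (d0, d1) -> (d0), which drops the reduction loop. The sketch below illustrates one way such a permutation map can be built with upstream MLIR APIs. It is an assumption about what the file's GetTransposeMapForReduction helper (not shown in this diff) computes, inferred from the comments and the CHECK'd maps; it is not a copy of that helper, and the name is hypothetical.

    // Illustrative sketch only; the real GetTransposeMapForReduction lives
    // elsewhere in legalize_to_linalg.cc and may be implemented differently.
    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/ADT/DenseSet.h"
    #include "llvm/ADT/SmallVector.h"
    #include "mlir/IR/AffineMap.h"

    static mlir::AffineMap GetTransposeMapForReductionSketch(
        mlir::MLIRContext* context, int rank,
        llvm::ArrayRef<int64_t> reduction_dims) {
      llvm::SmallDenseSet<int64_t, 4> reduction_set(reduction_dims.begin(),
                                                    reduction_dims.end());
      // permutation[j] is the tensor dimension iterated by loop j: parallel
      // dimensions first, in their original order, reduction dimensions last
      // so that the reduction loops end up innermost.
      llvm::SmallVector<unsigned, 4> permutation;
      for (int i = 0; i < rank; ++i)
        if (!reduction_set.count(i)) permutation.push_back(i);
      for (int64_t dim : reduction_dims)
        permutation.push_back(static_cast<unsigned>(dim));
      // getPermutationMap sends loop j to result slot permutation[j]; the
      // `src` indexing map needs the opposite direction (tensor dimension i
      // to the loop that iterates it), hence inversePermutation.
      return mlir::inversePermutation(
          mlir::AffineMap::getPermutationMap(permutation, context));
    }

For rank = 2 and reduction_dims = {0} this yields (d0, d1) -> (d1, d0), matching #[[MAP0]] above; the dst map then keeps only the leading src_rank - reduction_dims.size() loop dimensions, which is exactly what the exprs loop in the patch constructs.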