From e6a1f5f0f92fbbf3a1e756307524024de190cefd Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Mon, 1 Mar 2021 02:22:55 -0800 Subject: [PATCH] Add MinimumBroadcastShapesOp to chlo dialect. This op is useful for rank specialization of broadcasts. Kernel Generator needs to generate one kernel for each rank, so if we can minimize the rank of the broadcast shape, we can support more cases with the same number of special-cased kernels. PiperOrigin-RevId: 360137827 --- include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td | 45 ++++++++++++++++++++ lib/Dialect/mhlo/IR/chlo_ops.cc | 20 +++++++++ tests/chlo_ops.mlir | 26 +++++++++++ 3 files changed, 91 insertions(+) create mode 100644 tests/chlo_ops.mlir diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index c9db345..9a42d95 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -708,4 +708,49 @@ def HLOClient_BroadcastSelectOp : HLOClient_Op< }]; } +//===----------------------------------------------------------------------===// +// Helper ops +//===----------------------------------------------------------------------===// + +def HLOClient_MinimumBroadcastShapesOp : + HLOClient_Op<"minimum_broadcast_shapes", [NoSideEffect]> { + string summary = "Minimizes the rank of two or more shapes to be broadcasted"; + + string description = [{ + Given two or more 1D tensors representing shapes, returns one 1D tensor for + each operand, where operand `i` corresponds to output `i`. + + The returned tensors have the property that they specify a shape which is a + reshape of the corresponding input shape, and the broadcasted output shape + (using shape::BroadcastOp) of the returned shapes is a reshape of the + broadcasted output shape of the input shapes. Among all possibilities with + this property, the one is chosen which minimizes the rank of each returned + shape. + + The general idea of this op is that it can be used for ops which have a + broadcasting semantic to operate on shapes with a possibly smaller rank + while preserving equivalence of the computed values. After computing the + result of the op using reshaped operands, the result can be reshaped to the + result that would have been originally computed. + + Here is an example with two input shapes: + + ```mlir + chlo.minimum_broadcast_shapes [1, 2, 3, 1, 2, 1], + [1, 1, 1, 2, 3] -> [6, 2, 1], [2, 3] + ``` + + The broadcasted output shape of the operands is [1, 2, 3, 1, 2, 3], the + broadcasted output shape of the outputs is [6, 2, 3]. These two shapes are + reshapes of each other, and also each output is a reshape of the + corresponding input. + }]; + + let arguments = (ins Variadic<1DTensorOf<[Index]>>:$shapes); + let results = (outs Variadic<1DTensorOf<[Index]>>:$results); + + let assemblyFormat = "$shapes attr-dict `:` type($shapes) `->` type($results)"; + +} + #endif // CHLO_OPS diff --git a/lib/Dialect/mhlo/IR/chlo_ops.cc b/lib/Dialect/mhlo/IR/chlo_ops.cc index 57ae271..3a1caa9 100644 --- a/lib/Dialect/mhlo/IR/chlo_ops.cc +++ b/lib/Dialect/mhlo/IR/chlo_ops.cc @@ -337,6 +337,26 @@ static LogicalResult Verify(ConstantLikeOp op) { return success(); } +//===----------------------------------------------------------------------===// +// MinimumBroadcastShapesOp +//===----------------------------------------------------------------------===// +static LogicalResult Verify(MinimumBroadcastShapesOp op) { + // Check that the number of operands matches the number of outputs. + unsigned result_shapes_count = op.results().size(); + unsigned operand_shapes_count = op.shapes().size(); + if (operand_shapes_count != result_shapes_count) { + return op.emitOpError() + << "number of operand shapes (" << operand_shapes_count + << ") does not match number of result shapes (" + << result_shapes_count << ")"; + } + if (operand_shapes_count < 2) { + return op.emitOpError() << "number of operand shapes (" + << operand_shapes_count << ") should be >= 2"; + } + return success(); +} + LogicalResult ConstantLikeOp::inferReturnTypeComponents( MLIRContext* context, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, diff --git a/tests/chlo_ops.mlir b/tests/chlo_ops.mlir new file mode 100644 index 0000000..a4d5f79 --- /dev/null +++ b/tests/chlo_ops.mlir @@ -0,0 +1,26 @@ +// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s + +// CHECK-LABEL: func @minimum_broadcast_shapes +func @minimum_broadcast_shapes(%lhs: tensor, %rhs: tensor) + -> (tensor, tensor) { + %0, %1 = chlo.minimum_broadcast_shapes %lhs, %rhs : + tensor, tensor -> tensor, tensor + return %0, %1 : tensor, tensor +} + +// ----- + +func @minimum_broadcast_shapes_mismatch_operand_and_result_count(%lhs: tensor, %rhs: tensor) { + // expected-error @+1{{number of operand shapes (2) does not match number of result shapes (1)}} + %0 = chlo.minimum_broadcast_shapes %lhs, %rhs : + tensor, tensor -> tensor + return +} + +// ----- + +func @minimum_broadcast_shapes_one_operand(%arg: tensor) { + // expected-error @+1{{number of operand shapes (1) should be >= 2}} + %0 = chlo.minimum_broadcast_shapes %arg : tensor -> tensor + return +}