diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
index c9db345..9a42d95 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
@@ -708,4 +708,49 @@ def HLOClient_BroadcastSelectOp : HLOClient_Op<
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Helper ops
+//===----------------------------------------------------------------------===//
+
+def HLOClient_MinimumBroadcastShapesOp :
+    HLOClient_Op<"minimum_broadcast_shapes", [NoSideEffect]> {
+  string summary = "Minimizes the rank of two or more shapes to be broadcasted";
+
+  string description = [{
+    Given two or more 1D tensors representing shapes, returns one 1D tensor for
+    each operand, where operand `i` corresponds to output `i`.
+
+    The returned tensors have the property that they specify a shape which is a
+    reshape of the corresponding input shape, and the broadcasted output shape
+    (using shape::BroadcastOp) of the returned shapes is a reshape of the
+    broadcasted output shape of the input shapes. Among all possibilities with
+    this property, the one is chosen which minimizes the rank of each returned
+    shape.
+
+    The general idea of this op is that it can be used for ops which have a
+    broadcasting semantic to operate on shapes with a possibly smaller rank
+    while preserving equivalence of the computed values. After computing the
+    result of the op using reshaped operands, the result can be reshaped to the
+    result that would have been originally computed.
+
+    Here is an example with two input shapes:
+
+    ```mlir
+    chlo.minimum_broadcast_shapes [1, 2, 3, 1, 2, 1],
+                                  [1, 1, 1, 2, 3] -> [6, 2, 1], [2, 3]
+    ```
+
+    The broadcasted output shape of the operands is [1, 2, 3, 1, 2, 3], the
+    broadcasted output shape of the outputs is [6, 2, 3]. These two shapes are
+    reshapes of each other, and also each output is a reshape of the
+    corresponding input.
+  }];
+
+  let arguments = (ins Variadic<1DTensorOf<[Index]>>:$shapes);
+  let results = (outs Variadic<1DTensorOf<[Index]>>:$results);
+
+  let assemblyFormat = "$shapes attr-dict `:` type($shapes) `->` type($results)";
+
+}
+
 #endif // CHLO_OPS
diff --git a/lib/Dialect/mhlo/IR/chlo_ops.cc b/lib/Dialect/mhlo/IR/chlo_ops.cc
index 57ae271..3a1caa9 100644
--- a/lib/Dialect/mhlo/IR/chlo_ops.cc
+++ b/lib/Dialect/mhlo/IR/chlo_ops.cc
@@ -337,6 +337,26 @@ static LogicalResult Verify(ConstantLikeOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// MinimumBroadcastShapesOp
+//===----------------------------------------------------------------------===//
+static LogicalResult Verify(MinimumBroadcastShapesOp op) {
+  // Check that the number of operands matches the number of outputs.
+  unsigned result_shapes_count = op.results().size();
+  unsigned operand_shapes_count = op.shapes().size();
+  if (operand_shapes_count != result_shapes_count) {
+    return op.emitOpError()
+           << "number of operand shapes (" << operand_shapes_count
+           << ") does not match number of result shapes ("
+           << result_shapes_count << ")";
+  }
+  if (operand_shapes_count < 2) {
+    return op.emitOpError() << "number of operand shapes ("
+                            << operand_shapes_count << ") should be >= 2";
+  }
+  return success();
+}
+
 LogicalResult ConstantLikeOp::inferReturnTypeComponents(
     MLIRContext* context, Optional<Location> location, ValueRange operands,
     DictionaryAttr attributes, RegionRange regions,
diff --git a/tests/chlo_ops.mlir b/tests/chlo_ops.mlir
new file mode 100644
index 0000000..a4d5f79
--- /dev/null
+++ b/tests/chlo_ops.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s
+
+// CHECK-LABEL: func @minimum_broadcast_shapes
+func @minimum_broadcast_shapes(%lhs: tensor<?xindex>, %rhs: tensor<?xindex>)
+    -> (tensor<?xindex>, tensor<?xindex>) {
+  %0, %1 = chlo.minimum_broadcast_shapes %lhs, %rhs :
+      tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>
+  return %0, %1 : tensor<?xindex>, tensor<?xindex>
+}
+
+// -----
+
+func @minimum_broadcast_shapes_mismatch_operand_and_result_count(%lhs: tensor<?xindex>, %rhs: tensor<?xindex>) {
+  // expected-error @+1{{number of operand shapes (2) does not match number of result shapes (1)}}
+  %0 = chlo.minimum_broadcast_shapes %lhs, %rhs :
+      tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+  return
+}
+
+// -----
+
+func @minimum_broadcast_shapes_one_operand(%arg: tensor<?xindex>) {
+  // expected-error @+1{{number of operand shapes (1) should be >= 2}}
+  %0 = chlo.minimum_broadcast_shapes %arg : tensor<?xindex> -> tensor<?xindex>
+  return
+}
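
Not part of the patch, for illustration only: a minimal IR sketch of the rank-reduction pattern the op description outlines, assuming the usual `shape.shape_of` and `mhlo.dynamic_reshape` companions are available in the surrounding pipeline. The function name and all SSA value names are hypothetical, and the broadcasting op itself plus the reshape of its result back to the original broadcasted shape are elided.

```mlir
// Hypothetical usage sketch; names and types are illustrative only.
func @rank_reduced_broadcast(%lhs: tensor<?x?x?x?x?x?xf32>,
                             %rhs: tensor<?x?x?x?x?xf32>) {
  // Compute the operand shapes as extent tensors.
  %lhs_shape = shape.shape_of %lhs : tensor<?x?x?x?x?x?xf32> -> tensor<?xindex>
  %rhs_shape = shape.shape_of %rhs : tensor<?x?x?x?x?xf32> -> tensor<?xindex>
  // For runtime shapes [1, 2, 3, 1, 2, 1] and [1, 1, 1, 2, 3] this yields the
  // minimized shapes [6, 2, 1] and [2, 3] from the op description.
  %min_lhs, %min_rhs = chlo.minimum_broadcast_shapes %lhs_shape, %rhs_shape :
      tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>
  // Reshape each operand to its minimized (lower-rank) shape; the broadcasting
  // op would then run on these operands, and its result would be reshaped back
  // to the original broadcasted shape (both steps elided here).
  %lhs_small = "mhlo.dynamic_reshape"(%lhs, %min_lhs)
      : (tensor<?x?x?x?x?x?xf32>, tensor<?xindex>) -> tensor<?x?x?xf32>
  %rhs_small = "mhlo.dynamic_reshape"(%rhs, %min_rhs)
      : (tensor<?x?x?x?x?xf32>, tensor<?xindex>) -> tensor<?x?xf32>
  return
}
```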