Add MinimumBroadcastShapesOp to chlo dialect.

This op is useful for rank specialization of broadcasts. Kernel Generator
needs to generate one kernel for each rank, so if we can minimize the rank
of the broadcast shape, we can support more cases with the same number of
special-cased kernels.

PiperOrigin-RevId: 360137827
This commit is contained in:
Adrian Kuegel 2021-03-01 02:22:55 -08:00 committed by TensorFlow MLIR Team
parent 2d818c4fd9
commit e6a1f5f0f9
3 changed files with 91 additions and 0 deletions

View File

@@ -708,4 +708,49 @@ def HLOClient_BroadcastSelectOp : HLOClient_Op<
}];
}
//===----------------------------------------------------------------------===//
// Helper ops
//===----------------------------------------------------------------------===//
// chlo.minimum_broadcast_shapes: takes N 1-D index tensors describing shapes
// and returns N 1-D index tensors holding rank-minimized equivalents whose
// broadcast is a reshape of the original broadcast (full contract in the
// description below). The C++ verifier enforces N >= 2 and that operand and
// result counts match.
def HLOClient_MinimumBroadcastShapesOp :
HLOClient_Op<"minimum_broadcast_shapes", [NoSideEffect]> {
// NOTE(review): ODS ops conventionally override these inherited fields with
// `let summary = ...` / `let description = ...`; this re-declares them with
// `string` instead — confirm TableGen resolves this as intended against the
// HLOClient_Op / Op class hierarchy (not visible in this chunk).
string summary = "Minimizes the rank of two or more shapes to be broadcasted";
string description = [{
Given two or more 1D tensors representing shapes, returns one 1D tensor for
each operand, where operand `i` corresponds to output `i`.
The returned tensors have the property that they specify a shape which is a
reshape of the corresponding input shape, and the broadcasted output shape
(using shape::BroadcastOp) of the returned shapes is a reshape of the
broadcasted output shape of the input shapes. Among all possibilities with
this property, the one is chosen which minimizes the rank of each returned
shape.
The general idea of this op is that it can be used for ops which have a
broadcasting semantic to operate on shapes with a possibly smaller rank
while preserving equivalence of the computed values. After computing the
result of the op using reshaped operands, the result can be reshaped to the
result that would have been originally computed.
Here is an example with two input shapes:
```mlir
chlo.minimum_broadcast_shapes [1, 2, 3, 1, 2, 1],
[1, 1, 1, 2, 3] -> [6, 2, 1], [2, 3]
```
The broadcasted output shape of the operands is [1, 2, 3, 1, 2, 3], the
broadcasted output shape of the outputs is [6, 2, 3]. These two shapes are
reshapes of each other, and also each output is a reshape of the
corresponding input.
}];
// Variadic list of 1-D tensors of index type, one per input shape.
let arguments = (ins Variadic<1DTensorOf<[Index]>>:$shapes);
// One 1-D index tensor per operand (count equality checked in the verifier).
let results = (outs Variadic<1DTensorOf<[Index]>>:$results);
// Syntax: chlo.minimum_broadcast_shapes %a, %b : tA, tB -> tX, tY
let assemblyFormat = "$shapes attr-dict `:` type($shapes) `->` type($results)";
}
#endif // CHLO_OPS

View File

@@ -337,6 +337,26 @@ static LogicalResult Verify(ConstantLikeOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// MinimumBroadcastShapesOp
//===----------------------------------------------------------------------===//

// Verifies structural invariants of chlo.minimum_broadcast_shapes: the op
// must produce exactly one result shape per operand shape, and minimizing a
// broadcast only makes sense for at least two shapes.
static LogicalResult Verify(MinimumBroadcastShapesOp op) {
  unsigned num_operands = op.shapes().size();
  unsigned num_results = op.results().size();

  // Each operand shape must have a corresponding result shape.
  if (num_operands != num_results) {
    return op.emitOpError() << "number of operand shapes (" << num_operands
                            << ") does not match number of result shapes ("
                            << num_results << ")";
  }

  // A single shape has nothing to broadcast against.
  if (num_operands < 2) {
    return op.emitOpError() << "number of operand shapes (" << num_operands
                            << ") should be >= 2";
  }

  return success();
}
LogicalResult ConstantLikeOp::inferReturnTypeComponents(
MLIRContext* context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,

26
tests/chlo_ops.mlir Normal file
View File

@@ -0,0 +1,26 @@
// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s
// CHECK-LABEL: func @minimum_broadcast_shapes
// Round-trip test: a well-formed two-operand, two-result op must parse,
// verify, and re-print through mlir-hlo-opt.
func @minimum_broadcast_shapes(%lhs: tensor<?xindex>, %rhs: tensor<?xindex>)
-> (tensor<?xindex>, tensor<?xindex>) {
%0, %1 = chlo.minimum_broadcast_shapes %lhs, %rhs :
tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>
return %0, %1 : tensor<?xindex>, tensor<?xindex>
}
// -----
// Negative test: the verifier rejects an op whose result count (1) differs
// from its operand count (2).
func @minimum_broadcast_shapes_mismatch_operand_and_result_count(%lhs: tensor<?xindex>, %rhs: tensor<?xindex>) {
// expected-error @+1{{number of operand shapes (2) does not match number of result shapes (1)}}
%0 = chlo.minimum_broadcast_shapes %lhs, %rhs :
tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
return
}
// -----
// Negative test: the verifier rejects an op with fewer than two operand
// shapes, since a single shape has nothing to broadcast against.
func @minimum_broadcast_shapes_one_operand(%arg: tensor<?xindex>) {
// expected-error @+1{{number of operand shapes (1) should be >= 2}}
%0 = chlo.minimum_broadcast_shapes %arg : tensor<?xindex> -> tensor<?xindex>
return
}