Add MinimumBroadcastShapesOp to chlo dialect.
This op is useful for rank specialization of broadcasts. Kernel Generator needs to generate one kernel for each rank, so if we can minimize the rank of the broadcast shape, we can support more cases with the same number of special-cased kernels. PiperOrigin-RevId: 360137827
This commit is contained in:
parent
2d818c4fd9
commit
e6a1f5f0f9
|
@ -708,4 +708,49 @@ def HLOClient_BroadcastSelectOp : HLOClient_Op<
|
|||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Helper ops
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def HLOClient_MinimumBroadcastShapesOp :
|
||||
HLOClient_Op<"minimum_broadcast_shapes", [NoSideEffect]> {
|
||||
string summary = "Minimizes the rank of two or more shapes to be broadcasted";
|
||||
|
||||
string description = [{
|
||||
Given two or more 1D tensors representing shapes, returns one 1D tensor for
|
||||
each operand, where operand `i` corresponds to output `i`.
|
||||
|
||||
The returned tensors have the property that they specify a shape which is a
|
||||
reshape of the corresponding input shape, and the broadcasted output shape
|
||||
(using shape::BroadcastOp) of the returned shapes is a reshape of the
|
||||
broadcasted output shape of the input shapes. Among all possibilities with
|
||||
this property, the one is chosen which minimizes the rank of each returned
|
||||
shape.
|
||||
|
||||
The general idea of this op is that it can be used for ops which have a
|
||||
broadcasting semantic to operate on shapes with a possibly smaller rank
|
||||
while preserving equivalence of the computed values. After computing the
|
||||
result of the op using reshaped operands, the result can be reshaped to the
|
||||
result that would have been originally computed.
|
||||
|
||||
Here is an example with two input shapes:
|
||||
|
||||
```mlir
|
||||
chlo.minimum_broadcast_shapes [1, 2, 3, 1, 2, 1],
|
||||
[1, 1, 1, 2, 3] -> [6, 2, 1], [2, 3]
|
||||
```
|
||||
|
||||
The broadcasted output shape of the operands is [1, 2, 3, 1, 2, 3], the
|
||||
broadcasted output shape of the outputs is [6, 2, 3]. These two shapes are
|
||||
reshapes of each other, and also each output is a reshape of the
|
||||
corresponding input.
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<1DTensorOf<[Index]>>:$shapes);
|
||||
let results = (outs Variadic<1DTensorOf<[Index]>>:$results);
|
||||
|
||||
let assemblyFormat = "$shapes attr-dict `:` type($shapes) `->` type($results)";
|
||||
|
||||
}
|
||||
|
||||
#endif // CHLO_OPS
|
||||
|
|
|
@ -337,6 +337,26 @@ static LogicalResult Verify(ConstantLikeOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MinimumBroadcastShapesOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
static LogicalResult Verify(MinimumBroadcastShapesOp op) {
|
||||
// Check that the number of operands matches the number of outputs.
|
||||
unsigned result_shapes_count = op.results().size();
|
||||
unsigned operand_shapes_count = op.shapes().size();
|
||||
if (operand_shapes_count != result_shapes_count) {
|
||||
return op.emitOpError()
|
||||
<< "number of operand shapes (" << operand_shapes_count
|
||||
<< ") does not match number of result shapes ("
|
||||
<< result_shapes_count << ")";
|
||||
}
|
||||
if (operand_shapes_count < 2) {
|
||||
return op.emitOpError() << "number of operand shapes ("
|
||||
<< operand_shapes_count << ") should be >= 2";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ConstantLikeOp::inferReturnTypeComponents(
|
||||
MLIRContext* context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @minimum_broadcast_shapes
|
||||
func @minimum_broadcast_shapes(%lhs: tensor<?xindex>, %rhs: tensor<?xindex>)
|
||||
-> (tensor<?xindex>, tensor<?xindex>) {
|
||||
%0, %1 = chlo.minimum_broadcast_shapes %lhs, %rhs :
|
||||
tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>
|
||||
return %0, %1 : tensor<?xindex>, tensor<?xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @minimum_broadcast_shapes_mismatch_operand_and_result_count(%lhs: tensor<?xindex>, %rhs: tensor<?xindex>) {
|
||||
// expected-error @+1{{number of operand shapes (2) does not match number of result shapes (1)}}
|
||||
%0 = chlo.minimum_broadcast_shapes %lhs, %rhs :
|
||||
tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @minimum_broadcast_shapes_one_operand(%arg: tensor<?xindex>) {
|
||||
// expected-error @+1{{number of operand shapes (1) should be >= 2}}
|
||||
%0 = chlo.minimum_broadcast_shapes %arg : tensor<?xindex> -> tensor<?xindex>
|
||||
return
|
||||
}
|
Loading…
Reference in New Issue