[HLO:MLIR] Make binary op type reification emit shape_of instead of tensor ops
This gives cleaner code and allows shape optimizations to happen on the result. PiperOrigin-RevId: 362242975
This commit is contained in:
parent
9902e6ee32
commit
67a770e4e0
|
@ -3151,19 +3151,11 @@ LogicalResult deriveShapeFromFirstOperand(
|
|||
return failure();
|
||||
}
|
||||
auto loc = op->getLoc();
|
||||
SmallVector<Value, 4> shape_values;
|
||||
shape_values.reserve(operand_type.getRank());
|
||||
for (auto element : llvm::enumerate(operand_type.getShape())) {
|
||||
if (element.value() == ShapedType::kDynamicSize) {
|
||||
shape_values.push_back(
|
||||
builder->create<DimOp>(loc, operand, element.index()));
|
||||
} else {
|
||||
shape_values.push_back(
|
||||
builder->create<ConstantIndexOp>(loc, element.value()));
|
||||
}
|
||||
}
|
||||
*reifiedReturnShapes = SmallVector<Value, 1>{
|
||||
builder->create<tensor::FromElementsOp>(loc, shape_values)};
|
||||
// Some users rely on the result type being a static shape.
|
||||
auto shape_type =
|
||||
RankedTensorType::get(operand_type.getRank(), builder->getIndexType());
|
||||
reifiedReturnShapes->assign(
|
||||
{builder->create<shape::ShapeOfOp>(loc, shape_type, operand)});
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -577,6 +577,7 @@ struct HloLegalizeToLhlo
|
|||
ConversionTarget target(context);
|
||||
target.addLegalDialect<lmhlo::LmhloDialect>();
|
||||
target.addLegalDialect<StandardOpsDialect>();
|
||||
target.addLegalDialect<shape::ShapeDialect>();
|
||||
target.addLegalDialect<tensor::TensorDialect>();
|
||||
target.addIllegalDialect<mhlo::MhloDialect>();
|
||||
// Declare tensor_load and tensor_store illegal.
|
||||
|
|
|
@ -466,11 +466,7 @@ func @xor(%operand0: tensor<2x2xi32>, %operand1: tensor<2x2xi32>)
|
|||
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
%result = "mhlo.add"(%lhs, %rhs)
|
||||
: (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
|
||||
// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[DIM0]], %[[DIM1]] : tensor<2xindex>
|
||||
// CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
|
||||
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
|
||||
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
|
||||
// CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]])
|
||||
|
@ -486,11 +482,7 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
|||
func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
%result = "mhlo.tanh"(%arg0)
|
||||
: (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
|
||||
// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[DIM0]], %[[DIM1]] : tensor<2xindex>
|
||||
// CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
|
||||
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
|
||||
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
|
||||
// CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]])
|
||||
|
|
|
@ -5,10 +5,7 @@
|
|||
// CHECK-SAME: (%[[PRED:.*]]: tensor<2x?xi1>,
|
||||
func @select(%pred : tensor<2x?xi1>, %a : tensor<2x?xf32>, %b : tensor<2x?xf32>)
|
||||
-> tensor<2xindex> {
|
||||
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[DIM:.*]] = dim %[[PRED]], %[[C1]] : tensor<2x?xi1>
|
||||
// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[DIM]] : tensor<2xindex>
|
||||
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[PRED]] : tensor<2x?xi1> -> tensor<2xindex>
|
||||
// CHECK: return %[[SHAPE]] : tensor<2xindex>
|
||||
%0 = "mhlo.select"(%pred, %a, %b)
|
||||
: (tensor<2x?xi1>, tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
|
||||
|
@ -21,10 +18,7 @@ func @select(%pred : tensor<2x?xi1>, %a : tensor<2x?xf32>, %b : tensor<2x?xf32>)
|
|||
// CHECK-LABEL: @compare
|
||||
// CHECK-SAME: (%[[A:.*]]: tensor<2x?xf32>,
|
||||
func @compare(%a : tensor<2x?xf32>, %b : tensor<2x?xf32>) -> tensor<2xindex> {
|
||||
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[DIM:.*]] = dim %[[A]], %[[C1]] : tensor<2x?xf32>
|
||||
// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[DIM]] : tensor<2xindex>
|
||||
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<2x?xf32> -> tensor<2xindex>
|
||||
// CHECK: return %[[SHAPE]] : tensor<2xindex>
|
||||
%0 = "mhlo.compare"(%a, %b) {comparison_direction = "NE"}
|
||||
: (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xi1>
|
||||
|
|
|
@ -4,7 +4,8 @@
|
|||
// CHECK-LABEL: @shape_of_unary
|
||||
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x32xi16>)
|
||||
func @shape_of_unary(%arg : tensor<?x32xi16>) {
|
||||
// CHECK-NOT: shape_of
|
||||
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<?x32xi16> -> tensor<2xindex>
|
||||
// CHECK: "use"(%[[SHAPE]])
|
||||
%0 = "mhlo.convert"(%arg) : (tensor<?x32xi16>) -> tensor<?x32xf16>
|
||||
%1 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
|
||||
"use"(%1) : (tensor<?xindex>) -> ()
|
||||
|
@ -17,7 +18,8 @@ func @shape_of_unary(%arg : tensor<?x32xi16>) {
|
|||
// CHECK-LABEL: @shape_of_nary
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>)
|
||||
func @shape_of_nary(%arg0 : tensor<?x32xf16>, %arg1 : tensor<?x32xf16>) {
|
||||
// CHECK-NOT: shape_of
|
||||
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG0]] : tensor<?x32xf16> -> tensor<2xindex>
|
||||
// CHECK: "use"(%[[SHAPE]])
|
||||
%0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf16>
|
||||
%1 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
|
||||
"use"(%1) : (tensor<?xindex>) -> ()
|
||||
|
|
Loading…
Reference in New Issue