[HLO:MLIR] Make binary op type reification emit shape_of instead of tensor ops

This gives cleaner code and allows shape optimizations to happen on the result.

PiperOrigin-RevId: 362242975
This commit is contained in:
Benjamin Kramer 2021-03-11 02:00:50 -08:00 committed by TensorFlow MLIR Team
parent 9902e6ee32
commit 67a770e4e0
5 changed files with 14 additions and 33 deletions

View File

@ -3151,19 +3151,11 @@ LogicalResult deriveShapeFromFirstOperand(
return failure(); return failure();
} }
auto loc = op->getLoc(); auto loc = op->getLoc();
SmallVector<Value, 4> shape_values; // Some users rely on the result type being a static shape.
shape_values.reserve(operand_type.getRank()); auto shape_type =
for (auto element : llvm::enumerate(operand_type.getShape())) { RankedTensorType::get(operand_type.getRank(), builder->getIndexType());
if (element.value() == ShapedType::kDynamicSize) { reifiedReturnShapes->assign(
shape_values.push_back( {builder->create<shape::ShapeOfOp>(loc, shape_type, operand)});
builder->create<DimOp>(loc, operand, element.index()));
} else {
shape_values.push_back(
builder->create<ConstantIndexOp>(loc, element.value()));
}
}
*reifiedReturnShapes = SmallVector<Value, 1>{
builder->create<tensor::FromElementsOp>(loc, shape_values)};
return success(); return success();
} }

View File

@ -577,6 +577,7 @@ struct HloLegalizeToLhlo
ConversionTarget target(context); ConversionTarget target(context);
target.addLegalDialect<lmhlo::LmhloDialect>(); target.addLegalDialect<lmhlo::LmhloDialect>();
target.addLegalDialect<StandardOpsDialect>(); target.addLegalDialect<StandardOpsDialect>();
target.addLegalDialect<shape::ShapeDialect>();
target.addLegalDialect<tensor::TensorDialect>(); target.addLegalDialect<tensor::TensorDialect>();
target.addIllegalDialect<mhlo::MhloDialect>(); target.addIllegalDialect<mhlo::MhloDialect>();
// Declare tensor_load and tensor_store illegal. // Declare tensor_load and tensor_store illegal.

View File

@ -466,11 +466,7 @@ func @xor(%operand0: tensor<2x2xi32>, %operand1: tensor<2x2xi32>)
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> { func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
%result = "mhlo.add"(%lhs, %rhs) %result = "mhlo.add"(%lhs, %rhs)
: (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
// CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[DIM0]], %[[DIM1]] : tensor<2xindex>
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex> // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex> // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
// CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]]) // CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]])
@ -486,11 +482,7 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> { func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
%result = "mhlo.tanh"(%arg0) %result = "mhlo.tanh"(%arg0)
: (tensor<?x?xf32>) -> tensor<?x?xf32> : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
// CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[DIM0]], %[[DIM1]] : tensor<2xindex>
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex> // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex> // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
// CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]]) // CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]])

View File

@ -5,10 +5,7 @@
// CHECK-SAME: (%[[PRED:.*]]: tensor<2x?xi1>, // CHECK-SAME: (%[[PRED:.*]]: tensor<2x?xi1>,
func @select(%pred : tensor<2x?xi1>, %a : tensor<2x?xf32>, %b : tensor<2x?xf32>) func @select(%pred : tensor<2x?xi1>, %a : tensor<2x?xf32>, %b : tensor<2x?xf32>)
-> tensor<2xindex> { -> tensor<2xindex> {
// CHECK: %[[C2:.*]] = constant 2 : index // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[PRED]] : tensor<2x?xi1> -> tensor<2xindex>
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[DIM:.*]] = dim %[[PRED]], %[[C1]] : tensor<2x?xi1>
// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[DIM]] : tensor<2xindex>
// CHECK: return %[[SHAPE]] : tensor<2xindex> // CHECK: return %[[SHAPE]] : tensor<2xindex>
%0 = "mhlo.select"(%pred, %a, %b) %0 = "mhlo.select"(%pred, %a, %b)
: (tensor<2x?xi1>, tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32> : (tensor<2x?xi1>, tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
@ -21,10 +18,7 @@ func @select(%pred : tensor<2x?xi1>, %a : tensor<2x?xf32>, %b : tensor<2x?xf32>)
// CHECK-LABEL: @compare // CHECK-LABEL: @compare
// CHECK-SAME: (%[[A:.*]]: tensor<2x?xf32>, // CHECK-SAME: (%[[A:.*]]: tensor<2x?xf32>,
func @compare(%a : tensor<2x?xf32>, %b : tensor<2x?xf32>) -> tensor<2xindex> { func @compare(%a : tensor<2x?xf32>, %b : tensor<2x?xf32>) -> tensor<2xindex> {
// CHECK: %[[C2:.*]] = constant 2 : index // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<2x?xf32> -> tensor<2xindex>
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[DIM:.*]] = dim %[[A]], %[[C1]] : tensor<2x?xf32>
// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[DIM]] : tensor<2xindex>
// CHECK: return %[[SHAPE]] : tensor<2xindex> // CHECK: return %[[SHAPE]] : tensor<2xindex>
%0 = "mhlo.compare"(%a, %b) {comparison_direction = "NE"} %0 = "mhlo.compare"(%a, %b) {comparison_direction = "NE"}
: (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xi1> : (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xi1>

View File

@ -4,7 +4,8 @@
// CHECK-LABEL: @shape_of_unary // CHECK-LABEL: @shape_of_unary
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x32xi16>) // CHECK-SAME: (%[[ARG:.*]]: tensor<?x32xi16>)
func @shape_of_unary(%arg : tensor<?x32xi16>) { func @shape_of_unary(%arg : tensor<?x32xi16>) {
// CHECK-NOT: shape_of // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<?x32xi16> -> tensor<2xindex>
// CHECK: "use"(%[[SHAPE]])
%0 = "mhlo.convert"(%arg) : (tensor<?x32xi16>) -> tensor<?x32xf16> %0 = "mhlo.convert"(%arg) : (tensor<?x32xi16>) -> tensor<?x32xf16>
%1 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex> %1 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
"use"(%1) : (tensor<?xindex>) -> () "use"(%1) : (tensor<?xindex>) -> ()
@ -17,7 +18,8 @@ func @shape_of_unary(%arg : tensor<?x32xi16>) {
// CHECK-LABEL: @shape_of_nary // CHECK-LABEL: @shape_of_nary
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>) // CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>)
func @shape_of_nary(%arg0 : tensor<?x32xf16>, %arg1 : tensor<?x32xf16>) { func @shape_of_nary(%arg0 : tensor<?x32xf16>, %arg1 : tensor<?x32xf16>) {
// CHECK-NOT: shape_of // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG0]] : tensor<?x32xf16> -> tensor<2xindex>
// CHECK: "use"(%[[SHAPE]])
%0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf16> %0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf16>
%1 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex> %1 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
"use"(%1) : (tensor<?xindex>) -> () "use"(%1) : (tensor<?xindex>) -> ()