// RUN: mlir-hlo-opt --split-input-file --allow-unregistered-dialect --mhlo-move-up-dynamic-broadcasts-for-fusion --canonicalize --cse %s | FileCheck %s

// NOTE(review): the tensor types in this copy of the file appear to have lost
// their `<...>` parameters (a bare `tensor` is not a valid MLIR type). Restore
// them from the upstream version of this test before relying on it.

// Shape computations shall be reified.
// The `shape.shape_of` of the `mhlo.convert` result is expected to be
// rewritten as a `shape.shape_of` of the convert's operand, `%arg`.
// `"use"` is an unregistered op (hence `--allow-unregistered-dialect`). It
// presumably acts as an opaque consumer so the shape value is not dropped as
// dead code.
// CHECK-LABEL: @shape_of_unary
// CHECK-SAME: (%[[ARG:.*]]: tensor)
func @shape_of_unary(%arg : tensor) {
  // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor -> tensor
  // CHECK: "use"(%[[SHAPE]])
  %0 = "mhlo.convert"(%arg) : (tensor) -> tensor
  %1 = shape.shape_of %0 : tensor -> tensor
  "use"(%1) : (tensor) -> ()
  return
}

// -----

// Shape computations shall be reified.
// The shape of a chain of two `mhlo.subtract` ops is expected to be
// computed directly from the first function argument, `%arg0`.
// CHECK-LABEL: @shape_of_nary
// CHECK-SAME: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor)
func @shape_of_nary(%arg0 : tensor, %arg1 : tensor) {
  // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG0]] : tensor -> tensor
  // CHECK: "use"(%[[SHAPE]])
  %0 = mhlo.subtract %arg0, %arg1 : tensor
  %1 = mhlo.subtract %0, %arg1 : tensor
  %2 = shape.shape_of %1 : tensor -> tensor
  "use"(%2) : (tensor) -> ()
  return
}

// -----

// Broadcasts can be moved up over unary shape-preserving operations.
// The `mhlo.dynamic_broadcast_in_dim` of the convert result is expected to be
// applied to `%arg` first (keeping `broadcast_dimensions`). The convert then
// operates on the broadcast value.
// CHECK-LABEL: @bcast_unary
// CHECK-SAME: (%[[ARG:.*]]: tensor, %[[OUT_DIMS:.*]]: tensor<3xindex>)
func @bcast_unary(%arg : tensor, %out_dims : tensor<3xindex>) -> tensor {
  // CHECK: %[[BCASTED_OPERAND:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG]], %[[OUT_DIMS]])
  // CHECK-SAME: broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<3xindex>) -> tensor
  // CHECK: "mhlo.convert"(%[[BCASTED_OPERAND]]) : (tensor) -> tensor
  %0 = "mhlo.convert"(%arg) : (tensor) -> tensor
  %1 = "mhlo.dynamic_broadcast_in_dim"(%0, %out_dims) {
      broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> }
      : (tensor, tensor<3xindex>) -> tensor
  return %1 : tensor
}

// -----

// Broadcasts can be moved up over n-ary shape-preserving operations.
// CHECK-LABEL: @bcast_nary
// CHECK-SAME: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor, %[[OUT_DIMS:.*]]: tensor<3xindex>)
func @bcast_nary(%arg0 : tensor, %arg1 : tensor, %out_dims : tensor<3xindex>) -> tensor {
  // Both operands are expected to be broadcast to `%out_dims` individually,
  // `%arg0` first and then `%arg1`. The subtraction then runs on the broadcast
  // values, and no subtraction may appear before the broadcasts.
  // CHECK-NOT: subtract
  // CHECK: %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[OUT_DIMS]])
  // CHECK: %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[OUT_DIMS]])
  // CHECK: %{{.*}} = mhlo.subtract %[[BCASTED_ARG0]], %[[BCASTED_ARG1]] : tensor
  %0 = mhlo.subtract %arg0, %arg1 : tensor
  %1 = "mhlo.dynamic_broadcast_in_dim"(%0, %out_dims) {
      broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> }
      : (tensor, tensor<3xindex>) -> tensor
  return %1 : tensor
}

// -----

// Representative IR, as it appears when lowering `tf.Sub` and `tf.Cast`.
// The `mhlo.convert` of `%arg0` is expected to move below its broadcast.
// The expected order is: broadcast `%arg1`, broadcast `%arg0`, convert the
// broadcast `%arg0`, then subtract. No convert may appear before the
// broadcasts. The broadcast shape operand is matched loosely (any value).
// CHECK-LABEL: @cast_sub
// CHECK-SAME: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) -> tensor
func @cast_sub(%arg0: tensor, %arg1: tensor) -> tensor {
  // CHECK-NOT: convert
  // CHECK: %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %{{.*}})
  // CHECK: %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %{{.*}})
  // CHECK: %[[CONVERTED_BCASTED_ARG0:.*]] = "mhlo.convert"(%[[BCASTED_ARG0]]) : (tensor) -> tensor
  // CHECK: %{{.*}} = mhlo.subtract %[[BCASTED_ARG1]], %[[CONVERTED_BCASTED_ARG0]] : tensor
  %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor
  %1 = shape.shape_of %arg1 : tensor -> tensor
  %2 = shape.shape_of %0 : tensor -> tensor
  %3 = shape.cstr_broadcastable %1, %2 : tensor, tensor
  // The broadcasts and the subtraction are guarded by the broadcastability
  // witness `%3`.
  %4 = shape.assuming %3 -> (tensor) {
    %5 = shape.shape_of %arg1 : tensor -> tensor
    %6 = shape.shape_of %0 : tensor -> tensor
    %7 = shape.broadcast %5, %6 : tensor, tensor -> tensor
    %8 = tensor.cast %7 : tensor to tensor<3xindex>
    %9 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %8) {
        broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
        : (tensor, tensor<3xindex>) -> tensor
    %10 = "mhlo.dynamic_broadcast_in_dim"(%0, %8) {
        broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
        : (tensor, tensor<3xindex>) -> tensor
    %11 = mhlo.subtract %9, %10 : tensor
    shape.assuming_yield %11 : tensor
  }
  return %4 : tensor
}

// -----

// The `shape.broadcast` feeding `shape.cstr_broadcastable` is expected to be
// inlined. The constraint should take `%a`, `%b` and `%c` directly, and no
// `shape.broadcast` should appear before it.
// CHECK-LABEL: @inline_bcasted_shape_operands
// CHECK-SAME: (%[[A:.*]]: tensor, %[[B:.*]]: tensor, %[[C:.*]]: tensor)
func @inline_bcasted_shape_operands(%a : tensor, %b : tensor, %c : tensor) -> !shape.witness {
  // CHECK-NOT: shape.broadcast
  // CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[A]], %[[B]], %[[C]]
  // CHECK: return %[[WITNESS]] : !shape.witness
  %0 = shape.broadcast %a, %b : tensor, tensor -> tensor
  %1 = shape.cstr_broadcastable %0, %c : tensor, tensor
  return %1 : !shape.witness
}