// RUN: mlir-hlo-opt %s -mhlo-sink-constants-to-control-flow | FileCheck %s

// Tests sinking constants to a while loop: each constant used inside the
// cond/body regions is expected to be re-materialized inside that region.

// CHECK-LABEL: func @sink_const_to_while
func @sink_const_to_while(%arg0: tensor<i64>) -> tensor<i64> {
  // Both constants are only used inside the while regions, so after the pass
  // the while op is expected to directly follow the function header.
  // CHECK-NEXT: mhlo.while
  %c0 = mhlo.constant dense<1> : tensor<i64>
  %c1 = mhlo.constant dense<2> : tensor<i64>
  %0 = "mhlo.while"(%arg0) ( {
  ^bb0(%arg1: tensor<i64>):
    // Condition region: %c0 is expected to be cloned here and feed the compare.
    // CHECK: %[[ARG1A:.+]]: tensor<i64>
    // CHECK: %[[C0:.+]] = mhlo.constant dense<1> : tensor<i64>
    // CHECK: "mhlo.compare"(%[[C0]], %[[ARG1A]])
    %1 = "mhlo.compare"(%c0, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "mhlo.return"(%1) : (tensor<i1>) -> ()
  },  {
  ^bb0(%arg1: tensor<i64>):
    // Body region: %c1 is expected to be cloned once and reused by both adds;
    // its position relative to the first add is not constrained.
    // CHECK: %[[ARG1B:.+]]: tensor<i64>
    // CHECK-DAG: %[[C1:.+]] = mhlo.constant dense<2> : tensor<i64>
    // CHECK-DAG: %[[ADD0:.+]] = mhlo.add %[[ARG1B]], %[[ARG1B]]
    %2 = mhlo.add %arg1, %arg1 : tensor<i64>
    // CHECK: %[[ADD1:.+]] = mhlo.add %[[C1]], %[[ADD0]]
    %3 = mhlo.add %c1, %2 : tensor<i64>
    // CHECK: %[[ADD2:.+]] = mhlo.add %[[C1]], %[[ADD1]]
    %4 = mhlo.add %c1, %3 : tensor<i64>
    "mhlo.return"(%4) : (tensor<i64>) -> ()
  }) : (tensor<i64>) -> tensor<i64>
  return %0 : tensor<i64>
}

// Tests sinking constants to a conditional op.
// CHECK-LABEL: func @sink_const_to_conditional
func @sink_const_to_conditional(%arg0: tensor<i64>) -> tensor<i64> {
  // %c0 is also used outside the if (by the compare below); %c1 is only used
  // inside the else region.
  %c0 = mhlo.constant dense<1> : tensor<i64>
  %c1 = mhlo.constant dense<2> : tensor<i64>
  %0 = "mhlo.compare"(%arg0, %c0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
  %1 = "mhlo.tuple"(%arg0) : (tensor<i64>) -> tuple<tensor<i64>>
  // CHECK: mhlo.if
  %2 = "mhlo.if"(%0, %1, %1) ( {
  ^bb0(%arg1: tuple<tensor<i64>>):
    // Then region: a local copy of %c0 is expected to feed the add.
    // CHECK: %[[C0:.+]] = mhlo.constant dense<1> : tensor<i64>
    %3 = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
    // CHECK: %[[ADD0:.+]] = mhlo.add %[[C0]],
    %4 = mhlo.add %c0, %3 : tensor<i64>
    %5 = "mhlo.tuple"(%4) : (tensor<i64>) -> tuple<tensor<i64>>
    "mhlo.return"(%5) : (tuple<tensor<i64>>) -> ()
  },  {
  ^bb0(%arg1: tuple<tensor<i64>>):
    // Else region: a local copy of %c1 is expected to feed the add.
    // CHECK: %[[C1:.+]] = mhlo.constant dense<2> : tensor<i64>
    %6 = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
    // CHECK: %[[ADD1:.+]] = mhlo.add %[[C1]],
    %7 = mhlo.add %c1, %6 : tensor<i64>
    %8 = "mhlo.tuple"(%7) : (tensor<i64>) -> tuple<tensor<i64>>
    "mhlo.return"(%8) : (tuple<tensor<i64>>) -> ()
  }) : (tensor<i1>, tuple<tensor<i64>>, tuple<tensor<i64>>) -> tuple<tensor<i64>>
  %9 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
  return %9 : tensor<i64>
}

// Tests sinking a standard-dialect constant into a sort comparator region.

// CHECK-LABEL: func @sink_const_to_sort
func @sink_const_to_sort(%arg0: tensor<16xf32>) {
  %c0 = constant dense<1.0> : tensor<f32>
  // CHECK: "mhlo.sort"
  %0 = "mhlo.sort"(%arg0) ( {
  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
    // The comparator is expected to get its own copy of the constant.
    // CHECK: constant dense<1.000000e+00>
    %1 = "mhlo.divide"(%arg1, %c0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    %2 = "mhlo.divide"(%arg2, %c0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    %3 = "mhlo.compare"(%1, %2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
    "mhlo.return"(%3) : (tensor<i1>) -> ()
  }) {is_stable = true} : (tensor<16xf32>) -> tensor<16xf32>
  return
}