// RUN: mlir-hlo-opt %s -xla-hlo-fusion -split-input-file | FileCheck %s

// Regression tests for the fusion pass invoked on the command line above.
// Each case is separated by a split marker so that it is processed as an
// independent module, and the FileCheck directives describe the expected
// structure of the output for that case.
//
// NOTE(review): the parameterized tensor types in this file
// (tensor<?x?xf32>, tensor<f32>, tensor<?xf32>) were reconstructed from the
// ops that use them (float constant, reduction over dimension 1 with a
// scalar init value) -- confirm against the upstream test before relying on
// the exact shapes/element type.

// A chain of elementwise ops over the same operands is expected to end up in
// a single fusion op with two results (%1 and %2 are both returned).
// CHECK-LABEL: func @multi_outputs_same
func @multi_outputs_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
  %0 = "mhlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = "mhlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  %2 = "mhlo.add"(%1, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  // CHECK: %[[RET:.*]]:2 = "mhlo.fusion"
  // CHECK-NEXT: mhlo.add
  // CHECK-NEXT: mhlo.subtract
  // CHECK-NEXT: mhlo.add
  // CHECK-NEXT: mhlo.return
  return %1, %2 : tensor<?x?xf32>, tensor<?x?xf32>
}

// -----

// Five elementwise ops connected through the add of both abs results are
// expected to form one fusion op with three results (%2, %3 and %4).
// CHECK-LABEL: func @multi_outputs_same_2
func @multi_outputs_same_2(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
  %0 = "mhlo.abs"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = "mhlo.abs"(%arg1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
  %2 = "mhlo.add"(%0, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  %3 = "mhlo.abs"(%0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
  %4 = "mhlo.abs"(%1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
  // CHECK: %[[RET:.*]]:3 = "mhlo.fusion"
  // CHECK-NEXT: mhlo.abs
  // CHECK-NEXT: mhlo.abs
  // CHECK-NEXT: mhlo.add
  // CHECK-NEXT: mhlo.abs
  // CHECK-NEXT: mhlo.abs
  // CHECK-NEXT: mhlo.return
  return %2, %3, %4 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
}

// -----

// The two ops below share no operands (one uses only %arg0, the other only
// %arg1); no fusion op is expected in the output. Presumably this is because
// their result shapes cannot be proven equal -- verify against the pass.
// CHECK-LABEL: func @multi_outputs_not_sure_same
func @multi_outputs_not_sure_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
  %0 = "mhlo.add"(%arg0, %arg0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  // CHECK-NOT: mhlo.fusion
  %1 = "mhlo.subtract"(%arg1, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
}

// -----

// The leading add/subtract pair is expected to be fused on its own; the
// reduce (which reads %arg0 directly) and its consumer must stay unfused.
// CHECK-LABEL: func @reduce
func @reduce(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?xf32>) {
  %0 = "mhlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = "mhlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  // CHECK: %[[RET0:.*]] = "mhlo.fusion"
  // CHECK-NEXT: mhlo.add
  // CHECK-NEXT: mhlo.subtract
  // CHECK-NEXT: mhlo.return
  // Currently we do not support fusing arguments and ops without a direct
  // producer-consumer relationship. Thus the reduce op should not be fused
  // with the two ops above.
  %2 = mhlo.constant dense<0.000000e+00> : tensor<f32>
  %3 = "mhlo.reduce"(%arg0, %2) ( {
  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
    %4 = "mhlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    "mhlo.return"(%4) : (tensor<f32>) -> ()
  }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
  %4 = "mhlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  // The two ops above should not be fused since a reduce op cannot be
  // fused with its consumer.
  // CHECK-NOT: mhlo.fusion
  return %1, %4 : tensor<?x?xf32>, tensor<?xf32>
}

// -----

// Here the reduce consumes %1, so it is expected to join the add/subtract
// fusion together with its init constant (two results); the add that
// consumes the reduce result must remain outside of any fusion op.
// CHECK-LABEL: func @reduce_2
func @reduce_2(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?xf32>) {
  %0 = "mhlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = "mhlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>

  %2 = mhlo.constant dense<0.000000e+00> : tensor<f32>
  %3 = "mhlo.reduce"(%1, %2) ( {
  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
    %4 = "mhlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    "mhlo.return"(%4) : (tensor<f32>) -> ()
  }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
  // CHECK: %[[RET0:.*]]:2 = "mhlo.fusion"
  // CHECK-NEXT: mhlo.add
  // CHECK-NEXT: mhlo.subtract
  // CHECK-NEXT: mhlo.constant
  // CHECK-NEXT: mhlo.reduce
  // CHECK: mhlo.return

  // The following op should not be fused with the ops above since a reduce
  // op cannot be fused with its consumer.
  // CHECK-NOT: mhlo.fusion
  %4 = "mhlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  return %1, %4 : tensor<?x?xf32>, tensor<?xf32>
}