// RUN: mlir-hlo-opt --lhlo-fusion -split-input-file %s -o - | FileCheck %s

// Tests for the --lhlo-fusion pass. Each `// -----` section is split into an
// independent input by -split-input-file.

// An elementwise producer (abs writing %arg1) feeding an elementwise consumer
// (add reading %arg1) is expected to end up inside one lmhlo.fusion.
// CHECK-LABEL: @simple_kloop_fusion
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?x?xf32>) -> memref<?x?xf32>
func @simple_kloop_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>) -> memref<?x?xf32> {
  // CHECK: "lmhlo.fusion"() ( {
  // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[ARG3]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: }) : () -> ()
  // CHECK: return %[[ARG3]] : memref<?x?xf32>
  "lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  "lmhlo.add"(%arg1, %arg2, %arg3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  return %arg3 : memref<?x?xf32>
}

// -----

// Same as above, but the intermediate buffer %arg1 is also a function result,
// so the fusion has two live outputs.
// CHECK-LABEL: @simple_multi_output_kloop_fusion
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>)
func @simple_multi_output_kloop_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>) {
  // CHECK: "lmhlo.fusion"() ( {
  // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[ARG3]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: }) : () -> ()
  // CHECK: return %[[ARG1]], %[[ARG3]] : memref<?x?xf32>, memref<?x?xf32>
  "lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  "lmhlo.add"(%arg1, %arg2, %arg3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  return %arg1, %arg3 : memref<?x?xf32>, memref<?x?xf32>
}

// -----

// The dynamic_broadcast_in_dim sits between abs and add in program order but is
// not part of the fusion; the expected output places it after the fusion op.
// CHECK-LABEL: @simple_multi_output_kloop_fusion_with_reorder
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?x?xf32>, %[[ARG4:.*]]: memref<2xindex>, %[[ARG5:.*]]: memref<?x?xf32>)
func @simple_multi_output_kloop_fusion_with_reorder(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>, %arg4: memref<2xindex>, %arg5: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) {
  // CHECK: "lmhlo.fusion"() ( {
  // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[ARG3]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: }) : () -> ()
  // CHECK: "lmhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[ARG4]], %[[ARG5]])
  // CHECK: return %[[ARG1]], %[[ARG3]], %[[ARG5]] : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
  "lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  "lmhlo.dynamic_broadcast_in_dim"(%arg1, %arg4, %arg5) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (memref<?x?xf32>, memref<2xindex>, memref<?x?xf32>) -> ()
  "lmhlo.add"(%arg1, %arg2, %arg3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  return %arg1, %arg3, %arg5 : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
}

// -----

// A dynamic_reshape between two elementwise ops is expected to be fused
// together with them (per the test name, because the ops have the same number
// of elements).
// CHECK-LABEL: @same_num_elements_multi_output_kloop_fusion
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<2xi64>, %[[ARG3:.*]]: memref<?x?xf32>, %[[ARG4:.*]]: memref<?x?xf32>, %[[ARG5:.*]]: memref<?x?xf32>)
func @same_num_elements_multi_output_kloop_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<2xi64>, %arg3: memref<?x?xf32>, %arg4: memref<?x?xf32>, %arg5: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>) {
  // CHECK: "lmhlo.fusion"() ( {
  // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.dynamic_reshape"(%[[ARG1]], %[[ARG2]], %[[ARG3]])
  // CHECK: "lmhlo.add"(%[[ARG3]], %[[ARG4]], %[[ARG5]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: }) : () -> ()
  // CHECK: return %[[ARG1]], %[[ARG5]] : memref<?x?xf32>, memref<?x?xf32>
  "lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  "lmhlo.dynamic_reshape"(%arg1, %arg2, %arg3) : (memref<?x?xf32>, memref<2xi64>, memref<?x?xf32>) -> ()
  "lmhlo.add"(%arg3, %arg4, %arg5) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  return %arg1, %arg5 : memref<?x?xf32>, memref<?x?xf32>
}

// -----

// Two elementwise ops with no data dependence between them (add only touches
// %arg0/%arg1, subtract only %arg2/%arg3) must not be fused.
// CHECK-LABEL: @check_not_kloop_fusion
func @check_not_kloop_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>) {
  // CHECK-NOT: "lmhlo.fusion"
  "lmhlo.add"(%arg0, %arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  "lmhlo.subtract"(%arg2, %arg2, %arg3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  return %arg1, %arg3: memref<?x?xf32>, memref<?x?xf32>
}

// -----

// Buffers are allocated and deallocated between the fused ops in the input.
// The expected output hoists every alloc ahead of the fusion op and sinks the
// deallocs of the temporaries (%3, %4, %12 in the input) after it.
// CHECK-LABEL: @kloop_fusion_with_dealloc
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>)
func @kloop_fusion_with_dealloc(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>) {
  // CHECK: %[[TMP3:.*]] = memref.alloc
  // CHECK: %[[TMP5:.*]] = memref.alloc
  // CHECK: %[[TMP9:.*]] = memref.alloc
  // CHECK: %[[TMP13:.*]] = memref.alloc
  // CHECK: %[[TMP16:.*]] = memref.alloc
  // CHECK: "lmhlo.fusion"() ( {
  // CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[TMP3]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.multiply"(%[[ARG0]], %[[ARG1]], %[[TMP5]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.abs"(%[[TMP3]], %[[TMP9]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.abs"(%[[TMP5]], %[[TMP13]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.multiply"(%[[TMP9]], %[[TMP13]], %[[TMP16]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: }) : () -> ()
  // CHECK: memref.dealloc %[[TMP3]] : memref<?x?xf32>
  // CHECK: memref.dealloc %[[TMP5]] : memref<?x?xf32>
  // CHECK: memref.dealloc %[[TMP13]] : memref<?x?xf32>
  // CHECK: return %[[TMP9]], %[[TMP16]] : memref<?x?xf32>, memref<?x?xf32>
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %0 = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
  %1 = tensor.extract %0[%c0] : tensor<2xindex>
  %2 = tensor.extract %0[%c1] : tensor<2xindex>
  %3 = memref.alloc(%1, %2) : memref<?x?xf32>
  "lmhlo.add"(%arg0, %arg1, %3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  %4 = memref.alloc(%1, %2) : memref<?x?xf32>
  "lmhlo.multiply"(%arg0, %arg1, %4) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  %5 = shape.shape_of %3 : memref<?x?xf32> -> tensor<2xindex>
  %6 = tensor.extract %5[%c0] : tensor<2xindex>
  %7 = tensor.extract %5[%c1] : tensor<2xindex>
  %8 = memref.alloc(%6, %7) : memref<?x?xf32>
  "lmhlo.abs"(%3, %8) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  memref.dealloc %3 : memref<?x?xf32>
  %9 = shape.shape_of %4 : memref<?x?xf32> -> tensor<2xindex>
  %10 = tensor.extract %9[%c0] : tensor<2xindex>
  %11 = tensor.extract %9[%c1] : tensor<2xindex>
  %12 = memref.alloc(%10, %11) : memref<?x?xf32>
  "lmhlo.abs"(%4, %12) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  memref.dealloc %4 : memref<?x?xf32>
  %13 = shape.shape_of %8 : memref<?x?xf32> -> tensor<2xindex>
  %14 = tensor.extract %13[%c0] : tensor<2xindex>
  %15 = tensor.extract %13[%c1] : tensor<2xindex>
  %16 = memref.alloc(%14, %15) : memref<?x?xf32>
  "lmhlo.multiply"(%8, %12, %16) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  memref.dealloc %12 : memref<?x?xf32>
  return %8, %16 : memref<?x?xf32>, memref<?x?xf32>
}

// -----

// An elementwise producer feeding a reduce is fused with it (kInput fusion).
// CHECK-LABEL: @simple_kinput
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?xf32>, %[[ARG3:.*]]: memref<f32>
func @simple_kinput(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?xf32>, %init: memref<f32>) -> memref<?xf32> {
  // CHECK: "lmhlo.fusion"() ( {
  // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.reduce"(%[[ARG1]], %[[ARG3]], %[[ARG2]]) ( {
  // CHECK: }) : () -> ()
  // CHECK: return %[[ARG2]] : memref<?xf32>
  "lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  "lmhlo.reduce"(%arg1, %init, %arg2) ( {
  ^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
    "lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
    "lmhlo.terminator"() : () -> ()
  } ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
  return %arg2: memref<?xf32>
}

// -----

// Same as above, but the producer's buffer %arg1 is also a function result.
// CHECK-LABEL: @multi_output_kinput
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?xf32>, %[[ARG3:.*]]: memref<f32>
func @multi_output_kinput(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?xf32>, %init: memref<f32>) -> (memref<?x?xf32>, memref<?xf32>) {
  // CHECK: "lmhlo.fusion"() ( {
  // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.reduce"(%[[ARG1]], %[[ARG3]], %[[ARG2]]) ( {
  // CHECK: }) : () -> ()
  // CHECK: return %[[ARG1]], %[[ARG2]] : memref<?x?xf32>, memref<?xf32>
  "lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  "lmhlo.reduce"(%arg1, %init, %arg2) ( {
  ^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
    "lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
    "lmhlo.terminator"() : () -> ()
  } ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
  return %arg1, %arg2: memref<?x?xf32>, memref<?xf32>
}

// -----

// Two reductions over dimension 1 that share the producer %arg2 are expected in
// a single fusion together with the elementwise producers.
// CHECK-LABEL: @row_red_and_row_red_kinput
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?xf32>, %[[ARG4:.*]]: memref<?xf32>, %[[ARG5:.*]]: memref<?x?xf32>, %[[ARG6:.*]]: memref<f32>
func @row_red_and_row_red_kinput(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: memref<?xf32>, %arg4: memref<?xf32>, %arg5: memref<?x?xf32>, %init: memref<f32>) -> (memref<?xf32>, memref<?xf32>) {
  // CHECK: "lmhlo.fusion"() ( {
  // CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.abs"(%[[ARG2]], %[[ARG5]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.reduce"(%[[ARG5]], %[[ARG6]], %[[ARG3]]) ( {
  // CHECK: "lmhlo.reduce"(%[[ARG2]], %[[ARG6]], %[[ARG4]]) ( {
  // CHECK: }) : () -> ()
  // CHECK: return %[[ARG3]], %[[ARG4]] : memref<?xf32>, memref<?xf32>
  "lmhlo.add"(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  "lmhlo.abs"(%arg2, %arg5) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  "lmhlo.reduce"(%arg5, %init, %arg3) ( {
  ^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
    "lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
    "lmhlo.terminator"() : () -> ()
  } ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
  "lmhlo.reduce"(%arg2, %init, %arg4) ( {
  ^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
    "lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
    "lmhlo.terminator"() : () -> ()
  } ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
  return %arg3, %arg4: memref<?xf32>, memref<?xf32>
}

// -----

// Same as above, but the second reduce is over dimension 0; both reductions are
// still expected in a single fusion.
// CHECK-LABEL: @row_red_and_col_red_kinput
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?xf32>, %[[ARG4:.*]]: memref<?xf32>, %[[ARG5:.*]]: memref<?x?xf32>, %[[ARG6:.*]]: memref<f32>
func @row_red_and_col_red_kinput(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: memref<?xf32>, %arg4: memref<?xf32>, %arg5: memref<?x?xf32>, %init: memref<f32>) -> (memref<?xf32>, memref<?xf32>) {
  // CHECK: "lmhlo.fusion"() ( {
  // CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.abs"(%[[ARG2]], %[[ARG5]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.reduce"(%[[ARG5]], %[[ARG6]], %[[ARG3]]) ( {
  // CHECK: "lmhlo.reduce"(%[[ARG2]], %[[ARG6]], %[[ARG4]]) ( {
  // CHECK: }) : () -> ()
  // CHECK: return %[[ARG3]], %[[ARG4]] : memref<?xf32>, memref<?xf32>
  "lmhlo.add"(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  "lmhlo.abs"(%arg2, %arg5) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  "lmhlo.reduce"(%arg5, %init, %arg3) ( {
  ^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
    "lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
    "lmhlo.terminator"() : () -> ()
  } ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
  "lmhlo.reduce"(%arg2, %init, %arg4) ( {
  ^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
    "lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
    "lmhlo.terminator"() : () -> ()
  } ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
  return %arg3, %arg4: memref<?xf32>, memref<?xf32>
}

// -----

// The add that consumes the reduce result (%9) must stay outside the fusion.
// The producers and the lmhlo.constant providing the reduce init value are
// fused with the reduce.
// CHECK-LABEL: @reduce_should_not_have_consumer_in_the_fusion
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>
func @reduce_should_not_have_consumer_in_the_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?xf32>) {
  // CHECK: %[[TMP4:.*]] = memref.alloc
  // CHECK: %[[TMP7:.*]] = memref.alloc
  // CHECK: %[[TMP8:.*]] = memref.alloc
  // CHECK: %[[TMP9:.*]] = memref.alloc
  // CHECK: "lmhlo.fusion"() ( {
  // CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[TMP4]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.subtract"(%[[ARG0]], %[[TMP4]], %[[TMP7]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.constant"(%[[TMP8]]) {value = dense<0.000000e+00> : tensor<f32>} : (memref<f32>) -> ()
  // CHECK: "lmhlo.reduce"(%[[TMP7]], %[[TMP8]], %[[TMP9]]) ( {
  // CHECK: }) : () -> ()
  // CHECK: memref.dealloc %[[TMP4]] : memref<?x?xf32>
  // CHECK: memref.dealloc %[[TMP8]] : memref<f32>
  // CHECK: %[[TMP12:.*]] = memref.alloc
  // CHECK: "lmhlo.add"(%[[TMP9]], %[[TMP9]], %[[TMP12]]) : (memref<?xf32>, memref<?xf32>, memref<?xf32>) -> ()
  // CHECK: memref.dealloc %[[TMP9]] : memref<?xf32>
  // CHECK: return %[[TMP7]], %[[TMP12]] : memref<?x?xf32>, memref<?xf32>
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  %0 = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
  %1 = tensor.extract %0[%c0] : tensor<2xindex>
  %2 = tensor.extract %0[%c1] : tensor<2xindex>
  %3 = memref.alloc(%1, %2) : memref<?x?xf32>
  "lmhlo.add"(%arg0, %arg1, %3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  %4 = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
  %5 = tensor.extract %4[%c0] : tensor<2xindex>
  %6 = tensor.extract %4[%c1] : tensor<2xindex>
  %7 = memref.alloc(%5, %6) : memref<?x?xf32>
  "lmhlo.subtract"(%arg0, %3, %7) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  memref.dealloc %3 : memref<?x?xf32>
  %8 = memref.alloc() : memref<f32>
  "lmhlo.constant"(%8) {value = dense<0.000000e+00> : tensor<f32>} : (memref<f32>) -> ()
  %9 = memref.alloc(%5) : memref<?xf32>
  "lmhlo.reduce"(%7, %8, %9) ( {
  ^bb0(%arg2: memref<f32>, %arg3: memref<f32>, %arg4: memref<f32>):  // no predecessors
    "lmhlo.add"(%arg2, %arg3, %arg4) : (memref<f32>, memref<f32>, memref<f32>) -> ()
    "lmhlo.terminator"() : () -> ()
  }) {dimensions = dense<1> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
  memref.dealloc %8 : memref<f32>
  %10 = shape.shape_of %9 : memref<?xf32> -> tensor<1xindex>
  %11 = tensor.extract %10[%c0] : tensor<1xindex>
  %12 = memref.alloc(%11) : memref<?xf32>
  "lmhlo.add"(%9, %9, %12) : (memref<?xf32>, memref<?xf32>, memref<?xf32>) -> ()
  memref.dealloc %9 : memref<?xf32>
  return %7, %12 : memref<?x?xf32>, memref<?xf32>
}

// -----

// The constant's buffer %0 is itself a function result; no lmhlo.fusion is
// expected to be formed in that case.
// CHECK-LABEL: @const_should_not_be_output
func @const_should_not_be_output(%arg0: memref<f32>) -> (memref<f32>, memref<f32>) {
  // CHECK-NOT: lmhlo.fusion
  %0 = memref.alloc() : memref<f32>
  "lmhlo.constant"(%0) {value = dense<1.000000e+00> : tensor<f32>} : (memref<f32>) -> ()
  %1 = memref.alloc() : memref<f32>
  "lmhlo.add"(%arg0, %0, %1) : (memref<f32>, memref<f32>, memref<f32>) -> ()
  return %0, %1 : memref<f32>, memref<f32>
}