2020-07-07 07:28:26 +08:00
// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s
// -----
2021-01-27 09:23:49 +08:00
func @invalid_allreduce ( %input0 : memref < 2x f32 > , %input1 : memref < 3x f32 > ) {
// expected-error@+1 {{requires operand #1 (type: 'memref<3xf32>') and result #1 (type: 'memref<2xf32>') to have same type}}
"lmhlo.all_reduce" ( %input0 , %input1 , %input0 , %input0 ) ( {
^bb0 ( %arg0 : tensor < f32 > , %arg1 : tensor < f32 > ) :
%add = mhlo. add %arg0 , %arg1 : tensor < f32 >
"mhlo.return" ( %add ) : ( tensor < f32 > ) -> ( )
} )
{ channel_id = { handle = 1 : i64 , type = 0 : i64 } , constrain_layout = false,
2021-02-02 02:22:48 +08:00
replica_groups = dense< [ [ 0 , 1 , 2 , 3 ] , [ 5 , 6 , 7 , 4 ] ] > : tensor < 2x4x i64 > ,
2021-01-27 09:23:49 +08:00
use_global_device_ids = false} : ( memref < 2x f32 > , memref < 3x f32 > , memref < 2x f32 > , memref < 2x f32 > ) -> ( )
return
}
// -----
func @invalid_allreduce ( %input0 : memref < 2x f32 > , %input1 : memref < 3x f16 > ) {
2021-02-02 02:22:48 +08:00
// expected-error@+1 {{requires the same element type for all operands}}
2021-01-27 09:23:49 +08:00
"lmhlo.all_reduce" ( %input0 , %input1 , %input0 , %input1 ) ( {
^bb0 ( %arg0 : tensor < f32 > , %arg1 : tensor < f32 > ) :
%add = mhlo. add %arg0 , %arg1 : tensor < f32 >
"mhlo.return" ( %add ) : ( tensor < f32 > ) -> ( )
} )
{ channel_id = { handle = 1 : i64 , type = 0 : i64 } , constrain_layout = false,
replica_groups = dense< [ [ 0 , 1 , 2 , 3 ] , [ 5 , 6 , 7 , 8 ] ] > : tensor < 2x4x i64 > ,
use_global_device_ids = false} : ( memref < 2x f32 > , memref < 3x f16 > , memref < 2x f32 > , memref < 3x f16 > ) -> ( )
return
}
// -----
2021-05-07 02:02:00 +08:00
// CHECK-LABEL: func @mixed_types_allgather
func @mixed_types_allgather ( %a0 : memref < 1x1x f32 > , %a1 : memref < 1x1x i32 > ) {
"lmhlo.all_gather" ( %a0 , %a1 , %a0 , %a1 ) { all_gather_dimension = 0 : i64 ,
constrain_layout = false, replica_groups = dense< 0 > : tensor < 1x1x i64 > ,
use_global_device_ids = false} : ( memref < 1x1x f32 > , memref < 1x1x i32 > , memref < 1x1x f32 > , memref < 1x1x i32 > ) -> ( )
return
}
// -----
2021-01-27 09:23:49 +08:00
2021-02-02 02:22:48 +08:00
func @invalid_allgather ( %input0 : memref < 2x f32 > , %output : memref < 8x f32 > ) {
// expected-error@+1 {{replica id #1 seen more than once}}
"lmhlo.all_gather" ( %input0 , %output )
{ channel_id = { handle = 1 : i64 , type = 0 : i64 } , constrain_layout = false,
replica_groups = dense< [ [ 0 , 1 , 1 , 3 ] , [ 5 , 6 , 7 , 8 ] ] > : tensor < 2x4x i64 > ,
use_global_device_ids = false, all_gather_dimension = 0 : i64 } : ( memref < 2x f32 > , memref < 8x f32 > ) -> ( )
return
}
// -----
func @invalid_alltoall ( %input0 : memref < 2x f32 > , %output : memref < 8x f32 > ) {
// expected-error@+1 {{replica id #4 not seen in replica groups}}
"lmhlo.all_to_all" ( %input0 , %output )
{ channel_id = { handle = 1 : i64 , type = 0 : i64 } , constrain_layout = false,
replica_groups = dense< [ [ 0 , 1 , 2 , 3 ] , [ 5 , 6 , 7 , 8 ] ] > : tensor < 2x4x i64 > ,
use_global_device_ids = false, all_gather_dimension = 0 : i64 } : ( memref < 2x f32 > , memref < 8x f32 > ) -> ( )
return
}
// -----
func @invalid_alltoall ( %input0 : memref < 2x f32 > , %output : memref < 8x f32 > ) {
// expected-error@+1 {{replica groups should be a rank 2 tensor of 64 bit integers}}
"lmhlo.all_to_all" ( %input0 , %output )
{ channel_id = { handle = 1 : i64 , type = 0 : i64 } , constrain_layout = false,
replica_groups = dense< 0 > : tensor < 1x i64 > ,
use_global_device_ids = false, all_gather_dimension = 0 : i64 } : ( memref < 2x f32 > , memref < 8x f32 > ) -> ( )
return
}
// -----
2020-07-07 07:28:26 +08:00
// CHECK-LABEL: func @ceil
func @ceil ( %input : memref < 2x2x f32 > , %result : memref < 2x2x f32 > ) {
2020-07-09 01:05:32 +08:00
"lmhlo.ceil" ( %input , %result ) : ( memref < 2x2x f32 > , memref < 2x2x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
func @ceil ( %input : memref < 2x2x i32 > , %result : memref < 2x2x i32 > ) {
// expected-error@+1{{must be memref of floating-point values}}
2020-07-09 01:05:32 +08:00
"lmhlo.ceil" ( %input , %result ) : ( memref < 2x2x i32 > , memref < 2x2x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @cos
func @cos ( %input : memref < 2x2x f32 > , %result : memref < 2x2x f32 > ) {
2020-07-09 01:05:32 +08:00
"lmhlo.cosine" ( %input , %result ) : ( memref < 2x2x f32 > , memref < 2x2x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @cos
func @cos ( %input : memref < 2x2x complex< f32 > > , %result : memref < 2x2x complex< f32 > > ) {
2020-07-09 01:05:32 +08:00
"lmhlo.cosine" ( %input , %result ) : ( memref < 2x2x complex< f32 > > , memref < 2x2x complex< f32 > > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
func @cos ( %input : memref < 2x2x i32 > , %result : memref < 2x2x i32 > ) {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
2020-07-09 01:05:32 +08:00
"lmhlo.cosine" ( %input , %result ) : ( memref < 2x2x i32 > , memref < 2x2x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @sin
func @sin ( %input : memref < 2x2x f32 > , %result : memref < 2x2x f32 > ) {
2020-07-09 01:05:32 +08:00
"lmhlo.sine" ( %input , %result ) : ( memref < 2x2x f32 > , memref < 2x2x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @sin
func @sin ( %input : memref < 2x2x complex< f32 > > , %result : memref < 2x2x complex< f32 > > ) {
2020-07-09 01:05:32 +08:00
"lmhlo.sine" ( %input , %result ) : ( memref < 2x2x complex< f32 > > , memref < 2x2x complex< f32 > > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
func @sin ( %input : memref < 2x2x i32 > , %result : memref < 2x2x i32 > ) {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
2020-07-09 01:05:32 +08:00
"lmhlo.sine" ( %input , %result ) : ( memref < 2x2x i32 > , memref < 2x2x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @add_memrefs
func @add_memrefs ( %arg0 : memref < 1x i32 > , %arg1 : memref < 1x i32 > , %arg_out : memref < 1x i32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.add" ( %arg0 , %arg1 , %arg_out ) : ( memref < 1x i32 > , memref < 1x i32 > , memref < 1x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @abs_memref
func @abs_memref ( %in : memref < 10x f32 > , %out : memref < 10x f32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.abs" ( %in , %out ) : ( memref < 10x f32 > , memref < 10x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @convert_memref
func @convert_memref ( %in : memref < 10x f32 > , %out : memref < 10x i32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.convert" ( %in , %out ) : ( memref < 10x f32 > , memref < 10x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
func @convert_memref ( %in : memref < 10x f32 > , %out : memref < 9x i32 > ) -> ( ) {
// expected-error@+1{{requires the same shape for all operands}}
2020-07-09 01:05:32 +08:00
"lmhlo.convert" ( %in , %out ) : ( memref < 10x f32 > , memref < 9x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @exp
func @exp ( %input : memref < 2x2x f32 > , %result : memref < 2x2x f32 > ) {
2020-07-09 01:05:32 +08:00
"lmhlo.exponential" ( %input , %result ) : ( memref < 2x2x f32 > , memref < 2x2x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @exp
func @exp ( %input : memref < 2x2x complex< f32 > > , %result : memref < 2x2x complex< f32 > > ) {
2020-07-09 01:05:32 +08:00
"lmhlo.exponential" ( %input , %result ) : ( memref < 2x2x complex< f32 > > , memref < 2x2x complex< f32 > > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
func @exp ( %input : memref < 2x2x i32 > , %result : memref < 2x2x i32 > ) {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
2020-07-09 01:05:32 +08:00
"lmhlo.exponential" ( %input , %result ) : ( memref < 2x2x i32 > , memref < 2x2x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @log_memref
func @log_memref ( %in : memref < 10x f32 > , %out : memref < 10x f32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.log" ( %in , %out ) : ( memref < 10x f32 > , memref < 10x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @log_memref
func @log_memref ( %in : memref < 10x complex< f32 > > , %out : memref < 10x complex< f32 > > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.log" ( %in , %out ) : ( memref < 10x complex< f32 > > , memref < 10x complex< f32 > > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
func @log_memref ( %in : memref < 10x i32 > , %out : memref < 10x i32 > ) -> ( ) {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
2020-07-09 01:05:32 +08:00
"lmhlo.log" ( %in , %out ) : ( memref < 10x i32 > , memref < 10x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @neg_memref
func @neg_memref ( %in : memref < 10x f32 > , %out : memref < 10x f32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.negate" ( %in , %out ) : ( memref < 10x f32 > , memref < 10x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @rsqrt_memref
func @rsqrt_memref ( %in : memref < 10x f32 > , %out : memref < 10x f32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.rsqrt" ( %in , %out ) : ( memref < 10x f32 > , memref < 10x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @rsqrt_memref
func @rsqrt_memref ( %in : memref < 10x complex< f32 > > , %out : memref < 10x complex< f32 > > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.rsqrt" ( %in , %out ) : ( memref < 10x complex< f32 > > , memref < 10x complex< f32 > > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
func @rsqrt_memref ( %in : memref < 10x i32 > , %out : memref < 10x i32 > ) -> ( ) {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
2020-07-09 01:05:32 +08:00
"lmhlo.rsqrt" ( %in , %out ) : ( memref < 10x i32 > , memref < 10x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @sqrt_memref
func @sqrt_memref ( %in : memref < 10x f32 > , %out : memref < 10x f32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.sqrt" ( %in , %out ) : ( memref < 10x f32 > , memref < 10x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @sqrt_memref
func @sqrt_memref ( %in : memref < 10x complex< f32 > > , %out : memref < 10x complex< f32 > > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.sqrt" ( %in , %out ) : ( memref < 10x complex< f32 > > , memref < 10x complex< f32 > > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
func @sqrt_memref ( %in : memref < 10x i32 > , %out : memref < 10x i32 > ) -> ( ) {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
2020-07-09 01:05:32 +08:00
"lmhlo.sqrt" ( %in , %out ) : ( memref < 10x i32 > , memref < 10x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @sign_memref
func @sign_memref ( %in : memref < 10x f32 > , %out : memref < 10x f32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.sign" ( %in , %out ) : ( memref < 10x f32 > , memref < 10x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @tanh_memref
func @tanh_memref ( %in : memref < 10x f32 > , %out : memref < 10x f32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.tanh" ( %in , %out ) : ( memref < 10x f32 > , memref < 10x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @tanh_memref
func @tanh_memref ( %in : memref < 10x complex< f32 > > , %out : memref < 10x complex< f32 > > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.tanh" ( %in , %out ) : ( memref < 10x complex< f32 > > , memref < 10x complex< f32 > > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
func @tanh_memref ( %in : memref < 10x i32 > , %out : memref < 10x i32 > ) -> ( ) {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
2020-07-09 01:05:32 +08:00
"lmhlo.tanh" ( %in , %out ) : ( memref < 10x i32 > , memref < 10x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
func @tanh_memref ( %arg0 : memref < 1x f32 > , %arg1 : memref < 2x f32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
// expected-error@+1{{'lmhlo.tanh' op requires all operands to have the same type}}
"lmhlo.tanh" ( %arg0 , %arg1 ) : ( memref < 1x f32 > , memref < 2x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @add_memref
func @add_memref ( %lhs : memref < 10x f32 > , %rhs : memref < 10x f32 > , %out : memref < 10x f32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.add" ( %lhs , %rhs , %out ) : ( memref < 10x f32 > , memref < 10x f32 > , memref < 10x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @div_memref
func @div_memref ( %lhs : memref < 10x f32 > , %rhs : memref < 10x f32 > , %out : memref < 10x f32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.divide" ( %lhs , %rhs , %out ) : ( memref < 10x f32 > , memref < 10x f32 > , memref < 10x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @max_memref
func @max_memref ( %lhs : memref < 10x f32 > , %rhs : memref < 10x f32 > , %out : memref < 10x f32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.maximum" ( %lhs , %rhs , %out ) : ( memref < 10x f32 > , memref < 10x f32 > , memref < 10x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @min_memref
func @min_memref ( %lhs : memref < 10x f32 > , %rhs : memref < 10x f32 > , %out : memref < 10x f32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.minimum" ( %lhs , %rhs , %out ) : ( memref < 10x f32 > , memref < 10x f32 > , memref < 10x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @mul_memref
func @mul_memref ( %lhs : memref < 10x f32 > , %rhs : memref < 10x f32 > , %out : memref < 10x f32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.multiply" ( %lhs , %rhs , %out ) : ( memref < 10x f32 > , memref < 10x f32 > , memref < 10x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @sub_memref
func @sub_memref ( %lhs : memref < 10x f32 > , %rhs : memref < 10x f32 > , %out : memref < 10x f32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.subtract" ( %lhs , %rhs , %out ) : ( memref < 10x f32 > , memref < 10x f32 > , memref < 10x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @and_memref
func @and_memref ( %lhs : memref < 10x i32 > , %rhs : memref < 10x i32 > , %out : memref < 10x i32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.and" ( %lhs , %rhs , %out ) : ( memref < 10x i32 > , memref < 10x i32 > , memref < 10x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @and_memref
func @and_memref ( %lhs : memref < 10x i1 > , %rhs : memref < 10x i1 > , %out : memref < 10x i1 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.and" ( %lhs , %rhs , %out ) : ( memref < 10x i1 > , memref < 10x i1 > , memref < 10x i1 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
func @and_memref ( %lhs : memref < 10x f32 > , %rhs : memref < 10x f32 > , %out : memref < 10x f32 > ) -> ( ) {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
2020-07-09 01:05:32 +08:00
"lmhlo.and" ( %lhs , %rhs , %out ) : ( memref < 10x f32 > , memref < 10x f32 > , memref < 10x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @or_memref
func @or_memref ( %lhs : memref < 10x i32 > , %rhs : memref < 10x i32 > , %out : memref < 10x i32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.or" ( %lhs , %rhs , %out ) : ( memref < 10x i32 > , memref < 10x i32 > , memref < 10x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @or_memref
func @or_memref ( %lhs : memref < 10x i1 > , %rhs : memref < 10x i1 > , %out : memref < 10x i1 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.or" ( %lhs , %rhs , %out ) : ( memref < 10x i1 > , memref < 10x i1 > , memref < 10x i1 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
func @or_memref ( %lhs : memref < 10x f32 > , %rhs : memref < 10x f32 > , %out : memref < 10x f32 > ) -> ( ) {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
2020-07-09 01:05:32 +08:00
"lmhlo.or" ( %lhs , %rhs , %out ) : ( memref < 10x f32 > , memref < 10x f32 > , memref < 10x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @xor_memref
func @xor_memref ( %lhs : memref < 10x i32 > , %rhs : memref < 10x i32 > , %out : memref < 10x i32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.xor" ( %lhs , %rhs , %out ) : ( memref < 10x i32 > , memref < 10x i32 > , memref < 10x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @xor_memref
func @xor_memref ( %lhs : memref < 10x i1 > , %rhs : memref < 10x i1 > , %out : memref < 10x i1 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.xor" ( %lhs , %rhs , %out ) : ( memref < 10x i1 > , memref < 10x i1 > , memref < 10x i1 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
func @xor_memref ( %lhs : memref < 10x f32 > , %rhs : memref < 10x f32 > , %out : memref < 10x f32 > ) -> ( ) {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
2020-07-09 01:05:32 +08:00
"lmhlo.xor" ( %lhs , %rhs , %out ) : ( memref < 10x f32 > , memref < 10x f32 > , memref < 10x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @broadcast_in_dim_memref
func @broadcast_in_dim_memref ( %arg0 : memref < 1x2x i32 > , %out : memref < 1x2x2x i32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.broadcast_in_dim" ( %arg0 , %out ) { broadcast_dimensions = dense< [ 1 , 2 ] > : tensor < 2x i64 > } : ( memref < 1x2x i32 > , memref < 1x2x2x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @broadcast_in_dim_zero_rank_memref
func @broadcast_in_dim_zero_rank_memref ( %arg0 : memref < i32 > , %out : memref < 1x2x3x i32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.broadcast_in_dim" ( %arg0 , %out ) { broadcast_dimensions = dense< [ ] > : tensor < 0 xi64> } : ( memref < i32 > , memref < 1x2x3x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @reduce_memref
func @reduce_memref ( %input : memref < 10x f32 > , %init : memref < f32 > , %out : memref < 1x f32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.reduce" ( %input , %init , %out ) ( {
2020-07-07 07:28:26 +08:00
^bb0 ( %arg1 : memref < f32 > , %arg2 : memref < f32 > , %result : memref < f32 > ) :
2020-07-09 01:05:32 +08:00
"lmhlo.add" ( %arg1 , %arg2 , %result ) : ( memref < f32 > , memref < f32 > , memref < f32 > ) -> ( )
"lmhlo.terminator" ( ) : ( ) -> ( )
2020-07-07 07:28:26 +08:00
} ) { dimensions = dense< [ 0 ] > : tensor < 1x i64 > } : ( memref < 10x f32 > , memref < f32 > , memref < 1x f32 > ) -> ( )
return
}
// -----
// CHECK-LABEL: func @fusion_memref
func @fusion_memref ( %input1 : memref < 10x f32 > , %input2 : memref < 10x f32 > , %input3 : memref < 10x f32 > , %out : memref < 10x f32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.fusion" ( ) ( {
2021-03-17 04:31:59 +08:00
%0 = memref . tensor _load %input1 : memref < 10x f32 >
%1 = memref . tensor _load %input2 : memref < 10x f32 >
2020-07-07 12:51:24 +08:00
%2 = "mhlo.add" ( %0 , %1 ) { name = "add" } : ( tensor < 10x f32 > , tensor < 10x f32 > ) -> tensor < 10x f32 >
2021-03-17 04:31:59 +08:00
%3 = memref . tensor _load %input3 : memref < 10x f32 >
2020-07-07 12:51:24 +08:00
%4 = "mhlo.multiply" ( %2 , %3 ) { name = "multiply" } : ( tensor < 10x f32 > , tensor < 10x f32 > ) -> tensor < 10x f32 >
2021-03-17 04:31:59 +08:00
memref . tensor _store %4 , %out : memref < 10x f32 >
2020-07-09 01:05:32 +08:00
"lmhlo.terminator" ( ) : ( ) -> ( )
2020-07-07 07:28:26 +08:00
} ) : ( ) -> ( )
return
}
// -----
// CHECK-LABEL: func @case_memref
func @case_memref ( %index : memref < i32 > , %operand_1 : memref < f32 > , %operand_2 : memref < f32 > , %operand_3 : memref < f32 > , %out : memref < f32 > ) -> ( ) {
2021-03-12 06:41:50 +08:00
"lmhlo.case" ( %index ) ( {
^bb0 :
"lmhlo.negate" ( %operand_1 , %out ) : ( memref < f32 > , memref < f32 > ) -> ( )
2020-07-09 01:05:32 +08:00
"lmhlo.terminator" ( ) : ( ) -> ( )
2020-07-07 07:28:26 +08:00
} , {
2021-03-12 06:41:50 +08:00
^bb0 :
"lmhlo.copy" ( %operand_2 , %out ) : ( memref < f32 > , memref < f32 > ) -> ( )
2020-07-09 01:05:32 +08:00
"lmhlo.terminator" ( ) : ( ) -> ( )
2020-07-07 07:28:26 +08:00
} , {
2021-03-12 06:41:50 +08:00
^bb0 :
"lmhlo.add" ( %operand_3 , %operand_3 , %out ) : ( memref < f32 > , memref < f32 > , memref < f32 > ) -> ( )
2020-07-09 01:05:32 +08:00
"lmhlo.terminator" ( ) : ( ) -> ( )
2020-07-07 07:28:26 +08:00
}
2021-03-12 06:41:50 +08:00
) : ( memref < i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @atan2_memrefs
func @atan2_memrefs ( %arg0 : memref < 1x f32 > , %arg1 : memref < 1x f32 > , %arg_out : memref < 1x f32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.atan2" ( %arg0 , %arg1 , %arg_out ) : ( memref < 1x f32 > , memref < 1x f32 > , memref < 1x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @atan2_memrefs
func @atan2_memrefs ( %arg0 : memref < 1x complex< f32 > > , %arg1 : memref < 1x complex< f32 > > , %arg_out : memref < 1x complex< f32 > > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.atan2" ( %arg0 , %arg1 , %arg_out ) : ( memref < 1x complex< f32 > > , memref < 1x complex< f32 > > , memref < 1x complex< f32 > > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
func @atan2_memrefs ( %arg0 : memref < 1x i32 > , %arg1 : memref < 1x i32 > , %arg_out : memref < 1x i32 > ) -> ( ) {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
2020-07-09 01:05:32 +08:00
"lmhlo.atan2" ( %arg0 , %arg1 , %arg_out ) : ( memref < 1x i32 > , memref < 1x i32 > , memref < 1x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @bitcast_convert_memrefs
func @bitcast_convert_memrefs ( %arg0 : memref < 1x f32 > , %arg_out : memref < 1x i32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.bitcast_convert" ( %arg0 , %arg_out ) : ( memref < 1x f32 > , memref < 1x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
func @bitcast_convert_memrefs ( %arg0 : memref < 1x f32 > , %arg_out : memref < 2x i32 > ) -> ( ) {
// expected-error@+1{{requires the same shape for all operands}}
2020-07-09 01:05:32 +08:00
"lmhlo.bitcast_convert" ( %arg0 , %arg_out ) : ( memref < 1x f32 > , memref < 2x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @clz_memrefs
func @clz_memrefs ( %arg0 : memref < 1x i32 > , %arg_out : memref < 1x i32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.count_leading_zeros" ( %arg0 , %arg_out ) : ( memref < 1x i32 > , memref < 1x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @expm1_memrefs
func @expm1_memrefs ( %arg0 : memref < 1x f32 > , %arg_out : memref < 1x f32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.exponential_minus_one" ( %arg0 , %arg_out ) : ( memref < 1x f32 > , memref < 1x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @expm1_memrefs
func @expm1_memrefs ( %arg0 : memref < 1x complex< f32 > > , %arg_out : memref < 1x complex< f32 > > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.exponential_minus_one" ( %arg0 , %arg_out ) : ( memref < 1x complex< f32 > > , memref < 1x complex< f32 > > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @floor_memrefs
func @floor_memrefs ( %arg0 : memref < 1x f32 > , %arg_out : memref < 1x f32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.floor" ( %arg0 , %arg_out ) : ( memref < 1x f32 > , memref < 1x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
func @floor_memrefs ( %arg0 : memref < 1x i32 > , %arg_out : memref < 1x i32 > ) -> ( ) {
// expected-error@+1{{must be memref of floating-point values}}
2020-07-09 01:05:32 +08:00
"lmhlo.floor" ( %arg0 , %arg_out ) : ( memref < 1x i32 > , memref < 1x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @imag_memrefs
func @imag_memrefs ( %arg0 : memref < 1x complex< f32 > > , %arg_out : memref < 1x f32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.imag" ( %arg0 , %arg_out ) : ( memref < 1x complex< f32 > > , memref < 1x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
func @imag_memrefs ( %arg0 : memref < 1x f32 > , %arg_out : memref < 1x f32 > ) -> ( ) {
// expected-error@+1{{must be memref of complex-type values}}
2020-07-09 01:05:32 +08:00
"lmhlo.imag" ( %arg0 , %arg_out ) : ( memref < 1x f32 > , memref < 1x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @real_memrefs
func @real_memrefs ( %arg0 : memref < 1x complex< f32 > > , %arg_out : memref < 1x f32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.real" ( %arg0 , %arg_out ) : ( memref < 1x complex< f32 > > , memref < 1x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
func @real_memrefs ( %arg0 : memref < 1x f32 > , %arg_out : memref < 1x f32 > ) -> ( ) {
// expected-error@+1{{must be memref of complex-type values}}
2020-07-09 01:05:32 +08:00
"lmhlo.real" ( %arg0 , %arg_out ) : ( memref < 1x f32 > , memref < 1x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @is_finite_memrefs
func @is_finite_memrefs ( %arg0 : memref < 1x f32 > , %arg_out : memref < 1x i1 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.is_finite" ( %arg0 , %arg_out ) : ( memref < 1x f32 > , memref < 1x i1 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @log1p_memrefs
func @log1p_memrefs ( %arg0 : memref < 1x f32 > , %arg_out : memref < 1x f32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.log_plus_one" ( %arg0 , %arg_out ) : ( memref < 1x f32 > , memref < 1x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @log1p_memrefs
func @log1p_memrefs ( %arg0 : memref < 1x complex< f32 > > , %arg_out : memref < 1x complex< f32 > > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.log_plus_one" ( %arg0 , %arg_out ) : ( memref < 1x complex< f32 > > , memref < 1x complex< f32 > > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
func @log1p_memref ( %in : memref < 10x i32 > , %out : memref < 10x i32 > ) -> ( ) {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
2020-07-09 01:05:32 +08:00
"lmhlo.log_plus_one" ( %in , %out ) : ( memref < 10x i32 > , memref < 10x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @not_memrefs
func @not_memrefs ( %arg0 : memref < 1x i32 > , %arg_out : memref < 1x i32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.not" ( %arg0 , %arg_out ) : ( memref < 1x i32 > , memref < 1x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @not_memrefs
func @not_memrefs ( %arg0 : memref < 1x i1 > , %arg_out : memref < 1x i1 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.not" ( %arg0 , %arg_out ) : ( memref < 1x i1 > , memref < 1x i1 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
func @not_memrefs ( %arg0 : memref < 1x f32 > , %arg_out : memref < 1x f32 > ) -> ( ) {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
2020-07-09 01:05:32 +08:00
"lmhlo.not" ( %arg0 , %arg_out ) : ( memref < 1x f32 > , memref < 1x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @popcnt_memrefs
func @popcnt_memrefs ( %arg0 : memref < 1x i32 > , %arg_out : memref < 1x i32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.popcnt" ( %arg0 , %arg_out ) : ( memref < 1x i32 > , memref < 1x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
func @popcnt_memrefs ( %arg0 : memref < 1x f32 > , %arg_out : memref < 1x f32 > ) -> ( ) {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
2020-07-09 01:05:32 +08:00
"lmhlo.popcnt" ( %arg0 , %arg_out ) : ( memref < 1x f32 > , memref < 1x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @reduce_precision_memrefs
func @reduce_precision_memrefs ( %arg0 : memref < 1x f32 > , %arg_out : memref < 1x f32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.reduce_precision" ( %arg0 , %arg_out ) { exponent_bits = 4 : i32 , mantissa_bits = 4 : i32 } : ( memref < 1x f32 > , memref < 1x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @round_memrefs
func @round_memrefs ( %arg0 : memref < 1x f32 > , %arg_out : memref < 1x f32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.round_nearest_afz" ( %arg0 , %arg_out ) : ( memref < 1x f32 > , memref < 1x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
func @round_memrefs ( %arg0 : memref < 1x i32 > , %arg_out : memref < 1x i32 > ) -> ( ) {
// expected-error@+1{{must be memref of floating-point values}}
2020-07-09 01:05:32 +08:00
"lmhlo.round_nearest_afz" ( %arg0 , %arg_out ) : ( memref < 1x i32 > , memref < 1x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @shift_left_memrefs
func @shift_left_memrefs ( %arg0 : memref < 1x i32 > , %arg1 : memref < 1x i32 > , %arg_out : memref < 1x i32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.shift_left" ( %arg0 , %arg1 , %arg_out ) : ( memref < 1x i32 > , memref < 1x i32 > , memref < 1x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
func @shift_left_memrefs ( %arg0 : memref < 1x f32 > , %arg1 : memref < 1x f32 > , %arg_out : memref < 1x f32 > ) -> ( ) {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
2020-07-09 01:05:32 +08:00
"lmhlo.shift_left" ( %arg0 , %arg1 , %arg_out ) : ( memref < 1x f32 > , memref < 1x f32 > , memref < 1x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @shift_right_arithmetic_memrefs
func @shift_right_arithmetic_memrefs ( %arg0 : memref < 1x i32 > , %arg1 : memref < 1x i32 > , %arg_out : memref < 1x i32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.shift_right_arithmetic" ( %arg0 , %arg1 , %arg_out ) : ( memref < 1x i32 > , memref < 1x i32 > , memref < 1x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
func @shift_right_arithmetic_memrefs ( %arg0 : memref < 1x f32 > , %arg1 : memref < 1x f32 > , %arg_out : memref < 1x f32 > ) -> ( ) {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
2020-07-09 01:05:32 +08:00
"lmhlo.shift_right_arithmetic" ( %arg0 , %arg1 , %arg_out ) : ( memref < 1x f32 > , memref < 1x f32 > , memref < 1x f32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// CHECK-LABEL: func @shift_right_logical_memrefs
func @shift_right_logical_memrefs ( %arg0 : memref < 1x i32 > , %arg1 : memref < 1x i32 > , %arg_out : memref < 1x i32 > ) -> ( ) {
2020-07-09 01:05:32 +08:00
"lmhlo.shift_right_logical" ( %arg0 , %arg1 , %arg_out ) : ( memref < 1x i32 > , memref < 1x i32 > , memref < 1x i32 > ) -> ( )
2020-07-07 07:28:26 +08:00
return
}
// -----
// Invalid: lmhlo.shift_right_logical with f32 memref operands is rejected.
func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
  // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
  "lmhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
  return
}
// -----
// Valid: lmhlo.all_reduce with only replica_groups, and again with channel_id,
// constrain_layout and use_global_device_ids also set.
// CHECK-LABEL: func @all_reduce_memrefs
func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> () {
  "lmhlo.all_reduce"(%arg0, %arg_out) ({
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
    %max = mhlo.maximum %lhs, %rhs : tensor<f32>
    "mhlo.return"(%max) : (tensor<f32>) -> ()
  })
  { replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> } : (memref<10xf32>, memref<10xf32>) -> ()
  "lmhlo.all_reduce"(%arg0, %arg_out) ({
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
    %max = mhlo.maximum %lhs, %rhs : tensor<f32>
    "mhlo.return"(%max) : (tensor<f32>) -> ()
  })
  {
    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
    channel_id = { handle = 5 : i64, type = 2 : i64 },
    constrain_layout = true,
    use_global_device_ids = true
  } : (memref<10xf32>, memref<10xf32>) -> ()
  return
}
// -----
// Valid: lmhlo.collective_permute without and with a channel_id.
// CHECK-LABEL: func @collective_permute_memrefs
func @collective_permute_memrefs(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () {
  "lmhlo.collective_permute"(%arg0, %arg_out) {
    source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>
  } : (memref<128x32xf32>, memref<128x32xf32>) -> ()
  "lmhlo.collective_permute"(%arg0, %arg_out) {
    source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>,
    channel_id = { handle = 5 : i64, type = 2 : i64 }
  } : (memref<128x32xf32>, memref<128x32xf32>) -> ()
  return
}
// -----
// Invalid: source_target_pairs has shape (1, 3) instead of (N, 2).
func @invalid_collective_permute(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () {
  // expected-error@+1{{expect source_target_pairs attribute of shape (N, 2), but got (1, 3)}}
  "lmhlo.collective_permute"(%arg0, %arg_out) {
    source_target_pairs = dense<[[2, 3, 4]]> : tensor<1x3xi64>
  } : (memref<128x32xf32>, memref<128x32xf32>) -> ()
  return
}
// -----
// Invalid: source 1 appears in two source_target_pairs.
func @invalid_collective_permute(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () {
  // expected-error@+1{{duplicate sources not allowed.}}
  "lmhlo.collective_permute"(%arg0, %arg_out) {
    source_target_pairs = dense<[[1, 2], [1, 3]]> : tensor<2x2xi64>
  } : (memref<128x32xf32>, memref<128x32xf32>) -> ()
  return
}
// -----
// Invalid: target 2 appears in two source_target_pairs.
func @invalid_collective_permute(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () {
  // expected-error@+1{{duplicate targets not allowed.}}
  "lmhlo.collective_permute"(%arg0, %arg_out) {
    source_target_pairs = dense<[[1, 2], [0, 2]]> : tensor<2x2xi64>
  } : (memref<128x32xf32>, memref<128x32xf32>) -> ()
  return
}
// -----
// Valid: lmhlo.fft with fft_type "RFFT" from an f32 memref into a complex<f32> memref.
// CHECK-LABEL: func @fft_memrefs
func @fft_memrefs(%arg0: memref<3x9xf32>, %arg_out: memref<3x5xcomplex<f32>>) -> () {
  "lmhlo.fft"(%arg0, %arg_out) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (memref<3x9xf32>, memref<3x5xcomplex<f32>>) -> ()
  return
}
// -----
// Valid: lmhlo.batch_norm_grad with epsilon and feature_index attributes.
// CHECK-LABEL: func @batch_norm_grad_memrefs
func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
                              %arg3: memref<8xf32>, %arg4: memref<8x8x8x8xf32>,
                              %grad_operand: memref<8x8x8x8xf32>, %grad_scale: memref<8xf32>,
                              %grad_offset: memref<8xf32>) -> () {
  "lmhlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %grad_operand, %grad_scale, %grad_offset) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
      : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>,
         memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> ()
  return
}
// -----
// Valid: lmhlo.batch_norm_inference with epsilon and feature_index attributes.
// CHECK-LABEL: func @batch_norm_inference_memrefs
func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
                                   %arg3: memref<8xf32>, %arg4: memref<8xf32>, %arg_out: memref<8x8x8x8xf32>) -> () {
  "lmhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
      : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>) -> ()
  return
}
// -----
// Valid: lmhlo.batch_norm_training with epsilon and feature_index attributes.
// CHECK-LABEL: func @batch_norm_training_memrefs
func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
                                  %output: memref<8x8x8x8xf32>, %batch_mean: memref<8xf32>,
                                  %batch_var: memref<8xf32>) -> () {
  "lmhlo.batch_norm_training"(%arg0, %arg1, %arg2, %output, %batch_mean, %batch_var) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
      : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> ()
  return
}
// -----
// Valid: lmhlo.cholesky without and with the `lower` attribute.
// CHECK-LABEL: func @cholesky_memrefs
func @cholesky_memrefs(%arg0: memref<1x291x291xf32>, %arg_out: memref<1x291x291xf32>) -> () {
  "lmhlo.cholesky"(%arg0, %arg_out) : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
  "lmhlo.cholesky"(%arg0, %arg_out) { lower = true } : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
  return
}
// -----
// Valid: lmhlo.infeed into a single f32 memref with a string config.
// CHECK-LABEL: func @infeed_memrefs
func @infeed_memrefs(%arg_out: memref<3xf32>) -> () {
  "lmhlo.infeed"(%arg_out) { config = "x" } : (memref<3xf32>) -> ()
  return
}
// -----
// Valid: lmhlo.outfeed from a single f32 memref with a string config.
// CHECK-LABEL: func @outfeed_memrefs
func @outfeed_memrefs(%arg0: memref<3xf32>) -> () {
  "lmhlo.outfeed"(%arg0) { config = "x" } : (memref<3xf32>) -> ()
  return
}
// -----
// Valid: lmhlo.replica_id writing into a rank-0 ui32 memref.
// CHECK-LABEL: func @replica_id_memrefs
func @replica_id_memrefs(%arg_out: memref<ui32>) -> () {
  "lmhlo.replica_id"(%arg_out) : (memref<ui32>) -> ()
  return
}
// -----
// Valid: lmhlo.triangular_solve with explicit layouts and all solver flags set.
// CHECK-LABEL: func @triangular_solve_memrefs
func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %arg_out: memref<3x4xf32>) -> () {
  "lmhlo.triangular_solve"(%arg0, %arg1, %arg_out)
      {layout_a = dense<[1, 0]> : tensor<2xindex>,
       layout_b = dense<[1, 0]> : tensor<2xindex>,
       layout_output = dense<[1, 0]> : tensor<2xindex>,
       left_side = true, lower = true, transpose_a = "NO_TRANSPOSE",
       unit_diagonal = true}
      : (memref<4x4xf32>, memref<3x4xf32>, memref<3x4xf32>) -> ()
  return
}
// -----
// Valid: lmhlo.while on an i1 memref with two regions that each contain only
// lmhlo.terminator; %arg0 and %arg_out are unused.
// CHECK-LABEL: func @while_memrefs
func @while_memrefs(%arg0: memref<i64>, %arg_out: memref<i64>, %cond: memref<i1>) -> () {
  "lmhlo.while"(%cond) (
    { ^bb0: "lmhlo.terminator"() : () -> () },
    { ^bb0: "lmhlo.terminator"() : () -> () }
  ) : (memref<i1>) -> ()
  return
}
// -----
// Valid: lmhlo.while on an i1 memref with two terminator-only regions; the
// function's other memref arguments are unused.
// CHECK-LABEL: func @while_memrefs
func @while_memrefs(%arg0: memref<i64>, %arg1: memref<5xf32>, %arg0_out: memref<i64>, %arg1_out: memref<5xf32>, %cond: memref<i1>) -> () {
  "lmhlo.while"(%cond) (
    { ^bb0: "lmhlo.terminator"() : () -> () },
    { ^bb0: "lmhlo.terminator"() : () -> () }
  ) : (memref<i1>) -> ()
  return
}
// -----
// Valid: lmhlo.scatter with an mhlo.add update region, scatter_dimension_numbers,
// and the indices_are_sorted / unique_indices flags.
// CHECK-LABEL: func @scatter_memrefs
func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32>,
                      %updates: memref<10x300xf32>, %arg_out: memref<200x100x300xf32>) -> () {
  "lmhlo.scatter"(%input, %indices, %updates, %arg_out) ({
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):  // no predecessors
    %add = mhlo.add %lhs, %rhs : tensor<f32>
    "mhlo.return"(%add) : (tensor<f32>) -> ()
  }) {
    scatter_dimension_numbers = {
      update_window_dims = dense<[1]> : tensor<1xi64>,
      inserted_window_dims = dense<[0, 1]> : tensor<2xi64>,
      scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>,
      index_vector_dim = 1 : i64
    },
    indices_are_sorted = true,
    unique_indices = true
  } : (memref<200x100x300xf32>, memref<10x2xi32>, memref<10x300xf32>, memref<200x100x300xf32>) -> ()
  return
}
// -----
// Valid: lmhlo.map with an mhlo.add region over same-shaped f32 memrefs.
// CHECK-LABEL: func @map_memrefs
func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<20xf32>) -> () {
  "lmhlo.map"(%arg0, %arg1, %arg_out) ({
  ^bb0(%a: tensor<f32>, %b: tensor<f32>):
    %c = mhlo.add %a, %b : tensor<f32>
    "mhlo.return"(%c) : (tensor<f32>) -> ()
  }) {dimensions = dense<0> : tensor<1xi64>} : (memref<20xf32>, memref<20xf32>, memref<20xf32>) -> ()
  return
}
// -----
// Invalid: lmhlo.map whose output memref (10xf32) differs in shape from its inputs (20xf32).
func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<10xf32>) -> () {
  // expected-error@+1{{requires the same shape for all operands}}
  "lmhlo.map"(%arg0, %arg1, %arg_out) ({
  ^bb0(%a: tensor<f32>, %b: tensor<f32>):
    %c = mhlo.add %a, %b : tensor<f32>
    "mhlo.return"(%c) : (tensor<f32>) -> ()
  }) {dimensions = dense<0> : tensor<1xi64>} : (memref<20xf32>, memref<20xf32>, memref<10xf32>) -> ()
  return
}
// -----
// Valid: lmhlo.rng_get_and_update_state on a ui64 state memref with delta = 1.
// CHECK-LABEL: func @rng_get_and_update_state_memrefs
func @rng_get_and_update_state_memrefs(%state: memref<1xui64>) -> () {
  "lmhlo.rng_get_and_update_state"(%state) { delta = 1 : i64 } : (memref<1xui64>) -> ()
  return
}
// -----
// Valid: lmhlo.sort of two operands with both `dimension` and `is_stable` set.
// CHECK-LABEL: func @sort_memrefs
func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
                   %out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
  "lmhlo.sort"(%arg0, %arg1, %out0, %out1) ({
  ^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
    %7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
    "mhlo.return"(%7) : (tensor<i1>) -> ()
  }) {dimension = 1 : i64, is_stable = true} : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> ()
  return
}
// -----
// Valid: lmhlo.sort of two operands with only `dimension` set (no is_stable).
// CHECK-LABEL: func @sort_memrefs
func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
                   %out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
  "lmhlo.sort"(%arg0, %arg1, %out0, %out1) ({
  ^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
    %7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
    "mhlo.return"(%7) : (tensor<i1>) -> ()
  }) {dimension = 1 : i64} : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> ()
  return
}
// -----
// Valid: lmhlo.sort of two operands with no attributes at all.
// CHECK-LABEL: func @sort_memrefs
func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
                   %out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
  "lmhlo.sort"(%arg0, %arg1, %out0, %out1) ({
  ^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
    %7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
    "mhlo.return"(%7) : (tensor<i1>) -> ()
  }) : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> ()
  return
}
// -----
// Valid: lmhlo.custom_call whose target_arg_mapping has one unique, in-range
// entry per arg and per result (2 args and 2 results via operand_segment_sizes).
// CHECK-LABEL: func @valid_custom_call
func @valid_custom_call(%arg0: memref<1xf32>, %arg1: memref<1xf32>) -> () {
  "lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) {
    backend_config = "",
    call_target_name = "foo",
    has_side_effects = false,
    operand_segment_sizes = dense<2> : vector<2xi32>,
    target_arg_mapping = {
      num_args = 4 : i64,
      num_results = 3 : i64,
      args_to_target_args = [0, 3],
      results_to_target_results = [1, 2]
    }
  } : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
  return
}
// -----
// Invalid: args_to_target_args has 1 entry but the op has 2 args.
func @invalid_custom_call(%arg0: memref<1xf32>, %arg1: memref<1xf32>) -> () {
  // expected-error @+1 {{number of entries in the mapping for args (1) should match the number of args for the operation (2)}}
  "lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) {
    backend_config = "",
    call_target_name = "foo",
    has_side_effects = false,
    operand_segment_sizes = dense<2> : vector<2xi32>,
    target_arg_mapping = {
      num_args = 4 : i64,
      num_results = 3 : i64,
      args_to_target_args = [0],
      results_to_target_results = [1, 2]
    }
  } : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
  return
}
// -----
// Invalid: results_to_target_results has 1 entry but the op has 2 results.
func @invalid_custom_call(%arg0: memref<1xf32>, %arg1: memref<1xf32>) -> () {
  // expected-error @+1 {{number of entries in the mapping for results (1) should match the number of results for the operation (2)}}
  "lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) {
    backend_config = "",
    call_target_name = "foo",
    has_side_effects = false,
    operand_segment_sizes = dense<2> : vector<2xi32>,
    target_arg_mapping = {
      num_args = 4 : i64,
      num_results = 3 : i64,
      args_to_target_args = [0, 3],
      results_to_target_results = [1]
    }
  } : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
  return
}
// -----
// Invalid: target arg 0 appears twice in args_to_target_args.
func @invalid_custom_call(%arg0: memref<1xf32>, %arg1: memref<1xf32>) -> () {
  // expected-error @+1 {{entry 0 cannot appear more than once in the mapping for args}}
  "lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) {
    backend_config = "",
    call_target_name = "foo",
    has_side_effects = false,
    operand_segment_sizes = dense<2> : vector<2xi32>,
    target_arg_mapping = {
      num_args = 4 : i64,
      num_results = 3 : i64,
      args_to_target_args = [0, 0],
      results_to_target_results = [1, 2]
    }
  } : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
  return
}
// -----
// Invalid: target result 1 appears twice in results_to_target_results.
func @invalid_custom_call(%arg0: memref<1xf32>, %arg1: memref<1xf32>) -> () {
  // expected-error @+1 {{entry 1 cannot appear more than once in the mapping for results}}
  "lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) {
    backend_config = "",
    call_target_name = "foo",
    has_side_effects = false,
    operand_segment_sizes = dense<2> : vector<2xi32>,
    target_arg_mapping = {
      num_args = 4 : i64,
      num_results = 3 : i64,
      args_to_target_args = [0, 1],
      results_to_target_results = [1, 1]
    }
  } : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
  return
}
// -----
// Invalid: args_to_target_args entry 6 is out of range for num_args = 4.
func @invalid_custom_call(%arg0: memref<1xf32>, %arg1: memref<1xf32>) -> () {
  // expected-error @+1 {{entries in mapping for args must be >= 0 and less than target's number of args (4)}}
  "lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) {
    backend_config = "",
    call_target_name = "foo",
    has_side_effects = false,
    operand_segment_sizes = dense<2> : vector<2xi32>,
    target_arg_mapping = {
      num_args = 4 : i64,
      num_results = 3 : i64,
      args_to_target_args = [0, 6],
      results_to_target_results = [1, 2]
    }
  } : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
  return
}
// -----
// Invalid: results_to_target_results entry 3 is out of range for num_results = 3.
func @invalid_custom_call(%arg0: memref<1xf32>, %arg1: memref<1xf32>) -> () {
  // expected-error @+1 {{entries in mapping for results must be >= 0 and less than target's number of results (3)}}
  "lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) {
    backend_config = "",
    call_target_name = "foo",
    has_side_effects = false,
    operand_segment_sizes = dense<2> : vector<2xi32>,
    target_arg_mapping = {
      num_args = 4 : i64,
      num_results = 3 : i64,
      args_to_target_args = [0, 1],
      results_to_target_results = [1, 3]
    }
  } : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
  return
}
// -----
// Invalid: lmhlo.abs with a complex<f32> input and a complex<f32> output is rejected.
func @invalid_complex_abs_call(%input: memref<2xcomplex<f32>>, %result: memref<2xcomplex<f32>>) -> () {
  // expected-error @+1 {{requires output type to be the same as the element type of the input}}
  "lmhlo.abs"(%input, %result)
      : (memref<2xcomplex<f32>>, memref<2xcomplex<f32>>) -> ()
  return
}
// -----
// Invalid: lmhlo.abs with an f32 input and an f64 output (mismatched types).
func @invalid_float_abs_call(%input: memref<2xf32>, %result: memref<2xf64>) -> () {
  // expected-error @+1 {{requires all operands to have the same type}}
  "lmhlo.abs"(%input, %result) : (memref<2xf32>, memref<2xf64>) -> ()
  return
}