Prepare to remove tensor_load and tensor_store special handling from hlo to lhlo legalization.
This updates the tests to no longer rely on tensor_store. Once all users of this behavior have adopted, the tensor_store support will be removed. PiperOrigin-RevId: 348624899
This commit is contained in:
		
							parent
							
								
									c4accdcc41
								
							
						
					
					
						commit
						ccdd07f8e4
					
				| 
						 | 
				
			
			@ -220,6 +220,7 @@ struct HloToLhloDynamicReshapeConverter
 | 
			
		|||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// TODO(b/175670649) Fix this to no longer access original tensor operands.
 | 
			
		||||
class HloToLhloDynamicBroadcastInDimOpConverter
 | 
			
		||||
    : public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
 | 
			
		||||
 public:
 | 
			
		||||
| 
						 | 
				
			
			@ -462,19 +463,8 @@ struct HloToLhloReturnOpConverter : public BaseOpConversion<mhlo::ReturnOp> {
 | 
			
		|||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
class HloToLhloTensorLoadOpConverter
 | 
			
		||||
    : public BaseOpConversion<mlir::TensorLoadOp> {
 | 
			
		||||
 public:
 | 
			
		||||
  using BaseOpConversion<mlir::TensorLoadOp>::BaseOpConversion;
 | 
			
		||||
  LogicalResult matchAndRewrite(
 | 
			
		||||
      mlir::TensorLoadOp op, ArrayRef<Value> operands,
 | 
			
		||||
      ConversionPatternRewriter& rewriter) const final {
 | 
			
		||||
    rewriter.replaceOp(op, operands);
 | 
			
		||||
    return success();
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
class HloToLhloTensorStoreOpConverter
 | 
			
		||||
// TODO(b/175789537) Remove this pattern.
 | 
			
		||||
class HloToLhloTensorStoreOpLegacyConverter
 | 
			
		||||
    : public BaseOpConversion<mlir::TensorStoreOp> {
 | 
			
		||||
 public:
 | 
			
		||||
  using BaseOpConversion<mlir::TensorStoreOp>::BaseOpConversion;
 | 
			
		||||
| 
						 | 
				
			
			@ -569,9 +559,13 @@ struct HloLegalizeToLhlo
 | 
			
		|||
    target.addLegalDialect<lmhlo::LmhloDialect>();
 | 
			
		||||
    target.addLegalDialect<StandardOpsDialect>();
 | 
			
		||||
    target.addLegalDialect<tensor::TensorDialect>();
 | 
			
		||||
    target.addIllegalOp<mlir::TensorLoadOp>();
 | 
			
		||||
    target.addIllegalOp<mlir::TensorStoreOp>();
 | 
			
		||||
    target.addIllegalDialect<mhlo::MhloDialect>();
 | 
			
		||||
    // Declare tensor_load and tensor_store illegal.
 | 
			
		||||
    target.addIllegalOp<mlir::TensorLoadOp, mlir::TensorStoreOp>();
 | 
			
		||||
    // tensor_to_memref is illegal if it has uses.
 | 
			
		||||
    // TODO(b/175670649) Make tensor_to_memref illegal.
 | 
			
		||||
    target.addDynamicallyLegalOp<mlir::TensorToMemrefOp>(
 | 
			
		||||
        [](auto op) { return op->use_empty(); });
 | 
			
		||||
 | 
			
		||||
    BufferizeTypeConverter converter;
 | 
			
		||||
    auto isMemRefType = [](Type type) { return type.isa<BaseMemRefType>(); };
 | 
			
		||||
| 
						 | 
				
			
			@ -596,9 +590,15 @@ struct HloLegalizeToLhlo
 | 
			
		|||
    populateCallOpTypeConversionPattern(patterns, &context, converter);
 | 
			
		||||
    populateBranchOpInterfaceAndReturnOpTypeConversionPattern(
 | 
			
		||||
        patterns, &context, converter);
 | 
			
		||||
    populateEliminateBufferizeMaterializationsPatterns(&context, converter,
 | 
			
		||||
                                                       patterns);
 | 
			
		||||
 | 
			
		||||
    populateShapeStructuralTypeConversionsAndLegality(&context, converter,
 | 
			
		||||
                                                      patterns, target);
 | 
			
		||||
 | 
			
		||||
    // TODO(b/175789537) Remove this pattern.
 | 
			
		||||
    patterns.insert<HloToLhloTensorStoreOpLegacyConverter>(&context);
 | 
			
		||||
 | 
			
		||||
    if (failed(applyPartialConversion(getOperation(), target,
 | 
			
		||||
                                      std::move(patterns))))
 | 
			
		||||
      signalPassFailure();
 | 
			
		||||
| 
						 | 
				
			
			@ -668,9 +668,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
 | 
			
		|||
      HloToLhloOpConverter<mhlo::TransposeOp>,
 | 
			
		||||
      HloToLhloOpConverter<mhlo::XorOp>,
 | 
			
		||||
      HloToLhloReduceOpConverter,
 | 
			
		||||
      HloToLhloReturnOpConverter,
 | 
			
		||||
      HloToLhloTensorLoadOpConverter,
 | 
			
		||||
      HloToLhloTensorStoreOpConverter
 | 
			
		||||
      HloToLhloReturnOpConverter
 | 
			
		||||
  >(*converter, context);
 | 
			
		||||
  // clang-format on
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -45,8 +45,7 @@ func @trivial_broadcast_wrapper() {
 | 
			
		|||
    broadcast_dimensions = dense<0> : tensor<1xi64>
 | 
			
		||||
  } : (tensor<3xf32>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %output_buf = alloc() : memref<3x4xf32>
 | 
			
		||||
  tensor_store %output, %output_buf : memref<3x4xf32>
 | 
			
		||||
  %output_buf = tensor_to_memref %output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
 | 
			
		||||
| 
						 | 
				
			
			@ -63,8 +62,7 @@ func @trivial_broadcast_wrapper() {
 | 
			
		|||
    broadcast_dimensions = dense<0> : tensor<1xi64>
 | 
			
		||||
  } : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %dyn_output_buf = alloc() : memref<3x4xf32>
 | 
			
		||||
  tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32>
 | 
			
		||||
  %dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_dyn_output = memref_cast %dyn_output_buf
 | 
			
		||||
    : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
| 
						 | 
				
			
			@ -97,8 +95,7 @@ func @broadcast_in_X_dim_wrapper() {
 | 
			
		|||
    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
 | 
			
		||||
  } : (tensor<1x4xf32>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %output_buf = alloc() : memref<3x4xf32>
 | 
			
		||||
  tensor_store %output, %output_buf : memref<3x4xf32>
 | 
			
		||||
  %output_buf = tensor_to_memref %output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
 | 
			
		||||
| 
						 | 
				
			
			@ -114,8 +111,7 @@ func @broadcast_in_X_dim_wrapper() {
 | 
			
		|||
    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
 | 
			
		||||
  } : (tensor<1x4xf32>, tensor<2xindex>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %dyn_output_buf = alloc() : memref<3x4xf32>
 | 
			
		||||
  tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32>
 | 
			
		||||
  %dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_dyn_output = memref_cast %dyn_output_buf
 | 
			
		||||
    : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
| 
						 | 
				
			
			@ -145,8 +141,7 @@ func @broadcast_in_Y_dim_wrapper() {
 | 
			
		|||
    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
 | 
			
		||||
  } : (tensor<3x1xf32>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %output_buf = alloc() : memref<3x4xf32>
 | 
			
		||||
  tensor_store %output, %output_buf : memref<3x4xf32>
 | 
			
		||||
  %output_buf = tensor_to_memref %output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
 | 
			
		||||
| 
						 | 
				
			
			@ -163,8 +158,7 @@ func @broadcast_in_Y_dim_wrapper() {
 | 
			
		|||
    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
 | 
			
		||||
  } : (tensor<3x1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %dyn_output_buf = alloc() : memref<3x4xf32>
 | 
			
		||||
  tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32>
 | 
			
		||||
  %dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_dyn_output = memref_cast %dyn_output_buf
 | 
			
		||||
    : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
| 
						 | 
				
			
			@ -197,8 +191,7 @@ func @broadcast_in_X_dim_transpose_wrapper() {
 | 
			
		|||
    broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
 | 
			
		||||
  } : (tensor<4x1xf32>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %output_buf = alloc() : memref<3x4xf32>
 | 
			
		||||
  tensor_store %output, %output_buf : memref<3x4xf32>
 | 
			
		||||
  %output_buf = tensor_to_memref %output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
 | 
			
		||||
| 
						 | 
				
			
			@ -214,8 +207,7 @@ func @broadcast_in_X_dim_transpose_wrapper() {
 | 
			
		|||
    broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
 | 
			
		||||
  } : (tensor<4x1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %dyn_output_buf = alloc() : memref<3x4xf32>
 | 
			
		||||
  tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32>
 | 
			
		||||
  %dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_dyn_output = memref_cast %dyn_output_buf
 | 
			
		||||
    : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
| 
						 | 
				
			
			@ -245,8 +237,7 @@ func @broadcast_in_Y_dim_transpose_wrapper() {
 | 
			
		|||
    broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
 | 
			
		||||
  } : (tensor<1x3xf32>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %output_buf = alloc() : memref<3x4xf32>
 | 
			
		||||
  tensor_store %output, %output_buf : memref<3x4xf32>
 | 
			
		||||
  %output_buf = tensor_to_memref %output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
 | 
			
		||||
| 
						 | 
				
			
			@ -263,8 +254,7 @@ func @broadcast_in_Y_dim_transpose_wrapper() {
 | 
			
		|||
    broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
 | 
			
		||||
  } : (tensor<1x3xf32>, tensor<2xindex>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %dyn_output_buf = alloc() : memref<3x4xf32>
 | 
			
		||||
  tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32>
 | 
			
		||||
  %dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_dyn_output = memref_cast %dyn_output_buf
 | 
			
		||||
    : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
| 
						 | 
				
			
			@ -288,8 +278,7 @@ func @broadcast_scalar_1d_wrapper() {
 | 
			
		|||
    broadcast_dimensions = dense<0> : tensor<1xi64>
 | 
			
		||||
  } : (tensor<1xf32>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %output_buf = alloc() : memref<3x4xf32>
 | 
			
		||||
  tensor_store %output, %output_buf : memref<3x4xf32>
 | 
			
		||||
  %output_buf = tensor_to_memref %output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
 | 
			
		||||
| 
						 | 
				
			
			@ -306,8 +295,7 @@ func @broadcast_scalar_1d_wrapper() {
 | 
			
		|||
    broadcast_dimensions = dense<0> : tensor<1xi64>
 | 
			
		||||
  } : (tensor<1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %dyn_output_buf = alloc() : memref<3x4xf32>
 | 
			
		||||
  tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32>
 | 
			
		||||
  %dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_dyn_output = memref_cast %dyn_output_buf
 | 
			
		||||
    : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
| 
						 | 
				
			
			@ -331,8 +319,7 @@ func @broadcast_scalar_2d_wrapper() {
 | 
			
		|||
    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
 | 
			
		||||
  } : (tensor<1x1xf32>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %output_buf = alloc() : memref<3x4xf32>
 | 
			
		||||
  tensor_store %output, %output_buf : memref<3x4xf32>
 | 
			
		||||
  %output_buf = tensor_to_memref %output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
 | 
			
		||||
| 
						 | 
				
			
			@ -349,8 +336,7 @@ func @broadcast_scalar_2d_wrapper() {
 | 
			
		|||
    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
 | 
			
		||||
  } : (tensor<1x1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %dyn_output_buf = alloc() : memref<3x4xf32>
 | 
			
		||||
  tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32>
 | 
			
		||||
  %dyn_output_buf = tensor_to_memref %dyn_output : memref<3x4xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_dyn_output = memref_cast %dyn_output_buf
 | 
			
		||||
    : memref<3x4xf32> to memref<*xf32>
 | 
			
		||||
| 
						 | 
				
			
			@ -386,8 +372,7 @@ func @broadcast_to_the_same_shape() {
 | 
			
		|||
    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
 | 
			
		||||
  } : (tensor<2x3xf32>) -> tensor<2x3xf32>
 | 
			
		||||
 | 
			
		||||
  %output_buf = alloc() : memref<2x3xf32>
 | 
			
		||||
  tensor_store %output, %output_buf : memref<2x3xf32>
 | 
			
		||||
  %output_buf = tensor_to_memref %output : memref<2x3xf32>
 | 
			
		||||
 | 
			
		||||
  %unraked_output = memref_cast %output_buf : memref<2x3xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unraked_output) : (memref<*xf32>) -> ()
 | 
			
		||||
| 
						 | 
				
			
			@ -401,8 +386,7 @@ func @broadcast_to_the_same_shape() {
 | 
			
		|||
    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
 | 
			
		||||
  } : (tensor<2x3xf32>, tensor<2xindex>) -> tensor<2x3xf32>
 | 
			
		||||
 | 
			
		||||
  %dyn_output_buf = alloc() : memref<2x3xf32>
 | 
			
		||||
  tensor_store %dyn_output, %dyn_output_buf : memref<2x3xf32>
 | 
			
		||||
  %dyn_output_buf = tensor_to_memref %dyn_output : memref<2x3xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_dyn_output = memref_cast %dyn_output_buf
 | 
			
		||||
    : memref<2x3xf32> to memref<*xf32>
 | 
			
		||||
| 
						 | 
				
			
			@ -433,8 +417,7 @@ func @broadcast_1d_to_2d() {
 | 
			
		|||
    broadcast_dimensions = dense<0> : tensor<1xi64>
 | 
			
		||||
  } : (tensor<3xf32>) -> tensor<3x3xf32>
 | 
			
		||||
 | 
			
		||||
  %output_buf = alloc() : memref<3x3xf32>
 | 
			
		||||
  tensor_store %output, %output_buf : memref<3x3xf32>
 | 
			
		||||
  %output_buf = tensor_to_memref %output : memref<3x3xf32>
 | 
			
		||||
 | 
			
		||||
  %unraked_output = memref_cast %output_buf : memref<3x3xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unraked_output) : (memref<*xf32>) -> ()
 | 
			
		||||
| 
						 | 
				
			
			@ -451,8 +434,7 @@ func @broadcast_1d_to_2d() {
 | 
			
		|||
    broadcast_dimensions = dense<0> : tensor<1xi64>
 | 
			
		||||
  } : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x3xf32>
 | 
			
		||||
 | 
			
		||||
  %dyn_output_buf = alloc() : memref<3x3xf32>
 | 
			
		||||
  tensor_store %dyn_output, %dyn_output_buf : memref<3x3xf32>
 | 
			
		||||
  %dyn_output_buf = tensor_to_memref %dyn_output : memref<3x3xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_dyn_output = memref_cast %dyn_output_buf
 | 
			
		||||
    : memref<3x3xf32> to memref<*xf32>
 | 
			
		||||
| 
						 | 
				
			
			@ -484,8 +466,7 @@ func @broadcast_1d_to_2d_with_transpose() {
 | 
			
		|||
    broadcast_dimensions = dense<1> : tensor<1xi64>
 | 
			
		||||
  } : (tensor<3xf32>) -> tensor<3x3xf32>
 | 
			
		||||
 | 
			
		||||
  %output_buf = alloc() : memref<3x3xf32>
 | 
			
		||||
  tensor_store %output, %output_buf : memref<3x3xf32>
 | 
			
		||||
  %output_buf = tensor_to_memref %output : memref<3x3xf32>
 | 
			
		||||
 | 
			
		||||
  %unraked_output = memref_cast %output_buf : memref<3x3xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unraked_output) : (memref<*xf32>) -> ()
 | 
			
		||||
| 
						 | 
				
			
			@ -501,8 +482,7 @@ func @broadcast_1d_to_2d_with_transpose() {
 | 
			
		|||
    broadcast_dimensions = dense<1> : tensor<1xi64>
 | 
			
		||||
  } : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x3xf32>
 | 
			
		||||
 | 
			
		||||
  %dyn_output_buf = alloc() : memref<3x3xf32>
 | 
			
		||||
  tensor_store %dyn_output, %dyn_output_buf : memref<3x3xf32>
 | 
			
		||||
  %dyn_output_buf = tensor_to_memref %dyn_output : memref<3x3xf32>
 | 
			
		||||
 | 
			
		||||
  %unranked_dyn_output = memref_cast %dyn_output_buf
 | 
			
		||||
    : memref<3x3xf32> to memref<*xf32>
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -45,8 +45,7 @@ func @reduce_add() {
 | 
			
		|||
  }) {dimensions = dense<1> : tensor<1xi64>}
 | 
			
		||||
      : (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32>
 | 
			
		||||
 | 
			
		||||
  %output = alloc() : memref<2xf32>
 | 
			
		||||
  tensor_store %reduce, %output : memref<2xf32>
 | 
			
		||||
  %output = tensor_to_memref %reduce : memref<2xf32>
 | 
			
		||||
  %unranked_output = memref_cast %output : memref<2xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
 | 
			
		||||
  // CHECK: rank = 1 offset = 0 sizes = [2] strides = [1]
 | 
			
		||||
| 
						 | 
				
			
			@ -83,8 +82,7 @@ func @reduce_max() {
 | 
			
		|||
  }) {dimensions = dense<1> : tensor<1xi64>}
 | 
			
		||||
      : (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32>
 | 
			
		||||
 | 
			
		||||
  %output = alloc() : memref<2xf32>
 | 
			
		||||
  tensor_store %reduce, %output : memref<2xf32>
 | 
			
		||||
  %output = tensor_to_memref %reduce : memref<2xf32>
 | 
			
		||||
  %unranked_output = memref_cast %output : memref<2xf32> to memref<*xf32>
 | 
			
		||||
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
 | 
			
		||||
  // CHECK: rank = 1 offset = 0 sizes = [2] strides = [1]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -3,14 +3,12 @@
 | 
			
		|||
// RUN: | FILECHECK_OPTS="" FileCheck %s
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @attrs
 | 
			
		||||
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
 | 
			
		||||
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
 | 
			
		||||
  %tensor_result = "mhlo.exponential"(%tensor_operand)
 | 
			
		||||
func @attrs_copy(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
 | 
			
		||||
  %result = "mhlo.exponential"(%operand)
 | 
			
		||||
      {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
 | 
			
		||||
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
 | 
			
		||||
  // CHECK: "lmhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xf32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<2x2xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
| 
						 | 
				
			
			@ -19,7 +17,6 @@ func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {
 | 
			
		|||
  return %arg0 : tensor<4xf32>
 | 
			
		||||
}
 | 
			
		||||
//      CHECK: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
 | 
			
		||||
//  CHECK-NOT: "lmhlo.copy"
 | 
			
		||||
// CHECK-NEXT: return %[[ARG0]]
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
| 
						 | 
				
			
			@ -53,104 +50,81 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
 | 
			
		|||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @fusion
 | 
			
		||||
func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
 | 
			
		||||
             %summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
 | 
			
		||||
  // CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}})
 | 
			
		||||
func @fusion(%multiplier: tensor<2x2xf32>, %summand_1: tensor<2x2xf32>,
 | 
			
		||||
             %summand_2: tensor<2x2xf32>) -> tensor<2x2xf32> {
 | 
			
		||||
  // CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}})
 | 
			
		||||
  // CHECK-NEXT:  %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
 | 
			
		||||
  %tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32>
 | 
			
		||||
  %tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32>
 | 
			
		||||
  %sum = "mhlo.add"(%tensor_summand_1, %tensor_summand_2)
 | 
			
		||||
  %sum = "mhlo.add"(%summand_1, %summand_2)
 | 
			
		||||
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
 | 
			
		||||
  // CHECK-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
 | 
			
		||||
  // CHECK-NEXT:  %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
 | 
			
		||||
  %tensor_multiplier = tensor_load %multiplier : memref<2x2xf32>
 | 
			
		||||
  %tensor_result = "mhlo.multiply"(%sum, %tensor_multiplier)
 | 
			
		||||
  %result = "mhlo.multiply"(%sum, %multiplier)
 | 
			
		||||
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
 | 
			
		||||
  // CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
 | 
			
		||||
  // CHECK-NEXT:  dealloc %[[ADD_RESULT]] : memref<2x2xf32>
 | 
			
		||||
  // CHECK-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xf32>
 | 
			
		||||
  // CHECK-NEXT:  dealloc %[[MUL_RESULT]] : memref<2x2xf32>
 | 
			
		||||
  // CHECK-NEXT:  return
 | 
			
		||||
  return
 | 
			
		||||
  // CHECK-NEXT:  return %[[MUL_RESULT]] : memref<2x2xf32>
 | 
			
		||||
  return %result : tensor<2x2xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @copy
 | 
			
		||||
func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
 | 
			
		||||
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
 | 
			
		||||
  %tensor_result = "mhlo.copy"(%tensor_operand)
 | 
			
		||||
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
 | 
			
		||||
  // CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xf32>
 | 
			
		||||
  return
 | 
			
		||||
func @copy(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
 | 
			
		||||
  %result = "mhlo.copy"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
 | 
			
		||||
  // TODO(herhut): An explicit copy should not be removed.
 | 
			
		||||
  // TODO-CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}})
 | 
			
		||||
  return %result : tensor<2x2xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @exp
 | 
			
		||||
func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
 | 
			
		||||
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
 | 
			
		||||
  %tensor_result = "mhlo.exponential"(%tensor_operand)
 | 
			
		||||
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
 | 
			
		||||
func @exp(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
 | 
			
		||||
  %result = "mhlo.exponential"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
 | 
			
		||||
  // CHECK: "lmhlo.exponential"(%{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xf32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<2x2xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @log
 | 
			
		||||
func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
 | 
			
		||||
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
 | 
			
		||||
  %tensor_result = "mhlo.log"(%tensor_operand)
 | 
			
		||||
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
 | 
			
		||||
func @log(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
 | 
			
		||||
  %result = "mhlo.log"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
 | 
			
		||||
  // CHECK: "lmhlo.log"(%{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xf32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<2x2xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @select
 | 
			
		||||
func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
 | 
			
		||||
             %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
 | 
			
		||||
  %tensor_pred = tensor_load %pred : memref<2x2xi1>
 | 
			
		||||
  %tensor_lhs = tensor_load %lhs : memref<2x2xf32>
 | 
			
		||||
  %tensor_rhs = tensor_load %rhs : memref<2x2xf32>
 | 
			
		||||
  %tensor_result = "mhlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs)
 | 
			
		||||
func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
 | 
			
		||||
             %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
 | 
			
		||||
  %result = "mhlo.select"(%pred, %lhs, %rhs)
 | 
			
		||||
      : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
 | 
			
		||||
  // CHECK: "lmhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xf32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<2x2xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @compare
 | 
			
		||||
func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) {
 | 
			
		||||
  %tensor_lhs = tensor_load %lhs : memref<2x2xf32>
 | 
			
		||||
  %tensor_rhs = tensor_load %rhs : memref<2x2xf32>
 | 
			
		||||
  %tensor_result = "mhlo.compare"(%tensor_lhs, %tensor_rhs)
 | 
			
		||||
func @compare(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xi1> {
 | 
			
		||||
  %result = "mhlo.compare"(%lhs, %rhs)
 | 
			
		||||
      {comparison_direction = "EQ"}
 | 
			
		||||
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
 | 
			
		||||
  // CHECK: "lmhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xi1>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<2x2xi1>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @broadcast
 | 
			
		||||
func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
 | 
			
		||||
  %tensor_operand = tensor_load %operand : memref<5xf32>
 | 
			
		||||
  %tensor_result = "mhlo.broadcast_in_dim"(%tensor_operand)
 | 
			
		||||
func @broadcast(%operand: tensor<5xf32>) -> tensor<10x5xf32> {
 | 
			
		||||
  %result = "mhlo.broadcast_in_dim"(%operand)
 | 
			
		||||
      {broadcast_dimensions = dense<1> : tensor<1xi64>}
 | 
			
		||||
        : (tensor<5xf32>) -> tensor<10x5xf32>
 | 
			
		||||
  // CHECK: "lmhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<10x5xf32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<10x5xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
| 
						 | 
				
			
			@ -158,16 +132,14 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
 | 
			
		|||
// CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + d2 * s2)>
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @dyn_broadcast
 | 
			
		||||
func @dyn_broadcast(%operand: memref<?x?xf32>) -> index {
 | 
			
		||||
func @dyn_broadcast(%operand: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
 | 
			
		||||
  // CHECK-SAME: %[[OPERAND:.*]]: memref<?x?xf32>
 | 
			
		||||
  %tensor_operand = tensor_load %operand : memref<?x?xf32>
 | 
			
		||||
  %c1 = constant 1 : i64
 | 
			
		||||
  %shape = tensor_from_elements %c1, %c1, %c1 : tensor<3xi64>
 | 
			
		||||
  %tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
 | 
			
		||||
  %result = "mhlo.dynamic_broadcast_in_dim"(%operand, %shape) {
 | 
			
		||||
    broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
 | 
			
		||||
  } : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
 | 
			
		||||
  %rank = rank %tensor_result : tensor<?x?x?xf32>
 | 
			
		||||
  return %rank : index
 | 
			
		||||
  return %result : tensor<?x?x?xf32>
 | 
			
		||||
}
 | 
			
		||||
// CHECK: %[[SHAPE:.*]] = tensor_from_elements
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -196,82 +168,67 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) -> index {
 | 
			
		|||
// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
 | 
			
		||||
 | 
			
		||||
// CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #map>, memref<?x?x?xf32>) -> ()
 | 
			
		||||
// CHECK: dealloc %[[RESULT]] : memref<?x?x?xf32>
 | 
			
		||||
// CHECK: return %[[RESULT]] : memref<?x?x?xf32>
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @complex
 | 
			
		||||
func @complex(%real: memref<2x2xf32>,
 | 
			
		||||
              %imag: memref<2x2xf32>,
 | 
			
		||||
              %result: memref<2x2xcomplex<f32>>) {
 | 
			
		||||
  %tensor_real = tensor_load %real : memref<2x2xf32>
 | 
			
		||||
  %tensor_imag = tensor_load %imag : memref<2x2xf32>
 | 
			
		||||
  %tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag)
 | 
			
		||||
func @complex(%real: tensor<2x2xf32>, %imag: tensor<2x2xf32>)
 | 
			
		||||
    -> tensor<2x2xcomplex<f32>> {
 | 
			
		||||
  %result = "mhlo.complex"(%real, %imag)
 | 
			
		||||
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>>
 | 
			
		||||
  // CHECK: "lmhlo.complex"(%{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xcomplex<f32>>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<2x2xcomplex<f32>>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @complex_dyn
 | 
			
		||||
func @complex_dyn(%real: memref<?xf32>,
 | 
			
		||||
                  %imag: memref<?xf32>,
 | 
			
		||||
                  %result: memref<?xcomplex<f32>>) {
 | 
			
		||||
  %tensor_real = tensor_load %real : memref<?xf32>
 | 
			
		||||
  %tensor_imag = tensor_load %imag : memref<?xf32>
 | 
			
		||||
  %tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag)
 | 
			
		||||
func @complex_dyn(%real: tensor<?xf32>, %imag: tensor<?xf32>)
 | 
			
		||||
    -> tensor<?xcomplex<f32>> {
 | 
			
		||||
  %result = "mhlo.complex"(%real, %imag)
 | 
			
		||||
      : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xcomplex<f32>>
 | 
			
		||||
  // CHECK: "lmhlo.complex"(%{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<?xcomplex<f32>>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<?xcomplex<f32>>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @real
 | 
			
		||||
func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
 | 
			
		||||
  %tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
 | 
			
		||||
  %tensor_result = "mhlo.real"(%tensor_operand)
 | 
			
		||||
func @real(%operand: tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32> {
 | 
			
		||||
  %result = "mhlo.real"(%operand)
 | 
			
		||||
      : (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
 | 
			
		||||
  // CHECK: "lmhlo.real"(%{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xf32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<2x2xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @real_dyn
 | 
			
		||||
func @real_dyn(%operand: memref<?xcomplex<f32>>, %result: memref<?xf32>) {
 | 
			
		||||
  %tensor_operand = tensor_load %operand : memref<?xcomplex<f32>>
 | 
			
		||||
  %tensor_result = "mhlo.real"(%tensor_operand)
 | 
			
		||||
func @real_dyn(%operand: tensor<?xcomplex<f32>>) -> tensor<?xf32> {
 | 
			
		||||
  %result = "mhlo.real"(%operand)
 | 
			
		||||
      : (tensor<?xcomplex<f32>>) -> tensor<?xf32>
 | 
			
		||||
  // CHECK: "lmhlo.real"(%{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<?xf32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<?xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @imag
 | 
			
		||||
func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
 | 
			
		||||
  %tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
 | 
			
		||||
  %tensor_result = "mhlo.imag"(%tensor_operand)
 | 
			
		||||
func @imag(%operand: tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32> {
 | 
			
		||||
  %result = "mhlo.imag"(%operand)
 | 
			
		||||
      : (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
 | 
			
		||||
  // CHECK: "lmhlo.imag"(%{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xf32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<2x2xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @gather
 | 
			
		||||
func @gather(%operand: memref<13x7xf32>, %idxs: memref<5xi32>, %result: memref<5x7xf32>) {
 | 
			
		||||
  %tensor_operand = tensor_load %operand : memref<13x7xf32>
 | 
			
		||||
  %tensor_idxs = tensor_load %idxs : memref<5xi32>
 | 
			
		||||
  %tensor_result =
 | 
			
		||||
    "mhlo.gather"(%tensor_operand, %tensor_idxs)
 | 
			
		||||
func @gather(%operand: tensor<13x7xf32>, %idxs: tensor<5xi32>)
 | 
			
		||||
    -> tensor<5x7xf32> {
 | 
			
		||||
  %result =
 | 
			
		||||
    "mhlo.gather"(%operand, %idxs)
 | 
			
		||||
      { dimension_numbers =
 | 
			
		||||
        { collapsed_slice_dims = dense<0> : tensor<1xi64>
 | 
			
		||||
        , index_vector_dim = 1 : i64
 | 
			
		||||
| 
						 | 
				
			
			@ -282,269 +239,222 @@ func @gather(%operand: memref<13x7xf32>, %idxs: memref<5xi32>, %result: memref<5
 | 
			
		|||
      , slice_sizes = dense<[1, 7]> : tensor<2xi64> }
 | 
			
		||||
      : (tensor<13x7xf32>, tensor<5xi32>) -> tensor<5x7xf32>
 | 
			
		||||
  // CHECK: "lmhlo.gather"(%{{.*}}, %{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<5x7xf32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<5x7xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @imag_dyn
 | 
			
		||||
func @imag_dyn(%operand: memref<?xcomplex<f32>>, %result: memref<?xf32>) {
 | 
			
		||||
  %tensor_operand = tensor_load %operand : memref<?xcomplex<f32>>
 | 
			
		||||
  %tensor_result = "mhlo.imag"(%tensor_operand)
 | 
			
		||||
func @imag_dyn(%operand: tensor<?xcomplex<f32>>) -> tensor<?xf32> {
 | 
			
		||||
  %result = "mhlo.imag"(%operand)
 | 
			
		||||
      : (tensor<?xcomplex<f32>>) -> tensor<?xf32>
 | 
			
		||||
  // CHECK: "lmhlo.imag"(%{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<?xf32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<?xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @iota
 | 
			
		||||
func @iota(%result: memref<10xi32>) {
 | 
			
		||||
  %tensor_result = "mhlo.iota"()
 | 
			
		||||
// TODO(herhut): Dummy should not be required here.
 | 
			
		||||
func @iota(%dummy: tensor<?xf32>) -> tensor<10xi32> {
 | 
			
		||||
  %result = "mhlo.iota"()
 | 
			
		||||
      {iota_dimension = 0 : i64} : () -> tensor<10xi32>
 | 
			
		||||
  // CHECK: "lmhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<10xi32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<10xi32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @abs
 | 
			
		||||
func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
 | 
			
		||||
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
 | 
			
		||||
  %tensor_result = "mhlo.abs"(%tensor_operand)
 | 
			
		||||
func @abs(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
 | 
			
		||||
  %result = "mhlo.abs"(%operand)
 | 
			
		||||
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
 | 
			
		||||
  // CHECK: "lmhlo.abs"(%{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xf32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<2x2xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @and
 | 
			
		||||
func @and(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>,
 | 
			
		||||
          %result: memref<2x2xi32>) {
 | 
			
		||||
  %tensor_operand0 = tensor_load %operand0 : memref<2x2xi32>
 | 
			
		||||
  %tensor_operand1 = tensor_load %operand1 : memref<2x2xi32>
 | 
			
		||||
  %tensor_result = "mhlo.and"(%tensor_operand0, %tensor_operand1)
 | 
			
		||||
func @and(%operand0: tensor<2x2xi32>, %operand1: tensor<2x2xi32>)
 | 
			
		||||
    -> tensor<2x2xi32> {
 | 
			
		||||
  %result = "mhlo.and"(%operand0, %operand1)
 | 
			
		||||
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
 | 
			
		||||
  // CHECK: "lmhlo.and"(%{{.*}}, %{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xi32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<2x2xi32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @ceil
 | 
			
		||||
func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
 | 
			
		||||
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
 | 
			
		||||
  %tensor_result = "mhlo.ceil"(%tensor_operand)
 | 
			
		||||
func @ceil(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
 | 
			
		||||
  %result = "mhlo.ceil"(%operand)
 | 
			
		||||
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
 | 
			
		||||
  // CHECK: "lmhlo.ceil"(%{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xf32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<2x2xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @convert
 | 
			
		||||
func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
 | 
			
		||||
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
 | 
			
		||||
  %tensor_result = "mhlo.convert"(%tensor_operand)
 | 
			
		||||
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
 | 
			
		||||
  // CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}})
 | 
			
		||||
  // CHECK-NOT: tensor_store
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xf32>
 | 
			
		||||
  return
 | 
			
		||||
func @convert(%operand: tensor<2x2xf32>) -> tensor<2x2xi32> {
 | 
			
		||||
  %result = "mhlo.convert"(%operand)
 | 
			
		||||
      : (tensor<2x2xf32>) -> tensor<2x2xi32>
 | 
			
		||||
  // CHECK: "lmhlo.convert"(%{{.*}}, %{{.*}})
 | 
			
		||||
  return %result : tensor<2x2xi32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @cos
 | 
			
		||||
func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
 | 
			
		||||
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
 | 
			
		||||
  %tensor_result = "mhlo.cosine"(%tensor_operand)
 | 
			
		||||
func @cos(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
 | 
			
		||||
  %result = "mhlo.cosine"(%operand)
 | 
			
		||||
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
 | 
			
		||||
  // CHECK: "lmhlo.cosine"(%{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xf32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<2x2xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @floor
 | 
			
		||||
func @floor(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
 | 
			
		||||
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
 | 
			
		||||
  %tensor_result = "mhlo.floor"(%tensor_operand)
 | 
			
		||||
func @floor(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
 | 
			
		||||
  %result = "mhlo.floor"(%operand)
 | 
			
		||||
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
 | 
			
		||||
  // CHECK: "lmhlo.floor"(%{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xf32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<2x2xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @neg
 | 
			
		||||
func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
 | 
			
		||||
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
 | 
			
		||||
  %tensor_result = "mhlo.negate"(%tensor_operand)
 | 
			
		||||
func @neg(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
 | 
			
		||||
  %result = "mhlo.negate"(%operand)
 | 
			
		||||
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
 | 
			
		||||
  // CHECK: "lmhlo.negate"(%{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xf32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<2x2xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @not
 | 
			
		||||
func @not(%operand: memref<2x2xi32>, %result: memref<2x2xi32>) {
 | 
			
		||||
  %tensor_operand = tensor_load %operand : memref<2x2xi32>
 | 
			
		||||
  %tensor_result = "mhlo.not"(%tensor_operand)
 | 
			
		||||
func @not(%operand: tensor<2x2xi32>) -> tensor<2x2xi32> {
 | 
			
		||||
  %result = "mhlo.not"(%operand)
 | 
			
		||||
      : (tensor<2x2xi32>) -> tensor<2x2xi32>
 | 
			
		||||
  // CHECK: "lmhlo.not"(%{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xi32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<2x2xi32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @or
 | 
			
		||||
func @or(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>,
 | 
			
		||||
         %result: memref<2x2xi32>) {
 | 
			
		||||
  %tensor_operand0 = tensor_load %operand0 : memref<2x2xi32>
 | 
			
		||||
  %tensor_operand1 = tensor_load %operand1 : memref<2x2xi32>
 | 
			
		||||
  %tensor_result = "mhlo.or"(%tensor_operand0, %tensor_operand1)
 | 
			
		||||
func @or(%operand0: tensor<2x2xi32>, %operand1: tensor<2x2xi32>)
 | 
			
		||||
    -> tensor<2x2xi32> {
 | 
			
		||||
  %result = "mhlo.or"(%operand0, %operand1)
 | 
			
		||||
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
 | 
			
		||||
  // CHECK: "lmhlo.or"(%{{.*}}, %{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xi32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<2x2xi32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @rsqrt
 | 
			
		||||
func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
 | 
			
		||||
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
 | 
			
		||||
  %tensor_result = "mhlo.rsqrt"(%tensor_operand)
 | 
			
		||||
func @rsqrt(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
 | 
			
		||||
  %result = "mhlo.rsqrt"(%operand)
 | 
			
		||||
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
 | 
			
		||||
  // CHECK: "lmhlo.rsqrt"(%{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xf32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<2x2xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @sign
 | 
			
		||||
func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
 | 
			
		||||
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
 | 
			
		||||
  %tensor_result = "mhlo.sign"(%tensor_operand)
 | 
			
		||||
func @sign(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
 | 
			
		||||
  %result = "mhlo.sign"(%operand)
 | 
			
		||||
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
 | 
			
		||||
  // CHECK: "lmhlo.sign"(%{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xf32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<2x2xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @sqrt
 | 
			
		||||
func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
 | 
			
		||||
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
 | 
			
		||||
  %tensor_result = "mhlo.sqrt"(%tensor_operand)
 | 
			
		||||
func @sqrt(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
 | 
			
		||||
  %result = "mhlo.sqrt"(%operand)
 | 
			
		||||
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
 | 
			
		||||
  // CHECK: "lmhlo.sqrt"(%{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xf32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<2x2xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @shift_left
 | 
			
		||||
func @shift_left(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
 | 
			
		||||
                 %result: memref<2x2xi32>) {
 | 
			
		||||
  %tensor_lhs = tensor_load %lhs : memref<2x2xi32>
 | 
			
		||||
  %tensor_rhs = tensor_load %rhs : memref<2x2xi32>
 | 
			
		||||
  %tensor_result = "mhlo.shift_left"(%tensor_lhs, %tensor_rhs)
 | 
			
		||||
func @shift_left(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>)
 | 
			
		||||
    -> tensor<2x2xi32> {
 | 
			
		||||
  %result = "mhlo.shift_left"(%lhs, %rhs)
 | 
			
		||||
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
 | 
			
		||||
  // CHECK: "lmhlo.shift_left"(%{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xi32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<2x2xi32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @shift_right_arithmetic
 | 
			
		||||
func @shift_right_arithmetic(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
 | 
			
		||||
                             %result: memref<2x2xi32>) {
 | 
			
		||||
  %tensor_lhs = tensor_load %lhs : memref<2x2xi32>
 | 
			
		||||
  %tensor_rhs = tensor_load %rhs : memref<2x2xi32>
 | 
			
		||||
  %tensor_result = "mhlo.shift_right_arithmetic"(%tensor_lhs, %tensor_rhs)
 | 
			
		||||
func @shift_right_arithmetic(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>)
 | 
			
		||||
    -> tensor<2x2xi32> {
 | 
			
		||||
  %result = "mhlo.shift_right_arithmetic"(%lhs, %rhs)
 | 
			
		||||
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
 | 
			
		||||
  // CHECK: "lmhlo.shift_right_arithmetic"(%{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xi32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<2x2xi32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @shift_right_logical
 | 
			
		||||
func @shift_right_logical(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
 | 
			
		||||
                          %result: memref<2x2xi32>) {
 | 
			
		||||
  %tensor_lhs = tensor_load %lhs : memref<2x2xi32>
 | 
			
		||||
  %tensor_rhs = tensor_load %rhs : memref<2x2xi32>
 | 
			
		||||
  %tensor_result = "mhlo.shift_right_logical"(%tensor_lhs, %tensor_rhs)
 | 
			
		||||
func @shift_right_logical(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>)
 | 
			
		||||
    -> tensor<2x2xi32> {
 | 
			
		||||
  %result = "mhlo.shift_right_logical"(%lhs, %rhs)
 | 
			
		||||
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
 | 
			
		||||
  // CHECK: "lmhlo.shift_right_logical"(%{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xi32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<2x2xi32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @tanh
 | 
			
		||||
func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
 | 
			
		||||
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
 | 
			
		||||
  %tensor_result = "mhlo.tanh"(%tensor_operand)
 | 
			
		||||
func @tanh(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
 | 
			
		||||
  %result = "mhlo.tanh"(%operand)
 | 
			
		||||
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
 | 
			
		||||
  // CHECK: "lmhlo.tanh"(%{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xf32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<2x2xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @remainder
 | 
			
		||||
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
 | 
			
		||||
                %result: memref<2x2xf32>) {
 | 
			
		||||
  %tensor_lhs = tensor_load %lhs : memref<2x2xf32>
 | 
			
		||||
  %tensor_rhs = tensor_load %rhs : memref<2x2xf32>
 | 
			
		||||
  %tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs)
 | 
			
		||||
func @remainder(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>)
 | 
			
		||||
    -> tensor<2x2xf32> {
 | 
			
		||||
  %result = "mhlo.remainder"(%lhs, %rhs)
 | 
			
		||||
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
 | 
			
		||||
  // CHECK: "lmhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xf32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<2x2xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @xor
 | 
			
		||||
func @xor(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>,
 | 
			
		||||
          %result: memref<2x2xi32>) {
 | 
			
		||||
  %tensor_operand0 = tensor_load %operand0 : memref<2x2xi32>
 | 
			
		||||
  %tensor_operand1 = tensor_load %operand1 : memref<2x2xi32>
 | 
			
		||||
  %tensor_result = "mhlo.xor"(%tensor_operand0, %tensor_operand1)
 | 
			
		||||
func @xor(%operand0: tensor<2x2xi32>, %operand1: tensor<2x2xi32>)
 | 
			
		||||
    -> tensor<2x2xi32> {
 | 
			
		||||
  %result = "mhlo.xor"(%operand0, %operand1)
 | 
			
		||||
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
 | 
			
		||||
  // CHECK: "lmhlo.xor"(%{{.*}}, %{{.*}})
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xi32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<2x2xi32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// Dynamic shape binary element-wise operation.
 | 
			
		||||
// CHECK-LABEL: func @add_dyn
 | 
			
		||||
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
 | 
			
		||||
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
 | 
			
		||||
  %result = "mhlo.add"(%lhs, %rhs)
 | 
			
		||||
      : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
 | 
			
		||||
  // CHECK: %[[C0:.*]] = constant 0 : index
 | 
			
		||||
| 
						 | 
				
			
			@ -560,14 +470,15 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
 | 
			
		|||
  // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
 | 
			
		||||
  // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
 | 
			
		||||
  // CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<?x?xf32>
 | 
			
		||||
  // CHECK: return %[[RESULT]]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// Dynamic shape unary element-wise operation.
 | 
			
		||||
// CHECK-LABEL: func @tanh_dyn
 | 
			
		||||
func @tanh_dyn(%arg0: tensor<?x?xf32>) {
 | 
			
		||||
func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
 | 
			
		||||
  %result = "mhlo.tanh"(%arg0)
 | 
			
		||||
      : (tensor<?x?xf32>) -> tensor<?x?xf32>
 | 
			
		||||
  // CHECK: %[[C0:.*]] = constant 0 : index
 | 
			
		||||
| 
						 | 
				
			
			@ -583,7 +494,8 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) {
 | 
			
		|||
  // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
 | 
			
		||||
  // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
 | 
			
		||||
  // CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<?x?xf32>
 | 
			
		||||
  // CHECK: return %[[RESULT]]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
| 
						 | 
				
			
			@ -600,7 +512,8 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
 | 
			
		|||
//          rhs_contracting_dimensions = dense<0> : tensor<1xi64>}}
 | 
			
		||||
//        : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
 | 
			
		||||
  %dot = "mhlo.dot"(%arg0, %arg0)
 | 
			
		||||
          : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
 | 
			
		||||
          : (tensor<1024x1024xf32>, tensor<1024x1024xf32>)
 | 
			
		||||
              -> tensor<1024x1024xf32>
 | 
			
		||||
// CHECK: return %[[ALLOC]]
 | 
			
		||||
  return %dot : tensor<1024x1024xf32>
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -608,7 +521,8 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
 | 
			
		|||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @conv
 | 
			
		||||
func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor<3x5x5x4xf32> {
 | 
			
		||||
func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>)
 | 
			
		||||
    -> tensor<3x5x5x4xf32> {
 | 
			
		||||
  %c0 = constant 0 : index
 | 
			
		||||
  // CHECK: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
 | 
			
		||||
  // CHECK: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
 | 
			
		||||
| 
						 | 
				
			
			@ -663,63 +577,52 @@ func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
 | 
			
		|||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @transpose
 | 
			
		||||
func @transpose(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
 | 
			
		||||
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
 | 
			
		||||
  %tensor_result = "mhlo.transpose"(%tensor_operand) {permutation = dense<[1, 0]> : tensor<2xi64>}
 | 
			
		||||
func @transpose(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
 | 
			
		||||
  %result = "mhlo.transpose"(%operand) {permutation = dense<[1, 0]> : tensor<2xi64>}
 | 
			
		||||
              : (tensor<2x2xf32>) -> tensor<2x2xf32>
 | 
			
		||||
  // CHECK: "lmhlo.transpose"(%{{.*}}, %{{.*}}) {permutation = dense<[1, 0]> : tensor<2xi64>}
 | 
			
		||||
  // CHECK-NOT: tensor_store
 | 
			
		||||
  tensor_store %tensor_result, %result : memref<2x2xf32>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<2x2xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @custom_call
 | 
			
		||||
// CHECK-SAME:([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>, [[RESULT:%.*]]: memref<4x4xf16>)
 | 
			
		||||
func @custom_call(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memref<4x4xf16>) {
 | 
			
		||||
  %arg0_tensor = tensor_load %arg0 : memref<2x2xf32>
 | 
			
		||||
  %arg1_tensor = tensor_load %arg1 : memref<2x3xf32>
 | 
			
		||||
// CHECK-SAME:([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>)
 | 
			
		||||
func @custom_call(%arg0: tensor<2x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<4x4xf16> {
 | 
			
		||||
  // CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false, operand_segment_sizes = dense<[2, 1]> : vector<2xi32>}
 | 
			
		||||
  %result_tensor = "mhlo.custom_call"(%arg0_tensor, %arg1_tensor)
 | 
			
		||||
  %result = "mhlo.custom_call"(%arg0, %arg1)
 | 
			
		||||
              {backend_config = "", call_target_name = "foo", has_side_effect = false}
 | 
			
		||||
              : (tensor<2x2xf32>, tensor<2x3xf32>) -> tensor<4x4xf16>
 | 
			
		||||
  tensor_store %result_tensor, %result: memref<4x4xf16>
 | 
			
		||||
  return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @custom_call_multiout
 | 
			
		||||
// CHECK-SAME:([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>, [[RESULT:%.*]]: memref<4x4xf16>)
 | 
			
		||||
func @custom_call_multiout(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memref<4x4xf16>) {
 | 
			
		||||
  %arg0_tensor = tensor_load %arg0 : memref<2x2xf32>
 | 
			
		||||
  %arg1_tensor = tensor_load %arg1 : memref<2x3xf32>
 | 
			
		||||
  // CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}, %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false, operand_segment_sizes = dense<2> : vector<2xi32>}
 | 
			
		||||
  %temp:2 = "mhlo.custom_call"(%arg0_tensor, %arg1_tensor)
 | 
			
		||||
                   {backend_config = "", call_target_name = "foo", has_side_effect = false}
 | 
			
		||||
                   : (tensor<2x2xf32>, tensor<2x3xf32>) -> (tensor<4x4xf16>, tensor<4x4xf16>)
 | 
			
		||||
  %result_tensor = "mhlo.add"(%temp#0, %temp#1) : (tensor<4x4xf16>, tensor<4x4xf16>) -> tensor<4x4xf16>
 | 
			
		||||
  tensor_store %result_tensor, %result: memref<4x4xf16>
 | 
			
		||||
  return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @isfinite
 | 
			
		||||
func @isfinite(%arg0: memref<2x2xf32>, %result: memref<2x2xi1>) {
 | 
			
		||||
  %arg0_tensor = tensor_load %arg0 : memref<2x2xf32>
 | 
			
		||||
  // CHECK: "lmhlo.is_finite"(%{{.*}}, %{{.*}})
 | 
			
		||||
  %result_tensor = "mhlo.is_finite"(%arg0_tensor) : (tensor<2x2xf32>) -> tensor<2x2xi1>
 | 
			
		||||
  tensor_store %result_tensor, %result: memref<2x2xi1>
 | 
			
		||||
  return
 | 
			
		||||
  return %result : tensor<4x4xf16>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// Test that assuming ops propagate memref types.
 | 
			
		||||
// CHECK-LABEL: func @shape_assuming_memref
 | 
			
		||||
func @shape_assuming_memref(%arg0: tensor<?xf16>) -> tensor<?xf16> {
 | 
			
		||||
// CHECK-LABEL: func @custom_call_multiout
 | 
			
		||||
// CHECK-SAME:([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>)
 | 
			
		||||
func @custom_call_multiout(%arg0: tensor<2x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<4x4xf16> {
 | 
			
		||||
  // CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}, %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false, operand_segment_sizes = dense<2> : vector<2xi32>}
 | 
			
		||||
  %temp:2 = "mhlo.custom_call"(%arg0, %arg1)
 | 
			
		||||
                   {backend_config = "", call_target_name = "foo", has_side_effect = false}
 | 
			
		||||
                   : (tensor<2x2xf32>, tensor<2x3xf32>) -> (tensor<4x4xf16>, tensor<4x4xf16>)
 | 
			
		||||
  %result = "mhlo.add"(%temp#0, %temp#1) : (tensor<4x4xf16>, tensor<4x4xf16>) -> tensor<4x4xf16>
 | 
			
		||||
  return %result : tensor<4x4xf16>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @isfinite
 | 
			
		||||
func @isfinite(%arg0: tensor<2x2xf32>) -> tensor<2x2xi1> {
 | 
			
		||||
  // CHECK: "lmhlo.is_finite"(%{{.*}}, %{{.*}})
 | 
			
		||||
  %result = "mhlo.is_finite"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xi1>
 | 
			
		||||
  return %result : tensor<2x2xi1>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// Test that assuming ops propagate tensor types.
 | 
			
		||||
// CHECK-LABEL: func @shape_assuming_tensor
 | 
			
		||||
func @shape_assuming_tensor(%arg0: tensor<?xf16>) -> tensor<?xf16> {
 | 
			
		||||
  %0 = mhlo.constant dense<0.000000e+00> : tensor<f16>
 | 
			
		||||
  %1 = shape.const_witness true
 | 
			
		||||
  // CHECK: shape.assuming %{{.*}} -> (memref<?xf16>)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue