[HLO] Adopt custom syntax for convolution dimensions and window attributes (HLO)
PiperOrigin-RevId: 374923250
This commit is contained in:
		
							parent
							
								
									fc88cf1ff4
								
							
						
					
					
						commit
						41f663ce47
					
				|  | @ -1409,6 +1409,15 @@ def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]> { | ||||||
|                                       [](bool v) { return v; }); |                                       [](bool v) { return v; }); | ||||||
|     } |     } | ||||||
|   }]; |   }]; | ||||||
|  | 
 | ||||||
|  |  let assemblyFormat = [{ | ||||||
|  |     `(`operands`)` | ||||||
|  |        `dim_numbers` `=` custom<ConvolutionDimensions>($dimension_numbers) `,` | ||||||
|  |        `window` `=` `{` custom<WindowAttributes>($window_strides, $padding, | ||||||
|  |                                                  $lhs_dilation, $rhs_dilation, | ||||||
|  |                                                  $window_reversal) `}` | ||||||
|  |        attr-dict `:` functional-type(operands, results) | ||||||
|  |   }]; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| def HLO_CopyOp: HLO_Op<"copy", [NoSideEffect, SameOperandsAndResultType]> { | def HLO_CopyOp: HLO_Op<"copy", [NoSideEffect, SameOperandsAndResultType]> { | ||||||
|  |  | ||||||
|  | @ -3460,6 +3460,9 @@ OpFoldResult ScatterOp::fold(ArrayRef<Attribute> operands) { | ||||||
|   return DenseElementsAttr::get(base_type, results); |   return DenseElementsAttr::get(base_type, results); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | using mlir::hlo::parseWindowAttributes; | ||||||
|  | using mlir::hlo::printWindowAttributes; | ||||||
|  | 
 | ||||||
| }  // namespace mhlo
 | }  // namespace mhlo
 | ||||||
| }  // namespace mlir
 | }  // namespace mlir
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -1440,3 +1440,29 @@ func @rng_uniform_invalid(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor< | ||||||
|   %0 = "mhlo.rng_uniform"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<7xi64>) -> tensor<?xf32> |   %0 = "mhlo.rng_uniform"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<7xi64>) -> tensor<?xf32> | ||||||
|   return |   return | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | // CHECK: func @conv2d_generic | ||||||
|  | // CHECK: mhlo.convolution | ||||||
|  | // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] | ||||||
|  | // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} | ||||||
|  | func @conv2d_generic(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { | ||||||
|  |   %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = | ||||||
|  |        {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, | ||||||
|  |        feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : | ||||||
|  |        (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> | ||||||
|  |   return %0 : tensor<1x8x8x16xf32> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CHECK: func @conv2d | ||||||
|  | // CHECK: mhlo.convolution | ||||||
|  | // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] | ||||||
|  | // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} | ||||||
|  | func @conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { | ||||||
|  |   %0 = mhlo.convolution(%arg0, %arg1) | ||||||
|  |          dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], | ||||||
|  |          window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} | ||||||
|  |          {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = ["DEFAULT", "DEFAULT"]} : | ||||||
|  |        (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> | ||||||
|  |   return %0 : tensor<1x8x8x16xf32> | ||||||
|  | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue