[MLIR] Use MLIR provided 0DTensorOf/1DTensorOf instead of custom constraints.
- Use MLIR provided constraints for HLO_ScalarIntTensor and HLO_DimensionTensor. - Update unit tests to expect new error messages. PiperOrigin-RevId: 333313131
This commit is contained in:
		
							parent
							
								
									7d01a60de8
								
							
						
					
					
						commit
						87d9e6951e
					
				| 
						 | 
					@ -45,9 +45,7 @@ def HLO_Token : Type<CPred<"$_self.isa<TokenType>()">, "token">;
 | 
				
			||||||
def HLO_IntTensor : TensorOf<[HLO_Int]>;
 | 
					def HLO_IntTensor : TensorOf<[HLO_Int]>;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Any integer tensor type with rank 0 (i.e. representing a single integer).
 | 
					// Any integer tensor type with rank 0 (i.e. representing a single integer).
 | 
				
			||||||
def HLO_ScalarIntTensor : ShapedContainerType<
 | 
					def HLO_ScalarIntTensor : 0DTensorOf<[HLO_Int]>;
 | 
				
			||||||
  [HLO_Int], And<[IsTensorTypePred, HasAnyRankOfPred<[0]>]>,
 | 
					 | 
				
			||||||
  "a 0-dim integer tensor">;
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Any floating-point tensor types
 | 
					// Any floating-point tensor types
 | 
				
			||||||
def HLO_FpTensor : TensorOf<[AnyFloat]>;
 | 
					def HLO_FpTensor : TensorOf<[AnyFloat]>;
 | 
				
			||||||
| 
						 | 
					@ -67,10 +65,7 @@ def HLO_TensorOrTokenOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Token, HLO_Tuple]>;
 | 
				
			||||||
def HLO_DimensionValue : AnyTypeOf<[Index, HLO_Pred, HLO_Int]>;
 | 
					def HLO_DimensionValue : AnyTypeOf<[Index, HLO_Pred, HLO_Int]>;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Dynamic representation of a shape vector as a tensor.
 | 
					// Dynamic representation of a shape vector as a tensor.
 | 
				
			||||||
def HLO_DimensionTensor : ShapedContainerType<
 | 
					def HLO_DimensionTensor : 1DTensorOf<[HLO_DimensionValue]>;
 | 
				
			||||||
    [HLO_DimensionValue],
 | 
					 | 
				
			||||||
    And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>,
 | 
					 | 
				
			||||||
    "a 1D tensor of dimensions">;
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
// In general, static shaped tensor constraints should be avoided unless
 | 
					// In general, static shaped tensor constraints should be avoided unless
 | 
				
			||||||
// it is for a legacy op which is only correct with static shapes.
 | 
					// it is for a legacy op which is only correct with static shapes.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -739,7 +739,7 @@ func @dynamic_slice_mismatch_element_types(%arg0: tensor<3x4xi32>, %arg1: tensor
 | 
				
			||||||
// -----
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func @dynamic_slice_invalid_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
 | 
					func @dynamic_slice_invalid_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
 | 
				
			||||||
  // expected-error@+1 {{operand #1 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}}
 | 
					  // expected-error@+1 {{operand #1 must be 0D tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}}
 | 
				
			||||||
  %0 = "mhlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
 | 
					  %0 = "mhlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
 | 
				
			||||||
  return %0 : tensor<1x4xi32>
 | 
					  return %0 : tensor<1x4xi32>
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -755,7 +755,7 @@ func @dynamic_update_slice(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %sta
 | 
				
			||||||
// -----
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start: tensor<2xi64>) -> tensor<3x4xi64> {
 | 
					func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start: tensor<2xi64>) -> tensor<3x4xi64> {
 | 
				
			||||||
  // expected-error@+1 {{operand #2 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}}
 | 
					  // expected-error@+1 {{operand #2 must be 0D tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}}
 | 
				
			||||||
  %0 = "mhlo.dynamic-update-slice"(%input, %update, %start) : (tensor<3x4xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<3x4xi64>
 | 
					  %0 = "mhlo.dynamic-update-slice"(%input, %update, %start) : (tensor<3x4xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<3x4xi64>
 | 
				
			||||||
  return %0 : tensor<3x4xi64>
 | 
					  return %0 : tensor<3x4xi64>
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue