Add shape function for MHLO RngNormal and RngUniform
PiperOrigin-RevId: 368276963
This commit is contained in:
parent
c01e96d095
commit
fdd75daed6
|
@ -1344,7 +1344,9 @@ def HLO_TorchIndexSelectOp : HLO_Op<"torch_index_select", [NoSideEffect]> {
|
|||
// MHLO RNG Operators.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp {
|
||||
def HLO_RngUniformOp : HLO_Op<"rng_uniform",
|
||||
InferTensorType<["inferReturnTypeComponents"]>.traits>,
|
||||
BASE_HLO_RngUniformOp {
|
||||
let arguments = (ins
|
||||
HLO_PredIntOrFpTensor:$a,
|
||||
HLO_PredIntOrFpTensor:$b,
|
||||
|
@ -1354,9 +1356,18 @@ def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp {
|
|||
let results = (outs HLO_PredIntOrFpTensor);
|
||||
|
||||
let hasCustomHLOConverter = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Returns whether the return types are compatible.
|
||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
|
||||
return succeeded(::mlir::verifyCompatibleShapes(l, r));
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def HLO_RngNormalOp : HLO_Op<"rng_normal", []>, BASE_HLO_RngNormalOp {
|
||||
def HLO_RngNormalOp : HLO_Op<"rng_normal",
|
||||
InferTensorType<["inferReturnTypeComponents"]>.traits>,
|
||||
BASE_HLO_RngNormalOp {
|
||||
let arguments = (ins
|
||||
HLO_FpTensor:$mu,
|
||||
HLO_FpTensor:$sigma,
|
||||
|
@ -1366,6 +1377,13 @@ def HLO_RngNormalOp : HLO_Op<"rng_normal", []>, BASE_HLO_RngNormalOp {
|
|||
let results = (outs HLO_FpTensor);
|
||||
|
||||
let hasCustomHLOConverter = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Returns whether the return types are compatible.
|
||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
|
||||
return succeeded(::mlir::verifyCompatibleShapes(l, r));
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def HLO_RngBitGeneratorOp : HLO_Op<"rng_bit_generator", [NoSideEffect]>,
|
||||
|
|
|
@ -146,6 +146,46 @@ static void ReplaceOpWithRegion(PatternRewriter& rewriter, Operation* op,
|
|||
}
|
||||
|
||||
#include "mhlo_canonicalize.inc"
|
||||
|
||||
// Common shape function helper for RngNormal and RngUniform.
|
||||
static LogicalResult rngInferReturnTypeComponents(
|
||||
MLIRContext* context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
|
||||
if (operands.size() != 3)
|
||||
return emitOptionalError(location, "expected 3 operands");
|
||||
|
||||
SmallVector<int64_t> shapeVector;
|
||||
Value shapeOperand = operands[2];
|
||||
auto shapeOperandType = shapeOperand.getType().cast<ShapedType>();
|
||||
Type elementType = getElementTypeOrSelf(operands[1]);
|
||||
|
||||
// Match constant shape arguments.
|
||||
DenseIntElementsAttr shape;
|
||||
if (!matchPattern(shapeOperand, m_Constant(&shape))) {
|
||||
if (!shapeOperandType.hasRank()) {
|
||||
inferredReturnShapes.emplace_back(elementType);
|
||||
return success();
|
||||
}
|
||||
if (shapeOperandType.getRank() != 1)
|
||||
return emitOptionalError(location, "shape operand required to be 1D");
|
||||
int size = shapeOperandType.getDimSize(0);
|
||||
if (size == ShapedType::kDynamicSize) {
|
||||
inferredReturnShapes.emplace_back(elementType);
|
||||
return success();
|
||||
}
|
||||
shapeVector.resize(size, ShapedType::kDynamicSize);
|
||||
inferredReturnShapes.emplace_back(shapeVector, elementType);
|
||||
return success();
|
||||
}
|
||||
|
||||
shapeVector.reserve(shape.size());
|
||||
for (const APInt& fp : shape.getIntValues())
|
||||
shapeVector.push_back(fp.getSExtValue());
|
||||
inferredReturnShapes.emplace_back(shapeVector, elementType);
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1866,6 +1906,30 @@ LogicalResult ReduceOp::fold(ArrayRef<Attribute> operands,
|
|||
return failure();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RngNormalOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult RngNormalOp::inferReturnTypeComponents(
|
||||
MLIRContext* context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
|
||||
return rngInferReturnTypeComponents(context, location, operands, attributes,
|
||||
regions, inferredReturnShapes);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RngUniformOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult RngUniformOp::inferReturnTypeComponents(
|
||||
MLIRContext* context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
|
||||
return rngInferReturnTypeComponents(context, location, operands, attributes,
|
||||
regions, inferredReturnShapes);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SelectOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1423,3 +1423,20 @@ func @reduce_window_invalid(%arg0: tensor<4x2xf32>, %arg1: tensor<4x3xi32>, %ini
|
|||
window_strides = dense<[3, 1]> : tensor<2xi64> } : (tensor<4x2xf32>, tensor<4x3xi32>, tensor<f32>, tensor<i32>) -> (tensor<2x2xf32>, tensor<2x2xi32>)
|
||||
return %0#0, %0#1 : tensor<2x2xf32>, tensor<2x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @rng_normal_invalid(%arg0: tensor<f32>, %arg1: tensor<f32>) {
|
||||
%cst = "mhlo.constant"() {value = dense<7> : tensor<1xi64>} : () -> tensor<1xi64>
|
||||
// expected-error @+1 {{tensor<7xf32>}}
|
||||
%0 = "mhlo.rng_normal"(%arg0, %arg1, %cst) : (tensor<f32>, tensor<f32>, tensor<1xi64>) -> tensor<12xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @rng_uniform_invalid(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<7xi64>) {
|
||||
// expected-error @+1 {{tensor<?x?x?x?x?x?x?xf32>}}
|
||||
%0 = "mhlo.rng_uniform"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<7xi64>) -> tensor<?xf32>
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue