Add shape function for MHLO RngNormal and RngUniform

PiperOrigin-RevId: 368276963
Author:    Jacques Pienaar (2021-04-13 12:58:38 -07:00)
Committer: TensorFlow MLIR Team
Parent:    c01e96d095
Commit:    fdd75daed6
3 changed files with 101 additions and 2 deletions

@@ -1344,7 +1344,9 @@ def HLO_TorchIndexSelectOp : HLO_Op<"torch_index_select", [NoSideEffect]> {
 // MHLO RNG Operators.
 //===----------------------------------------------------------------------===//
 
-def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp {
+def HLO_RngUniformOp : HLO_Op<"rng_uniform",
+    InferTensorType<["inferReturnTypeComponents"]>.traits>,
+    BASE_HLO_RngUniformOp {
   let arguments = (ins
     HLO_PredIntOrFpTensor:$a,
     HLO_PredIntOrFpTensor:$b,
@@ -1354,9 +1356,18 @@ def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp {
   let results = (outs HLO_PredIntOrFpTensor);
 
   let hasCustomHLOConverter = 1;
+
+  let extraClassDeclaration = [{
+    // Returns whether the return types are compatible.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+      return succeeded(::mlir::verifyCompatibleShapes(l, r));
+    }
+  }];
 }
 
-def HLO_RngNormalOp : HLO_Op<"rng_normal", []>, BASE_HLO_RngNormalOp {
+def HLO_RngNormalOp : HLO_Op<"rng_normal",
+    InferTensorType<["inferReturnTypeComponents"]>.traits>,
+    BASE_HLO_RngNormalOp {
   let arguments = (ins
     HLO_FpTensor:$mu,
     HLO_FpTensor:$sigma,
@@ -1366,6 +1377,13 @@ def HLO_RngNormalOp : HLO_Op<"rng_normal", []>, BASE_HLO_RngNormalOp {
   let results = (outs HLO_FpTensor);
 
   let hasCustomHLOConverter = 1;
+
+  let extraClassDeclaration = [{
+    // Returns whether the return types are compatible.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+      return succeeded(::mlir::verifyCompatibleShapes(l, r));
+    }
+  }];
 }
 
 def HLO_RngBitGeneratorOp : HLO_Op<"rng_bit_generator", [NoSideEffect]>,
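
For illustration (not part of the commit), the effect of wiring up InferTensorType together with a shape-compatibility check rather than strict type equality: a declared result type verifies as long as it is shape-compatible with the inferred one. The function name below is hypothetical.

// -----

// With a constant shape operand, the trait infers a fully static result
// type (tensor<2x3xf32> here), and the declared type must be compatible.
func @rng_uniform_trait_demo(%a: tensor<f32>, %b: tensor<f32>) -> tensor<2x3xf32> {
  %shape = "mhlo.constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
  %0 = "mhlo.rng_uniform"(%a, %b, %shape)
      : (tensor<f32>, tensor<f32>, tensor<2xi64>) -> tensor<2x3xf32>
  return %0 : tensor<2x3xf32>
}

Because isCompatibleReturnTypes delegates to ::mlir::verifyCompatibleShapes, a dynamic declared type such as tensor<?x?xf32> would also be accepted here; only genuinely conflicting shapes are rejected.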

View File

@@ -146,6 +146,46 @@ static void ReplaceOpWithRegion(PatternRewriter& rewriter, Operation* op,
 }
 
 #include "mhlo_canonicalize.inc"
+
+// Common shape function helper for RngNormal and RngUniform.
+static LogicalResult rngInferReturnTypeComponents(
+    MLIRContext* context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
+  if (operands.size() != 3)
+    return emitOptionalError(location, "expected 3 operands");
+
+  SmallVector<int64_t> shapeVector;
+  Value shapeOperand = operands[2];
+  auto shapeOperandType = shapeOperand.getType().cast<ShapedType>();
+  Type elementType = getElementTypeOrSelf(operands[1]);
+
+  // Match constant shape arguments.
+  DenseIntElementsAttr shape;
+  if (!matchPattern(shapeOperand, m_Constant(&shape))) {
+    if (!shapeOperandType.hasRank()) {
+      inferredReturnShapes.emplace_back(elementType);
+      return success();
+    }
+    if (shapeOperandType.getRank() != 1)
+      return emitOptionalError(location, "shape operand required to be 1D");
+    int size = shapeOperandType.getDimSize(0);
+    if (size == ShapedType::kDynamicSize) {
+      inferredReturnShapes.emplace_back(elementType);
+      return success();
+    }
+    shapeVector.resize(size, ShapedType::kDynamicSize);
+    inferredReturnShapes.emplace_back(shapeVector, elementType);
+    return success();
+  }
+
+  shapeVector.reserve(shape.size());
+  for (const APInt& fp : shape.getIntValues())
+    shapeVector.push_back(fp.getSExtValue());
+  inferredReturnShapes.emplace_back(shapeVector, elementType);
+  return success();
+}
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
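
The helper covers three situations, sketched below in MLIR (a hypothetical example, not a test from this commit): a constant shape operand yields fully static dims; a non-constant but statically sized 1-D shape operand yields a known rank with all-dynamic dims; an unranked or dynamically sized shape operand falls back to an unranked result.

// -----

func @rng_inference_cases(%mu: tensor<f32>, %sigma: tensor<f32>,
                          %dyn: tensor<2xi64>) {
  // Constant shape: matchPattern succeeds and the dims are read from the
  // attribute, so the inferred type is tensor<7xf32>.
  %cst = "mhlo.constant"() {value = dense<7> : tensor<1xi64>} : () -> tensor<1xi64>
  %0 = "mhlo.rng_normal"(%mu, %sigma, %cst)
      : (tensor<f32>, tensor<f32>, tensor<1xi64>) -> tensor<7xf32>
  // Non-constant 1-D shape of static size 2: the rank is known but the
  // dims are not, so the inferred type is tensor<?x?xf32>.
  %1 = "mhlo.rng_normal"(%mu, %sigma, %dyn)
      : (tensor<f32>, tensor<f32>, tensor<2xi64>) -> tensor<?x?xf32>
  return
}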
@@ -1866,6 +1906,30 @@ LogicalResult ReduceOp::fold(ArrayRef<Attribute> operands,
   return failure();
 }
 
+//===----------------------------------------------------------------------===//
+// RngNormalOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult RngNormalOp::inferReturnTypeComponents(
+    MLIRContext* context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
+  return rngInferReturnTypeComponents(context, location, operands, attributes,
+                                      regions, inferredReturnShapes);
+}
+
+//===----------------------------------------------------------------------===//
+// RngUniformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult RngUniformOp::inferReturnTypeComponents(
+    MLIRContext* context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
+  return rngInferReturnTypeComponents(context, location, operands, attributes,
+                                      regions, inferredReturnShapes);
+}
+
 //===----------------------------------------------------------------------===//
 // SelectOp
 //===----------------------------------------------------------------------===//
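
Since both ops override isCompatibleReturnTypes to use shape compatibility, a declared result type may be more precise than the inferred one. A hypothetical example (the function name is illustrative):

// -----

func @rng_uniform_refined(%a: tensor<f32>, %b: tensor<f32>, %shape: tensor<3xi64>)
    -> tensor<4x5x6xf32> {
  // The inferred type is tensor<?x?x?xf32> (rank 3 from the static size of
  // %shape); the static declared type is shape-compatible, so this verifies.
  %0 = "mhlo.rng_uniform"(%a, %b, %shape)
      : (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<4x5x6xf32>
  return %0 : tensor<4x5x6xf32>
}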

@@ -1423,3 +1423,20 @@ func @reduce_window_invalid(%arg0: tensor<4x2xf32>, %arg1: tensor<4x3xi32>, %ini
             window_strides = dense<[3, 1]> : tensor<2xi64> } : (tensor<4x2xf32>, tensor<4x3xi32>, tensor<f32>, tensor<i32>) -> (tensor<2x2xf32>, tensor<2x2xi32>)
   return %0#0, %0#1 : tensor<2x2xf32>, tensor<2x2xi32>
 }
+
+// -----
+
+func @rng_normal_invalid(%arg0: tensor<f32>, %arg1: tensor<f32>) {
+  %cst = "mhlo.constant"() {value = dense<7> : tensor<1xi64>} : () -> tensor<1xi64>
+  // expected-error @+1 {{tensor<7xf32>}}
+  %0 = "mhlo.rng_normal"(%arg0, %arg1, %cst) : (tensor<f32>, tensor<f32>, tensor<1xi64>) -> tensor<12xf32>
+  return
+}
+
+// -----
+
+func @rng_uniform_invalid(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<7xi64>) {
+  // expected-error @+1 {{tensor<?x?x?x?x?x?x?xf32>}}
+  %0 = "mhlo.rng_uniform"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<7xi64>) -> tensor<?xf32>
+  return
+}
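
For contrast, valid counterparts of these two negative tests might look as follows (a sketch, not tests added by this commit): rng_normal verifies when the declared type matches the inferred tensor<7xf32>, and rng_uniform when the declared rank matches the inferred rank of 7.

// -----

func @rng_normal_valid(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<7xf32> {
  %cst = "mhlo.constant"() {value = dense<7> : tensor<1xi64>} : () -> tensor<1xi64>
  %0 = "mhlo.rng_normal"(%arg0, %arg1, %cst) : (tensor<f32>, tensor<f32>, tensor<1xi64>) -> tensor<7xf32>
  return %0 : tensor<7xf32>
}

// -----

func @rng_uniform_valid(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<7xi64>) -> tensor<?x?x?x?x?x?x?xf32> {
  %0 = "mhlo.rng_uniform"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<7xi64>) -> tensor<?x?x?x?x?x?x?xf32>
  return %0 : tensor<?x?x?x?x?x?x?xf32>
}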