Add shape function for MHLO RngNormal and RngUniform
PiperOrigin-RevId: 368276963
This commit is contained in:
parent
c01e96d095
commit
fdd75daed6
|
@ -1344,7 +1344,9 @@ def HLO_TorchIndexSelectOp : HLO_Op<"torch_index_select", [NoSideEffect]> {
|
||||||
// MHLO RNG Operators.
|
// MHLO RNG Operators.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp {
|
def HLO_RngUniformOp : HLO_Op<"rng_uniform",
|
||||||
|
InferTensorType<["inferReturnTypeComponents"]>.traits>,
|
||||||
|
BASE_HLO_RngUniformOp {
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
HLO_PredIntOrFpTensor:$a,
|
HLO_PredIntOrFpTensor:$a,
|
||||||
HLO_PredIntOrFpTensor:$b,
|
HLO_PredIntOrFpTensor:$b,
|
||||||
|
@ -1354,9 +1356,18 @@ def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp {
|
||||||
let results = (outs HLO_PredIntOrFpTensor);
|
let results = (outs HLO_PredIntOrFpTensor);
|
||||||
|
|
||||||
let hasCustomHLOConverter = 1;
|
let hasCustomHLOConverter = 1;
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
// Returns whether the return types are compatible.
|
||||||
|
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
|
||||||
|
return succeeded(::mlir::verifyCompatibleShapes(l, r));
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLO_RngNormalOp : HLO_Op<"rng_normal", []>, BASE_HLO_RngNormalOp {
|
def HLO_RngNormalOp : HLO_Op<"rng_normal",
|
||||||
|
InferTensorType<["inferReturnTypeComponents"]>.traits>,
|
||||||
|
BASE_HLO_RngNormalOp {
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
HLO_FpTensor:$mu,
|
HLO_FpTensor:$mu,
|
||||||
HLO_FpTensor:$sigma,
|
HLO_FpTensor:$sigma,
|
||||||
|
@ -1366,6 +1377,13 @@ def HLO_RngNormalOp : HLO_Op<"rng_normal", []>, BASE_HLO_RngNormalOp {
|
||||||
let results = (outs HLO_FpTensor);
|
let results = (outs HLO_FpTensor);
|
||||||
|
|
||||||
let hasCustomHLOConverter = 1;
|
let hasCustomHLOConverter = 1;
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
// Returns whether the return types are compatible.
|
||||||
|
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
|
||||||
|
return succeeded(::mlir::verifyCompatibleShapes(l, r));
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLO_RngBitGeneratorOp : HLO_Op<"rng_bit_generator", [NoSideEffect]>,
|
def HLO_RngBitGeneratorOp : HLO_Op<"rng_bit_generator", [NoSideEffect]>,
|
||||||
|
|
|
@ -146,6 +146,46 @@ static void ReplaceOpWithRegion(PatternRewriter& rewriter, Operation* op,
|
||||||
}
|
}
|
||||||
|
|
||||||
#include "mhlo_canonicalize.inc"
|
#include "mhlo_canonicalize.inc"
|
||||||
|
|
||||||
|
// Common shape function helper for RngNormal and RngUniform.
|
||||||
|
static LogicalResult rngInferReturnTypeComponents(
|
||||||
|
MLIRContext* context, Optional<Location> location, ValueRange operands,
|
||||||
|
DictionaryAttr attributes, RegionRange regions,
|
||||||
|
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
|
||||||
|
if (operands.size() != 3)
|
||||||
|
return emitOptionalError(location, "expected 3 operands");
|
||||||
|
|
||||||
|
SmallVector<int64_t> shapeVector;
|
||||||
|
Value shapeOperand = operands[2];
|
||||||
|
auto shapeOperandType = shapeOperand.getType().cast<ShapedType>();
|
||||||
|
Type elementType = getElementTypeOrSelf(operands[1]);
|
||||||
|
|
||||||
|
// Match constant shape arguments.
|
||||||
|
DenseIntElementsAttr shape;
|
||||||
|
if (!matchPattern(shapeOperand, m_Constant(&shape))) {
|
||||||
|
if (!shapeOperandType.hasRank()) {
|
||||||
|
inferredReturnShapes.emplace_back(elementType);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
if (shapeOperandType.getRank() != 1)
|
||||||
|
return emitOptionalError(location, "shape operand required to be 1D");
|
||||||
|
int size = shapeOperandType.getDimSize(0);
|
||||||
|
if (size == ShapedType::kDynamicSize) {
|
||||||
|
inferredReturnShapes.emplace_back(elementType);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
shapeVector.resize(size, ShapedType::kDynamicSize);
|
||||||
|
inferredReturnShapes.emplace_back(shapeVector, elementType);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
shapeVector.reserve(shape.size());
|
||||||
|
for (const APInt& fp : shape.getIntValues())
|
||||||
|
shapeVector.push_back(fp.getSExtValue());
|
||||||
|
inferredReturnShapes.emplace_back(shapeVector, elementType);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1866,6 +1906,30 @@ LogicalResult ReduceOp::fold(ArrayRef<Attribute> operands,
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// RngNormalOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
LogicalResult RngNormalOp::inferReturnTypeComponents(
|
||||||
|
MLIRContext* context, Optional<Location> location, ValueRange operands,
|
||||||
|
DictionaryAttr attributes, RegionRange regions,
|
||||||
|
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
|
||||||
|
return rngInferReturnTypeComponents(context, location, operands, attributes,
|
||||||
|
regions, inferredReturnShapes);
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// RngUniformOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
LogicalResult RngUniformOp::inferReturnTypeComponents(
|
||||||
|
MLIRContext* context, Optional<Location> location, ValueRange operands,
|
||||||
|
DictionaryAttr attributes, RegionRange regions,
|
||||||
|
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
|
||||||
|
return rngInferReturnTypeComponents(context, location, operands, attributes,
|
||||||
|
regions, inferredReturnShapes);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// SelectOp
|
// SelectOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -1423,3 +1423,20 @@ func @reduce_window_invalid(%arg0: tensor<4x2xf32>, %arg1: tensor<4x3xi32>, %ini
|
||||||
window_strides = dense<[3, 1]> : tensor<2xi64> } : (tensor<4x2xf32>, tensor<4x3xi32>, tensor<f32>, tensor<i32>) -> (tensor<2x2xf32>, tensor<2x2xi32>)
|
window_strides = dense<[3, 1]> : tensor<2xi64> } : (tensor<4x2xf32>, tensor<4x3xi32>, tensor<f32>, tensor<i32>) -> (tensor<2x2xf32>, tensor<2x2xi32>)
|
||||||
return %0#0, %0#1 : tensor<2x2xf32>, tensor<2x2xi32>
|
return %0#0, %0#1 : tensor<2x2xf32>, tensor<2x2xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @rng_normal_invalid(%arg0: tensor<f32>, %arg1: tensor<f32>) {
|
||||||
|
%cst = "mhlo.constant"() {value = dense<7> : tensor<1xi64>} : () -> tensor<1xi64>
|
||||||
|
// expected-error @+1 {{tensor<7xf32>}}
|
||||||
|
%0 = "mhlo.rng_normal"(%arg0, %arg1, %cst) : (tensor<f32>, tensor<f32>, tensor<1xi64>) -> tensor<12xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @rng_uniform_invalid(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<7xi64>) {
|
||||||
|
// expected-error @+1 {{tensor<?x?x?x?x?x?x?xf32>}}
|
||||||
|
%0 = "mhlo.rng_uniform"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<7xi64>) -> tensor<?xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue