diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index b3eabbf..87f0e78 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -1344,7 +1344,9 @@ def HLO_TorchIndexSelectOp : HLO_Op<"torch_index_select", [NoSideEffect]> { // MHLO RNG Operators. //===----------------------------------------------------------------------===// -def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp { +def HLO_RngUniformOp : HLO_Op<"rng_uniform", + InferTensorType<["inferReturnTypeComponents"]>.traits>, + BASE_HLO_RngUniformOp { let arguments = (ins HLO_PredIntOrFpTensor:$a, HLO_PredIntOrFpTensor:$b, @@ -1354,9 +1356,18 @@ def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp { let results = (outs HLO_PredIntOrFpTensor); let hasCustomHLOConverter = 1; + + let extraClassDeclaration = [{ + // Returns whether the return types are compatible. + static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) { + return succeeded(::mlir::verifyCompatibleShapes(l, r)); + } + }]; } -def HLO_RngNormalOp : HLO_Op<"rng_normal", []>, BASE_HLO_RngNormalOp { +def HLO_RngNormalOp : HLO_Op<"rng_normal", + InferTensorType<["inferReturnTypeComponents"]>.traits>, + BASE_HLO_RngNormalOp { let arguments = (ins HLO_FpTensor:$mu, HLO_FpTensor:$sigma, @@ -1366,6 +1377,13 @@ def HLO_RngNormalOp : HLO_Op<"rng_normal", []>, BASE_HLO_RngNormalOp { let results = (outs HLO_FpTensor); let hasCustomHLOConverter = 1; + + let extraClassDeclaration = [{ + // Returns whether the return types are compatible. + static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) { + return succeeded(::mlir::verifyCompatibleShapes(l, r)); + } + }]; } def HLO_RngBitGeneratorOp : HLO_Op<"rng_bit_generator", [NoSideEffect]>, diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index a0f63c9..62dc881 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -146,6 +146,46 @@ static void ReplaceOpWithRegion(PatternRewriter& rewriter, Operation* op, } #include "mhlo_canonicalize.inc" + +// Common shape function helper for RngNormal and RngUniform. +static LogicalResult rngInferReturnTypeComponents( + MLIRContext* context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferredReturnShapes) { + if (operands.size() != 3) + return emitOptionalError(location, "expected 3 operands"); + + SmallVector shapeVector; + Value shapeOperand = operands[2]; + auto shapeOperandType = shapeOperand.getType().cast(); + Type elementType = getElementTypeOrSelf(operands[1]); + + // Match constant shape arguments. + DenseIntElementsAttr shape; + if (!matchPattern(shapeOperand, m_Constant(&shape))) { + if (!shapeOperandType.hasRank()) { + inferredReturnShapes.emplace_back(elementType); + return success(); + } + if (shapeOperandType.getRank() != 1) + return emitOptionalError(location, "shape operand required to be 1D"); + int size = shapeOperandType.getDimSize(0); + if (size == ShapedType::kDynamicSize) { + inferredReturnShapes.emplace_back(elementType); + return success(); + } + shapeVector.resize(size, ShapedType::kDynamicSize); + inferredReturnShapes.emplace_back(shapeVector, elementType); + return success(); + } + + shapeVector.reserve(shape.size()); + for (const APInt& fp : shape.getIntValues()) + shapeVector.push_back(fp.getSExtValue()); + inferredReturnShapes.emplace_back(shapeVector, elementType); + return success(); +} + } // namespace //===----------------------------------------------------------------------===// @@ -1866,6 +1906,30 @@ LogicalResult ReduceOp::fold(ArrayRef operands, return failure(); } +//===----------------------------------------------------------------------===// +// RngNormalOp +//===----------------------------------------------------------------------===// + +LogicalResult RngNormalOp::inferReturnTypeComponents( + MLIRContext* context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferredReturnShapes) { + return rngInferReturnTypeComponents(context, location, operands, attributes, + regions, inferredReturnShapes); +} + +//===----------------------------------------------------------------------===// +// RngUniformOp +//===----------------------------------------------------------------------===// + +LogicalResult RngUniformOp::inferReturnTypeComponents( + MLIRContext* context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferredReturnShapes) { + return rngInferReturnTypeComponents(context, location, operands, attributes, + regions, inferredReturnShapes); +} + //===----------------------------------------------------------------------===// // SelectOp //===----------------------------------------------------------------------===// diff --git a/tests/ops.mlir b/tests/ops.mlir index 604bd5a..f67c0ac 100644 --- a/tests/ops.mlir +++ b/tests/ops.mlir @@ -1423,3 +1423,20 @@ func @reduce_window_invalid(%arg0: tensor<4x2xf32>, %arg1: tensor<4x3xi32>, %ini window_strides = dense<[3, 1]> : tensor<2xi64> } : (tensor<4x2xf32>, tensor<4x3xi32>, tensor, tensor) -> (tensor<2x2xf32>, tensor<2x2xi32>) return %0#0, %0#1 : tensor<2x2xf32>, tensor<2x2xi32> } + +// ----- + +func @rng_normal_invalid(%arg0: tensor, %arg1: tensor) { + %cst = "mhlo.constant"() {value = dense<7> : tensor<1xi64>} : () -> tensor<1xi64> + // expected-error @+1 {{tensor<7xf32>}} + %0 = "mhlo.rng_normal"(%arg0, %arg1, %cst) : (tensor, tensor, tensor<1xi64>) -> tensor<12xf32> + return +} + +// ----- + +func @rng_uniform_invalid(%arg0: tensor, %arg1: tensor, %arg2: tensor<7xi64>) { + // expected-error @+1 {{tensor}} + %0 = "mhlo.rng_uniform"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor<7xi64>) -> tensor + return +}