diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
index 4e4235d..7fec26a 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
@@ -414,6 +414,82 @@ def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
 }
 
+def HLO_ReshapeMemRefCastOp: Op<LHLO_Dialect, "reshape_memref_cast", [
+    DeclareOpInterfaceMethods<ViewLikeOpInterface>,
+    NoSideEffect]> {
+  let summary = "reshape memref cast operation";
+  let description = [{
+    The `reshape_memref_cast` operation converts a memref from one type to an
+    equivalent type with a provided shape. The data is never copied or moved.
+    The source and destination types are compatible if both have the same
+    element type, address space and identity layout map. The following
+    combinations are possible:
+
+    a. Both are ranked memref types.
+
+    ```mlir
+    // Reshape statically-shaped memref.
+    %dst = reshape_memref_cast %src(%shape)
+             : (memref<4x1xf32>, memref<1xi32>) to memref<4xf32>
+    %dst0 = reshape_memref_cast %src(%shape0)
+             : (memref<4x1xf32>, memref<2xi32>) to memref<2x2xf32>
+    ```
+
+    b. Source type is ranked, destination type is unranked.
+
+    ```mlir
+    // Reshape dynamically-shaped 1D memref.
+    %dst = reshape_memref_cast %src(%shape)
+             : (memref<?xf32>, memref<?xi32>) to memref<*xf32>
+    ```
+
+    c. Source type is unranked, destination type is ranked.
+
+    ```mlir
+    // Flatten unranked memref.
+    %dst = reshape_memref_cast %src(%shape)
+             : (memref<*xf32>, memref<1xi32>) to memref<?xf32>
+    ```
+
+    d. Both are unranked memref types.
+
+    ```mlir
+    // Reshape unranked memref.
+    %dst = reshape_memref_cast %src(%shape)
+             : (memref<*xf32>, memref<?xi32>) to memref<*xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    AnyRankedOrUnrankedMemRef:$operand,
+    MemRefRankOf<[AnySignlessInteger], [1]>:$shape
+  );
+  let results = (outs AnyRankedOrUnrankedMemRef:$result);
+
+  let builders = [OpBuilder<
+    "OpBuilder &builder, OperationState &result, MemRefType resultType, " #
+    "Value operand, Value shape", [{
+      result.addOperands(operand);
+      result.addOperands(shape);
+      result.types.push_back(resultType);
+    }]>];
+
+  let extraClassDeclaration = [{
+    MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
+  }];
+
+  let verifier = [{ return Verify(*this); }];
+  let assemblyFormat = [{
+    $operand `(` $shape `)` attr-dict `:` `(` type($operand) `,` type($shape)
+    `)` `->` type($result)
+  }];
+}
+
+
 //===----------------------------------------------------------------------===//
 // XLA Other op definitions.
 //===----------------------------------------------------------------------===//
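The `let builders` entry above only wires up the operands and the result type, so new ops are created through the usual `OpBuilder::create` path. Below is a minimal usage sketch, not part of the patch: `castToRank1`, `source`, and `shapeBuffer` are hypothetical names, and the namespace holding the generated `ReshapeMemRefCastOp` class (`xla_lhlo` or `lmhlo`) depends on the revision.

```cpp
// Hypothetical usage sketch for the builder declared above; not part of this patch.
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"  // MemRefType; BuiltinTypes.h on newer MLIR

using namespace mlir;
using namespace mlir::xla_lhlo;  // assumption: may be mlir::lmhlo depending on revision

// Casts a ranked `source` buffer to a rank-1 memref of length `dimSize`,
// with `shapeBuffer` holding that length at runtime.
static Value castToRank1(OpBuilder &builder, Location loc, Value source,
                         Value shapeBuffer, int64_t dimSize) {
  auto elementType = source.getType().cast<MemRefType>().getElementType();
  auto resultType = MemRefType::get({dimSize}, elementType);
  // Forwards to the build() method generated from `let builders`:
  // (resultType, operand, shape).
  return builder.create<ReshapeMemRefCastOp>(loc, resultType, source,
                                             shapeBuffer);
}
```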
diff --git a/lib/Dialect/mhlo/IR/lhlo_ops.cc b/lib/Dialect/mhlo/IR/lhlo_ops.cc
index 3e374a4..305df4f 100644
--- a/lib/Dialect/mhlo/IR/lhlo_ops.cc
+++ b/lib/Dialect/mhlo/IR/lhlo_ops.cc
@@ -86,6 +86,46 @@ static LogicalResult Verify(DynamicMemRefCastOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ReshapeMemRefCastOp
+//===----------------------------------------------------------------------===//
+
+Value ReshapeMemRefCastOp::getViewSource() { return operand(); }
+
+static LogicalResult Verify(ReshapeMemRefCastOp op) {
+  Type operandType = op.operand().getType();
+  Type resultType = op.result().getType();
+
+  Type operandElementType = operandType.cast<ShapedType>().getElementType();
+  Type resultElementType = resultType.cast<ShapedType>().getElementType();
+  if (operandElementType != resultElementType)
+    return op.emitOpError(
+        "element types of source and destination memref "
+        "types should be the same");
+
+  if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
+    if (!operandMemRefType.getAffineMaps().empty())
+      return op.emitOpError(
+          "operand memref type should have identity affine map");
+
+  int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
+  auto resultMemRefType = resultType.dyn_cast<MemRefType>();
+  if (resultMemRefType) {
+    if (shapeSize == ShapedType::kDynamicSize)
+      return op.emitOpError(
+          "cannot use shape operand with dynamic length to "
+          "cast statically-ranked memref type");
+    if (shapeSize != resultMemRefType.getRank())
+      return op.emitOpError(
+          "length of shape operand differs from the result's memref rank");
+
+    if (!resultMemRefType.getAffineMaps().empty())
+      return op.emitOpError(
+          "result memref type should have identity affine map");
+  }
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc"
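For context on the type-level facts the new verifier relies on, here is a small illustrative sketch, again not part of the patch: `illustrateVerifierRules` is a hypothetical helper, and the header locations vary across MLIR revisions.

```cpp
// Hypothetical sketch, not part of this patch: how the checks in
// Verify(ReshapeMemRefCastOp) evaluate for concrete types.
#include <cassert>

#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"  // ShapedType/MemRefType; BuiltinTypes.h on newer MLIR

using namespace mlir;

static void illustrateVerifierRules(MLIRContext *ctx) {
  Builder b(ctx);

  // Shape operand of type memref<2xi32>: getDimSize(0) == 2, so a ranked
  // result type must have rank 2 to pass the rank check.
  auto staticShape = MemRefType::get({2}, b.getIntegerType(32));
  assert(staticShape.getDimSize(0) == 2);

  // Shape operand of type memref<?xi32>: getDimSize(0) returns
  // ShapedType::kDynamicSize, which the verifier rejects whenever the
  // result type is ranked.
  auto dynamicShape =
      MemRefType::get({ShapedType::kDynamicSize}, b.getIntegerType(32));
  assert(dynamicShape.getDimSize(0) == ShapedType::kDynamicSize);

  // A memref with the default (identity) layout reports no affine maps,
  // which is what the "identity affine map" checks require of the operand
  // and of a ranked result.
  auto identityLayout = MemRefType::get({2, 2}, b.getF32Type());
  assert(identityLayout.getAffineMaps().empty());
}
```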