[MLIR][LHLO] Add ReshapeMemrefCastOp to LHLO.

PiperOrigin-RevId: 319799171
This commit is contained in:
Alexander Belyaev 2020-07-06 16:56:07 +00:00 committed by Mehdi Amini
parent fcf3df1541
commit 72010faaa7
2 changed files with 116 additions and 0 deletions

View File

@ -414,6 +414,82 @@ def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
}];
}
//===----------------------------------------------------------------------===//
// ReshapeMemRefCastOp
//===----------------------------------------------------------------------===//
def ReshapeMemRefCastOp: Op<LHLO_Dialect, "reshape_memref_cast", [
    DeclareOpInterfaceMethods<ViewLikeOpInterface>,
    NoSideEffect]> {
  let summary = "reshape memref cast operation";
  let description = [{
    The `reshape_memref_cast` operation converts a memref from one type to an
    equivalent type with a provided shape. The data is never copied or moved.
    The source and destination types are compatible if both have the same
    element type, address space and identity layout map. The following
    combinations are possible:

    a. Both are ranked memref types.

    ```mlir
    // Reshape statically-shaped memref.
    %dst = reshape_memref_cast %src(%shape)
             : (memref<4x1xf32>, memref<1xi32>) -> memref<4xf32>
    %dst0 = reshape_memref_cast %src(%shape0)
             : (memref<4x1xf32>, memref<2xi32>) -> memref<2x2xf32>
    ```

    b. Source type is ranked, destination type is unranked.

    ```mlir
    // Reshape dynamically-shaped 1D memref.
    %dst = reshape_memref_cast %src(%shape)
             : (memref<?xf32>, memref<?xi32>) -> memref<*xf32>
    ```

    c. Source type is unranked, destination type is ranked.

    ```mlir
    // Flatten unranked memref.
    %dst = reshape_memref_cast %src(%shape)
             : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
    ```

    d. Both are unranked memref types.

    ```mlir
    // Reshape unranked memref.
    %dst = reshape_memref_cast %src(%shape)
             : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
    ```
  }];

  let arguments = (ins
    AnyRankedOrUnrankedMemRef:$operand,
    // 1-D memref of signless integers holding the target shape; a dynamic
    // length is only allowed when the result type is unranked (see verifier).
    MemRefRankOf<[AnySignlessInteger], [1]>:$shape
  );
  let results = (outs AnyRankedOrUnrankedMemRef:$result);

  let builders = [OpBuilder<
    "OpBuilder &builder, OperationState &result, MemRefType resultType, " #
    "Value operand, Value shape", [{
       result.addOperands(operand);
       result.addOperands(shape);
       result.types.push_back(resultType);
     }]>];

  let extraClassDeclaration = [{
    MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
  }];

  let verifier = [{ return Verify(*this); }];
  let assemblyFormat = [{
    $operand `(` $shape `)` attr-dict `:` `(` type($operand) `,` type($shape)
    `)` `->` type($result)
  }];
}
//===----------------------------------------------------------------------===//
// XLA Other op definitions.
//===----------------------------------------------------------------------===//

View File

@ -86,6 +86,46 @@ static LogicalResult Verify(DynamicMemRefCastOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// ReshapeMemRefCastOp
//===----------------------------------------------------------------------===//
// ViewLikeOpInterface: the result memref aliases the buffer of `operand`
// (the op never copies data), so the source operand is the viewed value.
Value ReshapeMemRefCastOp::getViewSource() { return operand(); }
// Verifies reshape_memref_cast: source/result element types must match,
// any ranked memref involved must carry the identity layout, and a ranked
// result additionally requires a statically-sized shape operand whose
// length equals the result rank.
static LogicalResult Verify(ReshapeMemRefCastOp op) {
  auto srcShapedType = op.operand().getType().cast<ShapedType>();
  auto dstShapedType = op.result().getType().cast<ShapedType>();

  // The cast may change the shape but never the element type.
  if (srcShapedType.getElementType() != dstShapedType.getElementType())
    return op.emitOpError(
        "element types of source and destination memref "
        "types should be the same");

  // A ranked source must not have a non-identity layout.
  if (auto srcMemRefType = op.operand().getType().dyn_cast<MemRefType>())
    if (!srcMemRefType.getAffineMaps().empty())
      return op.emitOpError(
          "operand memref type should have identity affine map");

  // Number of elements in the 1-D `shape` operand (kDynamicSize if unknown).
  int64_t numShapeElements =
      op.shape().getType().cast<MemRefType>().getDimSize(0);

  if (auto dstMemRefType = op.result().getType().dyn_cast<MemRefType>()) {
    // A ranked result pins the rank, so the shape length must be static
    // and agree with it; the result layout must also be the identity.
    if (numShapeElements == ShapedType::kDynamicSize)
      return op.emitOpError(
          "cannot use shape operand with dynamic length to "
          "cast statically-ranked memref type");
    if (numShapeElements != dstMemRefType.getRank())
      return op.emitOpError(
          "length of shape operand differs from the result's memref rank");
    if (!dstMemRefType.getAffineMaps().empty())
      return op.emitOpError(
          "result memref type should have identity affine map");
  }
  return success();
}
#define GET_OP_CLASSES
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc"