[MLIR][LHLO] Add ReshapeMemrefCastOp to LHLO.
PiperOrigin-RevId: 319799171
This commit is contained in:
parent
fcf3df1541
commit
72010faaa7
|
@ -414,6 +414,82 @@ def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
|
|||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReshapeMemRefCastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ReshapeMemRefCastOp: Op<LHLO_Dialect, "reshape_memref_cast", [
|
||||
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
|
||||
NoSideEffect]> {
|
||||
let summary = "reshape memref cast operation";
|
||||
let description = [{
|
||||
The `reshape_memref_cast` operation converts a memref from one type to an
|
||||
equivalent type with a provided shape. The data is never copied or moved.
|
||||
The source and destination types are compatible if both have the same
|
||||
element type, address space and identity layout map. The following
|
||||
combinations are possible:
|
||||
|
||||
a. Both are ranked memref types.
|
||||
|
||||
```mlir
|
||||
// Reshape statically-shaped memref.
|
||||
%dst = reshape_memref_cast %src(%shape)
|
||||
: (memref<4x1xf32>, memref<1xi32>) to memref<4xf32>
|
||||
%dst0 = reshape_memref_cast %src(%shape0)
|
||||
: (memref<4x1xf32>, memref<2xi32>) to memref<2x2xf32>
|
||||
```
|
||||
|
||||
b. Source type is ranked, destination type is unranked.
|
||||
|
||||
```mlir
|
||||
// Reshape dynamically-shaped 1D memref.
|
||||
%dst = reshape_memref_cast %src(%shape)
|
||||
: (memref<?xf32>, memref<?xi32>) to memref<*xf32>
|
||||
```
|
||||
|
||||
c. Source type is unranked, destination type is ranked.
|
||||
|
||||
```mlir
|
||||
// Flatten unranked memref.
|
||||
%dst = reshape_memref_cast %src(%shape)
|
||||
: (memref<*xf32>, memref<1xi32>) to memref<?xf32>
|
||||
```
|
||||
|
||||
d. Both are unranked memref types.
|
||||
|
||||
```mlir
|
||||
// Reshape unranked memref.
|
||||
%dst = reshape_memref_cast %src(%shape)
|
||||
: (memref<*xf32>, memref<?xi32>) to memref<*xf32>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
AnyRankedOrUnrankedMemRef:$operand,
|
||||
MemRefRankOf<[AnySignlessInteger], [1]>:$shape
|
||||
);
|
||||
let results = (outs AnyRankedOrUnrankedMemRef:$result);
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"OpBuilder &builder, OperationState &result, MemRefType resultType, " #
|
||||
"Value operand, Value shape", [{
|
||||
result.addOperands(operand);
|
||||
result.addOperands(shape);
|
||||
result.types.push_back(resultType);
|
||||
}]>];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
|
||||
}];
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
let assemblyFormat = [{
|
||||
$operand `(` $shape `)` attr-dict `:` `(` type($operand) `,` type($shape)
|
||||
`)` `->` type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// XLA Other op definitions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -86,6 +86,46 @@ static LogicalResult Verify(DynamicMemRefCastOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReshapeMemrefCastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Value ReshapeMemRefCastOp::getViewSource() { return operand(); }
|
||||
|
||||
static LogicalResult Verify(ReshapeMemRefCastOp op) {
|
||||
Type operandType = op.operand().getType();
|
||||
Type resultType = op.result().getType();
|
||||
|
||||
Type operandElementType = operandType.cast<ShapedType>().getElementType();
|
||||
Type resultElementType = resultType.cast<ShapedType>().getElementType();
|
||||
if (operandElementType != resultElementType)
|
||||
return op.emitOpError(
|
||||
"element types of source and destination memref "
|
||||
"types should be the same");
|
||||
|
||||
if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
|
||||
if (!operandMemRefType.getAffineMaps().empty())
|
||||
return op.emitOpError(
|
||||
"operand memref type should have identity affine map");
|
||||
|
||||
int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
|
||||
auto resultMemRefType = resultType.dyn_cast<MemRefType>();
|
||||
if (resultMemRefType) {
|
||||
if (shapeSize == ShapedType::kDynamicSize)
|
||||
return op.emitOpError(
|
||||
"cannot use shape operand with dynamic length to "
|
||||
"cast statically-ranked memref type");
|
||||
if (shapeSize != resultMemRefType.getRank())
|
||||
return op.emitOpError(
|
||||
"length of shape operand differs from the result's memref rank");
|
||||
|
||||
if (!resultMemRefType.getAffineMaps().empty())
|
||||
return op.emitOpError(
|
||||
"result memref type should have identity affine map");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc"
|
||||
|
||||
|
|
Loading…
Reference in New Issue