[MLIR][LHLO] Add ReshapeMemrefCastOp to LHLO.
PiperOrigin-RevId: 319799171
This commit is contained in:
parent
fcf3df1541
commit
72010faaa7
|
@ -414,6 +414,82 @@ def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ReshapeMemRefCastOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def ReshapeMemRefCastOp: Op<LHLO_Dialect, "reshape_memref_cast", [
|
||||||
|
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
|
||||||
|
NoSideEffect]> {
|
||||||
|
let summary = "reshape memref cast operation";
|
||||||
|
let description = [{
|
||||||
|
The `reshape_memref_cast` operation converts a memref from one type to an
|
||||||
|
equivalent type with a provided shape. The data is never copied or moved.
|
||||||
|
The source and destination types are compatible if both have the same
|
||||||
|
element type, address space and identity layout map. The following
|
||||||
|
combinations are possible:
|
||||||
|
|
||||||
|
a. Both are ranked memref types.
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
// Reshape statically-shaped memref.
|
||||||
|
%dst = reshape_memref_cast %src(%shape)
|
||||||
|
: (memref<4x1xf32>, memref<1xi32>) to memref<4xf32>
|
||||||
|
%dst0 = reshape_memref_cast %src(%shape0)
|
||||||
|
: (memref<4x1xf32>, memref<2xi32>) to memref<2x2xf32>
|
||||||
|
```
|
||||||
|
|
||||||
|
b. Source type is ranked, destination type is unranked.
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
// Reshape dynamically-shaped 1D memref.
|
||||||
|
%dst = reshape_memref_cast %src(%shape)
|
||||||
|
: (memref<?xf32>, memref<?xi32>) to memref<*xf32>
|
||||||
|
```
|
||||||
|
|
||||||
|
c. Source type is unranked, destination type is ranked.
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
// Flatten unranked memref.
|
||||||
|
%dst = reshape_memref_cast %src(%shape)
|
||||||
|
: (memref<*xf32>, memref<1xi32>) to memref<?xf32>
|
||||||
|
```
|
||||||
|
|
||||||
|
d. Both are unranked memref types.
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
// Reshape unranked memref.
|
||||||
|
%dst = reshape_memref_cast %src(%shape)
|
||||||
|
: (memref<*xf32>, memref<?xi32>) to memref<*xf32>
|
||||||
|
```
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
AnyRankedOrUnrankedMemRef:$operand,
|
||||||
|
MemRefRankOf<[AnySignlessInteger], [1]>:$shape
|
||||||
|
);
|
||||||
|
let results = (outs AnyRankedOrUnrankedMemRef:$result);
|
||||||
|
|
||||||
|
let builders = [OpBuilder<
|
||||||
|
"OpBuilder &builder, OperationState &result, MemRefType resultType, " #
|
||||||
|
"Value operand, Value shape", [{
|
||||||
|
result.addOperands(operand);
|
||||||
|
result.addOperands(shape);
|
||||||
|
result.types.push_back(resultType);
|
||||||
|
}]>];
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
|
||||||
|
}];
|
||||||
|
|
||||||
|
let verifier = [{ return Verify(*this); }];
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$operand `(` $shape `)` attr-dict `:` `(` type($operand) `,` type($shape)
|
||||||
|
`)` `->` type($result)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// XLA Other op definitions.
|
// XLA Other op definitions.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -86,6 +86,46 @@ static LogicalResult Verify(DynamicMemRefCastOp op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ReshapeMemrefCastOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
Value ReshapeMemRefCastOp::getViewSource() { return operand(); }
|
||||||
|
|
||||||
|
static LogicalResult Verify(ReshapeMemRefCastOp op) {
|
||||||
|
Type operandType = op.operand().getType();
|
||||||
|
Type resultType = op.result().getType();
|
||||||
|
|
||||||
|
Type operandElementType = operandType.cast<ShapedType>().getElementType();
|
||||||
|
Type resultElementType = resultType.cast<ShapedType>().getElementType();
|
||||||
|
if (operandElementType != resultElementType)
|
||||||
|
return op.emitOpError(
|
||||||
|
"element types of source and destination memref "
|
||||||
|
"types should be the same");
|
||||||
|
|
||||||
|
if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
|
||||||
|
if (!operandMemRefType.getAffineMaps().empty())
|
||||||
|
return op.emitOpError(
|
||||||
|
"operand memref type should have identity affine map");
|
||||||
|
|
||||||
|
int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
|
||||||
|
auto resultMemRefType = resultType.dyn_cast<MemRefType>();
|
||||||
|
if (resultMemRefType) {
|
||||||
|
if (shapeSize == ShapedType::kDynamicSize)
|
||||||
|
return op.emitOpError(
|
||||||
|
"cannot use shape operand with dynamic length to "
|
||||||
|
"cast statically-ranked memref type");
|
||||||
|
if (shapeSize != resultMemRefType.getRank())
|
||||||
|
return op.emitOpError(
|
||||||
|
"length of shape operand differs from the result's memref rank");
|
||||||
|
|
||||||
|
if (!resultMemRefType.getAffineMaps().empty())
|
||||||
|
return op.emitOpError(
|
||||||
|
"result memref type should have identity affine map");
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc"
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc"
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue