PR #50236: [MLIR][DISC] Bufferize TransposeOp and ConcatenateOp

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/50236

support hlo-to-lhlo conversion for TransposeOp and ConcatenateOp
Copybara import of the project:

--
62860e717f2a14fbd3ddfb634aa6ff132d245a72 by Wenyi Zhao <reyizero@gmail.com>:

[MLIR][DISC] Bufferize TransposeOp and ConcatenateOp

--
ce2ff57c1edee1172cd2f36346cc0b34ec1c7467 by Wenyi Zhao <reyizero@gmail.com>:

fix

PiperOrigin-RevId: 379330954
This commit is contained in:
Wenyi Zhao 2021-06-14 12:35:47 -07:00 committed by TensorFlow MLIR Team
parent 23ebbb28d1
commit 7f94bd923b
5 changed files with 123 additions and 2 deletions

View File

@ -1382,7 +1382,7 @@ def HLO_ClampOp : HLO_Op<"clamp",
let results = (outs HLO_Tensor); let results = (outs HLO_Tensor);
} }
def HLO_ConcatenateOp : HLO_Op<"concatenate", def HLO_ConcatenateOp : HLO_ShapedInterfaceOp<"concatenate",
[NoSideEffect, SameOperandsAndResultElementType, [NoSideEffect, SameOperandsAndResultElementType,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> { DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "XLA's concatenate op"; let summary = "XLA's concatenate op";
@ -1901,7 +1901,7 @@ def HLO_TraceOp: HLO_Op<"trace", []> {
let hasCustomHLOConverter = 1; let hasCustomHLOConverter = 1;
} }
def HLO_TransposeOp: HLO_Op<"transpose", def HLO_TransposeOp: HLO_ShapedInterfaceOp<"transpose",
[NoSideEffect, SameOperandsAndResultElementType]> { [NoSideEffect, SameOperandsAndResultElementType]> {
let summary = "Transpose operator"; let summary = "Transpose operator";
let description = [{ let description = [{

View File

@ -47,6 +47,7 @@ MAP_HLO_TO_LHLO(ClampOp);
MAP_HLO_TO_LHLO(ConstOp); MAP_HLO_TO_LHLO(ConstOp);
MAP_HLO_TO_LHLO(CompareOp); MAP_HLO_TO_LHLO(CompareOp);
MAP_HLO_TO_LHLO(ComplexOp); MAP_HLO_TO_LHLO(ComplexOp);
MAP_HLO_TO_LHLO(ConcatenateOp);
MAP_HLO_TO_LHLO(ConvOp); MAP_HLO_TO_LHLO(ConvOp);
MAP_HLO_TO_LHLO(ConvertOp); MAP_HLO_TO_LHLO(ConvertOp);
MAP_HLO_TO_LHLO(CopyOp); MAP_HLO_TO_LHLO(CopyOp);

View File

@ -1417,6 +1417,57 @@ static LogicalResult Verify(ConcatenateOp op) {
return success(); return success();
} }
LogicalResult ConcatenateOp::reifyReturnTypeShapes(
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
ConcatenateOp::Adaptor adaptor(operands);
auto inputs = adaptor.val();
auto operand_type = inputs[0].getType().dyn_cast<RankedTensorType>();
// Not support unranked type a.t.m.
if (!operand_type) return failure();
Location loc = this->getLoc();
Type shape_scalar_type = builder.getIndexType();
auto to_shape_scalar_type = [&](Value v) {
return MaybeCastTo(builder, loc, v, shape_scalar_type);
};
SmallVector<SmallVector<Value, 4>, 4> all_shape_values;
for (size_t input_id = 0; input_id < inputs.size(); ++input_id) {
Value operand = inputs[input_id];
auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
if (!operand_type) return failure();
SmallVector<Value, 4> shape_vals;
for (const auto& element : llvm::enumerate(operand_type.getShape())) {
Value value_dim = to_shape_scalar_type(
builder.create<memref::DimOp>(loc, operand, element.index()));
shape_vals.push_back(value_dim);
}
all_shape_values.emplace_back(std::move(shape_vals));
}
int axis = this->dimension();
auto& shape_values = all_shape_values[0];
for (size_t vec_id = 1; vec_id < all_shape_values.size(); ++vec_id) {
auto& other_shape_values = all_shape_values[vec_id];
if (other_shape_values.size() != shape_values.size()) {
this->emitOpError()
<< "Concatenate expects all operands must be of the same rank";
return failure();
}
shape_values[axis] = builder.create<AddIOp>(loc, shape_values[axis],
other_shape_values[axis]);
}
Value output_shape = builder.create<tensor::FromElementsOp>(
loc, shape_scalar_type, shape_values);
reifiedReturnShapes.push_back(output_shape);
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// DynamicReshapeOp // DynamicReshapeOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -3363,6 +3414,40 @@ static LogicalResult Verify(TransposeOp op) {
return success(); return success();
} }
LogicalResult TransposeOp::reifyReturnTypeShapes(
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
TransposeOp::Adaptor adaptor(operands);
Value operand = adaptor.operand();
auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
// Not support unranked type a.t.m.
if (!operand_type) return failure();
Location loc = this->getLoc();
SmallVector<int64_t, 4> permutation(this->permutation().getValues<int64_t>());
SmallVector<Value, 4> shape_values(permutation.size());
Type shape_scalar_type = builder.getIndexType();
auto to_shape_scalar_type = [&](Value v) {
return MaybeCastTo(builder, loc, v, shape_scalar_type);
};
for (const auto& element : llvm::enumerate(operand_type.getShape())) {
int64_t idx = element.index();
auto it = std::find(permutation.begin(), permutation.end(), idx);
Value value_dim = to_shape_scalar_type(
builder.create<memref::DimOp>(loc, operand, element.index()));
shape_values[std::distance(permutation.begin(), it)] = value_dim;
}
Value output_shape = builder.create<tensor::FromElementsOp>(
loc, shape_scalar_type, shape_values);
reifiedReturnShapes.push_back(output_shape);
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TriangularSolveOp // TriangularSolveOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -696,6 +696,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
HloToLhloOpConverter<mhlo::CeilOp>, HloToLhloOpConverter<mhlo::CeilOp>,
HloToLhloOpConverter<mhlo::CompareOp>, HloToLhloOpConverter<mhlo::CompareOp>,
HloToLhloOpConverter<mhlo::ComplexOp>, HloToLhloOpConverter<mhlo::ComplexOp>,
HloToLhloOpConverter<mhlo::ConcatenateOp>,
HloToLhloOpConverter<mhlo::ConstOp>, HloToLhloOpConverter<mhlo::ConstOp>,
HloToLhloOpConverter<mhlo::ConvOp>, HloToLhloOpConverter<mhlo::ConvOp>,
HloToLhloOpConverter<mhlo::ConvertOp>, HloToLhloOpConverter<mhlo::ConvertOp>,

View File

@ -150,3 +150,37 @@ func @column_reduce(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>) -> tensor<?xf32>
: (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32> : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
return %0: tensor<?xf32> return %0: tensor<?xf32>
} }
// -----
// CHECK-LABEL: func @transpose
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>) -> memref<?x?xf32>
func @transpose(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK-NOT: tensor_load
// CHECK: %[[DIM0:.*]] = memref.dim %[[ARG]], %c0 : memref<?x?xf32>
// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG]], %c1 : memref<?x?xf32>
// CHECK: %[[OUT:.*]] = memref.alloc(%[[DIM1]], %[[DIM0]]) : memref<?x?xf32>
// CHECK: "lmhlo.transpose"(%[[ARG]], %[[OUT]])
%0 = "mhlo.transpose"(%arg0) {permutation = dense<[1,0]> : tensor<2xi64>} : (tensor<?x?xf32>) -> tensor<?x?xf32>
return %0: tensor<?x?xf32>
}
// -----
// CHECK-LABEL: func @concatenate
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xi32>, %[[ARG1:.*]]: memref<?x?xi32>, %[[ARG2:.*]]: memref<?x?xi32>) -> memref<?x?xi32>
func @concatenate(%a: tensor<?x?xi32>, %b: tensor<?x?xi32>, %c: tensor<?x?xi32>) -> tensor<?x?xi32> {
// CHECK-NOT: tensor_load
// CHECK: %[[ARG0_DIM0:.*]] = memref.dim %[[ARG0]], %c0 : memref<?x?xi32>
// CHECK: %[[ARG0_DIM1:.*]] = memref.dim %[[ARG0]], %c1 : memref<?x?xi32>
// CHECK: %[[ARG1_DIM1:.*]] = memref.dim %[[ARG1]], %c1 : memref<?x?xi32>
// CHECK: %[[ARG2_DIM1:.*]] = memref.dim %[[ARG2]], %c1 : memref<?x?xi32>
// CHECK: %[[TMP:.*]] = addi %[[ARG0_DIM1]], %[[ARG1_DIM1]] : index
// CHECK: %[[OUT_DIM1:.*]] = addi %[[TMP]], %[[ARG2_DIM1]] : index
// CHECK: %[[OUT:.*]] = memref.alloc(%[[ARG0_DIM0]], %[[OUT_DIM1]]) : memref<?x?xi32>
// CHECK: "lmhlo.concatenate"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[OUT]])
%concat = "mhlo.concatenate"(%a, %b, %c) {
dimension = 1
} : (tensor<?x?xi32>, tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
return %concat : tensor<?x?xi32>
}