PR #50211: [MLIR][DISC] Bufferize RealDynamicSliceOp and ReduceOp

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/50211

support hlo-to-lhlo conversion for RealDynamicSliceOp and ReduceOp
Copybara import of the project:

--
c417b336670a1fc256f7026dfe8080e46d13d79a by Wenyi Zhao <reyizero@gmail.com>:

[MLIR][DISC] Bufferize RealDynamicSliceOp and ReduceOp

PiperOrigin-RevId: 378972113
This commit is contained in:
Wenyi Zhao 2021-06-11 16:31:53 -07:00 committed by TensorFlow MLIR Team
parent 95ba03534f
commit 8388303fd2
5 changed files with 161 additions and 7 deletions

View File

@ -967,7 +967,7 @@ def HLO_AllToAllOp : HLO_Op<"all_to_all",
let results = (outs HLO_Tensor); let results = (outs HLO_Tensor);
} }
def HLO_ReduceOp: HLO_Op<"reduce", [ def HLO_ReduceOp: HLO_ShapedInterfaceOp<"reduce", [
RecursiveSideEffects, RecursiveSideEffects,
SameVariadicOperandSize, SameVariadicOperandSize,
SingleBlockImplicitTerminator<"ReturnOp">, SingleBlockImplicitTerminator<"ReturnOp">,
@ -2199,7 +2199,7 @@ def HLO_ReducePrecisionOp :
let results = (outs HLO_FpTensor:$output); let results = (outs HLO_FpTensor:$output);
} }
def HLO_RealDynamicSliceOp: HLO_Op< def HLO_RealDynamicSliceOp: HLO_ShapedInterfaceOp<
"real_dynamic_slice", "real_dynamic_slice",
[NoSideEffect, AllElementTypesMatch<["operand", "result"]>, [NoSideEffect, AllElementTypesMatch<["operand", "result"]>,
AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> { AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {

View File

@ -75,6 +75,7 @@ MAP_HLO_TO_LHLO(NegOp);
MAP_HLO_TO_LHLO(NotOp); MAP_HLO_TO_LHLO(NotOp);
MAP_HLO_TO_LHLO(OrOp); MAP_HLO_TO_LHLO(OrOp);
MAP_HLO_TO_LHLO(PowOp); MAP_HLO_TO_LHLO(PowOp);
MAP_HLO_TO_LHLO(RealDynamicSliceOp);
MAP_HLO_TO_LHLO(RealOp); MAP_HLO_TO_LHLO(RealOp);
MAP_HLO_TO_LHLO(ReduceOp); MAP_HLO_TO_LHLO(ReduceOp);
MAP_HLO_TO_LHLO(ReshapeOp); MAP_HLO_TO_LHLO(ReshapeOp);

View File

@ -1626,6 +1626,52 @@ static LogicalResult Verify(RealDynamicSliceOp op) {
return success(); return success();
} }
LogicalResult RealDynamicSliceOp::reifyReturnTypeShapes(
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
RealDynamicSliceOp::Adaptor adaptor(operands);
Value operand = adaptor.operand();
Value start_indices = adaptor.start_indices();
Value limit_indices = adaptor.limit_indices();
Value strides = adaptor.strides();
auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
// Not support unranked type a.t.m.
if (!operand_type) return failure();
Location loc = this->getLoc();
SmallVector<Value, 4> shape_values;
shape_values.reserve(operand_type.getRank());
Type shape_scalar_type =
start_indices.getType().cast<ShapedType>().getElementType();
Value one = builder.create<ConstantIndexOp>(loc, 1);
one = MaybeCastTo(builder, loc, one, shape_scalar_type);
for (const auto& element : llvm::enumerate(operand_type.getShape())) {
Value offset = builder.create<ConstantIndexOp>(loc, element.index());
Value value_start =
builder.create<tensor::ExtractOp>(loc, start_indices, offset);
Value value_limit =
builder.create<tensor::ExtractOp>(loc, limit_indices, offset);
Value value_stride =
builder.create<tensor::ExtractOp>(loc, strides, offset);
// size = (limit - start + stride - 1) / stride
shape_values.push_back(builder.create<SignedDivIOp>(
loc,
builder.create<SubIOp>(
loc,
builder.create<AddIOp>(
loc, value_stride,
builder.create<SubIOp>(loc, value_limit, value_start)),
one),
value_stride));
}
reifiedReturnShapes.push_back(builder.create<tensor::FromElementsOp>(
loc, shape_scalar_type, shape_values));
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// InfeedOp // InfeedOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -2072,6 +2118,46 @@ void ReduceOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
MLIRContext* context) { MLIRContext* context) {
results.insert<LowerBoolSplatConstantsIntoRegion>(context); results.insert<LowerBoolSplatConstantsIntoRegion>(context);
} }
LogicalResult ReduceOp::reifyReturnTypeShapes(
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
ReduceOp::Adaptor adaptor(operands);
auto inputs = adaptor.inputs();
auto operand_type = inputs[0].getType().dyn_cast<RankedTensorType>();
// Not support unranked type a.t.m.
if (!operand_type) return failure();
Location loc = this->getLoc();
SmallVector<Value, 4> shape_values;
SmallVector<int64_t, 4> dimensions(this->dimensions().getValues<int64_t>());
shape_values.reserve(operand_type.getRank());
Type shape_scalar_type = builder.getIndexType();
auto to_shape_scalar_type = [&](Value v) {
return MaybeCastTo(builder, loc, v, shape_scalar_type);
};
for (const auto& element : llvm::enumerate(operand_type.getShape())) {
int64_t idx = element.index();
auto it = std::find(dimensions.begin(), dimensions.end(), idx);
if (it != dimensions.end()) {
continue;
}
Value value_dim = to_shape_scalar_type(
builder.create<memref::DimOp>(loc, inputs[0], element.index()));
shape_values.push_back(value_dim);
}
Value output_shape = builder.create<tensor::FromElementsOp>(
loc, shape_scalar_type, shape_values);
for (size_t i = 0; i < inputs.size(); ++i) {
reifiedReturnShapes.push_back(output_shape);
}
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// RngNormalOp // RngNormalOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -434,11 +434,8 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
<< "tensor to buffer conversion expects a single block " << "tensor to buffer conversion expects a single block "
"in the region containing the operation"; "in the region containing the operation";
} }
const auto& original_results = op.getResults();
SmallVector<Value, 4> buffer_args(operands.begin(), operands.end()); SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
for (auto result : original_results) { if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
buffer_args.push_back(InsertAlloc(loc, result, &rewriter));
}
auto new_op = rewriter.create<lmhlo::ReduceOp>(loc, llvm::None, buffer_args, auto new_op = rewriter.create<lmhlo::ReduceOp>(loc, llvm::None, buffer_args,
op->getAttrs()); op->getAttrs());
@ -671,7 +668,8 @@ void populateDynamicHLOToLHLOOnlyConversionPattern(
patterns->insert<HloToLhloOpConverter<mhlo::DynamicBroadcastInDimOp>, patterns->insert<HloToLhloOpConverter<mhlo::DynamicBroadcastInDimOp>,
HloToLhloOpConverter<mhlo::DynamicIotaOp>, HloToLhloOpConverter<mhlo::DynamicIotaOp>,
HloToLhloOpConverter<mhlo::DynamicPadOp>, HloToLhloOpConverter<mhlo::DynamicPadOp>,
HloToLhloOpConverter<mhlo::DynamicReshapeOp> HloToLhloOpConverter<mhlo::DynamicReshapeOp>,
HloToLhloOpConverter<mhlo::RealDynamicSliceOp>
>(*converter, context); >(*converter, context);
// clang-format on // clang-format on
} }

View File

@ -81,3 +81,72 @@ func @dynamic_pad(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>, %arg2: tensor<2xin
%0 = "mhlo.dynamic_pad"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<?x?xf32>, tensor<f32>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<?x?xf32> %0 = "mhlo.dynamic_pad"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<?x?xf32>, tensor<f32>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<?x?xf32>
return %0: tensor<?x?xf32> return %0: tensor<?x?xf32>
} }
// -----
// CHECK-LABEL: func @real_dynamic_slice
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>,
// CHECK-SAME: %[[START:.*]]: memref<2xi32>, %[[LIMIT:.*]]: memref<2xi32>, %[[STRIDE:.*]]: memref<2xi32>) -> memref<?x?xf32>
func @real_dynamic_slice(%arg0: tensor<?x?xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>, %arg3: tensor<2xi32>) -> tensor<?x?xf32> {
// CHECK-NOT: tensor_load
// CHECK: %[[T0:.*]] = memref.load %[[START]][%c0] : memref<2xi32>
// CHECK: %[[T1:.*]] = memref.load %[[LIMIT]][%c0] : memref<2xi32>
// CHECK: %[[T2:.*]] = memref.load %[[STRIDE]][%c0] : memref<2xi32>
// CHECK: %[[T3:.*]] = subi %[[T1]], %[[T0]] : i32
// CHECK: %[[T4:.*]] = addi %[[T2]], %[[T3]] : i32
// CHECK: %[[T5:.*]] = subi %[[T4]], %c1_i32 : i32
// CHECK: %[[T6:.*]] = divi_signed %[[T5]], %[[T2]] : i32
// CHECK: %[[T7:.*]] = memref.load %[[START]][%c1] : memref<2xi32>
// CHECK: %[[T8:.*]] = memref.load %[[LIMIT]][%c1] : memref<2xi32>
// CHECK: %[[T9:.*]] = memref.load %[[STRIDE]][%c1] : memref<2xi32>
// CHECK: %[[T10:.*]] = subi %[[T8]], %[[T7]] : i32
// CHECK: %[[T11:.*]] = addi %[[T9]], %[[T10]] : i32
// CHECK: %[[T12:.*]] = subi %[[T11]], %c1_i32 : i32
// CHECK: %[[T13:.*]] = divi_signed %[[T12]], %[[T9]] : i32
// CHECK: %[[T14:.*]] = index_cast %[[T6]] : i32 to index
// CHECK: %[[T15:.*]] = index_cast %[[T13]] : i32 to index
// CHECK: %[[T16:.*]] = memref.alloc(%[[T14]], %[[T15]]) : memref<?x?xf32>
// CHECK: "lmhlo.real_dynamic_slice"(%[[ARG]], %[[START]], %[[LIMIT]], %[[STRIDE]], %[[T16]])
%0 = "mhlo.real_dynamic_slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<?x?xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x?xf32>
return %0: tensor<?x?xf32>
}
// -----
// CHECK-LABEL: func @row_reduce
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>, %[[VAL:.*]]: memref<f32>) -> memref<?xf32>
func @row_reduce(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>) -> tensor<?xf32> {
// CHECK-NOT: tensor_load
// CHECK: %[[DIM0:.*]] = memref.dim %[[ARG]], %c0 : memref<?x?xf32>
// CHECK: %[[OUT:.*]] = memref.alloc(%[[DIM0]]) : memref<?xf32>
// CHECK: lmhlo.reduce
// CHECK-SAME: %[[ARG]], %[[VAL]], %[[OUT]]
// CHECK: return %[[OUT]] : memref<?xf32>
%0 = "mhlo.reduce"(%arg0, %arg1) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>): // no predecessors
%1 = mhlo.add %arg2, %arg3 : tensor<f32>
"mhlo.return"(%1) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>}
: (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
return %0: tensor<?xf32>
}
// -----
// CHECK-LABEL: func @column_reduce
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>, %[[VAL:.*]]: memref<f32>) -> memref<?xf32>
func @column_reduce(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>) -> tensor<?xf32> {
// CHECK-NOT: tensor_load
// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG]], %c1 : memref<?x?xf32>
// CHECK: %[[OUT:.*]] = memref.alloc(%[[DIM1]]) : memref<?xf32>
// CHECK: lmhlo.reduce
// CHECK-SAME: %[[ARG]], %[[VAL]], %[[OUT]]
// CHECK: return %[[OUT]] : memref<?xf32>
%0 = "mhlo.reduce"(%arg0, %arg1) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>): // no predecessors
%1 = mhlo.add %arg2, %arg3 : tensor<f32>
"mhlo.return"(%1) : (tensor<f32>) -> ()
}) {dimensions = dense<0> : tensor<1xi64>}
: (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
return %0: tensor<?xf32>
}