[MLIR][HLO:Linalg] Lower mhlo.dynamic_iota to indexed_generic

This is the same as iota, but instead of taking the dimensions from the result
tensor we use the supplied shape extents tensor.

PiperOrigin-RevId: 362298548
This commit is contained in:
Benjamin Kramer 2021-03-11 08:30:08 -08:00 committed by TensorFlow MLIR Team
parent 09f8046816
commit 94f9740c67
2 changed files with 42 additions and 7 deletions

View File

@ -93,13 +93,23 @@ Value GetInitTensor(OpBuilder& b, Location loc, ShapedType type,
} }
SmallVector<Value, 2> ExtractDynamicSizes(OpBuilder& b, Location loc, SmallVector<Value, 2> ExtractDynamicSizes(OpBuilder& b, Location loc,
Value tensor) { Value tensor,
Value shape_tensor = nullptr) {
auto tensor_type = tensor.getType().dyn_cast<RankedTensorType>(); auto tensor_type = tensor.getType().dyn_cast<RankedTensorType>();
if (!tensor_type) return {}; if (!tensor_type) return {};
SmallVector<Value, 2> dyn_sizes; SmallVector<Value, 2> dyn_sizes;
for (auto& en : llvm::enumerate(tensor_type.getShape())) { for (auto& en : llvm::enumerate(tensor_type.getShape())) {
if (en.value() != ShapedType::kDynamicSize) continue; if (en.value() != ShapedType::kDynamicSize) continue;
dyn_sizes.push_back(b.create<DimOp>(loc, tensor, en.index())); // If a shape tensor is present extract from there.
if (shape_tensor) {
Value extract = b.create<tensor::ExtractOp>(
loc, shape_tensor,
ValueRange{b.create<ConstantIndexOp>(loc, en.index())});
dyn_sizes.push_back(
b.create<IndexCastOp>(loc, b.getIndexType(), extract));
} else {
dyn_sizes.push_back(b.create<DimOp>(loc, tensor, en.index()));
}
} }
return dyn_sizes; return dyn_sizes;
} }
@ -868,17 +878,20 @@ class IotaConverter : public OpConversionPattern<OpTy> {
unsigned nloops = result_shaped_type.getRank(); unsigned nloops = result_shaped_type.getRank();
Location loc = iota_op.getLoc(); Location loc = iota_op.getLoc();
auto dyn_sizes = isLHLO // If this is a dynamic iota, the first argument will be a shape tensor.
? SmallVector<Value, 2>() Value shape_tensor = args.size() > (isLHLO ? 1 : 0) ? args[0] : nullptr;
: ExtractDynamicSizes(rewriter, loc, auto dyn_sizes =
GetResultValue<isLHLO>(iota_op)); isLHLO
? SmallVector<Value, 2>()
: ExtractDynamicSizes(
rewriter, loc, GetResultValue<isLHLO>(iota_op), shape_tensor);
auto linalg_op = rewriter.create<linalg::IndexedGenericOp>( auto linalg_op = rewriter.create<linalg::IndexedGenericOp>(
loc, loc,
/*resultTensorTypes=*/ /*resultTensorTypes=*/
isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{result_shaped_type}, isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{result_shaped_type},
/*inputs=*/ValueRange{}, /*inputs=*/ValueRange{},
/*outputBuffers=*/ /*outputBuffers=*/
isLHLO ? ValueRange{args} isLHLO ? ValueRange{args.back()}
: ValueRange{GetInitTensor(rewriter, loc, result_shaped_type, : ValueRange{GetInitTensor(rewriter, loc, result_shaped_type,
dyn_sizes)}, dyn_sizes)},
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
@ -1636,6 +1649,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
BroadcastConverter<mhlo::BroadcastOp, false>, BroadcastConverter<mhlo::BroadcastOp, false>,
ConstConverter<mhlo::ConstOp>, HloDynamicBroadcastInDimConverter, ConstConverter<mhlo::ConstOp>, HloDynamicBroadcastInDimConverter,
HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>, HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
IotaConverter<mhlo::DynamicIotaOp, false>,
PointwiseToLinalgConverter<mhlo::AbsOp, false>, PointwiseToLinalgConverter<mhlo::AbsOp, false>,
PointwiseToLinalgConverter<mhlo::AddOp, false>, PointwiseToLinalgConverter<mhlo::AddOp, false>,
PointwiseToLinalgConverter<mhlo::AndOp, false>, PointwiseToLinalgConverter<mhlo::AndOp, false>,

View File

@ -839,6 +839,27 @@ func @iota() -> tensor<7x10xf32> {
// ----- // -----
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @iota
// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xi32>
func @iota(%shape: tensor<?xi32>) -> tensor<?x?x8xf32> {
%result = "mhlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<?xi32>) -> (tensor<?x?x8xf32>)
return %result : tensor<?x?x8xf32>
}
// CHECK: %[[E1:.*]] = tensor.extract %[[SHAPE]][%c0] : tensor<?xi32>
// CHECK: %[[I1:.*]] = index_cast %[[E1]] : i32 to index
// CHECK: %[[E2:.*]] = tensor.extract %[[SHAPE]][%c1] : tensor<?xi32>
// CHECK: %[[I2:.*]] = index_cast %[[E2]] : i32 to index
// CHECK: linalg.init_tensor [%[[I1]], %[[I2]], 8] : tensor<?x?x8xf32>
// CHECK: linalg.indexed_generic
// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index, %[[D2:.*]]: index, %{{.*}}: f32):
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
// CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32
// -----
func @shift_left(%lhs: tensor<2x2xi32>, func @shift_left(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
%result = "mhlo.shift_left"(%lhs, %rhs) %result = "mhlo.shift_left"(%lhs, %rhs)