[MLIR][HLO:Linalg] Lower mhlo.dynamic_iota to indexed_generic
This is the same as iota, but instead of taking the dimensions from the result tensor we use the supplied shape extents tensor. PiperOrigin-RevId: 362298548
This commit is contained in:
parent
09f8046816
commit
94f9740c67
|
@ -93,14 +93,24 @@ Value GetInitTensor(OpBuilder& b, Location loc, ShapedType type,
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Value, 2> ExtractDynamicSizes(OpBuilder& b, Location loc,
|
SmallVector<Value, 2> ExtractDynamicSizes(OpBuilder& b, Location loc,
|
||||||
Value tensor) {
|
Value tensor,
|
||||||
|
Value shape_tensor = nullptr) {
|
||||||
auto tensor_type = tensor.getType().dyn_cast<RankedTensorType>();
|
auto tensor_type = tensor.getType().dyn_cast<RankedTensorType>();
|
||||||
if (!tensor_type) return {};
|
if (!tensor_type) return {};
|
||||||
SmallVector<Value, 2> dyn_sizes;
|
SmallVector<Value, 2> dyn_sizes;
|
||||||
for (auto& en : llvm::enumerate(tensor_type.getShape())) {
|
for (auto& en : llvm::enumerate(tensor_type.getShape())) {
|
||||||
if (en.value() != ShapedType::kDynamicSize) continue;
|
if (en.value() != ShapedType::kDynamicSize) continue;
|
||||||
|
// If a shape tensor is present extract from there.
|
||||||
|
if (shape_tensor) {
|
||||||
|
Value extract = b.create<tensor::ExtractOp>(
|
||||||
|
loc, shape_tensor,
|
||||||
|
ValueRange{b.create<ConstantIndexOp>(loc, en.index())});
|
||||||
|
dyn_sizes.push_back(
|
||||||
|
b.create<IndexCastOp>(loc, b.getIndexType(), extract));
|
||||||
|
} else {
|
||||||
dyn_sizes.push_back(b.create<DimOp>(loc, tensor, en.index()));
|
dyn_sizes.push_back(b.create<DimOp>(loc, tensor, en.index()));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return dyn_sizes;
|
return dyn_sizes;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -868,17 +878,20 @@ class IotaConverter : public OpConversionPattern<OpTy> {
|
||||||
unsigned nloops = result_shaped_type.getRank();
|
unsigned nloops = result_shaped_type.getRank();
|
||||||
|
|
||||||
Location loc = iota_op.getLoc();
|
Location loc = iota_op.getLoc();
|
||||||
auto dyn_sizes = isLHLO
|
// If this is a dynamic iota, the first argument will be a shape tensor.
|
||||||
|
Value shape_tensor = args.size() > (isLHLO ? 1 : 0) ? args[0] : nullptr;
|
||||||
|
auto dyn_sizes =
|
||||||
|
isLHLO
|
||||||
? SmallVector<Value, 2>()
|
? SmallVector<Value, 2>()
|
||||||
: ExtractDynamicSizes(rewriter, loc,
|
: ExtractDynamicSizes(
|
||||||
GetResultValue<isLHLO>(iota_op));
|
rewriter, loc, GetResultValue<isLHLO>(iota_op), shape_tensor);
|
||||||
auto linalg_op = rewriter.create<linalg::IndexedGenericOp>(
|
auto linalg_op = rewriter.create<linalg::IndexedGenericOp>(
|
||||||
loc,
|
loc,
|
||||||
/*resultTensorTypes=*/
|
/*resultTensorTypes=*/
|
||||||
isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{result_shaped_type},
|
isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{result_shaped_type},
|
||||||
/*inputs=*/ValueRange{},
|
/*inputs=*/ValueRange{},
|
||||||
/*outputBuffers=*/
|
/*outputBuffers=*/
|
||||||
isLHLO ? ValueRange{args}
|
isLHLO ? ValueRange{args.back()}
|
||||||
: ValueRange{GetInitTensor(rewriter, loc, result_shaped_type,
|
: ValueRange{GetInitTensor(rewriter, loc, result_shaped_type,
|
||||||
dyn_sizes)},
|
dyn_sizes)},
|
||||||
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
|
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
|
||||||
|
@ -1636,6 +1649,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
BroadcastConverter<mhlo::BroadcastOp, false>,
|
BroadcastConverter<mhlo::BroadcastOp, false>,
|
||||||
ConstConverter<mhlo::ConstOp>, HloDynamicBroadcastInDimConverter,
|
ConstConverter<mhlo::ConstOp>, HloDynamicBroadcastInDimConverter,
|
||||||
HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
|
HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
|
||||||
|
IotaConverter<mhlo::DynamicIotaOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
|
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::AddOp, false>,
|
PointwiseToLinalgConverter<mhlo::AddOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::AndOp, false>,
|
PointwiseToLinalgConverter<mhlo::AndOp, false>,
|
||||||
|
|
|
@ -839,6 +839,27 @@ func @iota() -> tensor<7x10xf32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||||
|
// CHECK-LABEL: func @iota
|
||||||
|
// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xi32>
|
||||||
|
func @iota(%shape: tensor<?xi32>) -> tensor<?x?x8xf32> {
|
||||||
|
%result = "mhlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<?xi32>) -> (tensor<?x?x8xf32>)
|
||||||
|
return %result : tensor<?x?x8xf32>
|
||||||
|
}
|
||||||
|
// CHECK: %[[E1:.*]] = tensor.extract %[[SHAPE]][%c0] : tensor<?xi32>
|
||||||
|
// CHECK: %[[I1:.*]] = index_cast %[[E1]] : i32 to index
|
||||||
|
// CHECK: %[[E2:.*]] = tensor.extract %[[SHAPE]][%c1] : tensor<?xi32>
|
||||||
|
// CHECK: %[[I2:.*]] = index_cast %[[E2]] : i32 to index
|
||||||
|
// CHECK: linalg.init_tensor [%[[I1]], %[[I2]], 8] : tensor<?x?x8xf32>
|
||||||
|
// CHECK: linalg.indexed_generic
|
||||||
|
// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
|
||||||
|
// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index, %[[D2:.*]]: index, %{{.*}}: f32):
|
||||||
|
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
|
||||||
|
// CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
|
||||||
|
// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func @shift_left(%lhs: tensor<2x2xi32>,
|
func @shift_left(%lhs: tensor<2x2xi32>,
|
||||||
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||||
%result = "mhlo.shift_left"(%lhs, %rhs)
|
%result = "mhlo.shift_left"(%lhs, %rhs)
|
||||||
|
|
Loading…
Reference in New Issue