From da6593e9601ba03567f10ea9973c13e21a22adae Mon Sep 17 00:00:00 2001
From: Abhishek Varma <67887857+avarmapml@users.noreply.github.com>
Date: Thu, 17 Jun 2021 14:54:39 -0700
Subject: [PATCH] PR #50073: [MLIR] Add GatherOp lowering from lmhlo to Affine.

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/50073

-- Lowering of `GatherOp` is added from lmhlo to Affine. The lowering has been
   added as a part of the `lhlo-legalize-to-affine` pass.

Signed-off-by: Abhishek Varma
Copybara import of the project:

-- 5b3dcd4ab31a69f305cd079b869ee35ba6dc8bf5 by Abhishek Varma:

  [MLIR] Add GatherOp lowering from lmhlo to Affine.

  -- Lowering of `GatherOp` is added from lmhlo to Affine. The lowering has
     been added as a part of the `lhlo-legalize-to-affine` pass.

  Signed-off-by: Abhishek Varma

PiperOrigin-RevId: 380052157
---
 .../transforms/lhlo_legalize_to_affine.cc | 303 +++++++++++++++++-
 tests/lhlo-legalize-to-affine.mlir        | 201 ++++++++++++
 2 files changed, 503 insertions(+), 1 deletion(-)

diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc
index b211169..0346ab5 100644
--- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc
+++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc
@@ -177,6 +177,306 @@ struct ConcatOpConverter : public OpRewritePattern<ConcatenateOp> {
   }
 };
 
+/// Returns a zero value of type `type`. `type` is expected to be either
+/// an integer or a float type.
+static Value getZeroValue(Type type, Location loc, PatternRewriter& rewriter) {
+  assert(type.isIntOrFloat() && "Expected int or float");
+
+  if (IntegerType intType = type.dyn_cast<IntegerType>())
+    return rewriter.create<ConstantIntOp>(loc, 0, intType.getWidth());
+
+  FloatType floatType = type.cast<FloatType>();
+  return rewriter.create<ConstantFloatOp>(
+      loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
+}
+
+/// Emits an affine loop nest that fills the given `buffer` of memref type
+/// with `fillValue`.
+static void fillBuffer(Location loc, Value buffer, Value fillValue,
+                       PatternRewriter& builder) {
+  OpBuilder::InsertionGuard guard(builder);
+  MemRefType bufType = buffer.getType().cast<MemRefType>();
+  unsigned rank = bufType.getRank();
+  SmallVector<Value, 4> dimSizes;
+  dimSizes.reserve(rank);
+  for (unsigned i = 0; i < rank; ++i)
+    dimSizes.push_back(builder.create<memref::DimOp>(loc, buffer, i));
+
+  AffineMap idSymMap = builder.getSymbolIdentityMap();
+  AffineMap lbMap = builder.getConstantAffineMap(0);
+  SmallVector<Value, 4> ivs(rank);
+  AffineForOp forOp;
+  for (unsigned i = 0; i < rank; ++i) {
+    forOp = builder.create<AffineForOp>(loc, llvm::None, lbMap, dimSizes[i],
+                                        idSymMap);
+    builder.setInsertionPointToStart(forOp.getBody());
+    ivs[i] = forOp.getInductionVar();
+  }
+  Type fillValueType = fillValue.getType();
+  auto fillMemRefType = fillValueType.dyn_cast<MemRefType>();
+  assert(((fillMemRefType && fillMemRefType.getRank() == 0) ||
+          fillValueType.isIntOrFloat()) &&
+         "init value has to be a 0-d memref or int or fp");
+  Value initVal = fillMemRefType ? builder.create<AffineLoadOp>(
+                                       loc, fillValue, /*indices=*/llvm::None)
+                                 : fillValue;
+  builder.create<AffineStoreOp>(loc, initVal, buffer, ivs);
+}
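+
+// For example, filling a memref<1x128x512xf32> buffer with an f32 zero
+// emits a nest of roughly the following shape (cf. the zero-initialization
+// loops checked in the tests below; %buf and %cst are placeholder names):
+//   affine.for %i = 0 to 1 {
+//     affine.for %j = 0 to 128 {
+//       affine.for %k = 0 to 512 {
+//         affine.store %cst, %buf[%i, %j, %k] : memref<1x128x512xf32>
+//       }
+//     }
+//   }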
+
+/// Converts GatherOp to an affine loop nest.
+/// Pseudocode:
+///   1. Fill a temporary output buffer with 0.
+///   2. Repeat the following for each batch of start indices:
+///      1. For each index of 'operand':
+///         1. Fetch the start indices from 'start_indices'.
+///         2. Add the offsets to the start indices to get the final indices.
+///         3. Load the value 'operand_val' from the 'operand' buffer.
+///         4. Load the value 'prev_val' from the temporary output.
+///         5. If the final indices match the current indices of 'operand':
+///            'prev_val' = 'prev_val' + 'operand_val'
+///         6. Store 'prev_val' back to the temporary output.
+class GatherOpConverter : public OpRewritePattern<GatherOp> {
+ public:
+  using OpRewritePattern<GatherOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GatherOp op,
+                                PatternRewriter& rewriter) const final {
+    Location loc = op.getLoc();
+
+    // Operand array.
+    Value operand = op.operand();
+    MemRefType operand_type = operand.getType().cast<MemRefType>();
+    unsigned operand_rank = operand_type.getRank();
+    ArrayRef<int64_t> operand_shape = operand_type.getShape();
+
+    // Start_indices array.
+    Value start_indices = op.start_indices();
+    MemRefType start_indices_type =
+        start_indices.getType().cast<MemRefType>();
+    unsigned start_indices_rank = start_indices_type.getRank();
+    ArrayRef<int64_t> start_indices_shape = start_indices_type.getShape();
+
+    // Output array.
+    Value output = op.output();
+    MemRefType output_type = output.getType().cast<MemRefType>();
+    ArrayRef<int64_t> output_shape = output_type.getShape();
+
+    if (!operand_type.hasStaticShape() ||
+        !start_indices_type.hasStaticShape() || !output_type.hasStaticShape())
+      return rewriter.notifyMatchFailure(op,
+                                         "only static shaped type allowed");
+
+    mhlo::GatherDimensionNumbers gather_dim = op.dimension_numbersAttr();
+
+    // Collapsed_slice_dims.
+    DenseIntElementsAttr collapsed_slice_dims_attr =
+        gather_dim.collapsed_slice_dims();
+    SmallVector<int64_t, 4> collapsed_slice_dims;
+    for (const APInt& dim : collapsed_slice_dims_attr.getIntValues())
+      collapsed_slice_dims.push_back(dim.getSExtValue());
+
+    // Offset_dims.
+    DenseIntElementsAttr offset_dims_attr = gather_dim.offset_dims();
+    SmallVector<int64_t, 4> offset_dims;
+    for (const APInt& dim : offset_dims_attr.getIntValues())
+      offset_dims.push_back(dim.getSExtValue());
+
+    // Start_index_map.
+    DenseIntElementsAttr start_index_map_attr = gather_dim.start_index_map();
+    SmallVector<int64_t, 4> start_index_map;
+    for (const APInt& dim : start_index_map_attr.getIntValues())
+      start_index_map.push_back(dim.getSExtValue());
+
+    // Index_vector_dim.
+    IntegerAttr index_vector_dim_attr = gather_dim.index_vector_dim();
+    int64_t index_vector_dim = index_vector_dim_attr.getValue().getSExtValue();
+
+    // Slice_sizes.
+    DenseIntElementsAttr slice_sizes_attr = op.slice_sizesAttr();
+    SmallVector<int64_t, 4> slice_sizes;
+    for (const APInt& dim : slice_sizes_attr.getIntValues())
+      slice_sizes.push_back(dim.getSExtValue());
+
+    // Create constants with value 0. We need the integer-typed constant
+    // because the indices type will be an integer type.
+    Value zero_int_val = rewriter.create<ConstantIntOp>(
+        loc, 0, start_indices_type.getElementType());
+    Type element_type = output_type.getElementType();
+    Value zero_load_value = getZeroValue(element_type, loc, rewriter);
+    // Initialize the output buffer with 0.
+    fillBuffer(loc, output, zero_load_value, rewriter);
+
+    // We fetch the shape of start_indices at index_vector_dim. In case
+    // index_vector_dim is equal to the rank of start_indices, we implicitly
+    // consider start_indices to have a trailing dimension of size 1.
+    unsigned start_indices_numbers =
+        (index_vector_dim == start_indices_rank)
+            ? 1
+            : start_indices_shape[index_vector_dim];
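+
+    // For example, start_indices of type memref<4x2x5xi32> with
+    // index_vector_dim = 1 gives start_indices_numbers = 2 (test case 3
+    // below), while memref<1x128xi32> with index_vector_dim = 2 (equal to
+    // its rank) gets the implicit trailing dimension, so
+    // start_indices_numbers = 1 (test case 1).
+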
+    // We create integer constants up to start_indices_numbers; they are used
+    // to fetch the start indices in the affine transformation below.
+    SmallVector<Value, 4> start_indices_index;
+    for (unsigned i = 0; i < start_indices_numbers; i++) {
+      Value i_val = rewriter.create<ConstantIntOp>(
+          loc, i, start_indices_type.getElementType());
+      i_val =
+          rewriter.create<IndexCastOp>(loc, i_val, rewriter.getIndexType());
+      start_indices_index.push_back(i_val);
+    }
+
+    // S_in contains the multiple indices that form a starting index used in
+    // the input/operand tensor. O_in contains the multiple offsets of the
+    // corresponding starting index used in the input/operand tensor. We
+    // initialize both of them with 0.
+    SmallVector<Value, 4> S_in;
+    SmallVector<Value, 4> O_in;
+    Value zero_index_val = rewriter.create<IndexCastOp>(
+        loc, zero_int_val, rewriter.getIndexType());
+    for (unsigned i = 0; i < operand_rank; i++) {
+      S_in.push_back(zero_index_val);
+      O_in.push_back(zero_index_val);
+    }
+
+    // batch_induction_vars stores the loop induction variables pertaining to
+    // the batches of start indices.
+    SmallVector<Value, 4> batch_induction_vars;
+    // output_induction_vars stores the loop induction variables pertaining to
+    // both batches and offsets within the output tensor.
+    SmallVector<Value, 4> output_induction_vars;
+    // Create loops to iterate over each batch of starting index.
+    for (unsigned i = 0; i < start_indices_rank; i++) {
+      // The ith dimension of start_indices doesn't form a batch if it is
+      // equal to index_vector_dim.
+      if (i == index_vector_dim) continue;
+      AffineForOp for_op =
+          rewriter.create<AffineForOp>(loc, 0, start_indices_shape[i]);
+      batch_induction_vars.push_back(for_op.getInductionVar());
+      output_induction_vars.push_back(for_op.getInductionVar());
+      rewriter.setInsertionPointToStart(for_op.getBody());
+    }
+
+    // Create loops to iterate over each offset dimension within the output
+    // tensor.
+    for (unsigned i = 0, k = 0, e = offset_dims.size(); i < e; i++) {
+      AffineForOp for_op =
+          rewriter.create<AffineForOp>(loc, 0, output_shape[offset_dims[i]]);
+      rewriter.setInsertionPointToStart(for_op.getBody());
+      // We try to fetch the first non-collapsed dimension.
+      while (k < collapsed_slice_dims.size() && collapsed_slice_dims[k] == i)
+        k++;
+      // Remap offset_dims[i] to the non-collapsed dimension.
+      O_in[k++] = for_op.getInductionVar();
+      // We assume offset_dims to be sorted, so when inserted into
+      // output_induction_vars the loop induction variable lands at the
+      // correct position.
+      output_induction_vars.insert(
+          output_induction_vars.begin() + offset_dims[i],
+          for_op.getInductionVar());
+    }
+
+    // Create loops to iterate over all dimensions within the operand tensor.
+    SmallVector<Value, 4> operand_index;
+    for (unsigned i = 0, k = 0; i < operand_rank; i++) {
+      // We assume start_index_map to have sorted dimensions. We only include
+      // those dimensions of the operand tensor which are present in
+      // start_index_map.
+      if (k < start_index_map.size() && i == start_index_map[k++]) {
+        AffineForOp for_op =
+            rewriter.create<AffineForOp>(loc, 0, operand_shape[i]);
+        operand_index.push_back(for_op.getInductionVar());
+        rewriter.setInsertionPointToStart(for_op.getBody());
+      } else {
+        operand_index.push_back(O_in[i]);
+      }
+    }
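+
+    // For example, in test case 2 below (operand memref<16x11xf32>,
+    // start_index_map = [0, 1], offset_dims = [1, 2]), this yields one batch
+    // loop (0 to 5), two offset loops (0 to 8 and 0 to 6) and two operand
+    // loops (0 to 16 and 0 to 11), with operand_index = [iv0, iv1].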
+
+    // In case index_vector_dim is not equal to the rank of start_indices, we
+    // create another loop to iterate over the starting index and update
+    // batch_induction_vars.
+    if (index_vector_dim != start_indices_rank) {
+      for (unsigned i = 0; i < start_indices_numbers; i++) {
+        batch_induction_vars.insert(
+            batch_induction_vars.begin() + index_vector_dim,
+            start_indices_index[i]);
+        Value start_index = rewriter.create<AffineLoadOp>(
+            loc, start_indices, batch_induction_vars);
+        start_index = rewriter.create<IndexCastOp>(loc, start_index,
+                                                   rewriter.getIndexType());
+        S_in[start_index_map[i]] = start_index;
+        batch_induction_vars.erase(batch_induction_vars.begin() +
+                                   index_vector_dim);
+      }
+    } else {
+      // Since index_vector_dim is equal to start_indices_rank, we can
+      // directly fetch the start_index from batch_induction_vars.
+      Value start_index = rewriter.create<AffineLoadOp>(
+          loc, start_indices, batch_induction_vars);
+      start_index = rewriter.create<IndexCastOp>(loc, start_index,
+                                                 rewriter.getIndexType());
+      S_in[0] = start_index;
+    }
+
+    // We load the value at a particular operand index and populate the output
+    // tensor if the index constraints match.
+    Value load_value =
+        rewriter.create<AffineLoadOp>(loc, operand, operand_index);
+    SmallVector<Value, 4> predicates;
+    // Add the offsets to the corresponding starting index and compare the
+    // result with the corresponding operand index.
+    for (unsigned k = 0, i = 0; k < start_index_map.size(); k++) {
+      i = start_index_map[k];
+      Value add_start_index_offset = rewriter.create<AddIOp>(
+          loc, rewriter.getIndexType(), S_in[i], O_in[i]);
+      Value predicate = rewriter.create<CmpIOp>(
+          loc, CmpIPredicate::eq, add_start_index_offset, operand_index[i]);
+      predicates.push_back(predicate);
+    }
+
+    // Since the number of predicates is equal to start_index_map.size(), we
+    // iterate over adjacent pairs of predicates and join them with AndOp.
+    unsigned num_equality_checks = start_index_map.size() / 2;
+    // We store the final predicate formed by joining the other predicates
+    // with AndOp in result_predicate.
+    Value result_predicate = nullptr;
+    for (unsigned i = 0; i < 2 * num_equality_checks; i += 2) {
+      Value predicate_a = predicates[i];
+      Value predicate_b = predicates[i + 1];
+      Value and_predicate =
+          rewriter.create<AndOp>(loc, predicate_a, predicate_b);
+      result_predicate = (i == 0) ? and_predicate
+                                  : rewriter.create<AndOp>(
+                                        loc, result_predicate, and_predicate);
+    }
+    // We fetch the last predicate value. In case this is the only predicate,
+    // we let result_predicate be equal to this predicate value. Otherwise,
+    // if there is an odd number of predicates, we join it to the other
+    // predicates using AndOp.
+    Value predicate = predicates.back();
+    if (!result_predicate) result_predicate = predicate;
+    // In case there is an odd number of predicates, we join the last
+    // predicate to result_predicate using AndOp.
+    else if (start_index_map.size() % 2 == 1)
+      result_predicate =
+          rewriter.create<AndOp>(loc, result_predicate, predicate);
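+
+    // For example, with three predicates p0, p1, p2 (as in test case 5
+    // below), the loop joins p0 and p1, and the odd predicate p2 is joined
+    // afterwards: result_predicate = and(and(p0, p1), p2).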
+
+    // We use the loaded value if the index computed by adding the offsets to
+    // the starting index is equal to the current operand index, and 0
+    // otherwise.
+    Value select_load = rewriter.create<SelectOp>(
+        loc, result_predicate, load_value, zero_load_value);
+    // We load the value currently stored in the output array.
+    Value output_value =
+        rewriter.create<AffineLoadOp>(loc, output, output_induction_vars);
+
+    // The selected value is added to the previous value stored in the output
+    // array.
+    if (element_type.isa<IntegerType>())
+      output_value = rewriter.create<AddIOp>(loc, element_type, select_load,
+                                             output_value);
+    else
+      output_value = rewriter.create<AddFOp>(loc, element_type, select_load,
+                                             output_value);
+    rewriter.create<AffineStoreOp>(loc, output_value, output,
+                                   output_induction_vars);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 template <typename LhloOpTy>
 struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
   using OpRewritePattern<LhloOpTy>::OpRewritePattern;
@@ -225,7 +525,8 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context,
       BinaryOpConverter<lmhlo::MulOp>,
       BinaryOpConverter<lmhlo::SubOp>,
       ConcatOpConverter,
-      DotOpConverter>(context);
+      DotOpConverter,
+      GatherOpConverter>(context);
   // clang-format on
 }
 
diff --git a/tests/lhlo-legalize-to-affine.mlir b/tests/lhlo-legalize-to-affine.mlir
index 9177f43..f9a5cf0 100644
--- a/tests/lhlo-legalize-to-affine.mlir
+++ b/tests/lhlo-legalize-to-affine.mlir
@@ -238,3 +238,204 @@ func @concatenate_dynamic(%arg0: memref<1x?xf32>, %arg1: memref<1x?xf32>, %arg2:
   "lmhlo.copy"(%0, %arg2) : (memref<1x?xf32>, memref<1x?xf32>) -> ()
   "lmhlo.terminator"() : () -> ()
 }
+
+// Gather op.
+// Test case 1: A general GatherOp test case.
+// CHECK-LABEL: func @gather_1
+// CHECK-SAME: (%[[OPERAND:.*]]: memref<28996x512xf32>, %[[START_INDICES:.*]]: memref<1x128xi32>, %[[OUTPUT:.*]]: memref<1x128x512xf32>)
+func @gather_1(%arg0: memref<28996x512xf32>, %arg1: memref<1x128xi32>, %arg2: memref<1x128x512xf32>) {
+  %0 = memref.alloc() : memref<1x128x512xf32>
+  "lmhlo.gather"(%arg0, %arg1, %0) {dimension_numbers = { collapsed_slice_dims = dense<0> : tensor<1xi64>,
+                                                          index_vector_dim = 2 : i64,
+                                                          offset_dims = dense<2> : tensor<1xi64>,
+                                                          start_index_map = dense<0> : tensor<1xi64>},
+                                    indices_are_sorted = false, name = "gather.381", slice_sizes = dense<[1, 512]> : tensor<2xi64>} :
+      (memref<28996x512xf32>, memref<1x128xi32>, memref<1x128x512xf32>) -> ()
+  "lmhlo.copy"(%0, %arg2) : (memref<1x128x512xf32>, memref<1x128x512xf32>) -> ()
+  "lmhlo.terminator"() : () -> ()
+}
+// CHECK-NEXT: %[[zero:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[temp_output:.*]] = memref.alloc() : memref<1x128x512xf32>
+// CHECK-NEXT: affine.for %{{.*}} = 0 to 1 {
+// CHECK-NEXT: affine.for %{{.*}} = 0 to 128 {
+// CHECK-NEXT: affine.for %{{.*}} = 0 to 512 {
+// CHECK-NEXT: affine.store %[[zero]], %[[temp_output]][%{{.*}}, %{{.*}}, %{{.*}}] : memref<1x128x512xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: affine.for %[[batch0:.*]] = 0 to 1 {
+// CHECK-NEXT: affine.for %[[batch1:.*]] = 0 to 128 {
+// CHECK-NEXT: affine.for %[[offset0:.*]] = 0 to 512 {
+// CHECK-NEXT: affine.for %[[iv0:.*]] = 0 to 28996 {
+// CHECK-NEXT: %[[a:.*]] = affine.load %[[START_INDICES]][%[[batch0]], %[[batch1]]] : memref<1x128xi32>
+// CHECK-NEXT: %[[S_in0:.*]] = index_cast %[[a]] : i32 to index
+// CHECK-NEXT: %[[operand_val:.*]] = affine.load %[[OPERAND]][%[[iv0]], %[[offset0]]] : memref<28996x512xf32>
+// CHECK-NEXT: %[[pred:.*]] = cmpi eq, %[[S_in0]], %[[iv0]] : index
+// CHECK-NEXT: %[[selected_value:.*]] = select %[[pred]], %[[operand_val]], %[[zero]] : f32
+// CHECK-NEXT: %[[prev_value:.*]] = affine.load %[[temp_output]][%[[batch0]], %[[batch1]], %[[offset0]]] : memref<1x128x512xf32>
+// CHECK-NEXT: %[[final_value:.*]] = addf %[[selected_value]], %[[prev_value]] : f32
+// CHECK-NEXT: affine.store %[[final_value]], %[[temp_output]][%[[batch0]], %[[batch1]], %[[offset0]]] : memref<1x128x512xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
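+
+// For instance (an illustrative trace rather than a CHECK): if the start
+// index at [0, 5] holds 37, only the iv0 == 37 iteration selects the loaded
+// operand value, so row 37 of the operand is accumulated into
+// temp_output[0, 5, 0..511]; every other iteration adds zero.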
+
+// Test case 2: Checks for multi-dimensional starting indices.
+// CHECK-LABEL: func @gather_2
+// CHECK-SAME: (%[[OPERAND:.*]]: memref<16x11xf32>, %[[START_INDICES:.*]]: memref<5x2xi32>, %[[OUTPUT:.*]]: memref<5x8x6xf32>)
+func @gather_2(%arg0: memref<16x11xf32>, %arg1: memref<5x2xi32>, %arg2: memref<5x8x6xf32>) {
+  %0 = memref.alloc() : memref<5x8x6xf32>
+  "lmhlo.gather"(%arg0, %arg1, %0) {dimension_numbers = { collapsed_slice_dims = dense<-1> : tensor<1xi64>,
+                                                          index_vector_dim = 1 : i64,
+                                                          offset_dims = dense<[1,2]> : tensor<2xi64>,
+                                                          start_index_map = dense<[0,1]> : tensor<2xi64>},
+                                    indices_are_sorted = false, name = "gather.381", slice_sizes = dense<[8, 6]> : tensor<2xi64>} :
+      (memref<16x11xf32>, memref<5x2xi32>, memref<5x8x6xf32>) -> ()
+  "lmhlo.copy"(%0, %arg2) : (memref<5x8x6xf32>, memref<5x8x6xf32>) -> ()
+  "lmhlo.terminator"() : () -> ()
+}
+// CHECK-NEXT: %[[zero:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %c0 = constant 0 : index
+// CHECK-NEXT: %c1 = constant 1 : index
+// CHECK-NEXT: %[[temp_output:.*]] = memref.alloc() : memref<5x8x6xf32>
+// CHECK-NEXT: affine.for %{{.*}} = 0 to 5 {
+// CHECK-NEXT: affine.for %{{.*}} = 0 to 8 {
+// CHECK-NEXT: affine.for %{{.*}} = 0 to 6 {
+// CHECK-NEXT: affine.store %[[zero]], %[[temp_output]][%{{.*}}, %{{.*}}, %{{.*}}] : memref<5x8x6xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: affine.for %[[batch0:.*]] = 0 to 5 {
+// CHECK-NEXT: affine.for %[[offset0:.*]] = 0 to 8 {
+// CHECK-NEXT: affine.for %[[offset1:.*]] = 0 to 6 {
+// CHECK-NEXT: affine.for %[[iv0:.*]] = 0 to 16 {
+// CHECK-NEXT: affine.for %[[iv1:.*]] = 0 to 11 {
+// CHECK-NEXT: %[[a:.*]] = affine.load %[[START_INDICES]][%[[batch0]], %c0] : memref<5x2xi32>
+// CHECK-NEXT: %[[S_in0:.*]] = index_cast %[[a]] : i32 to index
+// CHECK-NEXT: %[[b:.*]] = affine.load %[[START_INDICES]][%[[batch0]], %c1] : memref<5x2xi32>
+// CHECK-NEXT: %[[S_in1:.*]] = index_cast %[[b]] : i32 to index
+// CHECK-NEXT: %[[operand_val:.*]] = affine.load %[[OPERAND]][%[[iv0]], %[[iv1]]] : memref<16x11xf32>
+// CHECK-NEXT: %[[In0:.*]] = addi %[[S_in0]], %[[offset0]] : index
+// CHECK-NEXT: %[[pred1:.*]] = cmpi eq, %[[In0]], %[[iv0]] : index
+// CHECK-NEXT: %[[In1:.*]] = addi %[[S_in1]], %[[offset1]] : index
+// CHECK-NEXT: %[[pred2:.*]] = cmpi eq, %[[In1]], %[[iv1]] : index
+// CHECK-NEXT: %[[and1:.*]] = and %[[pred1]], %[[pred2]] : i1
+// CHECK-NEXT: %[[selected_value:.*]] = select %[[and1]], %[[operand_val]], %[[zero]] : f32
+// CHECK-NEXT: %[[prev_value:.*]] = affine.load %[[temp_output]][%[[batch0]], %[[offset0]], %[[offset1]]] : memref<5x8x6xf32>
+// CHECK-NEXT: %[[final_value:.*]] = addf %[[selected_value]], %[[prev_value]] : f32
+// CHECK-NEXT: affine.store %[[final_value]], %[[temp_output]][%[[batch0]], %[[offset0]], %[[offset1]]] : memref<5x8x6xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
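+
+// Here both start-index components participate: the gathered element
+// operand[iv0, iv1] contributes to output[batch0, offset0, offset1] exactly
+// when iv0 == S_in0 + offset0 and iv1 == S_in1 + offset1, which is the
+// predicate chain checked above.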
+
+// Test case 3: Checks for multi-dimensional start_indices with a
+// multi-dimensional batch. This also tests the f16 type.
+// CHECK-LABEL: func @gather_3
+// CHECK-SAME: (%[[OPERAND:.*]]: memref<16x11xf16>, %[[START_INDICES:.*]]: memref<4x2x5xi32>, %[[OUTPUT:.*]]: memref<4x5x8x6xf16>)
+func @gather_3(%arg0: memref<16x11xf16>, %arg1: memref<4x2x5xi32>, %arg2: memref<4x5x8x6xf16>) {
+  %0 = memref.alloc() : memref<4x5x8x6xf16>
+  "lmhlo.gather"(%arg0, %arg1, %0) {dimension_numbers = { collapsed_slice_dims = dense<-1> : tensor<1xi64>,
+                                                          index_vector_dim = 1 : i64,
+                                                          offset_dims = dense<[2,3]> : tensor<2xi64>,
+                                                          start_index_map = dense<[0,1]> : tensor<2xi64>},
+                                    indices_are_sorted = false, name = "gather.381", slice_sizes = dense<[8, 6]> : tensor<2xi64>} :
+      (memref<16x11xf16>, memref<4x2x5xi32>, memref<4x5x8x6xf16>) -> ()
+  "lmhlo.copy"(%0, %arg2) : (memref<4x5x8x6xf16>, memref<4x5x8x6xf16>) -> ()
+  "lmhlo.terminator"() : () -> ()
+}
+// CHECK-NEXT: %[[zero:.*]] = constant 0.000000e+00 : f16
+// CHECK-NEXT: %c0 = constant 0 : index
+// CHECK-NEXT: %c1 = constant 1 : index
+// CHECK-NEXT: %[[temp_output:.*]] = memref.alloc() : memref<4x5x8x6xf16>
+// CHECK-NEXT: affine.for %{{.*}} = 0 to 4 {
+// CHECK-NEXT: affine.for %{{.*}} = 0 to 5 {
+// CHECK-NEXT: affine.for %{{.*}} = 0 to 8 {
+// CHECK-NEXT: affine.for %{{.*}} = 0 to 6 {
+// CHECK-NEXT: affine.store %[[zero]], %[[temp_output]][%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<4x5x8x6xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: affine.for %[[batch0:.*]] = 0 to 4 {
+// CHECK-NEXT: affine.for %[[batch1:.*]] = 0 to 5 {
+// CHECK-NEXT: affine.for %[[offset0:.*]] = 0 to 8 {
+// CHECK-NEXT: affine.for %[[offset1:.*]] = 0 to 6 {
+// CHECK-NEXT: affine.for %[[iv0:.*]] = 0 to 16 {
+// CHECK-NEXT: affine.for %[[iv1:.*]] = 0 to 11 {
+// CHECK-NEXT: %[[a:.*]] = affine.load %[[START_INDICES]][%[[batch0]], %c0, %[[batch1]]] : memref<4x2x5xi32>
+// CHECK-NEXT: %[[S_in0:.*]] = index_cast %[[a]] : i32 to index
+// CHECK-NEXT: %[[b:.*]] = affine.load %[[START_INDICES]][%[[batch0]], %c1, %[[batch1]]] : memref<4x2x5xi32>
+// CHECK-NEXT: %[[S_in1:.*]] = index_cast %[[b]] : i32 to index
+// CHECK-NEXT: %[[operand_val:.*]] = affine.load %[[OPERAND]][%[[iv0]], %[[iv1]]] : memref<16x11xf16>
+// CHECK-NEXT: %[[In0:.*]] = addi %[[S_in0]], %[[offset0]] : index
+// CHECK-NEXT: %[[pred1:.*]] = cmpi eq, %[[In0]], %[[iv0]] : index
+// CHECK-NEXT: %[[In1:.*]] = addi %[[S_in1]], %[[offset1]] : index
+// CHECK-NEXT: %[[pred2:.*]] = cmpi eq, %[[In1]], %[[iv1]] : index
+// CHECK-NEXT: %[[and1:.*]] = and %[[pred1]], %[[pred2]] : i1
+// CHECK-NEXT: %[[selected_value:.*]] = select %[[and1]], %[[operand_val]], %[[zero]] : f16
+// CHECK-NEXT: %[[prev_value:.*]] = affine.load %[[temp_output]][%[[batch0]], %[[batch1]], %[[offset0]], %[[offset1]]] : memref<4x5x8x6xf16>
+// CHECK-NEXT: %[[final_value:.*]] = addf %[[selected_value]], %[[prev_value]] : f16
+// CHECK-NEXT: affine.store %[[final_value]], %[[temp_output]][%[[batch0]], %[[batch1]], %[[offset0]], %[[offset1]]] : memref<4x5x8x6xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
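+
+// Note that with index_vector_dim = 1 the two index components are read from
+// the middle dimension of the start_indices buffer: the loads above use
+// [batch0, %c0, batch1] and [batch0, %c1, batch1], so the batch induction
+// variables surround the index vector dimension.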
+
+// Test case 4: Changing start_index_map : X -> [0,X]
+// CHECK-LABEL: func @gather_4
+// CHECK-SAME: (%[[OPERAND:.*]]: memref<16x11xf32>, %[[START_INDICES:.*]]: memref<5x4xi32>, %[[OUTPUT:.*]]: memref<4x5x6xf32>)
+func @gather_4(%arg0: memref<16x11xf32>, %arg1: memref<5x4xi32>, %arg2: memref<4x5x6xf32>) {
+  %0 = memref.alloc() : memref<4x5x6xf32>
+  "lmhlo.gather"(%arg0, %arg1, %0) {dimension_numbers = { collapsed_slice_dims = dense<0> : tensor<1xi64>,
+                                                          index_vector_dim = 2 : i64,
+                                                          offset_dims = dense<2> : tensor<1xi64>,
+                                                          start_index_map = dense<0> : tensor<1xi64>},
+                                    indices_are_sorted = false, name = "gather.381", slice_sizes = dense<[1, 6]> : tensor<2xi64>} :
+      (memref<16x11xf32>, memref<5x4xi32>, memref<4x5x6xf32>) -> ()
+  "lmhlo.copy"(%0, %arg2) : (memref<4x5x6xf32>, memref<4x5x6xf32>) -> ()
+  "lmhlo.terminator"() : () -> ()
+}
+// CHECK-NEXT: %[[zero:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[temp_output:.*]] = memref.alloc() : memref<4x5x6xf32>
+// CHECK-NEXT: affine.for %{{.*}} = 0 to 4 {
+// CHECK-NEXT: affine.for %{{.*}} = 0 to 5 {
+// CHECK-NEXT: affine.for %{{.*}} = 0 to 6 {
+// CHECK-NEXT: affine.store %[[zero]], %[[temp_output]][%{{.*}}, %{{.*}}, %{{.*}}] : memref<4x5x6xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: affine.for %[[batch0:.*]] = 0 to 5 {
+// CHECK-NEXT: affine.for %[[batch1:.*]] = 0 to 4 {
+// CHECK-NEXT: affine.for %[[offset0:.*]] = 0 to 6 {
+// CHECK-NEXT: affine.for %[[iv0:.*]] = 0 to 16 {
+// CHECK-NEXT: %[[a:.*]] = affine.load %[[START_INDICES]][%[[batch0]], %[[batch1]]] : memref<5x4xi32>
+// CHECK-NEXT: %[[S_in0:.*]] = index_cast %[[a]] : i32 to index
+// CHECK-NEXT: %[[operand_val:.*]] = affine.load %[[OPERAND]][%[[iv0]], %[[offset0]]] : memref<16x11xf32>
+// CHECK-NEXT: %[[pred:.*]] = cmpi eq, %[[S_in0]], %[[iv0]] : index
+// CHECK-NEXT: %[[selected_value:.*]] = select %[[pred]], %[[operand_val]], %[[zero]] : f32
+// CHECK-NEXT: %[[prev_value:.*]] = affine.load %[[temp_output]][%[[batch0]], %[[batch1]], %[[offset0]]] : memref<4x5x6xf32>
+// CHECK-NEXT: %[[final_value:.*]] = addf %[[selected_value]], %[[prev_value]] : f32
+// CHECK-NEXT: affine.store %[[final_value]], %[[temp_output]][%[[batch0]], %[[batch1]], %[[offset0]]] : memref<4x5x6xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+// Test case 5: Testing for more than two equality checks.
+// CHECK-LABEL: func @gather_5
+func @gather_5(%arg0: memref<28996x512x256xf32>, %arg1: memref<10x3xi32>, %arg2: memref<10x20x10x5xf32>) {
+  %0 = memref.alloc() : memref<10x20x10x5xf32>
+  "lmhlo.gather"(%arg0, %arg1, %0) {dimension_numbers = { collapsed_slice_dims = dense<-1> : tensor<1xi64>,
+                                                          index_vector_dim = 1 : i64,
+                                                          offset_dims = dense<[1,2,3]> : tensor<3xi64>,
+                                                          start_index_map = dense<[0,1,2]> : tensor<3xi64>},
+                                    indices_are_sorted = false, name = "gather.381", slice_sizes = dense<[20, 10, 5]> : tensor<3xi64>} :
+      (memref<28996x512x256xf32>, memref<10x3xi32>, memref<10x20x10x5xf32>) -> ()
+  "lmhlo.copy"(%0, %arg2) : (memref<10x20x10x5xf32>, memref<10x20x10x5xf32>) -> ()
+  "lmhlo.terminator"() : () -> ()
+}
+// CHECK: %[[and1:.*]] = and %{{.*}}, %{{.*}} : i1
+// CHECK-NEXT: and %[[and1]], %{{.*}} : i1
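+
+// These cases are driven by the file's existing RUN line (outside this diff);
+// a typical invocation would be something along the lines of:
+//   mlir-hlo-opt -lhlo-legalize-to-affine %s | FileCheck %s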