diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc
index 2767dcd..2794727 100644
--- a/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -642,11 +642,33 @@ void DynamicIotaOp::getCanonicalizationPatterns(
   results.insert(context);
 }
 
+static Value castToIndexTensor(OpBuilder& builder, Location loc,
+                               Value shape_op) {
+  ShapedType result_ty = shape::getExtentTensorType(
+      builder.getContext(),
+      shape_op.getType().cast<ShapedType>().getDimSize(0));
+  if (shape_op.getType() == result_ty) return shape_op;  // Nothing to do.
+  // index_cast is not defined on tensors, so emit a tensor.generate instead.
+  return builder.create<tensor::GenerateOp>(
+      loc, result_ty,
+      result_ty.hasStaticShape()
+          ? ValueRange{}
+          : ValueRange{builder.create<memref::DimOp>(loc, shape_op, 0)},
+      [&](OpBuilder& b, Location loc, ValueRange args) {
+        Value dim = args.front();
+        Value extent = b.create<tensor::ExtractOp>(loc, shape_op, dim);
+        Value casted =
+            b.create<IndexCastOp>(loc, extent, result_ty.getElementType());
+        b.create<tensor::YieldOp>(loc, casted);
+      });
+}
+
 LogicalResult DynamicIotaOp::reifyReturnTypeShapes(
-    OpBuilder&, ValueRange operands,
+    OpBuilder& builder, ValueRange operands,
     SmallVectorImpl<Value>& reifiedReturnShapes) {
   DynamicIotaOp::Adaptor adaptor(operands);
-  reifiedReturnShapes.push_back(adaptor.output_shape());
+  reifiedReturnShapes.push_back(
+      castToIndexTensor(builder, getLoc(), adaptor.output_shape()));
   return success();
 }
 
@@ -1192,10 +1214,11 @@ void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
 }
 
 LogicalResult DynamicBroadcastInDimOp::reifyReturnTypeShapes(
-    OpBuilder&, ValueRange operands,
+    OpBuilder& builder, ValueRange operands,
     SmallVectorImpl<Value>& reifiedReturnShapes) {
   DynamicBroadcastInDimOp::Adaptor adaptor(operands);
-  reifiedReturnShapes.push_back(adaptor.output_dimensions());
+  reifiedReturnShapes.push_back(
+      castToIndexTensor(builder, getLoc(), adaptor.output_dimensions()));
   return success();
 }
 
@@ -1627,10 +1650,11 @@ static LogicalResult Verify(DynamicReshapeOp op) {
 }
 
 LogicalResult DynamicReshapeOp::reifyReturnTypeShapes(
-    OpBuilder&, ValueRange operands,
+    OpBuilder& builder, ValueRange operands,
     SmallVectorImpl<Value>& reifiedReturnShapes) {
   DynamicReshapeOp::Adaptor adaptor(operands);
-  reifiedReturnShapes.push_back(adaptor.output_shape());
+  reifiedReturnShapes.push_back(
+      castToIndexTensor(builder, getLoc(), adaptor.output_shape()));
   return success();
 }
 
diff --git a/tests/reify-result-types.mlir b/tests/reify-result-types.mlir
new file mode 100644
index 0000000..136c724
--- /dev/null
+++ b/tests/reify-result-types.mlir
@@ -0,0 +1,50 @@
+// RUN: mlir-hlo-opt -resolve-shaped-type-result-dims -canonicalize \
+// RUN: -split-input-file %s -o - | FileCheck %s
+
+// CHECK-LABEL: @dynamic_broadcast_i32_shape
+func @dynamic_broadcast_i32_shape(%arg0 : tensor<?xi32>, %arg1 : tensor<*xf32>)
+    -> index {
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[DIM:.*]] = tensor.extract %arg0[%[[C0]]] : tensor<?xi32>
+  // CHECK: %[[RESULT:.*]] = index_cast %[[DIM]] : i32 to index
+  // CHECK: return %[[RESULT]]
+  %c0 = constant 0 : index
+  %0 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %arg0)
+      { broadcast_dimensions = dense<0> : tensor<1xi64> }
+      : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
+  %1 = memref.dim %0, %c0 : tensor<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dynamic_iota_i32_shape
+func @dynamic_iota_i32_shape(%arg0 : tensor<?xi32>) -> index {
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[DIM:.*]] = tensor.extract %arg0[%[[C0]]] : tensor<?xi32>
+  // CHECK: %[[RESULT:.*]] = index_cast %[[DIM]] : i32 to index
+  // CHECK: return %[[RESULT]]
+  %c0 = constant 0 : index
+  %0 = "mhlo.dynamic_iota"(%arg0)
+      {iota_dimension = 0 : i64}
+      : (tensor<?xi32>) -> tensor<?xf32>
+  %1 = memref.dim %0, %c0 : tensor<?xf32>
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dynamic_reshape_i32_shape
+func @dynamic_reshape_i32_shape(%arg0 : tensor<?xi32>, %arg1 : tensor<*xf32>)
+    -> index {
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[DIM:.*]] = tensor.extract %arg0[%[[C0]]] : tensor<?xi32>
+  // CHECK: %[[RESULT:.*]] = index_cast %[[DIM]] : i32 to index
+  // CHECK: return %[[RESULT]]
+  %c0 = constant 0 : index
+  %0 = "mhlo.dynamic_reshape"(%arg1, %arg0)
+      { broadcast_dimensions = dense<0> : tensor<1xi64> }
+      : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
+  %1 = memref.dim %0, %c0 : tensor<*xf32>
+  return %1 : index
+}