PR #49970: [MLIR][DISC] bufferize DynamicReshape and DynamicBroadcastInDim
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/49970 1, add hlo-to-lhlo support for DynamicReshape and DynamicBroadcastInDim 2, add a flag `convert-to-lmhlo-only` to seperate following two case: - hlo-to-lhlo only. Simply lowers all mhlo ops to their lmhlo counterparts, do not apply any optimization (e.g. elide any buffer copy). Buffer optimization is not easy in dynamic shape world especially when involving control flow, thus we leave this to another dedicated pass. - hlo-to-lhlo-or-memref-directly. Lowers some metadata-only mhlo ops (e.g. reshape) to memref dialect directly and Lowers others to their lmhlo counterparts. Copybara import of the project: -- 562bd65a368f6194405c4ae6900e3b4388a5ec03 by Wenyi Zhao <reyizero@gmail.com>: [MLIR][DISC] bufferize DynamicReshape and DynamicBroadcastInDim 1, add hlo-to-lhlo support for DynamicReshape and DynamicBroadcastInDim 2, add a flag `convert-to-lmhlo-only` to seperate following two case: - hlo-to-lhlo only. Simply lowers all mhlo ops to their lmhlo counterparts, do not apply any optimization (e.g. elide any buffer copy). Buffer optimization is not easy in dynamic shape world especially when involving control flow, thus we leave this to another dedicated pass. - hlo-to-lhlo-or-memref-directly. Lowers some metadata-only mhlo ops (e.g. reshape) to memref dialect directly and Lowers others to their lmhlo counterparts. PiperOrigin-RevId: 377603395
This commit is contained in:
parent
8b3a75ea25
commit
ade873a5e0
1
BUILD
1
BUILD
|
@ -880,6 +880,7 @@ cc_library(
|
||||||
":hlo",
|
":hlo",
|
||||||
":lhlo",
|
":lhlo",
|
||||||
":map_hlo_to_lhlo_op",
|
":map_hlo_to_lhlo_op",
|
||||||
|
":pass_details",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:MemRefDialect",
|
"@llvm-project//mlir:MemRefDialect",
|
||||||
|
|
|
@ -35,6 +35,11 @@ class HLO_Op<string mnemonic, list<OpTrait> traits> :
|
||||||
let verifier = [{ return Verify(*this); }];
|
let verifier = [{ return Verify(*this); }];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class HLO_ShapedInterfaceOp<string mnemonic, list<OpTrait> traits> :
|
||||||
|
HLO_Op<mnemonic, traits # [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||||
|
["reifyReturnTypeShapes"]>]> {
|
||||||
|
}
|
||||||
|
|
||||||
def HLO_LOOP_FUSION : StrEnumAttrCase<"kLoop">;
|
def HLO_LOOP_FUSION : StrEnumAttrCase<"kLoop">;
|
||||||
def HLO_INPUT_FUSION : StrEnumAttrCase<"kInput">;
|
def HLO_INPUT_FUSION : StrEnumAttrCase<"kInput">;
|
||||||
def HLO_OUTPUT_FUSION : StrEnumAttrCase<"kOutput">;
|
def HLO_OUTPUT_FUSION : StrEnumAttrCase<"kOutput">;
|
||||||
|
@ -1277,9 +1282,8 @@ def HLO_BroadcastInDimOp : HLO_Op<"broadcast_in_dim",
|
||||||
let hasCustomHLOConverter = 1;
|
let hasCustomHLOConverter = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim", [
|
def HLO_DynamicBroadcastInDimOp : HLO_ShapedInterfaceOp<
|
||||||
NoSideEffect, DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
"dynamic_broadcast_in_dim", [NoSideEffect]> {
|
||||||
["inferReturnTypeComponents", "reifyReturnTypeShapes"]>]> {
|
|
||||||
let summary = "Broadcast a tensor into the given dynamic shape by adding dimensions.";
|
let summary = "Broadcast a tensor into the given dynamic shape by adding dimensions.";
|
||||||
let description = [{
|
let description = [{
|
||||||
This is a generalization of the BroadcastInDimOp which accepts its output
|
This is a generalization of the BroadcastInDimOp which accepts its output
|
||||||
|
@ -1671,7 +1675,7 @@ def HLO_ReshapeOp: HLO_Op<"reshape",
|
||||||
let hasCustomHLOConverter = 1;
|
let hasCustomHLOConverter = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLO_DynamicReshapeOp: HLO_Op<"dynamic_reshape", [NoSideEffect]> {
|
def HLO_DynamicReshapeOp: HLO_ShapedInterfaceOp<"dynamic_reshape", [NoSideEffect]> {
|
||||||
let summary = "Reshape a tensor to a given, possibly dynamic, shape.";
|
let summary = "Reshape a tensor to a given, possibly dynamic, shape.";
|
||||||
let description = [{
|
let description = [{
|
||||||
Reshapes `operand` to `output_shape`.
|
Reshapes `operand` to `output_shape`.
|
||||||
|
|
|
@ -1535,7 +1535,7 @@ def LHLO_DynamicReshapeOp: LHLO_Op<"dynamic_reshape", []> {
|
||||||
}];
|
}];
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
Arg<LHLO_IntBuffer, "", [MemRead]>:$shape,
|
Arg<LHLO_DimensionBuffer, "", [MemRead]>:$shape,
|
||||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,6 +54,8 @@ MAP_HLO_TO_LHLO(CosOp);
|
||||||
MAP_HLO_TO_LHLO(CustomCallOp);
|
MAP_HLO_TO_LHLO(CustomCallOp);
|
||||||
MAP_HLO_TO_LHLO(DivOp);
|
MAP_HLO_TO_LHLO(DivOp);
|
||||||
MAP_HLO_TO_LHLO(DotOp);
|
MAP_HLO_TO_LHLO(DotOp);
|
||||||
|
MAP_HLO_TO_LHLO(DynamicBroadcastInDimOp);
|
||||||
|
MAP_HLO_TO_LHLO(DynamicReshapeOp);
|
||||||
MAP_HLO_TO_LHLO(ExpOp);
|
MAP_HLO_TO_LHLO(ExpOp);
|
||||||
MAP_HLO_TO_LHLO(Expm1Op);
|
MAP_HLO_TO_LHLO(Expm1Op);
|
||||||
MAP_HLO_TO_LHLO(FloorOp);
|
MAP_HLO_TO_LHLO(FloorOp);
|
||||||
|
|
|
@ -29,6 +29,12 @@ def ChloLegalizeToHloPass : FunctionPass<"chlo-legalize-to-hlo"> {
|
||||||
def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> {
|
def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> {
|
||||||
let summary = "Legalize from HLO dialect to LHLO dialect.";
|
let summary = "Legalize from HLO dialect to LHLO dialect.";
|
||||||
let constructor = "createLegalizeToLhloPass()";
|
let constructor = "createLegalizeToLhloPass()";
|
||||||
|
let options = [
|
||||||
|
Option<"convert_to_lmhlo_only_", "convert-to-lmhlo-only", "bool",
|
||||||
|
/*default=*/"false", "If enabled, simply lower all mhlo ops to their lmhlo counterparts, "
|
||||||
|
"otherwise, some metadata-only ops (e.g. reshape) may be lowerred "
|
||||||
|
"to memref dialect to elide some buffer copy.">,
|
||||||
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
def LegalizeControlFlowPass : Pass<"mhlo-legalize-control-flow", "FuncOp"> {
|
def LegalizeControlFlowPass : Pass<"mhlo-legalize-control-flow", "FuncOp"> {
|
||||||
|
|
|
@ -50,7 +50,8 @@ std::unique_ptr<FunctionPass> createChloLegalizeToHloPass(
|
||||||
|
|
||||||
/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
|
/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
|
||||||
/// buffers if necessary.
|
/// buffers if necessary.
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass();
|
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
|
||||||
|
bool convert_to_lmhlo_only = false);
|
||||||
|
|
||||||
// Lowers from HLO dialect to Linalg dialect.
|
// Lowers from HLO dialect to Linalg dialect.
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass();
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass();
|
||||||
|
|
|
@ -45,15 +45,24 @@ void PopulateGatherToTorchIndexSelectPatterns(
|
||||||
void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
|
void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
|
||||||
MLIRContext *ctx);
|
MLIRContext *ctx);
|
||||||
|
|
||||||
// Collection of rewrite patterns for lowering of dynamic HLOs to LHLO dialect.
|
// Collection of rewrite patterns for lowering of dynamic HLOs to LHLO or memref
|
||||||
void populateDynamicHLOToLHLOConversionPattern(
|
// dialect.
|
||||||
|
void populateDynamicHLOToLHLOOrMemRefConversionPattern(
|
||||||
MLIRContext *context, BufferizeTypeConverter *converter,
|
MLIRContext *context, BufferizeTypeConverter *converter,
|
||||||
OwningRewritePatternList *patterns, bool insert_copy = true);
|
OwningRewritePatternList *patterns, bool insert_copy = true);
|
||||||
|
|
||||||
|
// Collection of rewrite patterns for simply lowering all mhlo ops to their
|
||||||
|
// lmhlo counterparts, do not apply any optimization (e.g. elide any buffer
|
||||||
|
// copy).
|
||||||
|
void populateDynamicHLOToLHLOOnlyConversionPattern(
|
||||||
|
MLIRContext *context, BufferizeTypeConverter *converter,
|
||||||
|
OwningRewritePatternList *patterns);
|
||||||
|
|
||||||
// Collection of rewrite patterns for lowering of HLO to LHLO dialect.
|
// Collection of rewrite patterns for lowering of HLO to LHLO dialect.
|
||||||
void populateHLOToLHLOConversionPattern(MLIRContext *context,
|
void populateHLOToLHLOConversionPattern(MLIRContext *context,
|
||||||
BufferizeTypeConverter *converter,
|
BufferizeTypeConverter *converter,
|
||||||
OwningRewritePatternList *patterns);
|
OwningRewritePatternList *patterns,
|
||||||
|
bool convert_to_lmhlo_only = false);
|
||||||
|
|
||||||
// Collection of rewrite patterns for lowering of HLO to Linalg dialect.
|
// Collection of rewrite patterns for lowering of HLO to Linalg dialect.
|
||||||
void populateHLOToLinalgConversionPattern(MLIRContext *context,
|
void populateHLOToLinalgConversionPattern(MLIRContext *context,
|
||||||
|
|
|
@ -1021,12 +1021,6 @@ void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
|
||||||
context);
|
context);
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult DynamicBroadcastInDimOp::inferReturnTypeComponents(
|
|
||||||
MLIRContext*, llvm::Optional<mlir::Location>, ValueRange, DictionaryAttr,
|
|
||||||
RegionRange, llvm::SmallVectorImpl<mlir::ShapedTypeComponents>&) {
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult DynamicBroadcastInDimOp::reifyReturnTypeShapes(
|
LogicalResult DynamicBroadcastInDimOp::reifyReturnTypeShapes(
|
||||||
OpBuilder&, ValueRange operands,
|
OpBuilder&, ValueRange operands,
|
||||||
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||||
|
@ -1411,6 +1405,14 @@ static LogicalResult Verify(DynamicReshapeOp op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult DynamicReshapeOp::reifyReturnTypeShapes(
|
||||||
|
OpBuilder&, ValueRange operands,
|
||||||
|
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||||
|
DynamicReshapeOp::Adaptor adaptor(operands);
|
||||||
|
reifiedReturnShapes.push_back(adaptor.output_shape());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class DynamicReshapeOpNotActuallyDynamic
|
class DynamicReshapeOpNotActuallyDynamic
|
||||||
: public OpRewritePattern<DynamicReshapeOp> {
|
: public OpRewritePattern<DynamicReshapeOp> {
|
||||||
|
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
|
@ -576,8 +577,14 @@ class HloToLhloTensorStoreOpLegacyConverter
|
||||||
// "lmhlo.terminator"() : () -> ()
|
// "lmhlo.terminator"() : () -> ()
|
||||||
// }
|
// }
|
||||||
|
|
||||||
struct HloLegalizeToLhlo
|
struct HloLegalizeToLhlo : public HloLegalizeToLhloPassBase<HloLegalizeToLhlo> {
|
||||||
: public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> {
|
using HloLegalizeToLhloPassBase<HloLegalizeToLhlo>::HloLegalizeToLhloPassBase;
|
||||||
|
explicit HloLegalizeToLhlo(bool convert_to_lmhlo_only)
|
||||||
|
: HloLegalizeToLhloPassBase<
|
||||||
|
HloLegalizeToLhlo>::HloLegalizeToLhloPassBase() {
|
||||||
|
this->convert_to_lmhlo_only_ = convert_to_lmhlo_only;
|
||||||
|
}
|
||||||
|
|
||||||
void getDependentDialects(DialectRegistry& registry) const override {
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
registry.insert<lmhlo::LmhloDialect, memref::MemRefDialect,
|
registry.insert<lmhlo::LmhloDialect, memref::MemRefDialect,
|
||||||
shape::ShapeDialect>();
|
shape::ShapeDialect>();
|
||||||
|
@ -623,7 +630,8 @@ struct HloLegalizeToLhlo
|
||||||
isMemRefType);
|
isMemRefType);
|
||||||
});
|
});
|
||||||
|
|
||||||
populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
|
populateHLOToLHLOConversionPattern(&context, &converter, &patterns,
|
||||||
|
convert_to_lmhlo_only_);
|
||||||
populateFuncOpTypeConversionPattern(patterns, converter);
|
populateFuncOpTypeConversionPattern(patterns, converter);
|
||||||
populateCallOpTypeConversionPattern(patterns, converter);
|
populateCallOpTypeConversionPattern(patterns, converter);
|
||||||
populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
|
populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
|
||||||
|
@ -643,7 +651,9 @@ struct HloLegalizeToLhlo
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void populateDynamicHLOToLHLOConversionPattern(
|
// Lowers some metadata-only mhlo ops (e.g. reshape) to memref dialect
|
||||||
|
// directly and Lowers others to their lmhlo counterparts.
|
||||||
|
void populateDynamicHLOToLHLOOrMemRefConversionPattern(
|
||||||
MLIRContext* context, BufferizeTypeConverter* converter,
|
MLIRContext* context, BufferizeTypeConverter* converter,
|
||||||
OwningRewritePatternList* patterns, bool insert_copy) {
|
OwningRewritePatternList* patterns, bool insert_copy) {
|
||||||
patterns->insert<HloToLhloDynamicBroadcastInDimOpConverter>(
|
patterns->insert<HloToLhloDynamicBroadcastInDimOpConverter>(
|
||||||
|
@ -652,10 +662,28 @@ void populateDynamicHLOToLHLOConversionPattern(
|
||||||
HloToLhloReshapeUnrankedConverter>(*converter, context);
|
HloToLhloReshapeUnrankedConverter>(*converter, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Simply lowers all mhlo ops to their lmhlo counterparts, do not apply
|
||||||
|
// any optimization (e.g. elide any buffer copy).
|
||||||
|
void populateDynamicHLOToLHLOOnlyConversionPattern(
|
||||||
|
MLIRContext* context, BufferizeTypeConverter* converter,
|
||||||
|
OwningRewritePatternList* patterns) {
|
||||||
|
// clang-format off
|
||||||
|
patterns->insert<HloToLhloOpConverter<mhlo::DynamicBroadcastInDimOp>,
|
||||||
|
HloToLhloOpConverter<mhlo::DynamicReshapeOp>
|
||||||
|
>(*converter, context);
|
||||||
|
// clang-format on
|
||||||
|
}
|
||||||
|
|
||||||
void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
||||||
BufferizeTypeConverter* converter,
|
BufferizeTypeConverter* converter,
|
||||||
OwningRewritePatternList* patterns) {
|
OwningRewritePatternList* patterns,
|
||||||
populateDynamicHLOToLHLOConversionPattern(context, converter, patterns);
|
bool convert_to_lmhlo_only) {
|
||||||
|
if (convert_to_lmhlo_only) {
|
||||||
|
populateDynamicHLOToLHLOOnlyConversionPattern(context, converter, patterns);
|
||||||
|
} else {
|
||||||
|
populateDynamicHLOToLHLOOrMemRefConversionPattern(context, converter,
|
||||||
|
patterns);
|
||||||
|
}
|
||||||
// clang-format off
|
// clang-format off
|
||||||
patterns->insert<
|
patterns->insert<
|
||||||
HloToLhloCustomCallOpConverter,
|
HloToLhloCustomCallOpConverter,
|
||||||
|
@ -712,8 +740,9 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
||||||
// clang-format on
|
// clang-format on
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass() {
|
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
|
||||||
return std::make_unique<HloLegalizeToLhlo>();
|
bool convert_to_lmhlo_only) {
|
||||||
|
return std::make_unique<HloLegalizeToLhlo>(convert_to_lmhlo_only);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mhlo
|
} // namespace mhlo
|
||||||
|
|
|
@ -0,0 +1,35 @@
|
||||||
|
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=convert-to-lmhlo-only=true \
|
||||||
|
// RUN: -canonicalize -lhlo-legalize-tensor-load-op %s -o - | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @dynamic_reshape
|
||||||
|
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>, %[[SHAPE:.*]]: memref<3xindex>) -> memref<?x?x?xf32>
|
||||||
|
func @dynamic_reshape(%lhs: tensor<?x?xf32>, %rhs: tensor<3xindex>) -> tensor<?x?x?xf32> {
|
||||||
|
// CHECK-NOT: tensor_load
|
||||||
|
// CHECK: %[[DIM0:.*]] = memref.load %[[SHAPE]][%c0]
|
||||||
|
// CHECK: %[[DIM1:.*]] = memref.load %[[SHAPE]][%c1]
|
||||||
|
// CHECK: %[[DIM2:.*]] = memref.load %[[SHAPE]][%c2]
|
||||||
|
// CHECK: %[[OUTPUT:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]], %[[DIM2]])
|
||||||
|
// CHECK: "lmhlo.dynamic_reshape"(%[[ARG]], %[[SHAPE]], %[[OUTPUT]])
|
||||||
|
// CHECK: return %[[OUTPUT]]
|
||||||
|
%result = "mhlo.dynamic_reshape"(%lhs, %rhs)
|
||||||
|
: (tensor<?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||||
|
return %result : tensor<?x?x?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @dynamic_broadcast_in_dim
|
||||||
|
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>, %[[SHAPE:.*]]: memref<3xindex>) -> memref<?x?x?xf32>
|
||||||
|
func @dynamic_broadcast_in_dim(%operand: tensor<?x?xf32>, %shape: tensor<3xindex>) -> tensor<?x?x?xf32> {
|
||||||
|
// CHECK-NOT: tensor_load
|
||||||
|
// CHECK: %[[DIM0:.*]] = memref.load %[[SHAPE]][%c0]
|
||||||
|
// CHECK: %[[DIM1:.*]] = memref.load %[[SHAPE]][%c1]
|
||||||
|
// CHECK: %[[DIM2:.*]] = memref.load %[[SHAPE]][%c2]
|
||||||
|
// CHECK: %[[OUTPUT:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]], %[[DIM2]])
|
||||||
|
// CHECK: "lmhlo.dynamic_broadcast_in_dim"(%[[ARG]], %[[SHAPE]], %[[OUTPUT]])
|
||||||
|
// CHECK: return %[[OUTPUT]]
|
||||||
|
%result = "mhlo.dynamic_broadcast_in_dim"(%operand, %shape) {
|
||||||
|
broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
|
||||||
|
} : (tensor<?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||||
|
return %result : tensor<?x?x?xf32>
|
||||||
|
}
|
Loading…
Reference in New Issue