[XLA/GPU] Add XLA HLO -> LMHLO conversion to several ops, and implement them in XLA/GPU.
PiperOrigin-RevId: 353158172
This commit is contained in:
parent
ae10640d78
commit
d1c785381d
|
@ -713,7 +713,7 @@ def HLO_SliceOp: HLO_Op<
|
|||
}
|
||||
|
||||
def HLO_DynamicSliceOp: HLO_Op<"dynamic-slice",
|
||||
[NoSideEffect, AllElementTypesMatch<["operand", "result"]>]> {
|
||||
[NoSideEffect, AllElementTypesMatch<["operand", "result"]>]>, BASE_HLO_DynamicSliceOp {
|
||||
let arguments = (ins
|
||||
HLO_Tensor:$operand,
|
||||
Variadic<HLO_ScalarIntTensor>:$start_indices,
|
||||
|
@ -726,7 +726,7 @@ def HLO_DynamicSliceOp: HLO_Op<"dynamic-slice",
|
|||
|
||||
def HLO_DynamicUpdateSliceOp: HLO_Op<"dynamic-update-slice",
|
||||
[NoSideEffect, AllElementTypesMatch<["operand", "update", "result"]>,
|
||||
AllShapesMatch<["operand", "result"]>]> {
|
||||
AllShapesMatch<["operand", "result"]>]>, BASE_HLO_DynamicUpdateSliceOp {
|
||||
let arguments = (ins
|
||||
HLO_Tensor:$operand,
|
||||
HLO_Tensor:$update,
|
||||
|
|
|
@ -214,9 +214,7 @@ def LHLO_ReduceOp: LHLO_Op<"reduce", [SameVariadicOperandSize]>, BASE_HLO_Reduce
|
|||
let regions = (region SizedRegion<1>:$body);
|
||||
}
|
||||
|
||||
def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [
|
||||
SingleBlockImplicitTerminator<"TerminatorOp">
|
||||
]>, BASE_HLO_ReduceWindowOp {
|
||||
def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", []>, BASE_HLO_ReduceWindowOp {
|
||||
|
||||
let arguments = (ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||
|
@ -309,12 +307,22 @@ def LHLO_SliceOp: LHLO_Op<
|
|||
);
|
||||
}
|
||||
|
||||
def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> {
|
||||
def LHLO_DynamicSliceOp: LHLO_Op<"dynamic_slice",
|
||||
[AllElementTypesMatch<["operand", "output"]>]>, BASE_HLO_DynamicSliceOp {
|
||||
let arguments = (ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$start_indices,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||
I64ElementsAttr:$slice_sizes
|
||||
);
|
||||
}
|
||||
|
||||
def LHLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []>, BASE_HLO_DynamicUpdateSliceOp {
|
||||
let arguments = (ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$update,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$start_indices
|
||||
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$start_indices,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -135,6 +135,15 @@ DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit) {
|
|||
std::string LmhloToMhloOpName(llvm::StringRef op_name,
|
||||
mlir::MLIRContext *context) {
|
||||
assert(op_name.startswith("lmhlo.") && "Expected an LMHLO op");
|
||||
|
||||
if (op_name == "lmhlo.dot") {
|
||||
return "mhlo.dot_general";
|
||||
}
|
||||
|
||||
if (op_name == "lmhlo.dynamic_slice") {
|
||||
return "mhlo.dynamic-slice";
|
||||
}
|
||||
|
||||
std::string mhlo_op_name(op_name.drop_front(1));
|
||||
if (context->isOperationRegistered(mhlo_op_name)) return mhlo_op_name;
|
||||
return "";
|
||||
|
|
Loading…
Reference in New Issue