[XLA:GPU] Add conversion from HLO -> MLIR LMHLO for TriangularSolve
- Also add layout attributes for inputs and output for error checking. PiperOrigin-RevId: 355863625
This commit is contained in:
parent
1c4521cc42
commit
b251712b1d
|
@ -99,6 +99,17 @@ def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, HLO_Complex]>;
|
||||||
// Any pred, int or floating-point tensor types
|
// Any pred, int or floating-point tensor types
|
||||||
def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>;
|
def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>;
|
||||||
|
|
||||||
|
// A layout attribute (1D tensor of index type)
|
||||||
|
def HLO_LayoutAttr : Attr<
|
||||||
|
And<[IndexElementsAttr.predicate,
|
||||||
|
CPred<[{$_self.cast<::mlir::DenseIntElementsAttr>().getType().getRank()
|
||||||
|
== 1}]>]>,
|
||||||
|
"A 1D tensor of index type (layout)"> {
|
||||||
|
let storageType = IndexElementsAttr.storageType;
|
||||||
|
let returnType = IndexElementsAttr.returnType;
|
||||||
|
let convertFromStorage = IndexElementsAttr.convertFromStorage;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// MHLO nullary op definitions.
|
// MHLO nullary op definitions.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -641,7 +641,10 @@ def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType
|
||||||
BoolAttr:$left_side,
|
BoolAttr:$left_side,
|
||||||
BoolAttr:$lower,
|
BoolAttr:$lower,
|
||||||
BoolAttr:$unit_diagonal,
|
BoolAttr:$unit_diagonal,
|
||||||
HLO_TransposeAttr:$transpose_a
|
HLO_TransposeAttr:$transpose_a,
|
||||||
|
HLO_LayoutAttr:$layout_a,
|
||||||
|
HLO_LayoutAttr:$layout_b,
|
||||||
|
HLO_LayoutAttr:$layout_output
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -865,7 +865,12 @@ func @replica_id_memrefs(%arg_out: memref<ui32>) -> () {
|
||||||
|
|
||||||
// CHECK-LABEL: func @triangular_solve_memrefs
|
// CHECK-LABEL: func @triangular_solve_memrefs
|
||||||
func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %arg_out: memref<3x4xf32>) -> () {
|
func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %arg_out: memref<3x4xf32>) -> () {
|
||||||
"lmhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true}
|
"lmhlo.triangular_solve"(%arg0, %arg1, %arg_out)
|
||||||
|
{layout_a = dense<[1, 0]> : tensor<2xindex>,
|
||||||
|
layout_b = dense<[1, 0]> : tensor<2xindex>,
|
||||||
|
layout_output = dense<[1, 0]> : tensor<2xindex>,
|
||||||
|
left_side = true, lower = true, transpose_a = "NO_TRANSPOSE",
|
||||||
|
unit_diagonal = true}
|
||||||
: (memref<4x4xf32>, memref<3x4xf32>, memref<3x4xf32>) -> ()
|
: (memref<4x4xf32>, memref<3x4xf32>, memref<3x4xf32>) -> ()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue