From 909574e393027f9a68b4e4bfbfdda623445dcfea Mon Sep 17 00:00:00 2001 From: Prakalp Srivastava Date: Sat, 20 Feb 2021 07:06:24 -0800 Subject: [PATCH] Pass result element type to XlaBuilder for `mhlo.dot_general` and `mhlo.convolution` ops. `mhlo.dot_general` and `mhlo.convolution` result element type might be different from operand element type. See `preferred_element_type` attribute that allows i8xi8 to i32 dot computation. `mhlo` to HLO exporter should pass the result element type to Xla builder to override the shape inference of XLA. PiperOrigin-RevId: 358580718 --- include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td | 3 +++ 1 file changed, 3 insertions(+) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 77b5159..cca165e 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -960,6 +960,9 @@ def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, let results = (outs HLO_Tensor); let verifier = [{ return Verify(*this); }]; + // DotGeneral op required custom exporter to pass the preferred element type + // to Xla builder. + let hasCustomHLOConverter = 1; } // Define Base Einsum op within the HLO dialect as these are client ops and