From 9cbe5f2285537627c26841d06640bf454821051a Mon Sep 17 00:00:00 2001 From: Stephan Herhut Date: Fri, 31 Jul 2020 11:50:22 -0700 Subject: [PATCH] Constrain mhlo.const to static shaped tensors. Constants of unknown shape cannot be materialized. In most cases, one likely wants to use a scalar constant and rely on broadcasting instead. PiperOrigin-RevId: 324252475 --- include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td | 2 +- tests/ops.mlir | 18 +++++++++++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 93c5388..3d7b827 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -52,7 +52,7 @@ def HLO_ConstOp : HLO_Op<"constant", ); let results = (outs - HLO_Tensor:$output + HLO_StaticShapeTensor:$output ); let builders = [OpBuilder< diff --git a/tests/ops.mlir b/tests/ops.mlir index b46827b..920e62e 100644 --- a/tests/ops.mlir +++ b/tests/ops.mlir @@ -939,7 +939,23 @@ func @constants() -> () { func @constant_invalid() -> () { // expected-error@+1 {{op failed to verify that all of {value, output} have same type}} - %0 = "mhlo.constant"() {value = dense<0> : tensor} : () -> (tensor<*xi32>) + %0 = "mhlo.constant"() {value = dense<0> : tensor} : () -> (tensor<3xi32>) + return +} + +// ----- + +func @constant_invalid() -> () { + // expected-error@+1 {{op result #0 must be statically shaped tensor}} + %0 = "mhlo.constant"() {value = dense<1> : tensor} : () -> tensor + return +} + +// ----- + +func @constant_invalid() -> () { + // expected-error@+1 {{elements literal type must have static shape}} + %0 = "mhlo.constant"() {value = dense<1> : tensor} : () -> tensor return }