Update mhlo.constant to use a custom assembly format instead of a custom printer and parser (NFC).

PiperOrigin-RevId: 325560779
This commit is contained in:
Andy Ly 2020-08-07 21:33:04 -07:00 committed by Geoffrey Martin-Noble
parent 2b0a244a6b
commit 53fdda7f3e
3 changed files with 6 additions and 38 deletions

View File

@ -67,8 +67,7 @@ def HLO_ConstOp : HLO_Op<"constant",
"OpBuilder &builder, OperationState &result, Attribute value" "OpBuilder &builder, OperationState &result, Attribute value"
>]; >];
let printer = [{ return Print(*this, &p); }]; let assemblyFormat = "attr-dict $value";
let parser = [{ return ParseConstOp(&parser, &result); }];
let hasFolder = 1; let hasFolder = 1;

View File

@ -112,37 +112,6 @@ DenseIntElementsAttr BuildSliceLimits(DenseIntElementsAttr start_indices,
// ConstOp // ConstOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static void Print(ConstOp op, OpAsmPrinter* printer) {
// Print op name.
*printer << op.getOperationName();
// Elide attribute value while printing the attribute dictionary.
SmallVector<StringRef, 1> elided_attrs;
elided_attrs.push_back("value");
printer->printOptionalAttrDict(op.getAttrs(), elided_attrs);
*printer << ' ' << op.value();
}
static ParseResult ParseConstOp(OpAsmParser* parser, OperationState* result) {
if (parser->parseOptionalAttrDict(result->attributes)) return failure();
// If colon is not present after attribute dictionary, it should be short form
// and attribute 'value' is outside the dictionary.
if (failed(parser->parseOptionalColon())) {
Attribute value;
if (parser->parseAttribute(value, "value", result->attributes))
return failure();
return parser->addTypeToList(value.getType(), result->types);
}
// Long form should have type of the result after colon.
Type ty;
if (parser->parseType(ty)) return failure();
result->types.push_back(ty);
return success();
}
OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) { OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
assert(operands.empty() && "constant has no operands"); assert(operands.empty() && "constant has no operands");

View File

@ -480,7 +480,7 @@ func @map_non_scalar_computation_operand(%arg0: tensor<4x5xf32>, %arg1: tensor<4
// expected-error@+1 {{computation arguments must be 0-rank tensor, but got: arg #1 of type 'tensor<5xf32>'}} // expected-error@+1 {{computation arguments must be 0-rank tensor, but got: arg #1 of type 'tensor<5xf32>'}}
%0 = "mhlo.map"(%arg0, %arg1) ( { %0 = "mhlo.map"(%arg0, %arg1) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<5xf32>): ^bb0(%arg2: tensor<f32>, %arg3: tensor<5xf32>):
%1 = mhlo.constant {value = dense<2.0> : tensor<f32>} : tensor<f32> %1 = mhlo.constant dense<2.0> : tensor<f32>
"mhlo.return"(%1) : (tensor<f32>) -> () "mhlo.return"(%1) : (tensor<f32>) -> ()
}) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32>
return %0 : tensor<4x5xf32> return %0 : tensor<4x5xf32>
@ -492,7 +492,7 @@ func @map_mismatch_operand_and_computation_args(%arg0: tensor<4x5xf32>, %arg1: t
// expected-error@+1 {{element type of operands and computation arguments must match, but got: 'f32' and 'i32'}} // expected-error@+1 {{element type of operands and computation arguments must match, but got: 'f32' and 'i32'}}
%0 = "mhlo.map"(%arg0, %arg1) ( { %0 = "mhlo.map"(%arg0, %arg1) ( {
^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>): ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):
%1 = mhlo.constant {value = dense<2.0> : tensor<f32>} : tensor<f32> %1 = mhlo.constant dense<2.0> : tensor<f32>
"mhlo.return"(%1) : (tensor<f32>) -> () "mhlo.return"(%1) : (tensor<f32>) -> ()
}) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32>
return %0 : tensor<4x5xf32> return %0 : tensor<4x5xf32>
@ -504,7 +504,7 @@ func @map_invalid_number_of_computation_output(%arg0: tensor<4x5xf32>, %arg1: te
// expected-error@+1 {{computation must return single output, but got: 0}} // expected-error@+1 {{computation must return single output, but got: 0}}
%0 = "mhlo.map"(%arg0, %arg1) ( { %0 = "mhlo.map"(%arg0, %arg1) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>): ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%1 = mhlo.constant {value = dense<2.0> : tensor<f32>} : tensor<f32> %1 = mhlo.constant dense<2.0> : tensor<f32>
"mhlo.return"() : () -> () "mhlo.return"() : () -> ()
}) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32>
return %0 : tensor<4x5xf32> return %0 : tensor<4x5xf32>
@ -516,7 +516,7 @@ func @main_non_scalar_computation_output(%arg0: tensor<4x5xf32>, %arg1: tensor<4
// expected-error@+1 {{computation must return 0-rank tensor, but got: 'tensor<5xf32>'}} // expected-error@+1 {{computation must return 0-rank tensor, but got: 'tensor<5xf32>'}}
%0 = "mhlo.map"(%arg0, %arg1) ( { %0 = "mhlo.map"(%arg0, %arg1) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>): ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%1 = mhlo.constant {value = dense<2.0> : tensor<f32>} : tensor<5xf32> %1 = mhlo.constant dense<2.0> : tensor<5xf32>
"mhlo.return"(%1) : (tensor<5xf32>) -> () "mhlo.return"(%1) : (tensor<5xf32>) -> ()
}) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32>
return %0 : tensor<4x5xf32> return %0 : tensor<4x5xf32>
@ -528,7 +528,7 @@ func @mismatch_computation_output_type(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5
// expected-error@+1 {{element type of result and computation output must match, but got: 'f32' and 'i32'}} // expected-error@+1 {{element type of result and computation output must match, but got: 'f32' and 'i32'}}
%0 = "mhlo.map"(%arg0, %arg1) ( { %0 = "mhlo.map"(%arg0, %arg1) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>): ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%1 = mhlo.constant {value = dense<2> : tensor<i32>} : tensor<i32> %1 = mhlo.constant dense<2> : tensor<i32>
"mhlo.return"(%1) : (tensor<i32>) -> () "mhlo.return"(%1) : (tensor<i32>) -> ()
}) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32>
return %0 : tensor<4x5xf32> return %0 : tensor<4x5xf32>