Update mhlo.constant to use a custom assembly format instead of a custom printer and parser (NFC).
PiperOrigin-RevId: 325560779
This commit is contained in:
parent
2b0a244a6b
commit
53fdda7f3e
|
@ -67,8 +67,7 @@ def HLO_ConstOp : HLO_Op<"constant",
|
|||
"OpBuilder &builder, OperationState &result, Attribute value"
|
||||
>];
|
||||
|
||||
let printer = [{ return Print(*this, &p); }];
|
||||
let parser = [{ return ParseConstOp(&parser, &result); }];
|
||||
let assemblyFormat = "attr-dict $value";
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
|
|
|
@ -112,37 +112,6 @@ DenseIntElementsAttr BuildSliceLimits(DenseIntElementsAttr start_indices,
|
|||
// ConstOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void Print(ConstOp op, OpAsmPrinter* printer) {
|
||||
// Print op name.
|
||||
*printer << op.getOperationName();
|
||||
|
||||
// Elide attribute value while printing the attribute dictionary.
|
||||
SmallVector<StringRef, 1> elided_attrs;
|
||||
elided_attrs.push_back("value");
|
||||
printer->printOptionalAttrDict(op.getAttrs(), elided_attrs);
|
||||
|
||||
*printer << ' ' << op.value();
|
||||
}
|
||||
|
||||
static ParseResult ParseConstOp(OpAsmParser* parser, OperationState* result) {
|
||||
if (parser->parseOptionalAttrDict(result->attributes)) return failure();
|
||||
|
||||
// If colon is not present after attribute dictionary, it should be short form
|
||||
// and attribute 'value' is outside the dictionary.
|
||||
if (failed(parser->parseOptionalColon())) {
|
||||
Attribute value;
|
||||
if (parser->parseAttribute(value, "value", result->attributes))
|
||||
return failure();
|
||||
return parser->addTypeToList(value.getType(), result->types);
|
||||
}
|
||||
|
||||
// Long form should have type of the result after colon.
|
||||
Type ty;
|
||||
if (parser->parseType(ty)) return failure();
|
||||
result->types.push_back(ty);
|
||||
return success();
|
||||
}
|
||||
|
||||
OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.empty() && "constant has no operands");
|
||||
|
||||
|
|
|
@ -480,7 +480,7 @@ func @map_non_scalar_computation_operand(%arg0: tensor<4x5xf32>, %arg1: tensor<4
|
|||
// expected-error@+1 {{computation arguments must be 0-rank tensor, but got: arg #1 of type 'tensor<5xf32>'}}
|
||||
%0 = "mhlo.map"(%arg0, %arg1) ( {
|
||||
^bb0(%arg2: tensor<f32>, %arg3: tensor<5xf32>):
|
||||
%1 = mhlo.constant {value = dense<2.0> : tensor<f32>} : tensor<f32>
|
||||
%1 = mhlo.constant dense<2.0> : tensor<f32>
|
||||
"mhlo.return"(%1) : (tensor<f32>) -> ()
|
||||
}) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32>
|
||||
return %0 : tensor<4x5xf32>
|
||||
|
@ -492,7 +492,7 @@ func @map_mismatch_operand_and_computation_args(%arg0: tensor<4x5xf32>, %arg1: t
|
|||
// expected-error@+1 {{element type of operands and computation arguments must match, but got: 'f32' and 'i32'}}
|
||||
%0 = "mhlo.map"(%arg0, %arg1) ( {
|
||||
^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):
|
||||
%1 = mhlo.constant {value = dense<2.0> : tensor<f32>} : tensor<f32>
|
||||
%1 = mhlo.constant dense<2.0> : tensor<f32>
|
||||
"mhlo.return"(%1) : (tensor<f32>) -> ()
|
||||
}) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32>
|
||||
return %0 : tensor<4x5xf32>
|
||||
|
@ -504,7 +504,7 @@ func @map_invalid_number_of_computation_output(%arg0: tensor<4x5xf32>, %arg1: te
|
|||
// expected-error@+1 {{computation must return single output, but got: 0}}
|
||||
%0 = "mhlo.map"(%arg0, %arg1) ( {
|
||||
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
|
||||
%1 = mhlo.constant {value = dense<2.0> : tensor<f32>} : tensor<f32>
|
||||
%1 = mhlo.constant dense<2.0> : tensor<f32>
|
||||
"mhlo.return"() : () -> ()
|
||||
}) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32>
|
||||
return %0 : tensor<4x5xf32>
|
||||
|
@ -516,7 +516,7 @@ func @main_non_scalar_computation_output(%arg0: tensor<4x5xf32>, %arg1: tensor<4
|
|||
// expected-error@+1 {{computation must return 0-rank tensor, but got: 'tensor<5xf32>'}}
|
||||
%0 = "mhlo.map"(%arg0, %arg1) ( {
|
||||
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
|
||||
%1 = mhlo.constant {value = dense<2.0> : tensor<f32>} : tensor<5xf32>
|
||||
%1 = mhlo.constant dense<2.0> : tensor<5xf32>
|
||||
"mhlo.return"(%1) : (tensor<5xf32>) -> ()
|
||||
}) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32>
|
||||
return %0 : tensor<4x5xf32>
|
||||
|
@ -528,7 +528,7 @@ func @mismatch_computation_output_type(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5
|
|||
// expected-error@+1 {{element type of result and computation output must match, but got: 'f32' and 'i32'}}
|
||||
%0 = "mhlo.map"(%arg0, %arg1) ( {
|
||||
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
|
||||
%1 = mhlo.constant {value = dense<2> : tensor<i32>} : tensor<i32>
|
||||
%1 = mhlo.constant dense<2> : tensor<i32>
|
||||
"mhlo.return"(%1) : (tensor<i32>) -> ()
|
||||
}) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32>
|
||||
return %0 : tensor<4x5xf32>
|
||||
|
|
Loading…
Reference in New Issue