diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h index 9704f34..282b056 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h @@ -48,7 +48,8 @@ class HloClientDialect : public Dialect { #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc" template -static Value getConstantLike(OpBuilder& b, T constant, Value val) { +static Value getConstantLike(OpBuilder& b, Location loc, T constant, + Value val) { Type ty = getElementTypeOrSelf(val.getType()); auto getAttr = [&]() -> Attribute { @@ -56,8 +57,7 @@ static Value getConstantLike(OpBuilder& b, T constant, Value val) { if (ty.isa()) return b.getFloatAttr(ty, constant); llvm_unreachable("unhandled element type"); }; - // TODO(jpienaar): Add ability to pass loc via native call and update. - return b.create(b.getUnknownLoc(), getAttr(), val); + return b.create(loc, getAttr(), val); } } // namespace chlo diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td index c201aef..32940cb 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td @@ -28,7 +28,7 @@ class ConstantSplat : NativeCodeCall< "hlo::getSplat(&$_builder, $0, " # value # ")">; class HLO_ConstantLike : NativeCodeCall< - "chlo::getConstantLike($_builder, " # value # ", $0)">; + "chlo::getConstantLike($_builder, $_loc, " # value # ", $0)">; def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;