Add zeta and broadcasting_zeta to chlo dialect.
PiperOrigin-RevId: 354500879
This commit is contained in:
parent
eb8d5a5e39
commit
e61ef86fdb
|
@ -66,7 +66,7 @@ class HLOClient_Op<string mnemonic, list<OpTrait> traits> :
|
||||||
// broadcasting (via the broadcast_dimensions attribute) and implicit degenerate
|
// broadcasting (via the broadcast_dimensions attribute) and implicit degenerate
|
||||||
// shape broadcasting.
|
// shape broadcasting.
|
||||||
//
|
//
|
||||||
// These correspond to operations in the mhlo dialect without the
|
// These correspond to operations in the chlo and mhlo dialects without the
|
||||||
// "broadcast_" prefix, except that those ops require same-shaped operands and
|
// "broadcast_" prefix, except that those ops require same-shaped operands and
|
||||||
// results.
|
// results.
|
||||||
//
|
//
|
||||||
|
@ -256,8 +256,31 @@ def HLOClient_BroadcastSubOp : HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def HLOCLient_BroadcastZetaOp : HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
|
"broadcast_zeta",
|
||||||
|
[NoSideEffect, SameOperandsAndResultElementType]> {
|
||||||
|
let summary = "Hurwitz zeta function";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
Returns `Zeta(operand, operand)` element-wise.
|
||||||
|
|
||||||
|
$$
|
||||||
|
\(\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}\)
|
||||||
|
$$
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
HLO_FpTensor:$lhs,
|
||||||
|
HLO_FpTensor:$rhs,
|
||||||
|
// Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix
|
||||||
|
// padded rank-broadcast semantics if omitted.
|
||||||
|
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
|
||||||
|
);
|
||||||
|
let results = (outs HLO_FpTensor);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// XLA binary elementwise op definitions.
|
// XLA binary logical elementwise op definitions.
|
||||||
// The same description as the arithmetic binary elementwise ops applies.
|
// The same description as the arithmetic binary elementwise ops applies.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
@ -309,6 +332,38 @@ def HLOClient_BroadcastXorOp : HLOClient_BroadcastBinaryLogicalElementwiseOp<
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// XLA non-broadcasting binary operations.
|
||||||
|
//
|
||||||
|
// These are operations that are supported by the XLA Builder API but that are
|
||||||
|
// not part of the HLO compiler instructions as modelled by the MHLO dialect.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def HLOClient_ZetaOp : HLOClient_Op<"zeta",
|
||||||
|
[NoSideEffect, SameOperandsAndResultType]> {
|
||||||
|
let summary = "Hurwitz zeta function";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
Returns `Zeta(operand, operand)` element-wise.
|
||||||
|
|
||||||
|
$$
|
||||||
|
\(\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}\)
|
||||||
|
$$
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
HLO_FpTensor:$lhs,
|
||||||
|
HLO_FpTensor:$rhs
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs HLO_FpTensor);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$lhs `,` $rhs attr-dict `:`
|
||||||
|
`(` type($lhs) `,` type($rhs) `)` `->` type(results)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Broadcasting complex op
|
// Broadcasting complex op
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -53,7 +53,7 @@ struct HloCompareAdaptor {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Populate a pattern for each Broadcasting CHlo op. This requires the pattern
|
// Populate a pattern for each Broadcasting CHlo op. This requires the pattern
|
||||||
// to take a ChloOpTy, MhloOpTy, and an Adaptor as templated values.
|
// to take a ChloOpTy, NonBroadcastingOpTy, and an Adaptor as templated values.
|
||||||
template <template <typename, typename, typename> class Pattern,
|
template <template <typename, typename, typename> class Pattern,
|
||||||
typename... ConstructorArgs>
|
typename... ConstructorArgs>
|
||||||
void PopulateForBroadcastingBinaryOp(MLIRContext *context,
|
void PopulateForBroadcastingBinaryOp(MLIRContext *context,
|
||||||
|
@ -79,6 +79,7 @@ void PopulateForBroadcastingBinaryOp(MLIRContext *context,
|
||||||
POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp);
|
POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp);
|
||||||
POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp);
|
POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp);
|
||||||
POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp);
|
POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp);
|
||||||
|
POPULATE_BCAST(BroadcastZetaOp, ZetaOp);
|
||||||
|
|
||||||
// Broadcasting ops requiring special construction.
|
// Broadcasting ops requiring special construction.
|
||||||
patterns
|
patterns
|
||||||
|
|
|
@ -318,6 +318,7 @@ BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightArithmeticOp);
|
||||||
BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightLogicalOp);
|
BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightLogicalOp);
|
||||||
BROADCAST_BINARY_OP_DEFS(BroadcastSubOp);
|
BROADCAST_BINARY_OP_DEFS(BroadcastSubOp);
|
||||||
BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);
|
BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);
|
||||||
|
BROADCAST_BINARY_OP_DEFS(BroadcastZetaOp);
|
||||||
|
|
||||||
#undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS
|
#undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS
|
||||||
#undef BROADCAST_BINARY_OP_DEFS
|
#undef BROADCAST_BINARY_OP_DEFS
|
||||||
|
|
|
@ -39,14 +39,14 @@ struct ChloLegalizeToHloPass
|
||||||
OwningRewritePatternList conversionPatterns;
|
OwningRewritePatternList conversionPatterns;
|
||||||
conversionTarget.addIllegalDialect<chlo::HloClientDialect>();
|
conversionTarget.addIllegalDialect<chlo::HloClientDialect>();
|
||||||
|
|
||||||
// Consider the mhlo dialect legal for tests.
|
// Consider the mhlo dialect legal for tests. Also add helper dialects
|
||||||
conversionTarget.addLegalDialect<mhlo::MhloDialect>();
|
// that are needed by the patterns.
|
||||||
|
conversionTarget.addLegalDialect<
|
||||||
|
MhloDialect, mlir::StandardOpsDialect, mlir::tensor::TensorDialect,
|
||||||
|
mlir::shape::ShapeDialect, mlir::scf::SCFDialect>();
|
||||||
|
|
||||||
// The conversion uses helpers from the standard dialect.
|
// TODO(herhut): This is temporary while Zeta cannot be lowered to hlo.
|
||||||
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
|
conversionTarget.addLegalOp<chlo::ZetaOp>();
|
||||||
conversionTarget.addLegalDialect<mlir::tensor::TensorDialect>();
|
|
||||||
conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();
|
|
||||||
conversionTarget.addLegalDialect<mlir::scf::SCFDialect>();
|
|
||||||
|
|
||||||
chlo::PopulateLegalizeChloToHloPatterns(&getContext(), &conversionPatterns);
|
chlo::PopulateLegalizeChloToHloPatterns(&getContext(), &conversionPatterns);
|
||||||
|
|
||||||
|
|
|
@ -237,3 +237,13 @@ func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4x
|
||||||
%0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
|
%0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
|
||||||
return %0 : tensor<4xi1>
|
return %0 : tensor<4xi1>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: @ZetaWithoutBroadcast
|
||||||
|
func @ZetaWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>)
|
||||||
|
-> tensor<4xf32> {
|
||||||
|
// CHECK: chlo.zeta %arg0, %arg1
|
||||||
|
%0 = chlo.broadcast_zeta %arg0, %arg1
|
||||||
|
: (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||||
|
return %0 : tensor<4xf32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue