Rename xla_chlo dialect into chlo
Following on the plan of isolating the compiler/mlir/hlo directory. PiperOrigin-RevId: 320212018
This commit is contained in:
parent
7c4a5d62b5
commit
94dcb90d38
|
@ -28,18 +28,18 @@ limitations under the License.
|
||||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.h"
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace xla_chlo {
|
namespace chlo {
|
||||||
|
|
||||||
class XlaHloClientDialect : public Dialect {
|
class HloClientDialect : public Dialect {
|
||||||
public:
|
public:
|
||||||
explicit XlaHloClientDialect(MLIRContext *context);
|
explicit HloClientDialect(MLIRContext *context);
|
||||||
static StringRef getDialectNamespace() { return "xla_chlo"; }
|
static StringRef getDialectNamespace() { return "chlo"; }
|
||||||
};
|
};
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc"
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc"
|
||||||
|
|
||||||
} // namespace xla_chlo
|
} // namespace chlo
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_
|
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_
|
||||||
|
|
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||||
//
|
//
|
||||||
// The typical use of this dialect is for client libraries to be able to emit
|
// The typical use of this dialect is for client libraries to be able to emit
|
||||||
// less constrained ops and rely on the conversion framework to lower any
|
// less constrained ops and rely on the conversion framework to lower any
|
||||||
// xla_chlo ops to canonical mhlo ops.
|
// chlo ops to canonical mhlo ops.
|
||||||
//
|
//
|
||||||
// See: https://www.tensorflow.org/xla/operation_semantics
|
// See: https://www.tensorflow.org/xla/operation_semantics
|
||||||
|
|
||||||
|
@ -35,8 +35,8 @@ include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectIn
|
||||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
||||||
|
|
||||||
def HLOClient_Dialect : Dialect {
|
def HLOClient_Dialect : Dialect {
|
||||||
let name = "xla_chlo";
|
let name = "chlo";
|
||||||
let cppNamespace = "xla_chlo";
|
let cppNamespace = "chlo";
|
||||||
let summary = [{
|
let summary = [{
|
||||||
XLA Client HLO Ops
|
XLA Client HLO Ops
|
||||||
}];
|
}];
|
||||||
|
|
|
@ -84,14 +84,14 @@ void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options,
|
||||||
|
|
||||||
} // namespace lmhlo
|
} // namespace lmhlo
|
||||||
|
|
||||||
namespace xla_chlo {
|
namespace chlo {
|
||||||
|
|
||||||
// Populates a collection of conversion patterns for legalizing client-HLO to
|
// Populates a collection of conversion patterns for legalizing client-HLO to
|
||||||
// HLO.
|
// HLO.
|
||||||
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
||||||
OwningRewritePatternList *patterns);
|
OwningRewritePatternList *patterns);
|
||||||
|
|
||||||
} // namespace xla_chlo
|
} // namespace chlo
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,7 @@ limitations under the License.
|
||||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace xla_chlo {
|
namespace chlo {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static LogicalResult Verify(T op) {
|
static LogicalResult Verify(T op) {
|
||||||
|
@ -263,10 +263,10 @@ BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);
|
||||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// xla_chlo Dialect Constructor
|
// chlo Dialect Constructor
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
XlaHloClientDialect::XlaHloClientDialect(MLIRContext* context)
|
HloClientDialect::HloClientDialect(MLIRContext* context)
|
||||||
: Dialect(getDialectNamespace(), context) {
|
: Dialect(getDialectNamespace(), context) {
|
||||||
addOperations<
|
addOperations<
|
||||||
#define GET_OP_LIST
|
#define GET_OP_LIST
|
||||||
|
@ -274,5 +274,5 @@ XlaHloClientDialect::XlaHloClientDialect(MLIRContext* context)
|
||||||
>();
|
>();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace xla_chlo
|
} // namespace chlo
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -19,6 +19,5 @@ limitations under the License.
|
||||||
|
|
||||||
// Static initialization for XLA dialect registration.
|
// Static initialization for XLA dialect registration.
|
||||||
static mlir::DialectRegistration<mlir::mhlo::XlaHloDialect> mhlo_ops;
|
static mlir::DialectRegistration<mlir::mhlo::XlaHloDialect> mhlo_ops;
|
||||||
static mlir::DialectRegistration<mlir::xla_chlo::XlaHloClientDialect>
|
static mlir::DialectRegistration<mlir::chlo::HloClientDialect> chlo_ops;
|
||||||
xla_chlo_ops;
|
|
||||||
static mlir::DialectRegistration<mlir::lmhlo::LmhloDialect> lmhlo_ops;
|
static mlir::DialectRegistration<mlir::lmhlo::LmhloDialect> lmhlo_ops;
|
||||||
|
|
|
@ -25,7 +25,7 @@ limitations under the License.
|
||||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace xla_chlo {
|
namespace chlo {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
@ -235,5 +235,5 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
||||||
context, patterns);
|
context, patterns);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace xla_chlo
|
} // namespace chlo
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace xla_chlo {
|
namespace chlo {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ struct TestChloLegalizeToHloPass
|
||||||
ConversionTarget conversionTarget(getContext());
|
ConversionTarget conversionTarget(getContext());
|
||||||
OwningRewritePatternList conversionPatterns;
|
OwningRewritePatternList conversionPatterns;
|
||||||
|
|
||||||
conversionTarget.addIllegalDialect<XlaHloClientDialect>();
|
conversionTarget.addIllegalDialect<HloClientDialect>();
|
||||||
// Consider the mhlo dialect legal for tests.
|
// Consider the mhlo dialect legal for tests.
|
||||||
conversionTarget.addLegalDialect<mhlo::XlaHloDialect>();
|
conversionTarget.addLegalDialect<mhlo::XlaHloDialect>();
|
||||||
// The conversion uses helpers from the Standard dialect.
|
// The conversion uses helpers from the Standard dialect.
|
||||||
|
@ -49,9 +49,9 @@ struct TestChloLegalizeToHloPass
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
} // namespace xla_chlo
|
} // namespace chlo
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
static mlir::PassRegistration<mlir::xla_chlo::TestChloLegalizeToHloPass> pass(
|
static mlir::PassRegistration<mlir::chlo::TestChloLegalizeToHloPass> pass(
|
||||||
"test-xla-chlo-legalize-to-hlo",
|
"test-xla-chlo-legalize-to-hlo",
|
||||||
"Test pass for applying chlo -> hlo legalization patterns");
|
"Test pass for applying chlo -> hlo legalization patterns");
|
||||||
|
|
|
@ -11,7 +11,7 @@ func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<1xinde
|
||||||
// CHECK-DAG: %[[BCAST_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
|
// CHECK-DAG: %[[BCAST_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
|
||||||
// CHECK: %[[EXTENTS:.+]] = shape.to_extent_tensor %[[BCAST_S]]
|
// CHECK: %[[EXTENTS:.+]] = shape.to_extent_tensor %[[BCAST_S]]
|
||||||
// CHECK: return %[[EXTENTS]]
|
// CHECK: return %[[EXTENTS]]
|
||||||
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||||
%1 = "xla_test.reify_return_type_shapes"(%0) : (tensor<?xf32>) -> tensor<1xindex>
|
%1 = "xla_test.reify_return_type_shapes"(%0) : (tensor<?xf32>) -> tensor<1xindex>
|
||||||
return %1 : tensor<1xindex>
|
return %1 : tensor<1xindex>
|
||||||
}
|
}
|
||||||
|
@ -19,7 +19,7 @@ func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<1xinde
|
||||||
// -----
|
// -----
|
||||||
// CHECK-LABEL: @complex_ranked_components
|
// CHECK-LABEL: @complex_ranked_components
|
||||||
func @complex_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>> {
|
func @complex_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>> {
|
||||||
%0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
|
%0 = chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
|
||||||
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = complex<f32>}
|
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = complex<f32>}
|
||||||
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
|
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
|
||||||
return %1 : tensor<?x?xcomplex<f32>>
|
return %1 : tensor<?x?xcomplex<f32>>
|
||||||
|
@ -28,7 +28,7 @@ func @complex_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) ->
|
||||||
// -----
|
// -----
|
||||||
// CHECK-LABEL: @compare_ranked_components
|
// CHECK-LABEL: @compare_ranked_components
|
||||||
func @compare_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xi1> {
|
func @compare_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xi1> {
|
||||||
%0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
|
%0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
|
||||||
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = i1}
|
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = i1}
|
||||||
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?x?xi1>) -> tensor<?x?xi1>
|
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?x?xi1>) -> tensor<?x?xi1>
|
||||||
return %0 : tensor<?x?xi1>
|
return %0 : tensor<?x?xi1>
|
||||||
|
@ -37,7 +37,7 @@ func @compare_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) ->
|
||||||
// -----
|
// -----
|
||||||
// CHECK-LABEL: @broadcast_add_ranked_components_r1
|
// CHECK-LABEL: @broadcast_add_ranked_components_r1
|
||||||
func @broadcast_add_ranked_components_r1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
func @broadcast_add_ranked_components_r1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
||||||
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||||
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1], element_type0 = f32}
|
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1], element_type0 = f32}
|
||||||
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?xf32>) -> tensor<?xf32>
|
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||||
return %1 : tensor<?xf32>
|
return %1 : tensor<?xf32>
|
||||||
|
@ -46,7 +46,7 @@ func @broadcast_add_ranked_components_r1(%arg0: tensor<?xf32>, %arg1: tensor<?xf
|
||||||
// -----
|
// -----
|
||||||
// CHECK-LABEL: @broadcast_add_ranked_components_r1x2
|
// CHECK-LABEL: @broadcast_add_ranked_components_r1x2
|
||||||
func @broadcast_add_ranked_components_r1x2(%arg0: tensor<?xf32>, %arg1: tensor<?x3xf32>) -> tensor<?x3xf32> {
|
func @broadcast_add_ranked_components_r1x2(%arg0: tensor<?xf32>, %arg1: tensor<?x3xf32>) -> tensor<?x3xf32> {
|
||||||
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?x3xf32>) -> tensor<?x3xf32>
|
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?x3xf32>) -> tensor<?x3xf32>
|
||||||
// TODO: Overly broad shapes are being returned. Tighten the calculation
|
// TODO: Overly broad shapes are being returned. Tighten the calculation
|
||||||
// and update/extend these tests.
|
// and update/extend these tests.
|
||||||
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = f32}
|
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = f32}
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
// CHECK-LABEL: @addWithoutBroadcast
|
// CHECK-LABEL: @addWithoutBroadcast
|
||||||
func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||||
// CHECK: mhlo.add %arg0, %arg1
|
// CHECK: mhlo.add %arg0, %arg1
|
||||||
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||||
return %0 : tensor<4xf32>
|
return %0 : tensor<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?
|
||||||
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
|
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
// CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xf32>
|
// CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xf32>
|
||||||
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
return %0 : tensor<?x?xf32>
|
return %0 : tensor<?x?xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
|
||||||
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
|
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
// CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xcomplex<f32>>
|
// CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xcomplex<f32>>
|
||||||
%0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
|
%0 = chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
|
||||||
return %0 : tensor<?x?xcomplex<f32>>
|
return %0 : tensor<?x?xcomplex<f32>>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -68,7 +68,7 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
|
||||||
// CHECK: shape.assuming_yield %[[RESULT]]
|
// CHECK: shape.assuming_yield %[[RESULT]]
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
// CHECK: return %[[FINAL_RESULT]] : tensor<?x?xi1>
|
// CHECK: return %[[FINAL_RESULT]] : tensor<?x?xi1>
|
||||||
%0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
|
%0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
|
||||||
return %0 : tensor<?x?xi1>
|
return %0 : tensor<?x?xi1>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -77,7 +77,7 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
|
||||||
// CHECK-LABEL: @dynamicNonScalarBroadcastDimensions
|
// CHECK-LABEL: @dynamicNonScalarBroadcastDimensions
|
||||||
func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
|
func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
|
||||||
// CHECK: mhlo.add
|
// CHECK: mhlo.add
|
||||||
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
|
%0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
|
||||||
return %0 : tensor<1x4xf32>
|
return %0 : tensor<1x4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,7 +86,7 @@ func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<
|
||||||
// CHECK-LABEL: @dynamicNonScalarByScalarBroadcastDimensions
|
// CHECK-LABEL: @dynamicNonScalarByScalarBroadcastDimensions
|
||||||
func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<f32>) -> tensor<1x4xf32> {
|
func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<f32>) -> tensor<1x4xf32> {
|
||||||
// CHECK: mhlo.add
|
// CHECK: mhlo.add
|
||||||
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x4xf32>, tensor<f32>) -> tensor<1x4xf32>
|
%0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x4xf32>, tensor<f32>) -> tensor<1x4xf32>
|
||||||
return %0 : tensor<1x4xf32>
|
return %0 : tensor<1x4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -95,7 +95,7 @@ func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1:
|
||||||
func @dynamicNonScalarBroadcastDimensionsSizeMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
|
func @dynamicNonScalarBroadcastDimensionsSizeMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
|
||||||
// expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}}
|
// expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}}
|
||||||
// expected-error @+1 {{failed to legalize operation}}
|
// expected-error @+1 {{failed to legalize operation}}
|
||||||
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
|
%0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
|
||||||
return %0 : tensor<1x4xf32>
|
return %0 : tensor<1x4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -104,7 +104,7 @@ func @dynamicNonScalarBroadcastDimensionsSizeMismatch(%arg0: tensor<1x4xf32>, %a
|
||||||
func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
|
func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
|
||||||
// expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}}
|
// expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}}
|
||||||
// expected-error @+1 {{failed to legalize operation}}
|
// expected-error @+1 {{failed to legalize operation}}
|
||||||
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
|
%0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
|
||||||
return %0 : tensor<1x4xf32>
|
return %0 : tensor<1x4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -114,7 +114,7 @@ func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1:
|
||||||
// CHECK-LABEL: @andWithoutBroadcast
|
// CHECK-LABEL: @andWithoutBroadcast
|
||||||
func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
|
func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
|
||||||
// CHECK: mhlo.and %arg0, %arg1
|
// CHECK: mhlo.and %arg0, %arg1
|
||||||
%0 = xla_chlo.broadcast_and %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
|
%0 = chlo.broadcast_and %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
|
||||||
return %0 : tensor<4xi1>
|
return %0 : tensor<4xi1>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -122,7 +122,7 @@ func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4x
|
||||||
// CHECK-LABEL: @atan2WithoutBroadcast
|
// CHECK-LABEL: @atan2WithoutBroadcast
|
||||||
func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||||
// CHECK: mhlo.atan2 %arg0, %arg1
|
// CHECK: mhlo.atan2 %arg0, %arg1
|
||||||
%0 = xla_chlo.broadcast_atan2 %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
%0 = chlo.broadcast_atan2 %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||||
return %0 : tensor<4xf32>
|
return %0 : tensor<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -130,7 +130,7 @@ func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tenso
|
||||||
// CHECK-LABEL: @compareWithoutBroadcast
|
// CHECK-LABEL: @compareWithoutBroadcast
|
||||||
func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xi1> {
|
func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xi1> {
|
||||||
// CHECK: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
// CHECK: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||||
%0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
%0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||||
return %0 : tensor<4xi1>
|
return %0 : tensor<4xi1>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -138,7 +138,7 @@ func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
|
||||||
// CHECK-LABEL: @complexWithoutBroadcast
|
// CHECK-LABEL: @complexWithoutBroadcast
|
||||||
func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xcomplex<f32>> {
|
func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xcomplex<f32>> {
|
||||||
// CHECK: "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
|
// CHECK: "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
|
||||||
%0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
|
%0 = chlo.broadcast_complex %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
|
||||||
return %0 : tensor<4xcomplex<f32>>
|
return %0 : tensor<4xcomplex<f32>>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -146,7 +146,7 @@ func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
|
||||||
// CHECK-LABEL: @divideWithoutBroadcast
|
// CHECK-LABEL: @divideWithoutBroadcast
|
||||||
func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||||
// CHECK: mhlo.divide %arg0, %arg1
|
// CHECK: mhlo.divide %arg0, %arg1
|
||||||
%0 = xla_chlo.broadcast_divide %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
%0 = chlo.broadcast_divide %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||||
return %0 : tensor<4xf32>
|
return %0 : tensor<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -154,7 +154,7 @@ func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tens
|
||||||
// CHECK-LABEL: @maximumWithoutBroadcast
|
// CHECK-LABEL: @maximumWithoutBroadcast
|
||||||
func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||||
// CHECK: mhlo.maximum %arg0, %arg1
|
// CHECK: mhlo.maximum %arg0, %arg1
|
||||||
%0 = xla_chlo.broadcast_maximum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
%0 = chlo.broadcast_maximum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||||
return %0 : tensor<4xf32>
|
return %0 : tensor<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -162,7 +162,7 @@ func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
|
||||||
// CHECK-LABEL: @minimumWithoutBroadcast
|
// CHECK-LABEL: @minimumWithoutBroadcast
|
||||||
func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||||
// CHECK: mhlo.minimum %arg0, %arg1
|
// CHECK: mhlo.minimum %arg0, %arg1
|
||||||
%0 = xla_chlo.broadcast_minimum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
%0 = chlo.broadcast_minimum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||||
return %0 : tensor<4xf32>
|
return %0 : tensor<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -170,7 +170,7 @@ func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
|
||||||
// CHECK-LABEL: @multiplyWithoutBroadcast
|
// CHECK-LABEL: @multiplyWithoutBroadcast
|
||||||
func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||||
// CHECK: mhlo.multiply %arg0, %arg1
|
// CHECK: mhlo.multiply %arg0, %arg1
|
||||||
%0 = xla_chlo.broadcast_multiply %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
%0 = chlo.broadcast_multiply %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||||
return %0 : tensor<4xf32>
|
return %0 : tensor<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -178,7 +178,7 @@ func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> te
|
||||||
// CHECK-LABEL: @orWithoutBroadcast
|
// CHECK-LABEL: @orWithoutBroadcast
|
||||||
func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
|
func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
|
||||||
// CHECK: mhlo.or %arg0, %arg1
|
// CHECK: mhlo.or %arg0, %arg1
|
||||||
%0 = xla_chlo.broadcast_or %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
|
%0 = chlo.broadcast_or %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
|
||||||
return %0 : tensor<4xi1>
|
return %0 : tensor<4xi1>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -186,7 +186,7 @@ func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi
|
||||||
// CHECK-LABEL: @powerWithoutBroadcast
|
// CHECK-LABEL: @powerWithoutBroadcast
|
||||||
func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||||
// CHECK: mhlo.power %arg0, %arg1
|
// CHECK: mhlo.power %arg0, %arg1
|
||||||
%0 = xla_chlo.broadcast_power %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
%0 = chlo.broadcast_power %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||||
return %0 : tensor<4xf32>
|
return %0 : tensor<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -194,7 +194,7 @@ func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tenso
|
||||||
// CHECK-LABEL: @remainderWithoutBroadcast
|
// CHECK-LABEL: @remainderWithoutBroadcast
|
||||||
func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||||
// CHECK: mhlo.remainder %arg0, %arg1
|
// CHECK: mhlo.remainder %arg0, %arg1
|
||||||
%0 = xla_chlo.broadcast_remainder %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
%0 = chlo.broadcast_remainder %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||||
return %0 : tensor<4xf32>
|
return %0 : tensor<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -202,7 +202,7 @@ func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> t
|
||||||
// CHECK-LABEL: @shift_leftWithoutBroadcast
|
// CHECK-LABEL: @shift_leftWithoutBroadcast
|
||||||
func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||||
// CHECK: mhlo.shift_left %arg0, %arg1
|
// CHECK: mhlo.shift_left %arg0, %arg1
|
||||||
%0 = xla_chlo.broadcast_shift_left %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
%0 = chlo.broadcast_shift_left %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||||
return %0 : tensor<4xf32>
|
return %0 : tensor<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -210,7 +210,7 @@ func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) ->
|
||||||
// CHECK-LABEL: @shift_right_arithmeticWithoutBroadcast
|
// CHECK-LABEL: @shift_right_arithmeticWithoutBroadcast
|
||||||
func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||||
// CHECK: mhlo.shift_right_arithmetic %arg0, %arg1
|
// CHECK: mhlo.shift_right_arithmetic %arg0, %arg1
|
||||||
%0 = xla_chlo.broadcast_shift_right_arithmetic %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
%0 = chlo.broadcast_shift_right_arithmetic %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||||
return %0 : tensor<4xf32>
|
return %0 : tensor<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -218,7 +218,7 @@ func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor
|
||||||
// CHECK-LABEL: @shift_right_logicalWithoutBroadcast
|
// CHECK-LABEL: @shift_right_logicalWithoutBroadcast
|
||||||
func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||||
// CHECK: mhlo.shift_right_logical %arg0, %arg1
|
// CHECK: mhlo.shift_right_logical %arg0, %arg1
|
||||||
%0 = xla_chlo.broadcast_shift_right_logical %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
%0 = chlo.broadcast_shift_right_logical %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||||
return %0 : tensor<4xf32>
|
return %0 : tensor<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -226,7 +226,7 @@ func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4x
|
||||||
// CHECK-LABEL: @subWithoutBroadcast
|
// CHECK-LABEL: @subWithoutBroadcast
|
||||||
func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||||
// CHECK: mhlo.subtract %arg0, %arg1
|
// CHECK: mhlo.subtract %arg0, %arg1
|
||||||
%0 = xla_chlo.broadcast_subtract %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
%0 = chlo.broadcast_subtract %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||||
return %0 : tensor<4xf32>
|
return %0 : tensor<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -234,6 +234,6 @@ func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<
|
||||||
// CHECK-LABEL: @xorWithoutBroadcast
|
// CHECK-LABEL: @xorWithoutBroadcast
|
||||||
func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
|
func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
|
||||||
// CHECK: mhlo.xor %arg0, %arg1
|
// CHECK: mhlo.xor %arg0, %arg1
|
||||||
%0 = xla_chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
|
%0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
|
||||||
return %0 : tensor<4xi1>
|
return %0 : tensor<4xi1>
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue