aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/xla_helpers.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_helpers.h')
-rw-r--r--tensorflow/compiler/tf2xla/xla_helpers.h48
1 files changed, 9 insertions, 39 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h
index c320016998..d6ca4ab934 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.h
+++ b/tensorflow/compiler/tf2xla/xla_helpers.h
@@ -28,22 +28,6 @@ namespace tensorflow {
// Helper methods for building XLA computations.
class XlaHelpers {
public:
- // Returns a handle representing the minimum value of a scalar
- // element of data_type. -inf for floating-point types.
- static xla::XlaOp MinValue(xla::XlaBuilder* b, DataType data_type);
-
- // Returns a handle representing the minimum finite value of a scalar
- // element of data_type.
- static xla::XlaOp MinFiniteValue(xla::XlaBuilder* b, DataType data_type);
-
- // Returns a handle representing the maximum value of a scalar
- // element of data_type. inf for floating point types.
- static xla::XlaOp MaxValue(xla::XlaBuilder* b, DataType data_type);
-
- // Returns a handle representing the maximum finite value of a scalar
- // element of data_type.
- static xla::XlaOp MaxFiniteValue(xla::XlaBuilder* b, DataType data_type);
-
// Returns a handle representing the zero value of a scalar
// element of data_type.
static xla::XlaOp Zero(xla::XlaBuilder* b, DataType data_type);
@@ -52,10 +36,6 @@ class XlaHelpers {
// element of data_type.
static xla::XlaOp One(xla::XlaBuilder* b, DataType data_type);
- // Returns the machine epsilon for floating-point type `data_type`, i.e.,
- // the difference between 1.0 and the next representable value.
- static xla::XlaOp Epsilon(xla::XlaBuilder* b, DataType data_type);
-
// Returns a handle representing the given value of an integer scalar
// element of data_type.
// Note that unlike One and Zero, does not work on boolean types.
@@ -73,25 +53,15 @@ class XlaHelpers {
gtl::ArraySlice<int64> shape,
xla::Literal* output);
- // Sets `argmax` to the argmax of `input` along `axis`. `input_shape` and
- // `input_dtype` are the shape and dtype of `input` respectively, and
- // `output_type` is the dtype to use for `argmax`.
- static Status ArgMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx,
- const xla::XlaOp& input, const TensorShape& input_shape,
- DataType input_type, DataType output_type, int axis,
- xla::XlaOp* argmax);
-
- // Sets `argmin` to the argmin of `input` along `axis`. `input_shape` and
- // `input_dtype` are the shape and dtype of `input` respectively, and
- // `output_type` is the dtype to use for `argmin`.
- static Status ArgMin(xla::XlaBuilder* builder, XlaOpKernelContext* ctx,
- const xla::XlaOp& input, const TensorShape& input_shape,
- DataType input_type, DataType output_type, int axis,
- xla::XlaOp* argmin);
-
- // Sets *iota to a rank 1 tensor with values [0, 1, 2, ...] of `dtype`.
- static Status Iota(xla::XlaBuilder* builder, DataType dtype, int64 size,
- xla::XlaOp* iota);
+ // Returns the argmax of `input` along `axis`. `output_type` is the type to
+ // use for the output.
+ static xla::XlaOp ArgMax(xla::XlaOp input, xla::PrimitiveType output_type,
+ int axis);
+
+ // Returns the argmin of `input` along `axis`. `output_type` is the type to
+ // use for the output.
+ static xla::XlaOp ArgMin(xla::XlaOp input, xla::PrimitiveType output_type,
+ int axis);
// Converts `indices` into a one-hot representation. `depth` is the size
// of the new axis to add. `axis` is the position at which to add the new