diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_helpers.h')
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_helpers.h | 48 |
1 files changed, 9 insertions, 39 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index c320016998..d6ca4ab934 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -28,22 +28,6 @@ namespace tensorflow { // Helper methods for building XLA computations. class XlaHelpers { public: - // Returns a handle representing the minimum value of a scalar - // element of data_type. -inf for floating-point types. - static xla::XlaOp MinValue(xla::XlaBuilder* b, DataType data_type); - - // Returns a handle representing the minimum finite value of a scalar - // element of data_type. - static xla::XlaOp MinFiniteValue(xla::XlaBuilder* b, DataType data_type); - - // Returns a handle representing the maximum value of a scalar - // element of data_type. inf for floating point types. - static xla::XlaOp MaxValue(xla::XlaBuilder* b, DataType data_type); - - // Returns a handle representing the maximum finite value of a scalar - // element of data_type. - static xla::XlaOp MaxFiniteValue(xla::XlaBuilder* b, DataType data_type); - // Returns a handle representing the zero value of a scalar // element of data_type. static xla::XlaOp Zero(xla::XlaBuilder* b, DataType data_type); @@ -52,10 +36,6 @@ class XlaHelpers { // element of data_type. static xla::XlaOp One(xla::XlaBuilder* b, DataType data_type); - // Returns the machine epsilon for floating-point type `data_type`, i.e., - // the difference between 1.0 and the next representable value. - static xla::XlaOp Epsilon(xla::XlaBuilder* b, DataType data_type); - // Returns a handle representing the given value of an integer scalar // element of data_type. // Note that unlike One and Zero, does not work on boolean types. @@ -73,25 +53,15 @@ class XlaHelpers { gtl::ArraySlice<int64> shape, xla::Literal* output); - // Sets `argmax` to the argmax of `input` along `axis`. `input_shape` and - // `input_dtype` are the shape and dtype of `input` respectively, and - // `output_type` is the dtype to use for `argmax`. - static Status ArgMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, - const xla::XlaOp& input, const TensorShape& input_shape, - DataType input_type, DataType output_type, int axis, - xla::XlaOp* argmax); - - // Sets `argmin` to the argmin of `input` along `axis`. `input_shape` and - // `input_dtype` are the shape and dtype of `input` respectively, and - // `output_type` is the dtype to use for `argmin`. - static Status ArgMin(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, - const xla::XlaOp& input, const TensorShape& input_shape, - DataType input_type, DataType output_type, int axis, - xla::XlaOp* argmin); - - // Sets *iota to a rank 1 tensor with values [0, 1, 2, ...] of `dtype`. - static Status Iota(xla::XlaBuilder* builder, DataType dtype, int64 size, - xla::XlaOp* iota); + // Returns the argmax of `input` along `axis`. `output_type` is the type to + // use for the output. + static xla::XlaOp ArgMax(xla::XlaOp input, xla::PrimitiveType output_type, + int axis); + + // Returns the argmin of `input` along `axis`. `output_type` is the type to + // use for the output. + static xla::XlaOp ArgMin(xla::XlaOp input, xla::PrimitiveType output_type, + int axis); // Converts `indices` into a one-hot representation. `depth` is the size // of the new axis to add. `axis` is the position at which to add the new |