aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/array_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/array_ops.py')
-rw-r--r--tensorflow/python/ops/array_ops.py106
1 files changed, 106 insertions, 0 deletions
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 78e10de933..30d2a6ed44 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1817,6 +1817,112 @@ def _DepthToSpaceShape(op):
[input_shape[0], height, width, new_depth])]
+def one_hot(indices, depth, on_value=1, off_value=0,
+ axis=None, dtype=dtypes.float32, name=None):
+ """Returns a one-hot tensor.
+
+ The locations represented by indices in `indices` take value `on_value`,
+ while all other locations take value `off_value`. By default, `on_value` is 1,
+ and `off_value` is 0. The type of the output tensor is specified by `dtype`,
+ which defaults to `tf.float32`.
+
+ If the input `indices` is rank `N`, the output will have rank `N+1`. The
+ new axis is created at dimension `axis` (default: the new axis is appended
+ at the end).
+
+ If `indices` is a scalar the output shape will be a vector of length `depth`
+
+ If `indices` is a vector of length `features`, the output shape will be:
+ ```
+ features x depth if axis == -1
+ depth x features if axis == 0
+ ```
+
+ If `indices` is a matrix (batch) with shape `[batch, features]`, the output
+ shape will be:
+ ```
+ batch x features x depth if axis == -1
+ batch x depth x features if axis == 1
+ depth x batch x features if axis == 0
+ ```
+
+
+ Examples
+ =========
+
+ Suppose that
+
+ ```
+ indices = [0, 2, -1, 1]
+ depth = 3
+ on_value = 5.0
+ off_value = 0.0
+ axis = -1
+ ```
+
+ Then output is `[4 x 3]`:
+
+ ```
+ output =
+ [5.0 0.0 0.0] // one_hot(0)
+ [0.0 0.0 5.0] // one_hot(2)
+ [0.0 0.0 0.0] // one_hot(-1)
+ [0.0 5.0 0.0] // one_hot(1)
+ ```
+
+ Suppose that
+
+ ```
+ indices = [[0, 2], [1, -1]]
+ depth = 3
+ on_value = 1.0
+ off_value = 0.0
+ axis = -1
+ ```
+
+ Then output is `[2 x 2 x 3]`:
+
+ ```
+ output =
+ [
+ [1.0, 0.0, 0.0] // one_hot(0)
+ [0.0, 0.0, 1.0] // one_hot(2)
+ ][
+ [0.0, 1.0, 0.0] // one_hot(1)
+ [0.0, 0.0, 0.0] // one_hot(-1)
+ ]
+ ```
+
+ Args:
+ indices: A `Tensor` of indices.
+ depth: A scalar defining the depth of the one hot dimension.
+ on_value: A scalar defining the value to fill in output when `indices[j]
+ = i`. (default: 1)
+ off_value: A scalar defining the value to fill in output when `indices[j]
+ != i`. (default: 0)
+ axis: The axis to fill (default: -1, a new inner-most axis).
+ dtype: The data type of the output tensor.
+
+ Returns:
+ output: The one-hot tensor.
+
+ Raises:
+ TypeError: If dtype is `tf.string`
+ """
+ # Check for bad dtype specification
+ if dtype == dtypes.string:
+ raise TypeError("dtype must be a numeric type")
+
+ with ops.op_scope([indices, depth, on_value, off_value,
+ axis, dtype], name, "one_hot") as name:
+ on_value = ops.convert_to_tensor(on_value, dtype=dtype, name="on_value")
+ off_value = ops.convert_to_tensor(off_value, dtype=dtype, name="off_value")
+ indices = ops.convert_to_tensor(indices, dtype=dtypes.int64, name="indices")
+ depth = ops.convert_to_tensor(depth, dtype=dtypes.int32, name="depth")
+ return gen_array_ops._one_hot(indices, depth, on_value,
+ off_value, axis, name)
+
+
@ops.RegisterShape("OneHot")
def _OneHotShape(op):
"""Shape function for the OneHot op.