diff options
Diffstat (limited to 'tensorflow/python/ops/array_ops.py')
-rw-r--r-- | tensorflow/python/ops/array_ops.py | 106 |
1 files changed, 106 insertions, 0 deletions
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 78e10de933..30d2a6ed44 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -1817,6 +1817,112 @@ def _DepthToSpaceShape(op): [input_shape[0], height, width, new_depth])] +def one_hot(indices, depth, on_value=1, off_value=0, + axis=None, dtype=dtypes.float32, name=None): + """Returns a one-hot tensor. + + The locations represented by indices in `indices` take value `on_value`, + while all other locations take value `off_value`. By default, `on_value` is 1, + and `off_value` is 0. The type of the output tensor is specified by `dtype`, + which defaults to `tf.float32`. + + If the input `indices` is rank `N`, the output will have rank `N+1`. The + new axis is created at dimension `axis` (default: the new axis is appended + at the end). + + If `indices` is a scalar the output shape will be a vector of length `depth` + + If `indices` is a vector of length `features`, the output shape will be: + ``` + features x depth if axis == -1 + depth x features if axis == 0 + ``` + + If `indices` is a matrix (batch) with shape `[batch, features]`, the output + shape will be: + ``` + batch x features x depth if axis == -1 + batch x depth x features if axis == 1 + depth x batch x features if axis == 0 + ``` + + + Examples + ========= + + Suppose that + + ``` + indices = [0, 2, -1, 1] + depth = 3 + on_value = 5.0 + off_value = 0.0 + axis = -1 + ``` + + Then output is `[4 x 3]`: + + ``` + output = + [5.0 0.0 0.0] // one_hot(0) + [0.0 0.0 5.0] // one_hot(2) + [0.0 0.0 0.0] // one_hot(-1) + [0.0 5.0 0.0] // one_hot(1) + ``` + + Suppose that + + ``` + indices = [[0, 2], [1, -1]] + depth = 3 + on_value = 1.0 + off_value = 0.0 + axis = -1 + ``` + + Then output is `[2 x 2 x 3]`: + + ``` + output = + [ + [1.0, 0.0, 0.0] // one_hot(0) + [0.0, 0.0, 1.0] // one_hot(2) + ][ + [0.0, 1.0, 0.0] // one_hot(1) + [0.0, 0.0, 0.0] // one_hot(-1) + ] + ``` + + Args: + indices: A `Tensor` of indices. + depth: A scalar defining the depth of the one hot dimension. + on_value: A scalar defining the value to fill in output when `indices[j] + = i`. (default: 1) + off_value: A scalar defining the value to fill in output when `indices[j] + != i`. (default: 0) + axis: The axis to fill (default: -1, a new inner-most axis). + dtype: The data type of the output tensor. + + Returns: + output: The one-hot tensor. + + Raises: + TypeError: If dtype is `tf.string` + """ + # Check for bad dtype specification + if dtype == dtypes.string: + raise TypeError("dtype must be a numeric type") + + with ops.op_scope([indices, depth, on_value, off_value, + axis, dtype], name, "one_hot") as name: + on_value = ops.convert_to_tensor(on_value, dtype=dtype, name="on_value") + off_value = ops.convert_to_tensor(off_value, dtype=dtype, name="off_value") + indices = ops.convert_to_tensor(indices, dtype=dtypes.int64, name="indices") + depth = ops.convert_to_tensor(depth, dtype=dtypes.int32, name="depth") + return gen_array_ops._one_hot(indices, depth, on_value, + off_value, axis, name) + + @ops.RegisterShape("OneHot") def _OneHotShape(op): """Shape function for the OneHot op. |