diff options
author | 2015-11-25 08:48:47 -0800 | |
---|---|---|
committer | 2015-11-25 08:48:47 -0800 | |
commit | 854f49bd43588c062b046384f239f64a3d819702 (patch) | |
tree | c2373bf71ef65ae4c116ea703947141281c1eace /tensorflow/python/ops | |
parent | 9c3043ff3bf31a6a81810b4ce9e87ef936f1f529 (diff) |
TensorFlow: Upstream changes to git
Changes:
- Updates to docs
- Several changes for Python 3 compatibility
- Added license headers
Base CL: 108710566
Diffstat (limited to 'tensorflow/python/ops')
29 files changed, 687 insertions, 480 deletions
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 9cccc8771f..50f3facf2e 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -65,10 +65,10 @@ import sys import tensorflow.python.platform import numpy as np +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util -from tensorflow.python.framework import types from tensorflow.python.ops import common_shapes from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_math_ops @@ -460,9 +460,9 @@ def transpose(a, perm=None, name="transpose"): [3 6]] # Equivalently - tf.transpose(x perm=[0, 1]) ==> [[1 4] - [2 5] - [3 6]] + tf.transpose(x, perm=[1, 0]) ==> [[1 4] + [2 5] + [3 6]] # 'perm' is more useful for n-dimensional tensors, for n > 2 # 'x' is [[[1 2 3] @@ -502,7 +502,7 @@ def transpose(a, perm=None, name="transpose"): return ret -def zeros(shape, dtype=types.float32, name=None): +def zeros(shape, dtype=dtypes.float32, name=None): """Creates a tensor with all elements set to zero. This operation returns a tensor of type `dtype` with shape `shape` and @@ -528,7 +528,7 @@ def zeros(shape, dtype=types.float32, name=None): else: shape = ops.convert_to_tensor(shape, name="shape") output = fill(shape, constant(0, dtype=dtype), name=name) - assert output.dtype.base_dtype == types.as_dtype(dtype).base_dtype + assert output.dtype.base_dtype == dtypes.as_dtype(dtype).base_dtype return output @@ -594,12 +594,12 @@ def ones_like(tensor, dtype=None, name=None): return ones(ones_shape, dtype=dtype, name=name) -def zeros_initializer(shape, dtype=types.float32): +def zeros_initializer(shape, dtype=dtypes.float32): """An adaptor for zeros() to match the Initializer spec.""" return zeros(shape, dtype) -def ones(shape, dtype=types.float32, name=None): +def ones(shape, dtype=dtypes.float32, name=None): """Creates a tensor with all elements set to 1. This operation returns a tensor of type `dtype` with shape `shape` and all @@ -625,7 +625,7 @@ def ones(shape, dtype=types.float32, name=None): else: shape = ops.convert_to_tensor(shape, name="shape") output = fill(shape, constant(1, dtype=dtype), name=name) - assert output.dtype.base_dtype == types.as_dtype(dtype).base_dtype + assert output.dtype.base_dtype == dtypes.as_dtype(dtype).base_dtype return output diff --git a/tensorflow/python/ops/candidate_sampling_ops.py b/tensorflow/python/ops/candidate_sampling_ops.py index 0a3e068523..6b8ecc13ea 100644 --- a/tensorflow/python/ops/candidate_sampling_ops.py +++ b/tensorflow/python/ops/candidate_sampling_ops.py @@ -311,7 +311,7 @@ def all_candidate_sampler(true_classes, num_true, num_sampled, unique, def compute_accidental_hits(true_classes, sampled_candidates, num_true, seed=None, name=None): - """Compute the ids of positions in sampled_candidates matching true_classes. + """Compute the position ids in `sampled_candidates` matching `true_classes`. In Candidate Sampling, this operation facilitates virtually removing sampled classes which happen to match target classes. This is done diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py index da66610e57..a2b39d6594 100644 --- a/tensorflow/python/ops/clip_ops.py +++ b/tensorflow/python/ops/clip_ops.py @@ -22,8 +22,8 @@ import collections import six +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import constant_op from tensorflow.python.ops import math_ops @@ -65,7 +65,7 @@ def clip_by_norm(t, clip_norm, name=None): """Clips tensor values to a maximum L2-norm. Given a tensor `t`, and a maximum clip value `clip_norm`, this operation - normalizes `t` so that its L2-norm is less than or equal to `clip_norm'. + normalizes `t` so that its L2-norm is less than or equal to `clip_norm`. Specifically, if the L2-norm is already less than or equal to `clip_norm`, then `t` is not modified. If the L2-norm is greater than `clip_norm`, then this operation returns a tensor of the same type and shape as `t` with its @@ -146,18 +146,18 @@ def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None): if you've already computed the global norm for `t_list`, you can specify the global norm with `use_norm`. - To perform the clipping, the values t_list[i] are set to: + To perform the clipping, the values `t_list[i]` are set to: - `t_list[i] * clip_norm / max(global_norm, clip_norm)` + t_list[i] * clip_norm / max(global_norm, clip_norm) where: - `global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))` + global_norm = sqrt(sum([l2norm(t)**2 for t in t_list])) If `clip_norm > global_norm` then the entries in `t_list` remain as they are, otherwise they're all shrunk by the global ratio. - Any of the entries of `t_list` that are of type None are ignored. + Any of the entries of `t_list` that are of type `None` are ignored. This is the correct way to perform gradient clipping (for example, see R. Pascanu, T. Mikolov, and Y. Bengio, "On the difficulty of training @@ -219,7 +219,7 @@ def clip_by_average_norm(t, clip_norm, name=None): Given a tensor `t`, and a maximum clip value `clip_norm`, this operation normalizes `t` so that its average L2-norm is less than or equal to - `clip_norm'. Specifically, if the average L2-norm is already less than or + `clip_norm`. Specifically, if the average L2-norm is already less than or equal to `clip_norm`, then `t` is not modified. If the average L2-norm is greater than `clip_norm`, then this operation returns a tensor of the same type and shape as `t` with its values set to: @@ -244,7 +244,7 @@ def clip_by_average_norm(t, clip_norm, name=None): # Calculate L2-norm per element, clip elements by ratio of clip_norm to # L2-norm per element - n_element = math_ops.cast(array_ops.size(t), types.float32) + n_element = math_ops.cast(array_ops.size(t), dtypes.float32) l2norm_inv = math_ops.rsqrt( math_ops.reduce_sum(t * t, math_ops.range(array_ops.rank(t)))) tclip = array_ops.identity( diff --git a/tensorflow/python/ops/common_shapes.py b/tensorflow/python/ops/common_shapes.py index b3decb2651..c0615f1d3b 100644 --- a/tensorflow/python/ops/common_shapes.py +++ b/tensorflow/python/ops/common_shapes.py @@ -130,10 +130,10 @@ def get2d_conv_output_size(input_height, input_width, filter_height, # Compute number of rows in the output, based on the padding. if input_height.value is None or filter_height.value is None: out_rows = None - elif padding_type == "VALID": + elif padding_type == b"VALID": out_rows = ((input_height.value - filter_height.value + row_stride) // row_stride) - elif padding_type == "SAME": + elif padding_type == b"SAME": out_rows = (input_height.value + row_stride - 1) // row_stride else: raise ValueError("Invalid value for padding: %r" % padding_type) @@ -141,10 +141,10 @@ def get2d_conv_output_size(input_height, input_width, filter_height, # Compute number of columns in the output, based on the padding. if input_width.value is None or filter_width.value is None: out_cols = None - elif padding_type == "VALID": + elif padding_type == b"VALID": out_cols = ((input_width.value - filter_width.value + col_stride) // col_stride) - elif padding_type == "SAME": + elif padding_type == b"SAME": out_cols = (input_width.value + col_stride - 1) // col_stride return out_rows, out_cols diff --git a/tensorflow/python/ops/constant_op.py b/tensorflow/python/ops/constant_op.py index 9cf5e99a1e..f2aaad37a9 100644 --- a/tensorflow/python/ops/constant_op.py +++ b/tensorflow/python/ops/constant_op.py @@ -105,10 +105,10 @@ import tensorflow.python.platform import numpy as np from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util -from tensorflow.python.framework import types def constant(value, dtype=None, shape=None, name="Const"): @@ -182,10 +182,10 @@ def _tensor_shape_tensor_conversion_function(s, dtype=None, name=None): raise ValueError( "Cannot convert a partially known TensorShape to a Tensor: %s" % s) if dtype is not None: - if dtype not in (types.int32, types.int64): + if dtype not in (dtypes.int32, dtypes.int64): raise TypeError("Cannot convert a TensorShape to dtype: %s" % dtype) else: - dtype = types.int32 + dtype = dtypes.int32 if name is None: name = "shape_as_tensor" return constant(s.as_list(), dtype=dtype, name=name) @@ -197,10 +197,10 @@ def _dimension_tensor_conversion_function(d, dtype=None, name=None): if d.value is None: raise ValueError("Cannot convert an unknown Dimension to a Tensor: %s" % d) if dtype is not None: - if dtype not in (types.int32, types.int64): + if dtype not in (dtypes.int32, dtypes.int64): raise TypeError("Cannot convert a TensorShape to dtype: %s" % dtype) else: - dtype = types.int32 + dtype = dtypes.int32 if name is None: name = "shape_as_tensor" return constant(d.value, dtype=dtype, name=name) diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index cd15d921cf..8eb1bd79bf 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -69,9 +69,9 @@ from __future__ import print_function import six from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import common_shapes from tensorflow.python.ops import constant_op @@ -443,7 +443,7 @@ def _GetRealValue(value): # The begin position of the slice at slice_index. slice_index = forward_ctxt.grad_context.index - b1 = array_ops.zeros(elem_rank_vec, dtype=types.int32) + b1 = array_ops.zeros(elem_rank_vec, dtype=dtypes.int32) b = array_ops.concat(0, [array_ops.expand_dims(slice_index, 0), b1]) # The slice at slice_index. @@ -1134,7 +1134,7 @@ def _AsTensorList(x, p): Returns: A list of Tensors or IndexedSlices. """ - if not isinstance(x, list) and not isinstance(x, _basetuple): + if not isinstance(x, (list, _basetuple)): x = [x] l = [] @@ -1249,7 +1249,10 @@ def group(*inputs, **kwargs): # 2-level tree. The root node is the returned NoOp node. # deps contains 1 NoOp node for each device. deps = [] - for dev in sorted(six.iterkeys(ops_on_device)): + def device_key(dev): + """A sort key that allows None to be compared to strings.""" + return "" if dev is None else dev + for dev in sorted(six.iterkeys(ops_on_device), key=device_key): deps.append(_GroupControlDeps(dev, ops_on_device[dev])) return _GroupControlDeps(None, deps, name=name) @@ -1332,7 +1335,7 @@ def fold(fn, elems, elem_shape, name=None): d0 = elem_shape[0] n = math_ops.div(s0, d0) b1 = array_ops.zeros(array_ops.expand_dims(array_ops.rank(elems) - 1, 0), - dtype=types.int32) + dtype=dtypes.int32) # Initialize the output with slice 0 b = array_ops.concat(0, [[0], b1]) o = array_ops.slice(elems, b, elem_shape) @@ -1374,7 +1377,7 @@ def case(pred_fn_pairs, default, exclusive=False, name="case"): Expressions: ``` - f1 = lambda: tf.onstant(17) + f1 = lambda: tf.constant(17) f2 = lambda: tf.constant(23) r = case([(tf.less(x, y), f1)], default=f2) ``` @@ -1428,7 +1431,7 @@ def case(pred_fn_pairs, default, exclusive=False, name="case"): if not isinstance(tup, _basetuple) or len(tup) != 2: raise TypeError("Each entry in pred_fn_pairs must be a 2-tuple") pred, fn = tup - if pred.dtype != types.bool: + if pred.dtype != dtypes.bool: raise TypeError("pred must be of type bool: %s", pred.name) if not callable(fn): raise TypeError("fn for pred %s must be callable." % pred.name) @@ -1468,7 +1471,7 @@ def case(pred_fn_pairs, default, exclusive=False, name="case"): # TODO(ebrevdo): Add Where() for DT_BOOL, replace with Size(Where(preds)) preds_c = array_ops.concat(0, preds, name="preds_c") num_true_conditions = math_ops.reduce_sum( - math_ops.cast(preds_c, types.int32), name="num_true_conds") + math_ops.cast(preds_c, dtypes.int32), name="num_true_conds") at_most_one_true_condition = math_ops.less( num_true_conditions, constant_op.constant(2, name="two_true_conds")) diff --git a/tensorflow/python/ops/data_flow_grad.py b/tensorflow/python/ops/data_flow_grad.py index 219a503fd7..c46c29ed57 100644 --- a/tensorflow/python/ops/data_flow_grad.py +++ b/tensorflow/python/ops/data_flow_grad.py @@ -19,8 +19,8 @@ from __future__ import division from __future__ import print_function from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import constant_op from tensorflow.python.ops import data_flow_ops @@ -36,8 +36,8 @@ def _DynamicStitchGrads(op, grad): indices_grad = [None] * num_values def AsInt32(x): - return (x if op.inputs[0].dtype == types.int32 else - math_ops.cast(x, types.int32)) + return (x if op.inputs[0].dtype == dtypes.int32 else + math_ops.cast(x, dtypes.int32)) inputs = [AsInt32(op.inputs[i]) for i in xrange(num_values)] if isinstance(grad, ops.IndexedSlices): output_shape = array_ops.shape(op.outputs[0]) @@ -54,3 +54,8 @@ ops.NoGradient("QueueDequeue") ops.NoGradient("QueueDequeueMany") ops.NoGradient("QueueClose") ops.NoGradient("QueueSize") + +ops.NoGradient("Stack") +ops.NoGradient("StackPush") +ops.NoGradient("StackPop") +ops.NoGradient("StackClose") diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py index c50fad3211..261193715c 100644 --- a/tensorflow/python/ops/data_flow_ops.py +++ b/tensorflow/python/ops/data_flow_ops.py @@ -125,12 +125,12 @@ class QueueBase(object): A `QueueBase` object. Raises: - TypeError: when `queues` is not a list of `QueueBase` objects, + TypeError: When `queues` is not a list of `QueueBase` objects, or when the data types of `queues` are not all the same. """ if ((not queues) or (not isinstance(queues, list)) or - (not all([isinstance(x, QueueBase) for x in queues]))): + (not all(isinstance(x, QueueBase) for x in queues))): raise TypeError("A list of queues expected") dtypes = queues[0].dtypes @@ -458,6 +458,12 @@ ops.RegisterShape("QueueEnqueue")(common_shapes.unknown_shape) ops.RegisterShape("QueueEnqueueMany")(common_shapes.unknown_shape) +ops.RegisterShape("Stack")(common_shapes.scalar_shape) +ops.RegisterShape("StackPush")(common_shapes.unknown_shape) +ops.RegisterShape("StackPop")(common_shapes.unknown_shape) +ops.RegisterShape("StackClose")(common_shapes.unknown_shape) + + @ops.RegisterShape("QueueClose") def _ScalarToVoidShape(op): """Shape function for ops that take a scalar and produce no outputs.""" diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py index 9fb1cf99f2..ff39c3c751 100644 --- a/tensorflow/python/ops/embedding_ops.py +++ b/tensorflow/python/ops/embedding_ops.py @@ -19,8 +19,8 @@ from __future__ import division from __future__ import print_function from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops @@ -43,7 +43,7 @@ def embedding_lookup(params, ids, name=None): Args: params: A list of tensors with the same shape and type. - ids: A `Tensor` with type `int32` containing the ids to be looked + ids: A `Tensor` with type `int32` or `int64` containing the ids to be looked up in `params`. name: A name for the operation (optional). @@ -69,8 +69,8 @@ def embedding_lookup(params, ids, name=None): original_indices = math_ops.range(array_ops.size(flat_ids)) # Compute flat_ids % partitions for each id ids_mod_p = flat_ids % np - if ids_mod_p.dtype != types.int32: - ids_mod_p = math_ops.cast(ids_mod_p, types.int32) + if ids_mod_p.dtype != dtypes.int32: + ids_mod_p = math_ops.cast(ids_mod_p, dtypes.int32) # Partition single list of ids based on ids % np into np separate lists plist = data_flow_ops.dynamic_partition(flat_ids, ids_mod_p, np) # Similarly, partition the original indices. @@ -178,8 +178,8 @@ def embedding_lookup_sparse(params, sp_ids, sp_weights, with ops.op_scope(params + [sp_ids], name, "embedding_lookup_sparse") as name: segment_ids = sp_ids.indices[:, 0] - if segment_ids.dtype != types.int32: - segment_ids = math_ops.cast(segment_ids, types.int32) + if segment_ids.dtype != dtypes.int32: + segment_ids = math_ops.cast(segment_ids, dtypes.int32) ids = sp_ids.values if ignore_weights: diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py index 72c7bc2499..b790d9af6c 100644 --- a/tensorflow/python/ops/gradients.py +++ b/tensorflow/python/ops/gradients.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Implements the graph generation for computation of gradients.""" from __future__ import absolute_import @@ -27,16 +26,17 @@ import tensorflow.python.platform import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util -from tensorflow.python.framework import types # pylint: disable=unused-import from tensorflow.python.ops import array_grad from tensorflow.python.ops import array_ops from tensorflow.python.ops import constant_op from tensorflow.python.ops import control_flow_grad from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import image_grad from tensorflow.python.ops import logging_ops from tensorflow.python.ops import linalg_grad from tensorflow.python.ops import math_grad @@ -45,7 +45,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.platform import logging - # Warn the user if we convert a sparse representation to dense with at # least this number of elements. _LARGE_SPARSE_NUM_ELEMENTS = 100000000 @@ -69,8 +68,8 @@ def _IndexedSlicesToTensor(value, dtype=None, name=None): """ if dtype and not dtype.is_compatible_with(value.dtype): raise ValueError( - "Tensor conversion requested dtype %s for IndexedSlices with dtype %s" - % (dtype.name, value.dtype.name)) + "Tensor conversion requested dtype %s for IndexedSlices with dtype %s" % + (dtype.name, value.dtype.name)) if value.dense_shape is None: raise ValueError( "Tensor conversion requested for IndexedSlices without dense_shape: %s" @@ -88,11 +87,14 @@ def _IndexedSlicesToTensor(value, dtype=None, name=None): warnings.warn( "Converting sparse IndexedSlices to a dense Tensor of unknown shape. " "This may consume a large amount of memory.") - return math_ops.unsorted_segment_sum( - value.values, value.indices, value.dense_shape[0], name=name) + return math_ops.unsorted_segment_sum(value.values, + value.indices, + value.dense_shape[0], + name=name) -ops.register_tensor_conversion_function(ops.IndexedSlices, _IndexedSlicesToTensor) +ops.register_tensor_conversion_function(ops.IndexedSlices, + _IndexedSlicesToTensor) def _MarkReachedOps(from_ops, reached_ops): @@ -236,14 +238,16 @@ def _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops): y = ys[i] if grad_y is None: with ops.device(_GetGradsDevice(y.op, colocate_gradients_with_ops)): - grad_ys[i] = array_ops.fill(array_ops.shape(y), - constant_op.constant(1, dtype=y.dtype)) + grad_ys[i] = array_ops.fill( + array_ops.shape(y), + constant_op.constant(1, + dtype=y.dtype)) else: if grad_y.dtype != y.dtype: raise ValueError("Y and ys_grad must be of the same type, " "not y: %s, ys_grad: %s " % - (types.as_dtype(y.dtype).name, - types.as_dtype(grad_y.dtype).name)) + (dtypes.as_dtype(y.dtype).name, + dtypes.as_dtype(grad_y.dtype).name)) return grad_ys @@ -265,11 +269,10 @@ def _VerifyGeneratedGradients(grads, op): inp = op.inputs[i] if grad is not None: if not grad.dtype.is_compatible_with(inp.dtype): - raise ValueError( - "Gradient type %s generated for op %s does " - "not match input type %s" % - (types.as_dtype(grad.dtype).name, op.node_def, - types.as_dtype(inp.dtype).name)) + raise ValueError("Gradient type %s generated for op %s does " + "not match input type %s" % + (dtypes.as_dtype(grad.dtype).name, op.node_def, + dtypes.as_dtype(inp.dtype).name)) def _StopOps(from_ops, pending_count): @@ -301,7 +304,10 @@ def _StopOps(from_ops, pending_count): return stop_ops -def gradients(ys, xs, grad_ys=None, name="gradients", +def gradients(ys, + xs, + grad_ys=None, + name="gradients", colocate_gradients_with_ops=False, gate_gradients=False, aggregation_method=None): @@ -319,7 +325,7 @@ def gradients(ys, xs, grad_ys=None, name="gradients", `grad_ys` is a list of tensors of the same length as `ys` that holds the initial gradients for each y in `ys`. When `grad_ys` is None, we fill in a tensor of '1's of the shape of y for each y in `ys`. A - user can provide their own initial 'grad_ys` to compute the + user can provide their own initial `grad_ys` to compute the derivatives using a different initial gradient for each y (e.g., if one wanted to weight the gradient differently for each value in each y). @@ -369,8 +375,8 @@ def gradients(ys, xs, grad_ys=None, name="gradients", # to the xs. to_ops = [t.op for t in ys] from_ops = [t.op for t in xs] - pending_count, has_control_flow = _PendingCount( - ops.get_default_graph(), to_ops, from_ops) + pending_count, has_control_flow = _PendingCount(ops.get_default_graph(), + to_ops, from_ops) # Iterate over the collected ops. # @@ -418,9 +424,9 @@ def gradients(ys, xs, grad_ys=None, name="gradients", # output, it means that the cost does not depend on output[i], # therefore dC/doutput[i] is 0. for i, out_grad in enumerate(out_grads): - if (not out_grad - and types.as_dtype(op.outputs[i].dtype).base_dtype in ( - types.float32, types.float64)): + if (not out_grad and + dtypes.as_dtype(op.outputs[i].dtype).base_dtype in + (dtypes.float32, dtypes.float64)): # Only floating-point outputs get a zero gradient. Gradient # functions should ignore the gradient for other outputs. out_grads[i] = array_ops.zeros_like(op.outputs[i]) @@ -485,7 +491,8 @@ def _GetGrad(grads, t): """Gets gradient for tensor "t".""" op = t.op op_grads = grads.get(op) - if not op_grads: return None + if not op_grads: + return None t_grad = op_grads[t.value_index] assert not isinstance(t_grad, list), ( "gradients list should have been aggregated by now.") @@ -622,8 +629,8 @@ def _AggregatedGrads(grads, op, has_control_flow, aggregation_method=None): # indices. out_grads[i] = ops.IndexedSlices( array_ops.concat(0, [x.values for x in out_grad]), - array_ops.concat(0, [x.indices for x in out_grad]), - out_grad[0].dense_shape) + array_ops.concat(0, [x.indices + for x in out_grad]), out_grad[0].dense_shape) else: out_grads[i] = [] return out_grads diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index 9a87319697..5afc8a779b 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -24,9 +24,9 @@ import tensorflow.python.platform import numpy as np +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.framework import types # pylint: disable=unused-import from tensorflow.python.ops import array_grad from tensorflow.python.ops import array_ops @@ -68,7 +68,7 @@ def _OpsBetween(graph, to_ops, from_ops): reached_ops[op._id] = True gradients._MarkReachedOps(from_ops, reached_ops) between_ops = gradients._GatherInputs(to_ops, reached_ops) - between_ops.sort(lambda x, y: y._id - x._id) + between_ops.sort(key=lambda x: -x._id) return between_ops @@ -246,13 +246,13 @@ class GradientsTest(test_util.TensorFlowTestCase): @ops.RegisterGradient("TestOp") def _TestOpGrad(op, float_grad, string_grad): """Gradient function for TestOp.""" - self.assertEquals(float_grad.dtype, types.float32) + self.assertEquals(float_grad.dtype, dtypes.float32) self.assertFalse(string_grad) return float_grad ops.RegisterShape("TestOp")(None) c = constant(1.0) - x, y = g.create_op("TestOp", [c], [types.float32, types.string]).outputs + x, y = g.create_op("TestOp", [c], [dtypes.float32, dtypes.string]).outputs z = x * 2.0 w = z * 3.0 grads = gradients.gradients(z, [c]) @@ -314,7 +314,7 @@ class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase): c = constant_op.constant(np_val) c_sparse = math_ops._as_indexed_slices(c) c_sparse = ops.IndexedSlices( - c_sparse.values, math_ops.cast(c_sparse.indices, types.int64), + c_sparse.values, math_ops.cast(c_sparse.indices, dtypes.int64), c_sparse.dense_shape) self.assertAllEqual(np_val.shape, c_sparse.dense_shape.eval()) c_dense = math_ops.mul(c_sparse, 1.0) @@ -322,16 +322,16 @@ class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase): def testWarnings(self): # Smaller than the threshold: no warning. - c_sparse = ops.IndexedSlices(array_ops.placeholder(types.float32), - array_ops.placeholder(types.int32), + c_sparse = ops.IndexedSlices(array_ops.placeholder(dtypes.float32), + array_ops.placeholder(dtypes.int32), constant([4, 4, 4, 4])) with warnings.catch_warnings(record=True) as w: math_ops.mul(c_sparse, 1.0) self.assertEqual(0, len(w)) # Greater than or equal to the threshold: warning. - c_sparse = ops.IndexedSlices(array_ops.placeholder(types.float32), - array_ops.placeholder(types.int32), + c_sparse = ops.IndexedSlices(array_ops.placeholder(dtypes.float32), + array_ops.placeholder(dtypes.int32), constant([100, 100, 100, 100])) with warnings.catch_warnings(record=True) as w: math_ops.mul(c_sparse, 1.0) @@ -341,9 +341,9 @@ class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase): in str(w[0].message)) # Unknown dense shape: warning. - c_sparse = ops.IndexedSlices(array_ops.placeholder(types.float32), - array_ops.placeholder(types.int32), - array_ops.placeholder(types.int32)) + c_sparse = ops.IndexedSlices(array_ops.placeholder(dtypes.float32), + array_ops.placeholder(dtypes.int32), + array_ops.placeholder(dtypes.int32)) with warnings.catch_warnings(record=True) as w: math_ops.mul(c_sparse, 1.0) self.assertEqual(1, len(w)) diff --git a/tensorflow/python/ops/image_grad.py b/tensorflow/python/ops/image_grad.py new file mode 100644 index 0000000000..55f1c74f0e --- /dev/null +++ b/tensorflow/python/ops/image_grad.py @@ -0,0 +1,57 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Contains Gradient functions for image ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import gen_image_ops + + +@ops.RegisterGradient("ResizeNearestNeighbor") +def _ResizeNearestNeighborGrad(op, grad): + """The derivatives for nearest neighbor resizing. + + Args: + op: The ResizeNearestNeighbor op. + grad: The tensor representing the gradient w.r.t. the output. + + Returns: + The gradients w.r.t. the input and the output. + """ + grads = gen_image_ops.resize_nearest_neighbor_grad( + grad, op.inputs[0].get_shape()[1:3]) + return [grads, None] + + +@ops.RegisterShape("ResizeNearestNeighborGrad") +def _ResizeShape(op): + """Shape function for the resize grad ops.""" + input_shape = op.inputs[0].get_shape().with_rank(4) + size = tensor_util.ConstantValue(op.inputs[1]) + if size is not None: + height = size[0] + width = size[1] + else: + height = None + width = None + return [ + tensor_shape.TensorShape([input_shape[0], height, width, input_shape[3]]) + ] diff --git a/tensorflow/python/ops/image_grad_test.py b/tensorflow/python/ops/image_grad_test.py new file mode 100644 index 0000000000..488e741e5b --- /dev/null +++ b/tensorflow/python/ops/image_grad_test.py @@ -0,0 +1,85 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Python ops defined in image_grad.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=g-bad-import-order, +# pylint: disable=unused-import +import tensorflow.python.platform +from tensorflow.python.kernel_tests import gradient_checker as gc + +import numpy as np +import tensorflow as tf +# pylint: enable=g-bad-import-order +# pylint: enable=unused-import + + +class ResizeNearestNeighborOpTest(tf.test.TestCase): + + def testShapeIsCorrectAfterOp(self): + in_shape = [1, 2, 2, 1] + out_shape = [1, 4, 6, 1] + + x = np.arange(0, 4).reshape(in_shape).astype(np.float32) + + with self.test_session() as sess: + input_tensor = tf.constant(x, shape=in_shape) + resize_out = tf.image.resize_nearest_neighbor(input_tensor, + out_shape[1:3]) + self.assertEqual(out_shape, list(resize_out.get_shape())) + + resize_out = sess.run(resize_out) + self.assertEqual(out_shape, list(resize_out.shape)) + + def testGradFromResizeToLargerInBothDims(self): + in_shape = [1, 2, 3, 1] + out_shape = [1, 4, 6, 1] + + x = np.arange(0, 6).reshape(in_shape).astype(np.float32) + + with self.test_session(): + input_tensor = tf.constant(x, shape=in_shape) + resize_out = tf.image.resize_nearest_neighbor(input_tensor, + out_shape[1:3]) + err = gc.ComputeGradientError(input_tensor, + in_shape, + resize_out, + out_shape, + x_init_value=x) + self.assertLess(err, 1e-3) + + def testGradFromResizeToSmallerInBothDims(self): + in_shape = [1, 4, 6, 1] + out_shape = [1, 2, 3, 1] + + x = np.arange(0, 24).reshape(in_shape).astype(np.float32) + + with self.test_session(): + input_tensor = tf.constant(x, shape=in_shape) + resize_out = tf.image.resize_nearest_neighbor(input_tensor, + out_shape[1:3]) + err = gc.ComputeGradientError(input_tensor, + in_shape, + resize_out, + out_shape, + x_init_value=x) + self.assertLess(err, 1e-3) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py index a95e8e9765..7be02b220f 100644 --- a/tensorflow/python/ops/image_ops.py +++ b/tensorflow/python/ops/image_ops.py @@ -55,10 +55,6 @@ image = tf.image.decode_jpeg(...) resized_image = tf.image.resize_bilinear(image, [299, 299]) ``` -<i>Maybe refer to the Queue examples that show how to add images to a Queue -after resizing them to a fixed size, and how to dequeue batches of resized -images from the Queue.</i> - @@resize_images @@resize_area @@ -109,11 +105,11 @@ import math import tensorflow.python.platform +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util -from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops from tensorflow.python.ops import common_shapes @@ -354,7 +350,7 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height, This op cuts a rectangular part out of `image`. The top-left corner of the returned image is at `offset_height, offset_width` in `image`, and its lower-right corner is at - `offset_height + target_height, offset_width + target_width'. + `offset_height + target_height, offset_width + target_width`. Args: image: 3-D tensor with shape `[height, width, channels]` @@ -554,7 +550,7 @@ def per_image_whitening(image): height, width, depth = _ImageDimensions(image) num_pixels = height * width * depth - image = math_ops.cast(image, dtype=types.float32) + image = math_ops.cast(image, dtype=dtypes.float32) image_mean = math_ops.reduce_mean(image) variance = (math_ops.reduce_mean(math_ops.square(image)) - @@ -592,7 +588,7 @@ def random_brightness(image, max_delta, seed=None): 3-D tensor of images of shape `[height, width, channels]` Raises: - ValueError: if max_delta is negative. + ValueError: if `max_delta` is negative. """ _Check3DImage(image) @@ -664,8 +660,8 @@ def adjust_brightness(image, delta, min_value=None, max_value=None): with ops.op_scope([image, delta, min_value, max_value], None, 'adjust_brightness') as name: adjusted = math_ops.add( - math_ops.cast(image, types.float32), - math_ops.cast(delta, types.float32), + math_ops.cast(image, dtypes.float32), + math_ops.cast(delta, dtypes.float32), name=name) if image.dtype.is_integer: rounded = math_ops.round(adjusted) diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py index 141fb8c13a..8c52a232cb 100644 --- a/tensorflow/python/ops/init_ops.py +++ b/tensorflow/python/ops/init_ops.py @@ -19,7 +19,7 @@ from __future__ import division from __future__ import print_function import math -from tensorflow.python.framework import types +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import constant_op from tensorflow.python.ops import math_ops @@ -38,7 +38,7 @@ def constant_initializer(value=0.0): Returns: An initializer that generates tensors with a single value. """ - def _initializer(shape, dtype=types.float32): + def _initializer(shape, dtype=dtypes.float32): return constant_op.constant(value, dtype=dtype, shape=shape) return _initializer @@ -57,7 +57,7 @@ def random_uniform_initializer(minval=0.0, maxval=1.0, seed=None): Returns: An initializer that generates tensors with a uniform distribution. """ - def _initializer(shape, dtype=types.float32): + def _initializer(shape, dtype=dtypes.float32): return random_ops.random_uniform(shape, minval, maxval, dtype, seed=seed) return _initializer @@ -76,7 +76,7 @@ def random_normal_initializer(mean=0.0, stddev=1.0, seed=None): Returns: An initializer that generates tensors with a normal distribution. """ - def _initializer(shape, dtype=types.float32): + def _initializer(shape, dtype=dtypes.float32): return random_ops.random_normal(shape, mean, stddev, dtype, seed=seed) return _initializer @@ -101,7 +101,7 @@ def truncated_normal_initializer(mean=0.0, stddev=1.0, seed=None): An initializer that generates tensors with a truncated normal distribution. """ - def _initializer(shape, dtype=types.float32): + def _initializer(shape, dtype=dtypes.float32): return random_ops.truncated_normal(shape, mean, stddev, dtype, seed=seed) return _initializer @@ -132,7 +132,7 @@ def uniform_unit_scaling_initializer(factor=1.0, seed=None): Returns: An initializer that generates tensors with unit variance. """ - def _initializer(shape, dtype=types.float32): + def _initializer(shape, dtype=dtypes.float32): input_size = 1.0 # Estimating input size is not possible to do perfectly, but we try. # The estimate, obtained by multiplying all dimensions but the last one, @@ -146,7 +146,7 @@ def uniform_unit_scaling_initializer(factor=1.0, seed=None): # TODO(vrv): Unhide when we are ready to expose this publicly. -def _random_walk(shape, nonlinearity, dtype=types.float32, seed=None, +def _random_walk(shape, nonlinearity, dtype=dtypes.float32, seed=None, name="random_walk"): """Create a random tensor such that backprop neither vanishes nor explodes. @@ -200,7 +200,7 @@ class _RandomWalkInitializer(object): self._nonlinearity = nonlinearity self._seed = seed - def __call__(self, shape, dtype=types.float32): + def __call__(self, shape, dtype=dtypes.float32): """Generate a tensor used to initialize a variable.""" return random_ops._random_walk(shape, self._nonlinearity, dtype, seed=self._seed) diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py index 69d4ee6266..97b491bd53 100644 --- a/tensorflow/python/ops/io_ops.py +++ b/tensorflow/python/ops/io_ops.py @@ -121,9 +121,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import types from tensorflow.python.ops import common_shapes from tensorflow.python.ops import gen_io_ops # pylint: disable=wildcard-import @@ -183,7 +183,7 @@ def _restore_slice(file_pattern, tensor_name, shape_and_slice, tensor_type, Returns: A tensor of type "tensor_type". """ - base_type = types.as_dtype(tensor_type).base_dtype + base_type = dtypes.as_dtype(tensor_type).base_dtype return gen_io_ops._restore_slice( file_pattern, tensor_name, shape_and_slice, base_type, preferred_shard, name=name) diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index 9f9e0943ea..12444916ce 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import constant_op from tensorflow.python.ops import data_flow_ops @@ -288,7 +288,7 @@ def _MulGrad(op, grad): sx = array_ops.shape(x) sy = array_ops.shape(y) rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) - if x.dtype.base_dtype == types.complex64: + if x.dtype.base_dtype == dtypes.complex64: return (array_ops.reshape(math_ops.reduce_sum(grad * math_ops.conj(y), rx), sx), array_ops.reshape(math_ops.reduce_sum(math_ops.conj(x) * grad, ry), sy)) else: @@ -412,29 +412,14 @@ def _SparseMatMulGrad(op, grad): assert t1 in is_sparse and t2 in is_sparse t1_sparse = is_sparse[t1] t2_sparse = is_sparse[t2] - if not t1_sparse and not t2_sparse: - return math_ops.matmul(t1, t2, - transpose_a=transpose_a, - transpose_b=transpose_b) - transpose_out = False - if not t1_sparse: - transpose_out = True - t1, t2 = t2, t1 - t1_sparse, t2_sparse = t2_sparse, t1_sparse - assert t1_sparse - transpose_a, transpose_b = not transpose_b, not transpose_a - if transpose_b: t2 = array_ops.transpose(t2) transpose_b = False - m = math_ops.matmul(t1, t2, - transpose_a=transpose_a, - transpose_b=transpose_b, - a_is_sparse=t1_sparse, - b_is_sparse=t2_sparse) - if transpose_out: - m = array_ops.transpose(m) - return m + return math_ops.matmul(t1, t2, + transpose_a=transpose_a, + transpose_b=transpose_b, + a_is_sparse=t1_sparse, + b_is_sparse=t2_sparse) if not t_a and not t_b: return (_SparseMatMul(grad, op.inputs[1], transpose_b=True), @@ -515,7 +500,7 @@ def _ConjGrad(_, grad): @ops.RegisterGradient("Cast") def _CastGrad(op, grad): - t = [types.float32, types.float64, types.bfloat16] + t = [dtypes.float32, dtypes.float64, dtypes.bfloat16] src_type = op.inputs[0].dtype.base_dtype dst_type = grad.dtype.base_dtype if src_type in t and dst_type in t: diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 29c2fcf432..2555b4ba17 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -156,10 +156,10 @@ import tensorflow.python.platform import numpy as np import six.moves +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util -from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import common_shapes from tensorflow.python.ops import gen_math_ops @@ -196,7 +196,7 @@ def abs(x, name=None): """ with ops.op_scope([x], name, "Abs") as name: x = ops.convert_to_tensor(x, name="x") - if x.dtype == types.complex64: + if x.dtype == dtypes.complex64: return gen_math_ops.complex_abs(x, name=name) return gen_math_ops._abs(x, name=name) @@ -333,7 +333,7 @@ def to_float(x, name="ToFloat"): Raises: TypeError: If `x` cannot be cast to the `float32`. """ - return cast(x, types.float32, name=name) + return cast(x, dtypes.float32, name=name) def to_double(x, name="ToDouble"): @@ -349,7 +349,7 @@ def to_double(x, name="ToDouble"): Raises: TypeError: If `x` cannot be cast to the `float64`. """ - return cast(x, types.float64, name=name) + return cast(x, dtypes.float64, name=name) def to_int32(x, name="ToInt32"): @@ -365,7 +365,7 @@ def to_int32(x, name="ToInt32"): Raises: TypeError: If `x` cannot be cast to the `int32`. """ - return cast(x, types.int32, name=name) + return cast(x, dtypes.int32, name=name) def to_int64(x, name="ToInt64"): @@ -381,7 +381,7 @@ def to_int64(x, name="ToInt64"): Raises: TypeError: If `x` cannot be cast to the `int64`. """ - return cast(x, types.int64, name=name) + return cast(x, dtypes.int64, name=name) def to_bfloat16(x, name="ToBFloat16"): @@ -397,7 +397,7 @@ def to_bfloat16(x, name="ToBFloat16"): Raises: TypeError: If `x` cannot be cast to the `bfloat16`. """ - return cast(x, types.bfloat16, name=name) + return cast(x, dtypes.bfloat16, name=name) ops.Tensor._override_operator("__neg__", neg) @@ -437,14 +437,14 @@ def _OverrideBinaryOperatorHelper(func, op_name): # Conversion table for __truediv__. None entries mean no conversion required. _TRUEDIV_TABLE = { - types.uint8: types.float32, - types.int8: types.float32, - types.int16: types.float32, - types.int32: types.float64, - types.int64: types.float64, - types.float32: None, - types.float64: None, - types.complex64: None, + dtypes.uint8: dtypes.float32, + dtypes.int8: dtypes.float32, + dtypes.int16: dtypes.float32, + dtypes.int32: dtypes.float64, + dtypes.int64: dtypes.float64, + dtypes.float32: None, + dtypes.float64: None, + dtypes.complex64: None, } @@ -891,7 +891,7 @@ def matmul(a, b, with ops.op_scope([a, b], name, "MatMul") as name: a = ops.convert_to_tensor(a, name="a") b = ops.convert_to_tensor(b, name="b") - if a.dtype == types.float32 and (a_is_sparse or b_is_sparse): + if a.dtype == dtypes.float32 and (a_is_sparse or b_is_sparse): return sparse_matmul(a, b, transpose_a=transpose_a, transpose_b=transpose_b, @@ -953,14 +953,14 @@ def _as_indexed_slices_list(inputs): raise TypeError("Expected a list or tuple, not a %s" % type(inputs)) outputs = [_as_indexed_slices(i) for i in inputs] with_int32_index = [o.indices for o in outputs - if o.indices.dtype == types.int32] + if o.indices.dtype == dtypes.int32] if not with_int32_index or len(with_int32_index) == len(outputs): return outputs casted_outputs = [] for o in outputs: - if o.indices.dtype == types.int32: + if o.indices.dtype == dtypes.int32: casted_outputs.append( - ops.IndexedSlices(o.values, cast(o.indices, types.int64), + ops.IndexedSlices(o.values, cast(o.indices, dtypes.int64), o.dense_shape)) else: casted_outputs.append(o) diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py index 749faaf73a..17160f909e 100644 --- a/tensorflow/python/ops/nn.py +++ b/tensorflow/python/ops/nn.py @@ -204,9 +204,9 @@ from __future__ import division from __future__ import print_function from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import candidate_sampling_ops from tensorflow.python.ops import constant_op @@ -360,7 +360,7 @@ def zero_fraction(value, name=None): value = ops.convert_to_tensor(value, name="value") zero = constant_op.constant(0, dtype=value.dtype, name="zero") return math_ops.reduce_mean(math_ops.cast(math_ops.equal(value, zero), - types.float32)) + dtypes.float32)) def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): @@ -633,35 +633,42 @@ def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled, sum to 1 per-example. Args: - weights: tensor of label embeddings with shape = [num_classes, dim]. - biases: tensor of num_classes label biases - inputs: tensor with shape = [batch_size, dim] corresponding to forward - activations of the input network - labels: int tensor with shape [batch_size, num_true] - num_sampled: number of label classes to sample per batch - num_classes: number of possible label classes in the data (e.g. vocab size) - num_true: number of target classes per example (default: 1) + weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor` + objects whose concatenation along dimension 0 has shape + `[num_classes, dim]`. The (possibly-sharded) class embeddings. + biases: A `Tensor` of shape `[num_classes]`. The class biases. + inputs: A `Tensor` of shape `[batch_size, dim]`. The forward + activations of the input network. + labels: A `Tensor` of type `int64` and shape `[batch_size, + num_true]`. The target classes. Note that this format differs from + the `labels` argument of `nn.softmax_cross_entropy_with_logits`. + num_sampled: An `int`. The number of classes to randomly sample per batch. + num_classes: An `int`. The number of possible classes. + num_true: An `int`. The number of target classes per training example. sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`, - `sampled_expected_count`) returned by a `*_candidate_sampler` function - to use (if None, we default to `log_uniform_candidate_sampler`) - subtract_log_q: subtract the log expected count of the labels in the sample - to get the logits of the true labels (default: True) - Turn off for Negative Sampling. - remove_accidental_hits: whether to remove "accidental hits" where a sampled - label equals the true labels (bool, default: False) + `sampled_expected_count`) returned by a `*_candidate_sampler` function. + (if None, we default to `log_uniform_candidate_sampler`) + subtract_log_q: A `bool`. whether to subtract the log expected count of + the labels in the sample to get the logits of the true labels. + Default is True. Turn off for Negative Sampling. + remove_accidental_hits: A `bool`. whether to remove "accidental hits" + where a sampled class equals one of the target classes. Default is + False. name: A name for the operation (optional). - Returns: - out_logits, out_labels: tensors with shape - `[batch_size, num_true + num_sampled]` for passing to either - `sigmoid_cross_entropy_with_logits` (NCE) - or `softmax_cross_entropy_with_logits` (sampled softmax). + out_logits, out_labels: `Tensor` objects each with shape + `[batch_size, num_true + num_sampled]`, for passing to either + `nn.sigmoid_cross_entropy_with_logits` (NCE) or + `nn.softmax_cross_entropy_with_logits` (sampled softmax). """ + if not isinstance(weights, list): + weights = [weights] + with ops.op_scope( - [weights, biases, inputs, labels], name, "compute_sampled_logits"): - if labels.dtype != types.int64: - labels = math_ops.cast(labels, types.int64) + weights + [biases, inputs, labels], name, "compute_sampled_logits"): + if labels.dtype != dtypes.int64: + labels = math_ops.cast(labels, dtypes.int64) labels_flat = array_ops.reshape(labels, [-1]) # Sample the negative labels. @@ -726,7 +733,7 @@ def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled, # This is how SparseToDense expects the indices. acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1]) acc_ids_2d_int32 = array_ops.reshape(math_ops.cast( - acc_ids, types.int32), [-1, 1]) + acc_ids, dtypes.int32), [-1, 1]) sparse_indices = array_ops.concat( 1, [acc_indices_2d, acc_ids_2d_int32], "sparse_indices") # Create sampled_logits_shape = [batch_size, num_sampled] @@ -777,17 +784,19 @@ def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes, with an otherwise unused class. Args: - weights: A `Tensor` of shape [num_classes, dim]. The class embeddings. - biases: A `Tensor` of shape [num_classes]. The class biases. - inputs: A `Tensor` of shape [batch_size, dim]. The forward + weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor` + objects whose concatenation along dimension 0 has shape + [num_classes, dim]. The (possibly-sharded) class embeddings. + biases: A `Tensor` of shape `[num_classes]`. The class biases. + inputs: A `Tensor` of shape `[batch_size, dim]`. The forward activations of the input network. labels: A `Tensor` of type `int64` and shape `[batch_size, - num_true]`. The target classes. + num_true]`. The target classes. num_sampled: An `int`. The number of classes to randomly sample per batch. num_classes: An `int`. The number of possible classes. num_true: An `int`. The number of target classes per training example. - sampled_values: a tuple of `(sampled_candidates, true_expected_count, - sampled_expected_count)` returned by a `*_candidate_sampler` function. + sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`, + `sampled_expected_count`) returned by a `*_candidate_sampler` function. (if None, we default to `log_uniform_candidate_sampler`) remove_accidental_hits: A `bool`. Whether to remove "accidental hits" where a sampled class equals one of the target classes. If set to @@ -799,7 +808,7 @@ def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes, name: A name for the operation (optional). Returns: - A batch_size 1-D tensor of per-example NCE losses. + A `batch_size` 1-D tensor of per-example NCE losses. """ logits, labels = _compute_sampled_logits( weights, biases, inputs, labels, num_sampled, num_classes, @@ -838,18 +847,20 @@ def sampled_softmax_loss(weights, biases, inputs, labels, num_sampled, Also see Section 3 of http://arxiv.org/abs/1412.2007 for the math. Args: - weights: A `Tensor` of shape [num_classes, dim]. The class embeddings. - biases: A `Tensor` of shape [num_classes]. The class biases. - inputs: A `Tensor` of shape [batch_size, dim]. The forward + weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor` + objects whose concatenation along dimension 0 has shape + [num_classes, dim]. The (possibly-sharded) class embeddings. + biases: A `Tensor` of shape `[num_classes]`. The class biases. + inputs: A `Tensor` of shape `[batch_size, dim]`. The forward activations of the input network. labels: A `Tensor` of type `int64` and shape `[batch_size, - num_true]`. The target classes. Note that this format differs from - the `labels` argument of `nn.softmax_cross_entropy_with_logits`. + num_true]`. The target classes. Note that this format differs from + the `labels` argument of `nn.softmax_cross_entropy_with_logits`. num_sampled: An `int`. The number of classes to randomly sample per batch. num_classes: An `int`. The number of possible classes. num_true: An `int`. The number of target classes per training example. - sampled_values: a tuple of `(sampled_candidates, true_expected_count, - sampled_expected_count)` returned by a `*_candidate_sampler` function. + sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`, + `sampled_expected_count`) returned by a `*_candidate_sampler` function. (if None, we default to `log_uniform_candidate_sampler`) remove_accidental_hits: A `bool`. whether to remove "accidental hits" where a sampled class equals one of the target classes. Default is @@ -857,7 +868,7 @@ def sampled_softmax_loss(weights, biases, inputs, labels, num_sampled, name: A name for the operation (optional). Returns: - A batch_size 1-D tensor of per-example sampled softmax losses. + A `batch_size` 1-D tensor of per-example sampled softmax losses. """ logits, labels = _compute_sampled_logits( diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index f0709dd2a9..1eb8ef4c69 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -22,10 +22,10 @@ from __future__ import print_function import tensorflow.python.platform import numpy as np +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util -from tensorflow.python.framework import types from tensorflow.python.ops import common_shapes from tensorflow.python.ops import gen_nn_ops # pylint: disable=wildcard-import diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index c8962f60af..65e28978ba 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -25,8 +25,8 @@ import tensorflow.python.platform import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util -from tensorflow.python.framework import types from tensorflow.python.kernel_tests import gradient_checker as gc from tensorflow.python.ops import array_ops from tensorflow.python.ops import constant_op @@ -50,7 +50,7 @@ class SigmoidCrossEntropyWithLogitsTest(test_util.TensorFlowTestCase): pred = [min(max(p, eps), 1 - eps) for p in pred] return [-z * log(y) - (1 - z) * log(1 - y) for y, z in zip(pred, targets)] - def _Inputs(self, x=None, y=None, dtype=types.float64, sizes=None): + def _Inputs(self, x=None, y=None, dtype=dtypes.float64, sizes=None): x = [-100, -2, -2, 0, 2, 2, 2, 100] if x is None else x y = [0, 0, 1, 0, 0, 1, 0.5, 1] if y is None else y assert len(x) == len(y) @@ -70,7 +70,7 @@ class SigmoidCrossEntropyWithLogitsTest(test_util.TensorFlowTestCase): def testLogisticOutput(self): for use_gpu in [True, False]: with self.test_session(use_gpu=use_gpu): - logits, targets, losses = self._Inputs(dtype=types.float32) + logits, targets, losses = self._Inputs(dtype=dtypes.float32) loss = nn.sigmoid_cross_entropy_with_logits(logits, targets) np_loss = np.array(losses).astype(np.float32) tf_loss = loss.eval() @@ -79,7 +79,7 @@ class SigmoidCrossEntropyWithLogitsTest(test_util.TensorFlowTestCase): def testLogisticOutputMultiDim(self): for use_gpu in [True, False]: with self.test_session(use_gpu=use_gpu): - logits, targets, losses = self._Inputs(dtype=types.float32, + logits, targets, losses = self._Inputs(dtype=dtypes.float32, sizes=[2, 2, 2]) loss = nn.sigmoid_cross_entropy_with_logits(logits, targets) np_loss = np.array(losses).astype(np.float32) @@ -168,9 +168,9 @@ class DeConv2DTest(test_util.TensorFlowTestCase): f_shape = [3, 3, 2, 3] x = constant_op.constant(1.0, shape=x_shape, name="x", - dtype=types.float32) + dtype=dtypes.float32) f = constant_op.constant(1.0, shape=f_shape, name="filter", - dtype=types.float32) + dtype=dtypes.float32) output = nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME") value = output.eval() @@ -205,9 +205,9 @@ class DeConv2DTest(test_util.TensorFlowTestCase): f_shape = [3, 3, 2, 3] x = constant_op.constant(1.0, shape=x_shape, name="x", - dtype=types.float32) + dtype=dtypes.float32) f = constant_op.constant(1.0, shape=f_shape, name="filter", - dtype=types.float32) + dtype=dtypes.float32) output = nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME") value = output.eval() @@ -237,9 +237,9 @@ class DeConv2DTest(test_util.TensorFlowTestCase): f_shape = [3, 3, 2, 3] x = constant_op.constant(1.0, shape=x_shape, name="x", - dtype=types.float32) + dtype=dtypes.float32) f = constant_op.constant(1.0, shape=f_shape, name="filter", - dtype=types.float32) + dtype=dtypes.float32) output = nn.deconv2d(x, f, y_shape, strides=strides, padding="VALID") value = output.eval() @@ -281,8 +281,8 @@ class DeConv2DTest(test_util.TensorFlowTestCase): x_val = np.random.random_sample(x_shape).astype(np.float64) f_val = np.random.random_sample(f_shape).astype(np.float64) with self.test_session(): - x = constant_op.constant(x_val, name="x", dtype=types.float32) - f = constant_op.constant(f_val, name="f", dtype=types.float32) + x = constant_op.constant(x_val, name="x", dtype=dtypes.float32) + f = constant_op.constant(f_val, name="f", dtype=dtypes.float32) output = nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME") err = gc.ComputeGradientError([x, f], [x_shape, f_shape], output, y_shape) print("DeConv gradient err = %g " % err) @@ -355,7 +355,7 @@ class DropoutTest(test_util.TensorFlowTestCase): with self.test_session(): t = constant_op.constant(1.0, shape=[x_dim, y_dim], - dtype=types.float32) + dtype=dtypes.float32) dropout = nn.dropout(t, keep_prob) final_count = 0 self.assertEqual([x_dim, y_dim], dropout.get_shape()) @@ -384,7 +384,7 @@ class DropoutTest(test_util.TensorFlowTestCase): with self.test_session(): t = constant_op.constant(1.0, shape=[x_dim, y_dim], - dtype=types.float32) + dtype=dtypes.float32) dropout = nn.dropout(t, keep_prob, noise_shape=[x_dim, 1]) self.assertEqual([x_dim, y_dim], dropout.get_shape()) final_count = 0 @@ -410,7 +410,7 @@ class DropoutTest(test_util.TensorFlowTestCase): with self.test_session(): t = constant_op.constant(1.0, shape=[x_dim, y_dim], - dtype=types.float32) + dtype=dtypes.float32) dropout = nn.dropout(t, keep_prob, noise_shape=[x_dim, 1]) self.assertEqual([x_dim, y_dim], dropout.get_shape()) for _ in xrange(0, num_iter): @@ -431,8 +431,8 @@ class DropoutTest(test_util.TensorFlowTestCase): with self.test_session(): t = constant_op.constant(1.0, shape=[x_dim, y_dim], - dtype=types.float32) - keep_prob_placeholder = array_ops.placeholder(types.float32) + dtype=dtypes.float32) + keep_prob_placeholder = array_ops.placeholder(dtypes.float32) dropout = nn.dropout(t, keep_prob_placeholder) final_count = 0 self.assertEqual([x_dim, y_dim], dropout.get_shape()) @@ -453,9 +453,9 @@ class DropoutTest(test_util.TensorFlowTestCase): x_dim = 40 y_dim = 30 keep_prob = 0.5 - x = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=types.float32) + x = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) dropout_x = nn.dropout( - x, keep_prob, noise_shape=array_ops.placeholder(types.int32)) + x, keep_prob, noise_shape=array_ops.placeholder(dtypes.int32)) self.assertEqual(x.get_shape(), dropout_x.get_shape()) def testInvalidKeepProb(self): @@ -463,7 +463,7 @@ class DropoutTest(test_util.TensorFlowTestCase): y_dim = 30 t = constant_op.constant(1.0, shape=[x_dim, y_dim], - dtype=types.float32) + dtype=dtypes.float32) with self.assertRaises(ValueError): nn.dropout(t, -1.0) with self.assertRaises(ValueError): @@ -471,9 +471,9 @@ class DropoutTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): nn.dropout(t, [0.0, 1.0]) with self.assertRaises(ValueError): - nn.dropout(t, array_ops.placeholder(types.float64)) + nn.dropout(t, array_ops.placeholder(dtypes.float64)) with self.assertRaises(ValueError): - nn.dropout(t, array_ops.placeholder(types.float32, shape=[2])) + nn.dropout(t, array_ops.placeholder(dtypes.float32, shape=[2])) def testShapedDropoutShapeError(self): # Runs shaped dropout and verifies an error is thrown on misshapen noise. @@ -482,7 +482,7 @@ class DropoutTest(test_util.TensorFlowTestCase): keep_prob = 0.5 t = constant_op.constant(1.0, shape=[x_dim, y_dim], - dtype=types.float32) + dtype=dtypes.float32) with self.assertRaises(ValueError): _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim + 10]) with self.assertRaises(ValueError): @@ -641,7 +641,7 @@ class MomentsTest(test_util.TensorFlowTestCase): assert len(shape) == 4 x_numpy = np.random.normal(size=shape).astype(np.float32) - x = array_ops.placeholder(types.float32, shape=[None] * len(shape)) + x = array_ops.placeholder(dtypes.float32, shape=[None] * len(shape)) axes = [0, 1, 2] if global_norm else [0] mean, var = nn.moments(x, axes) @@ -722,6 +722,7 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase): self._num_classes = 5 self._dim = 10 self._batch_size = 3 + self._num_shards = 3 def _GenerateTestInputs(self): np.random.seed(0) @@ -729,8 +730,11 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase): biases = np.random.randn(self._num_classes).astype(np.float32) hidden_acts = np.random.randn(self._batch_size, self._dim).astype( np.float32) - - return weights, biases, hidden_acts + sharded_weights = [ + weights[[row for row in range(self._num_classes) + if row % self._num_shards == shard]] + for shard in range(self._num_shards)] + return weights, biases, hidden_acts, sharded_weights def _ComputeSampledLogitsNP(self, true_w, true_b, sampled_w, sampled_b, hidden_acts, @@ -763,11 +767,14 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase): subtract_log_q, remove_accidental_hits, name="sampled_loss_TF"): # Should be called from within a `with test_session():` block - weights_tf = constant_op.constant(weights) + if isinstance(weights, list): + weights_tf = [constant_op.constant(shard) for shard in weights] + else: + weights_tf = constant_op.constant(weights) biases_tf = constant_op.constant(biases) hidden_acts_tf = constant_op.constant(hidden_acts, shape=(self._batch_size, self._dim)) - labels_tf = constant_op.constant(labels, dtype=types.int64, + labels_tf = constant_op.constant(labels, dtype=dtypes.int64, shape=(self._batch_size, num_true)) pred_logits_tf, pred_labels_tf = nn._compute_sampled_logits( @@ -780,7 +787,7 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase): def testComputeSampledLogitsShapes(self): # We just check that the shapes of the returned values are correct. - weights, biases, hidden_acts = self._GenerateTestInputs() + weights, biases, hidden_acts, _ = self._GenerateTestInputs() sampled = [1, 0, 2, 3] num_sampled = len(sampled) true_exp = sampled_exp = [1., 1., 1., 1.] @@ -811,7 +818,7 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase): def testComputeSampledLogitsValues(self): # Here we check the actual numerics. - weights, biases, hidden_acts = self._GenerateTestInputs() + weights, biases, hidden_acts, sharded_weights = self._GenerateTestInputs() eps = 1e-3 sampled = [1, 0, 2, 3] num_sampled = len(sampled) @@ -887,6 +894,23 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase): self.assertAllClose(logits_np, logits_tf_val, eps) self.assertAllClose(labels_np, labels_tf_val, eps) + # Test 4: Test 1, with sharded weights + logits_np, labels_np = self._ComputeSampledLogitsNP( + true_w, true_b, sampled_w, sampled_b, hidden_acts, + num_true=num_true_test) + logits_tf, labels_tf = self._ComputeSampledLogitsTF( + sharded_weights, biases, hidden_acts, labels, num_sampled, + self._num_classes, + num_true=num_true_test, + sampled_vals=test_sampled_vals, + subtract_log_q=False, + remove_accidental_hits=False, + name="sampled_loss_test1_num_true%d" % num_true_test) + + logits_tf_val, labels_tf_val = sess.run([logits_tf, labels_tf]) + self.assertAllClose(logits_np, logits_tf_val, eps) + self.assertAllClose(labels_np, labels_tf_val, eps) + def testNCELoss(self): # A simple test to verify the numerics. @@ -898,7 +922,7 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase): pred = np.minimum(np.maximum(pred, eps), 1 - eps) return -targets * np.log(pred) - (1. - targets) * np.log(1. - pred) - weights, biases, hidden_acts = self._GenerateTestInputs() + weights, biases, hidden_acts, sharded_weights = self._GenerateTestInputs() labels = [0, 1, 2] true_w, true_b = weights[labels], biases[labels] sampled = [1, 0, 2, 3] @@ -932,6 +956,17 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase): self.assertAllClose(nce_loss_np, nce_loss_tf.eval(), 1e-4) + # Test with sharded weights + nce_loss_tf = nn.nce_loss( + [constant_op.constant(shard) for shard in sharded_weights], + biases_tf, inputs_tf, labels_tf, + num_sampled=1, + num_classes=self._num_classes, + num_true=1, + sampled_values=test_sampled_vals) + + self.assertAllClose(nce_loss_np, nce_loss_tf.eval(), 1e-4) + def testSampledSoftmaxLoss(self): # A simple test to verify the numerics. @@ -943,7 +978,7 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase): pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True) return -np.sum(targets * np.log(pred + 1.0e-20), axis=1) - weights, biases, hidden_acts = self._GenerateTestInputs() + weights, biases, hidden_acts, sharded_weights = self._GenerateTestInputs() labels = [0, 1, 2] true_w, true_b = weights[labels], biases[labels] sampled = [1, 0, 2, 3] @@ -977,6 +1012,19 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase): self.assertAllClose( sampled_softmax_loss_np, sampled_softmax_loss_tf.eval(), 1e-4) + # Test with sharded weights + sampled_softmax_loss_tf = nn.sampled_softmax_loss( + [constant_op.constant(shard) for shard in sharded_weights], + biases_tf, inputs_tf, labels_tf, + num_sampled=1, + num_classes=self._num_classes, + num_true=1, + sampled_values=test_sampled_vals, + remove_accidental_hits=False) + + self.assertAllClose( + sampled_softmax_loss_np, sampled_softmax_loss_tf.eval(), 1e-4) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/ops/numerics.py b/tensorflow/python/ops/numerics.py index b2312332f1..8525da35f5 100644 --- a/tensorflow/python/ops/numerics.py +++ b/tensorflow/python/ops/numerics.py @@ -19,8 +19,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -45,7 +45,7 @@ def verify_tensor_all_finite(t, msg, name=None): def add_check_numerics_ops(): - """Connect a check_numerics to every floating point tensor. + """Connect a `check_numerics` to every floating point tensor. `check_numerics` operations themselves are added for each `float` or `double` tensor in the graph. For all ops in the graph, the `check_numerics` op for @@ -62,7 +62,7 @@ def add_check_numerics_ops(): # added, and ops can only be added once its inputs are added. for op in ops.get_default_graph().get_operations(): for output in op.outputs: - if output.dtype in [types.float32, types.float64]: + if output.dtype in [dtypes.float32, dtypes.float64]: message = op.name + ":" + str(output.value_index) with ops.control_dependencies(check_op): check_op = [array_ops.check_numerics(output, message=message)] diff --git a/tensorflow/python/ops/op_def_library.py b/tensorflow/python/ops/op_def_library.py index 562299fdd6..ad0406d43d 100644 --- a/tensorflow/python/ops/op_def_library.py +++ b/tensorflow/python/ops/op_def_library.py @@ -19,7 +19,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numbers import six from tensorflow.core.framework import attr_value_pb2 @@ -27,11 +26,12 @@ from tensorflow.core.framework import op_def_pb2 from tensorflow.core.framework import tensor_pb2 from tensorflow.core.framework import tensor_shape_pb2 from tensorflow.core.framework import types_pb2 +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import types as types_lib from tensorflow.python.ops import constant_op from tensorflow.python.platform import logging +from tensorflow.python.util import compat def _Attr(op_def, name): @@ -55,8 +55,8 @@ def _SatisfiesTypeConstraint(dtype, attr_def): if dtype not in allowed_list: raise TypeError( "DataType %s for attr '%s' not in list of allowed values: %s" % - (types_lib.as_dtype(dtype).name, attr_def.name, - ", ".join(types_lib.as_dtype(x).name for x in allowed_list))) + (dtypes.as_dtype(dtype).name, attr_def.name, + ", ".join(dtypes.as_dtype(x).name for x in allowed_list))) def _IsListParameter(arg): @@ -137,7 +137,7 @@ def _Restructure(l, structure): def _MakeFloat(v, arg_name): - if not isinstance(v, numbers.Real): + if not isinstance(v, compat.real_types): raise TypeError("Expected float for argument '%s' not %s." % (arg_name, repr(v))) return float(v) @@ -155,11 +155,10 @@ def _MakeInt(v, arg_name): def _MakeStr(v, arg_name): - if not isinstance(v, six.string_types): + if not isinstance(v, compat.bytes_or_text_types): raise TypeError("Expected string for argument '%s' not %s." % (arg_name, repr(v))) - # TODO(irving): Figure out what to do here from Python 3 - return str(v) # Convert unicode strings to bytes. + return compat.as_bytes(v) # Convert unicode strings to bytes. def _MakeBool(v, arg_name): @@ -171,7 +170,7 @@ def _MakeBool(v, arg_name): def _MakeType(v, attr_def): try: - v = types_lib.as_dtype(v) + v = dtypes.as_dtype(v) except TypeError: raise TypeError("Expected DataType for argument '%s' not %s." % (attr_def.name, repr(v))) @@ -381,7 +380,7 @@ class OpDefLibrary(object): try: values = ops.convert_n_to_tensor_or_indexed_slices( values, name=input_arg.name, - dtype=types_lib.as_dtype(dtype).base_dtype if dtype else None) + dtype=dtypes.as_dtype(dtype).base_dtype if dtype else None) except (TypeError, ValueError): assert dtype is not None, "Should not fail if dtype is None" assert input_arg.number_attr, "Should be number_attr case" @@ -394,11 +393,11 @@ class OpDefLibrary(object): (input_name, op_type_name, observed)) if input_arg.type != types_pb2.DT_INVALID: raise TypeError("%s that do not match expected type %s." % - (prefix, types_lib.as_dtype(dtype).name)) + (prefix, dtypes.as_dtype(dtype).name)) elif input_arg.type_attr in attrs: raise TypeError("%s that do not match type %s inferred from " "earlier arguments." % - (prefix, types_lib.as_dtype(dtype).name)) + (prefix, dtypes.as_dtype(dtype).name)) else: raise TypeError("%s that don't all match." % prefix) @@ -423,11 +422,11 @@ class OpDefLibrary(object): (input_name, op_type_name, observed)) if input_arg.type != types_pb2.DT_INVALID: raise TypeError("%s expected type of %s." % - (prefix, types_lib.as_dtype(input_arg.type).name)) + (prefix, dtypes.as_dtype(input_arg.type).name)) else: raise TypeError( "%s type %s of argument '%s'." % - (prefix, types_lib.as_dtype(attrs[input_arg.type_attr]).name, + (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name, inferred_from[input_arg.type_attr])) types = [values.dtype] @@ -501,8 +500,8 @@ class OpDefLibrary(object): "Input '%s' of '%s' Op has type list of %s that does not " "match type list %s of argument '%s'." % (input_name, op_type_name, - ", ".join(types_lib.as_dtype(x).name for x in attr_value), - ", ".join(types_lib.as_dtype(x).name + ", ".join(dtypes.as_dtype(x).name for x in attr_value), + ", ".join(dtypes.as_dtype(x).name for x in attrs[input_arg.type_list_attr]), inferred_from[input_arg.type_list_attr])) else: @@ -567,8 +566,9 @@ class OpDefLibrary(object): if attr_value.s not in attr_def.allowed_values.list.s: raise ValueError( "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." % - (key, op_type_name, attr_value.s, - '", "'.join(attr_def.allowed_values.list.s))) + (key, op_type_name, compat.as_text(attr_value.s), + '", "'.join(map(compat.as_text, + attr_def.allowed_values.list.s)))) elif attr_def.type == "list(string)": attr_value.list.s.extend([_MakeStr(x, key) for x in value]) if attr_def.HasField("allowed_values"): @@ -576,8 +576,9 @@ class OpDefLibrary(object): if x not in attr_def.allowed_values.list.s: raise ValueError( "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." % - (key, op_type_name, x, - '", "'.join(attr_def.allowed_values.list.s))) + (key, op_type_name, compat.as_text(x), + '", "'.join(map(compat.as_text, + attr_def.allowed_values.list.s)))) elif attr_def.type == "int": attr_value.i = _MakeInt(value, key) if attr_def.has_minimum: @@ -640,7 +641,7 @@ class OpDefLibrary(object): types = [arg.type] output_structure.append(None) if arg.is_ref: - types = [types_lib.as_dtype(x).as_ref for x in types] + types = [dtypes.as_dtype(x).as_ref for x in types] output_types.extend(types) if keywords: diff --git a/tensorflow/python/ops/op_def_library_test.py b/tensorflow/python/ops/op_def_library_test.py index 4a8cdbe9a4..4b9d66e8a6 100644 --- a/tensorflow/python/ops/op_def_library_test.py +++ b/tensorflow/python/ops/op_def_library_test.py @@ -23,10 +23,10 @@ from google.protobuf import text_format from tensorflow.core.framework import op_def_pb2 from tensorflow.core.framework import tensor_shape_pb2 +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import types from tensorflow.python.ops.op_def_library import OpDefLibrary from tensorflow.python.platform import googletest @@ -104,13 +104,13 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): def testNoRegisteredOpFails(self): with self.assertRaises(RuntimeError) as cm: self._lib.apply_op("unknown", g=self._g) - self.assertEqual(cm.exception.message, "Unrecognized Op name unknown") + self.assertEqual(str(cm.exception), "Unrecognized Op name unknown") def testAddOpValidation(self): with self.assertRaises(TypeError) as cm: self._add_op("name: 'MissingTypeAttr' " "input_arg { name: 'a' type_attr: 'T' } ") - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Inconsistent OpDef for 'MissingTypeAttr', " "missing attr 'T'") @@ -119,13 +119,13 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): "output_arg { name: 'a' type_attr: 'T' } " "attr { name: 'T' type: 'int' }") self.assertEqual( - cm.exception.message, + str(cm.exception), "Attr 'T' of 'BadTypeAttr' used as a type_attr but has type int") with self.assertRaises(TypeError) as cm: self._add_op("name: 'MissingNumberAttr' " "input_arg { name: 'a' type: DT_INT32 number_attr: 'N' } ") - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Inconsistent OpDef for 'MissingNumberAttr', " "missing attr 'N'") @@ -134,21 +134,21 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): "output_arg { name: 'a' type: DT_INT32 number_attr: 'N' } " "attr { name: 'N' type: 'type' }") self.assertEqual( - cm.exception.message, + str(cm.exception), "Attr 'N' of 'BadNumberAttr' used as a number_attr but has type type") with self.assertRaises(TypeError) as cm: self._add_op("name: 'TwoTypesA' " "input_arg { name: 'a' type: DT_INT32 type_attr: 'T' } " "attr { name: 'T' type: 'type' }") - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Arg 'a' of 'TwoTypesA' must have one type field not 2") with self.assertRaises(TypeError) as cm: self._add_op("name: 'TwoTypesB' " "input_arg { name: 'a' type: DT_INT32 type_list_attr: 'T' } " "attr { name: 'T' type: 'list(type)' }") - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Arg 'a' of 'TwoTypesB' must have one type field not 2") with self.assertRaises(TypeError) as cm: @@ -157,17 +157,17 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): "type_list_attr: 'U' } " "attr { name: 'T' type: 'type' } " "attr { name: 'U' type: 'list(type)' }") - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Arg 'a' of 'ThreeTypes' must have one type field not 3") with self.assertRaises(TypeError) as cm: self._add_op("name: 'NoTypes' output_arg { name: 'a' } ") - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Arg 'a' of 'NoTypes' must have one type field not 0") def testSimple(self): out = self._lib.apply_op("Simple", a=3) - self.assertEquals(types.float32, out.dtype) + self.assertEqual(dtypes.float32, out.dtype) self.assertProtoEquals(""" name: 'Simple' op: 'Simple' input: 'Simple/a' """, out.op.node_def) @@ -190,39 +190,39 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): def testSimpleFailures(self): with self.assertRaises(TypeError) as cm: self._lib.apply_op("Simple", a="Bad string") - self.assertEqual(cm.exception.message, - "Expected int32, got 'Bad string' instead.") + self.assertEqual(str(cm.exception), + "Expected int32, got 'Bad string' of type 'str' instead.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("Simple", a=self.Tensor(types.string)) - self.assertEqual(cm.exception.message, + self._lib.apply_op("Simple", a=self.Tensor(dtypes.string)) + self.assertEqual(str(cm.exception), "Input 'a' of 'Simple' Op has type string " "that does not match expected type of int32.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("Simple", a=6, extra="bogus") - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "apply_op() got unexpected keyword arguments: extra") with self.assertRaises(TypeError) as cm: self._lib.apply_op("Simple", a=6, extra1="bogus", extra2="also_bogus") - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "apply_op() got unexpected keyword arguments: extra1, " "extra2") with self.assertRaises(TypeError) as cm: self._lib.apply_op("Simple") - self.assertEqual(cm.exception.message, "No argument for input a") + self.assertEqual(str(cm.exception), "No argument for input a") with self.assertRaises(TypeError) as cm: self._lib.apply_op("Simple", wrong=7) - self.assertEqual(cm.exception.message, "No argument for input a") + self.assertEqual(str(cm.exception), "No argument for input a") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("Simple", a=[self.Tensor(types.int32)]) - self.assertStartsWith( - cm.exception.message, - "Expected int32, got list containing Tensors instead.") + self._lib.apply_op("Simple", a=[self.Tensor(dtypes.int32)]) + self.assertStartsWith(str(cm.exception), + "Expected int32, got list containing Tensors of type " + "'_Message' instead.") def testReservedInput(self): self._add_op("name: 'ReservedInput' " @@ -239,34 +239,34 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): "attr { name: 'T' type: 'type' }") out = self._lib.apply_op("Polymorphic", a=7, name="p") - self.assertEquals(types.int32, out.dtype) + self.assertEqual(dtypes.int32, out.dtype) self.assertProtoEquals(""" name: 'p' op: 'Polymorphic' input: 'p/a' attr { key: 'T' value { type: DT_INT32 } } """, out.op.node_def) out = self._lib.apply_op("Polymorphic", a="s", name="q") - self.assertEquals(types.string, out.dtype) + self.assertEqual(dtypes.string, out.dtype) self.assertProtoEquals(""" name: 'q' op: 'Polymorphic' input: 'q/a' attr { key: 'T' value { type: DT_STRING } } """, out.op.node_def) out = self._lib.apply_op("Polymorphic", a=["s", "t", "u"], name="r") - self.assertEquals(types.string, out.dtype) + self.assertEqual(dtypes.string, out.dtype) self.assertProtoEquals(""" name: 'r' op: 'Polymorphic' input: 'r/a' attr { key: 'T' value { type: DT_STRING } } """, out.op.node_def) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("Polymorphic", a="s", T=types.string) - self.assertEqual(cm.exception.message, + self._lib.apply_op("Polymorphic", a="s", T=dtypes.string) + self.assertEqual(str(cm.exception), "Should not specify value for inferred attr 'T'.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("Polymorphic", a=[self.Tensor(types.bool)]) - self.assertEqual(cm.exception.message, + self._lib.apply_op("Polymorphic", a=[self.Tensor(dtypes.bool)]) + self.assertEqual(str(cm.exception), "List of Tensors when single Tensor expected") def testPolymorphicOut(self): @@ -274,15 +274,15 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): "output_arg { name: 'out' type_attr: 'T' } " "attr { name: 'T' type: 'type' }") - out = self._lib.apply_op("PolymorphicOut", T=types.int32, name="p") - self.assertEquals(types.int32, out.dtype) + out = self._lib.apply_op("PolymorphicOut", T=dtypes.int32, name="p") + self.assertEqual(dtypes.int32, out.dtype) self.assertProtoEquals(""" name: 'p' op: 'PolymorphicOut' attr { key: 'T' value { type: DT_INT32 } } """, out.op.node_def) - out = self._lib.apply_op("PolymorphicOut", T=types.bool, name="q") - self.assertEquals(types.bool, out.dtype) + out = self._lib.apply_op("PolymorphicOut", T=dtypes.bool, name="q") + self.assertEqual(dtypes.bool, out.dtype) self.assertProtoEquals(""" name: 'q' op: 'PolymorphicOut' attr { key: 'T' value { type: DT_BOOL } } @@ -290,12 +290,12 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("PolymorphicOut") - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "No argument for attr T") with self.assertRaises(TypeError) as cm: self._lib.apply_op("PolymorphicOut", T=None) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Expected DataType for argument 'T' not None.") def testPolymorphicDefaultOut(self): @@ -305,15 +305,15 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): " default_value { type: DT_STRING } }") out = self._lib.apply_op("PolymorphicDefaultOut", T=None, name="p") - self.assertEquals(types.string, out.dtype) + self.assertEqual(dtypes.string, out.dtype) self.assertProtoEquals(""" name: 'p' op: 'PolymorphicDefaultOut' attr { key: 'T' value { type: DT_STRING } } """, out.op.node_def) - out = self._lib.apply_op("PolymorphicDefaultOut", T=types.bool, + out = self._lib.apply_op("PolymorphicDefaultOut", T=dtypes.bool, name="q") - self.assertEquals(types.bool, out.dtype) + self.assertEqual(dtypes.bool, out.dtype) self.assertProtoEquals(""" name: 'q' op: 'PolymorphicDefaultOut' attr { key: 'T' value { type: DT_BOOL } } @@ -327,14 +327,14 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): "attr { name: 'T' type: 'type' }") out = self._lib.apply_op("Binary", a=8, b=9, name="b") - self.assertEquals(types.int32, out.dtype) + self.assertEqual(dtypes.int32, out.dtype) self.assertProtoEquals(""" name: 'b' op: 'Binary' input: 'b/a' input: 'b/b' attr { key: 'T' value { type: DT_INT32 } } """, out.op.node_def) out = self._lib.apply_op("Binary", a="left", b="right", name="c") - self.assertEquals(types.string, out.dtype) + self.assertEqual(dtypes.string, out.dtype) self.assertProtoEquals(""" name: 'c' op: 'Binary' input: 'c/a' input: 'c/b' attr { key: 'T' value { type: DT_STRING } } @@ -342,13 +342,13 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("Binary", a="left", b=12) - self.assertEqual(cm.exception.message, - "Expected string, got 12 instead.") + self.assertEqual(str(cm.exception), + "Expected string, got 12 of type 'int' instead.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("Binary", a=self.Tensor(types.string), - b=self.Tensor(types.int32)) - self.assertEqual(cm.exception.message, + self._lib.apply_op("Binary", a=self.Tensor(dtypes.string), + b=self.Tensor(dtypes.int32)) + self.assertEqual(str(cm.exception), "Input 'b' of 'Binary' Op has type int32 " "that does not match type string of argument 'a'.") @@ -360,14 +360,14 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): " type: DT_STRING type: DT_BOOL } } }") out = self._lib.apply_op("Restrict", a="foo", name="g") - self.assertEquals(types.string, out.dtype) + self.assertEqual(dtypes.string, out.dtype) self.assertProtoEquals(""" name: 'g' op: 'Restrict' input: 'g/a' attr { key: 'T' value { type: DT_STRING } } """, out.op.node_def) out = self._lib.apply_op("Restrict", a=True, name="h") - self.assertEquals(types.bool, out.dtype) + self.assertEqual(dtypes.bool, out.dtype) self.assertProtoEquals(""" name: 'h' op: 'Restrict' input: 'h/a' attr { key: 'T' value { type: DT_BOOL } } @@ -375,7 +375,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("Restrict", a=17) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "DataType int32 for attr 'T' " "not in list of allowed values: " "string, bool") @@ -404,7 +404,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("TypeList", a=17) - self.assertStartsWith(cm.exception.message, + self.assertStartsWith(str(cm.exception), "Expected list for 'a' " "argument to 'TypeList' Op, not ") @@ -429,7 +429,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("TypeListTwice", a=["foo", True], b=["bar", 6]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Input 'b' of 'TypeListTwice' Op has type list of " "string, int32 that does not match type list " "string, bool of argument 'a'.") @@ -439,18 +439,18 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): "output_arg { name: 'out' type_list_attr: 'T' } " "attr { name: 'T' type: 'list(type)' }") - out, = self._lib.apply_op("OutTypeList", T=[types.float32], name="x") - self.assertEquals(types.float32, out.dtype) + out, = self._lib.apply_op("OutTypeList", T=[dtypes.float32], name="x") + self.assertEqual(dtypes.float32, out.dtype) self.assertProtoEquals(""" name: 'x' op: 'OutTypeList' attr { key: 'T' value { list { type: DT_FLOAT } } } """, out.op.node_def) out1, out2 = self._lib.apply_op("OutTypeList", - T=[types.int32, types.bool], + T=[dtypes.int32, dtypes.bool], name="w") - self.assertEquals(types.int32, out1.dtype) - self.assertEquals(types.bool, out2.dtype) + self.assertEqual(dtypes.int32, out1.dtype) + self.assertEqual(dtypes.bool, out2.dtype) self.assertProtoEquals(""" name: 'w' op: 'OutTypeList' attr { key: 'T' value { list { type: DT_INT32 type: DT_BOOL } } } @@ -460,8 +460,8 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): self.assertEqual([], out) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("OutTypeList", T=types.int32) - self.assertEqual(cm.exception.message, "Expected list for attr T") + self._lib.apply_op("OutTypeList", T=dtypes.int32) + self.assertEqual(str(cm.exception), "Expected list for attr T") def testTypeListRestrict(self): self._add_op("name: 'TypeListRestrict' " @@ -477,7 +477,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("TypeListRestrict", a=[True, 12]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "DataType int32 for attr 'T' " "not in list of allowed values: string, bool") @@ -488,10 +488,10 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): " type: DT_STRING type: DT_BOOL } } }") out1, out2 = self._lib.apply_op("OutTypeListRestrict", - t=[types.bool, types.string], + t=[dtypes.bool, dtypes.string], name="u") - self.assertEquals(types.bool, out1.dtype) - self.assertEquals(types.string, out2.dtype) + self.assertEqual(dtypes.bool, out1.dtype) + self.assertEqual(dtypes.string, out2.dtype) self.assertProtoEquals(""" name: 'u' op: 'OutTypeListRestrict' attr { key: 't' value { list { type: DT_BOOL type: DT_STRING } } } @@ -499,8 +499,8 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("OutTypeListRestrict", - t=[types.string, types.int32]) - self.assertEqual(cm.exception.message, + t=[dtypes.string, dtypes.int32]) + self.assertEqual(str(cm.exception), "DataType int32 for attr 't' " "not in list of allowed values: string, bool") @@ -518,22 +518,22 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("Attr", a="bad") - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Expected int for argument 'a' not 'bad'.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("Attr", a=[12]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Expected int for argument 'a' not [12].") with self.assertRaises(TypeError) as cm: self._lib.apply_op("Attr", a=None) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Expected int for argument 'a' not None.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("Attr") - self.assertEqual(cm.exception.message, "No argument for attr a") + self.assertEqual(str(cm.exception), "No argument for attr a") def testAttrFloat(self): self._add_op("name: 'AttrFloat' attr { name: 'a' type: 'float' }") @@ -550,7 +550,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("AttrFloat", a="bad") - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Expected float for argument 'a' not 'bad'.") def testAttrBool(self): @@ -568,17 +568,17 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("AttrBool", a=0) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Expected bool for argument 'a' not 0.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("AttrBool", a=1) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Expected bool for argument 'a' not 1.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("AttrBool", a=[]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Expected bool for argument 'a' not [].") def testAttrBoolList(self): @@ -597,7 +597,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("AttrBoolList", a=[0]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Expected bool for argument 'a' not 0.") def testAttrMin(self): @@ -610,7 +610,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError) as cm: self._lib.apply_op("AttrMin", a=2) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Attr 'a' of 'AttrMin' Op passed 2 less than minimum 5.") def testAttrListMin(self): @@ -625,7 +625,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError) as cm: self._lib.apply_op("AttrListMin", a=[17]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Attr 'a' of 'AttrListMin' Op " "passed list of length 1 less than minimum 2.") @@ -641,7 +641,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError) as cm: self._lib.apply_op("AttrEnum", a="invalid") - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), 'Attr \'a\' of \'AttrEnum\' Op ' 'passed string \'invalid\' not in: ' '"apples", "oranges".') @@ -659,7 +659,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError) as cm: self._lib.apply_op("AttrEnumList", a=["apples", "invalid", "oranges"]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), 'Attr \'a\' of \'AttrEnumList\' Op ' 'passed string \'invalid\' not ' 'in: "apples", "oranges".') @@ -705,7 +705,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): # TODO(josh11b): Re-enable this test once we stop promoting scalars to shapes. # with self.assertRaises(TypeError) as cm: # self._lib.apply_op("AttrShape", a=5) - # self.assertEqual(cm.exception.message, + # self.assertEqual(str(cm.exception), # "Don't know how to convert 5 to a TensorShapeProto for " # "argument 'a'") @@ -817,42 +817,42 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("NIntsIn", a=["foo", "bar"]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Tensors in list passed to 'a' of 'NIntsIn' Op have types " "[string, string] that do not match expected type int32.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("NIntsIn", a=[self.Tensor(types.string), - self.Tensor(types.string)]) - self.assertEqual(cm.exception.message, + self._lib.apply_op("NIntsIn", a=[self.Tensor(dtypes.string), + self.Tensor(dtypes.string)]) + self.assertEqual(str(cm.exception), "Tensors in list passed to 'a' of 'NIntsIn' Op have " "types [string, string] that do not match expected type " "int32.") with self.assertRaises(ValueError) as cm: self._lib.apply_op("NIntsIn", a=[99]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "List argument 'a' to 'NIntsIn' Op " "with length 1 shorter than " "minimum length 2.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NIntsIn", a=[38, "bar"]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Tensors in list passed to 'a' of 'NIntsIn' Op have types " "[int32, string] that do not match expected type int32.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("NIntsIn", a=[self.Tensor(types.int32), - self.Tensor(types.string)]) - self.assertEqual(cm.exception.message, + self._lib.apply_op("NIntsIn", a=[self.Tensor(dtypes.int32), + self.Tensor(dtypes.string)]) + self.assertEqual(str(cm.exception), "Tensors in list passed to 'a' of 'NIntsIn' Op " "have types [int32, string] that do not match expected " "type int32.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NIntsIn", a=17) - self.assertStartsWith(cm.exception.message, + self.assertStartsWith(str(cm.exception), "Expected list for 'a' argument " "to 'NIntsIn' Op, not ") @@ -885,7 +885,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, op.node_def) op = self._lib.apply_op("NPolymorphicIn", - a=[1, self.Tensor(types.float32, name="x")], + a=[1, self.Tensor(dtypes.float32, name="x")], name="q") self.assertProtoEquals(""" name: 'q' op: 'NPolymorphicIn' input: 'q/a_0' input: 'x' @@ -895,33 +895,33 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError) as cm: self._lib.apply_op("NPolymorphicIn", a=[99]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "List argument 'a' to 'NPolymorphicIn' Op with length 1 " "shorter than minimum length 2.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NPolymorphicIn", a=[38, "bar"]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "All tensors passed to 'a' of 'NPolymorphicIn' " "Op must have the same type.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NPolymorphicIn", - a=[38, self.Tensor(types.string)]) - self.assertEqual(cm.exception.message, + a=[38, self.Tensor(dtypes.string)]) + self.assertEqual(str(cm.exception), "Tensors in list passed to 'a' of 'NPolymorphicIn' Op " "have types [int32, string] that don't all match.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NPolymorphicIn", - a=["abcd", self.Tensor(types.int32)]) - self.assertEqual(cm.exception.message, + a=["abcd", self.Tensor(dtypes.int32)]) + self.assertEqual(str(cm.exception), "Tensors in list passed to 'a' of 'NPolymorphicIn' Op " "have types [string, int32] that don't all match.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NPolymorphicIn", a=17) - self.assertStartsWith(cm.exception.message, + self.assertStartsWith(str(cm.exception), "Expected list for 'a' argument " "to 'NPolymorphicIn' Op, not ") @@ -951,7 +951,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("NPolymorphicRestrictIn", a=[1, 2]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "DataType int32 for attr 'T' " "not in list of allowed values: string, bool") @@ -975,7 +975,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError) as cm: self._lib.apply_op("NInTwice", a=[1, 2, 3], b=["too short"]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "List argument 'b' to 'NInTwice' Op " "with length 1 must match " "length 3 of argument 'a'.") @@ -997,23 +997,23 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError) as cm: self._lib.apply_op("NInPolymorphicTwice", a=[1, 2, 3], b=[5]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "List argument 'b' to 'NInPolymorphicTwice' Op " "with length 1 " "must match length 3 of argument 'a'.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NInPolymorphicTwice", a=[1, 2], b=["one", "two"]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Tensors in list passed to 'b' of 'NInPolymorphicTwice' " "Op have types [string, string] that do not match type " "int32 inferred from earlier arguments.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NInPolymorphicTwice", - a=[self.Tensor(types.int32)], - b=[self.Tensor(types.string)]) - self.assertEqual(cm.exception.message, + a=[self.Tensor(dtypes.int32)], + b=[self.Tensor(dtypes.string)]) + self.assertEqual(str(cm.exception), "Tensors in list passed to 'b' of " "'NInPolymorphicTwice' Op have types [string] that do not " "match type int32 inferred from earlier arguments.") @@ -1046,8 +1046,8 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, op.node_def) op = self._lib.apply_op("NInTwoTypeVariables", - a=[self.Tensor(types.int32, name="q")], - b=[self.Tensor(types.string, name="r")], + a=[self.Tensor(dtypes.int32, name="q")], + b=[self.Tensor(dtypes.string, name="r")], name="p") self.assertProtoEquals(""" name: 'p' op: 'NInTwoTypeVariables' input: 'q' input: 'r' @@ -1058,7 +1058,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError) as cm: self._lib.apply_op("NInTwoTypeVariables", a=[1, 2, 3], b=["5"]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "List argument 'b' to 'NInTwoTypeVariables' Op " "with length 1 " "must match length 3 of argument 'a'.") @@ -1090,22 +1090,22 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("InPolymorphicTwice", a=[], b=[3, 4, 5]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Don't know how to infer type variable from empty input " "list passed to input 'a' of 'InPolymorphicTwice' Op.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("InPolymorphicTwice", a=[1, 2], b=["one", "two"]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Tensors in list passed to 'b' of 'InPolymorphicTwice' Op " "have types [string, string] that do not match type int32 " "inferred from earlier arguments.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("InPolymorphicTwice", - a=[self.Tensor(types.int32)], - b=[self.Tensor(types.string)]) - self.assertEqual(cm.exception.message, + a=[self.Tensor(dtypes.int32)], + b=[self.Tensor(dtypes.string)]) + self.assertEqual(str(cm.exception), "Tensors in list passed to 'b' of 'InPolymorphicTwice' " "Op have types [string] that do not match type int32 " "inferred from earlier arguments.") @@ -1116,31 +1116,31 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }") out1, out2 = self._lib.apply_op("NIntsOut", N=2, name="n") - self.assertEquals(types.int32, out1.dtype) - self.assertEquals(types.int32, out2.dtype) + self.assertEqual(dtypes.int32, out1.dtype) + self.assertEqual(dtypes.int32, out2.dtype) self.assertProtoEquals(""" name: 'n' op: 'NIntsOut' attr { key: 'N' value { i: 2 } } """, out1.op.node_def) out1, out2, out3, out4, out5 = self._lib.apply_op( "NIntsOut", N=5, name="o") - self.assertEquals(types.int32, out1.dtype) - self.assertEquals(types.int32, out2.dtype) - self.assertEquals(types.int32, out3.dtype) - self.assertEquals(types.int32, out4.dtype) - self.assertEquals(types.int32, out5.dtype) + self.assertEqual(dtypes.int32, out1.dtype) + self.assertEqual(dtypes.int32, out2.dtype) + self.assertEqual(dtypes.int32, out3.dtype) + self.assertEqual(dtypes.int32, out4.dtype) + self.assertEqual(dtypes.int32, out5.dtype) self.assertProtoEquals(""" name: 'o' op: 'NIntsOut' attr { key: 'N' value { i: 5 } } """, out5.op.node_def) with self.assertRaises(ValueError) as cm: self._lib.apply_op("NIntsOut", N=1) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Attr 'N' of 'NIntsOut' Op passed 1 less than minimum 2.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NIntsOut", N=[3]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Expected int for argument 'N' not [3].") def testNIntsOutDefault(self): @@ -1151,16 +1151,16 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): out1, out2, out3 = self._lib.apply_op( "NIntsOutDefault", N=None, name="z") - self.assertEquals(types.int32, out1.dtype) - self.assertEquals(types.int32, out2.dtype) - self.assertEquals(types.int32, out3.dtype) + self.assertEqual(dtypes.int32, out1.dtype) + self.assertEqual(dtypes.int32, out2.dtype) + self.assertEqual(dtypes.int32, out3.dtype) self.assertProtoEquals(""" name: 'z' op: 'NIntsOutDefault' attr { key: 'N' value { i: 3 } } """, out1.op.node_def) out1, out2 = self._lib.apply_op("NIntsOutDefault", N=2, name="y") - self.assertEquals(types.int32, out1.dtype) - self.assertEquals(types.int32, out2.dtype) + self.assertEqual(dtypes.int32, out1.dtype) + self.assertEqual(dtypes.int32, out2.dtype) self.assertProtoEquals(""" name: 'y' op: 'NIntsOutDefault' attr { key: 'N' value { i: 2 } } """, out2.op.node_def) @@ -1172,9 +1172,9 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }") out1, out2 = self._lib.apply_op("NPolymorphicOut", N=2, - T=types.int32, name="n") - self.assertEquals(types.int32, out1.dtype) - self.assertEquals(types.int32, out2.dtype) + T=dtypes.int32, name="n") + self.assertEqual(dtypes.int32, out1.dtype) + self.assertEqual(dtypes.int32, out2.dtype) self.assertProtoEquals(""" name: 'n' op: 'NPolymorphicOut' attr { key: 'T' value { type: DT_INT32 } } @@ -1182,10 +1182,10 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, out1.op.node_def) out1, out2, out3 = self._lib.apply_op( - "NPolymorphicOut", T=types.string, N=3, name="o") - self.assertEquals(types.string, out1.dtype) - self.assertEquals(types.string, out2.dtype) - self.assertEquals(types.string, out3.dtype) + "NPolymorphicOut", T=dtypes.string, N=3, name="o") + self.assertEqual(dtypes.string, out1.dtype) + self.assertEqual(dtypes.string, out2.dtype) + self.assertEqual(dtypes.string, out3.dtype) self.assertProtoEquals(""" name: 'o' op: 'NPolymorphicOut' attr { key: 'T' value { type: DT_STRING } } @@ -1193,15 +1193,15 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, out3.op.node_def) with self.assertRaises(ValueError) as cm: - self._lib.apply_op("NPolymorphicOut", N=1, T=types.string) - self.assertEqual(cm.exception.message, + self._lib.apply_op("NPolymorphicOut", N=1, T=dtypes.string) + self.assertEqual(str(cm.exception), "Attr 'N' of 'NPolymorphicOut' Op " "passed 1 less than minimum 2.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("NPolymorphicOut", N=3, T=[types.string]) + self._lib.apply_op("NPolymorphicOut", N=3, T=[dtypes.string]) self.assertEqual( - cm.exception.message, + str(cm.exception), "Expected DataType for argument 'T' not [tf.string].") def testNPolymorphicOutDefault(self): @@ -1214,8 +1214,8 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): out1, out2 = self._lib.apply_op( "NPolymorphicOutDefault", N=None, T=None, name="r") - self.assertEquals(types.bool, out1.dtype) - self.assertEquals(types.bool, out2.dtype) + self.assertEqual(dtypes.bool, out1.dtype) + self.assertEqual(dtypes.bool, out2.dtype) self.assertProtoEquals(""" name: 'r' op: 'NPolymorphicOutDefault' attr { key: 'T' value { type: DT_BOOL } } @@ -1224,9 +1224,9 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): out1, out2, out3 = self._lib.apply_op( "NPolymorphicOutDefault", N=3, T=None, name="s") - self.assertEquals(types.bool, out1.dtype) - self.assertEquals(types.bool, out2.dtype) - self.assertEquals(types.bool, out3.dtype) + self.assertEqual(dtypes.bool, out1.dtype) + self.assertEqual(dtypes.bool, out2.dtype) + self.assertEqual(dtypes.bool, out3.dtype) self.assertProtoEquals(""" name: 's' op: 'NPolymorphicOutDefault' attr { key: 'T' value { type: DT_BOOL } } @@ -1234,9 +1234,9 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, out1.op.node_def) out1, out2 = self._lib.apply_op( - "NPolymorphicOutDefault", N=None, T=types.int32, name="t") - self.assertEquals(types.int32, out1.dtype) - self.assertEquals(types.int32, out2.dtype) + "NPolymorphicOutDefault", N=None, T=dtypes.int32, name="t") + self.assertEqual(dtypes.int32, out1.dtype) + self.assertEqual(dtypes.int32, out2.dtype) self.assertProtoEquals(""" name: 't' op: 'NPolymorphicOutDefault' attr { key: 'T' value { type: DT_INT32 } } @@ -1244,10 +1244,10 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, out1.op.node_def) out1, out2, out3 = self._lib.apply_op( - "NPolymorphicOutDefault", N=3, T=types.int32, name="u") - self.assertEquals(types.int32, out1.dtype) - self.assertEquals(types.int32, out2.dtype) - self.assertEquals(types.int32, out3.dtype) + "NPolymorphicOutDefault", N=3, T=dtypes.int32, name="u") + self.assertEqual(dtypes.int32, out1.dtype) + self.assertEqual(dtypes.int32, out2.dtype) + self.assertEqual(dtypes.int32, out3.dtype) self.assertProtoEquals(""" name: 'u' op: 'NPolymorphicOutDefault' attr { key: 'T' value { type: DT_INT32 } } @@ -1262,10 +1262,10 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }") out1, out2, out3 = self._lib.apply_op( - "NPolymorphicRestrictOut", N=3, T=types.bool, name="u") - self.assertEquals(types.bool, out1.dtype) - self.assertEquals(types.bool, out2.dtype) - self.assertEquals(types.bool, out3.dtype) + "NPolymorphicRestrictOut", N=3, T=dtypes.bool, name="u") + self.assertEqual(dtypes.bool, out1.dtype) + self.assertEqual(dtypes.bool, out2.dtype) + self.assertEqual(dtypes.bool, out3.dtype) self.assertProtoEquals(""" name: 'u' op: 'NPolymorphicRestrictOut' attr { key: 'T' value { type: DT_BOOL } } @@ -1273,8 +1273,8 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, out1.op.node_def) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("NPolymorphicRestrictOut", N=2, T=types.int32) - self.assertEqual(cm.exception.message, + self._lib.apply_op("NPolymorphicRestrictOut", N=2, T=dtypes.int32) + self.assertEqual(str(cm.exception), "DataType int32 for attr 'T' " "not in list of allowed values: string, bool") @@ -1286,8 +1286,8 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): "output_arg { name: 'a' type_attr: 'T' is_ref: true } " "attr { name: 'T' type: 'type' } ") - out = self._lib.apply_op("RefOut", T=types.bool, name="o") - self.assertEquals(types.bool_ref, out.dtype) + out = self._lib.apply_op("RefOut", T=dtypes.bool, name="o") + self.assertEqual(dtypes.bool_ref, out.dtype) self.assertProtoEquals(""" name: 'o' op: 'RefOut' attr { key: 'T' value { type: DT_BOOL } } @@ -1300,7 +1300,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, op.node_def) # Can pass ref to non-ref input. - out = self._lib.apply_op("RefOut", T=types.int32, name="r") + out = self._lib.apply_op("RefOut", T=dtypes.int32, name="r") out = self._lib.apply_op("Simple", a=out, name="s") self.assertProtoEquals(""" name: 's' op: 'Simple' input: 'r' @@ -1309,7 +1309,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): # Can't pass non-ref to ref input. with self.assertRaises(TypeError) as cm: self._lib.apply_op("RefIn", a=2) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Input 'a' of 'RefIn' Op requires l-value input") def testSpecifyDevice(self): @@ -1340,9 +1340,9 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): a, b = self._lib.apply_op("MixedStruct", n_a=n_a) self.assertTrue(isinstance(a, list)) self.assertEqual(n_a, len(a)) - self.assertTrue(all(x.dtype == types.int32 for x in a)) + self.assertTrue(all(x.dtype == dtypes.int32 for x in a)) self.assertTrue(isinstance(b, ops.Tensor)) - self.assertEqual(types.float32, b.dtype) + self.assertEqual(dtypes.float32, b.dtype) def testStructuredOutputMultipleLists(self): self._add_op("name: 'ComplexStruct' " @@ -1355,15 +1355,15 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): for n_a in [0, 1, 3]: for n_b in [0, 1, 3]: for t_c in [[], - [types.int32], - [types.int32, types.float32]]: + [dtypes.int32], + [dtypes.int32, dtypes.float32]]: a, b, c = self._lib.apply_op("ComplexStruct", n_a=n_a, n_b=n_b, t_c=t_c) self.assertEqual(n_a, len(a)) - self.assertTrue(all(x.dtype == types.int32 for x in a)) + self.assertTrue(all(x.dtype == dtypes.int32 for x in a)) self.assertEqual(n_b, len(b)) - self.assertTrue(all(x.dtype == types.int64 for x in b)) + self.assertTrue(all(x.dtype == dtypes.int64 for x in b)) self.assertEqual(t_c, [x.dtype for x in c]) @@ -1387,19 +1387,19 @@ class OpDefLibraryGraphTest(test_util.TensorFlowTestCase): def testNoGraph(self): out = self._lib.apply_op("Simple", a=3) - self.assertEquals(out.graph, ops.get_default_graph()) + self.assertEqual(out.graph, ops.get_default_graph()) def testDefaultGraph(self): with self._g.as_default(): out = self._lib.apply_op("Simple", a=3) - self.assertEquals(out.graph, self._g) + self.assertEqual(out.graph, self._g) def testIgnoreDefaultGraphWithGraphArgument(self): default_g = ops.Graph() with default_g.as_default(): out = self._lib.apply_op("Simple", a=3, g=self._g) - self.assertEquals(ops.get_default_graph(), default_g) - self.assertEquals(out.graph, self._g) + self.assertEqual(ops.get_default_graph(), default_g) + self.assertEqual(out.graph, self._g) def testDifferentGraphFails(self): a = self._lib.apply_op("Simple", a=3, g=self._g) @@ -1407,7 +1407,7 @@ class OpDefLibraryGraphTest(test_util.TensorFlowTestCase): b = self._lib.apply_op("Simple", a=4, g=other_g) with self.assertRaises(ValueError) as cm: self._lib.apply_op("Binary", a=a, b=b) - self.assertTrue("must be from the same graph" in cm.exception.message) + self.assertTrue("must be from the same graph" in str(cm.exception)) def testDifferentGraphFailsWithGraphArgument(self): other_g = ops.Graph() @@ -1416,7 +1416,7 @@ class OpDefLibraryGraphTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError) as cm: self._lib.apply_op("Binary", a=a, b=b, g=self._g) self.assertTrue( - "not from the passed-in graph" in cm.exception.message) + "not from the passed-in graph" in str(cm.exception)) if __name__ == "__main__": diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py index a6408f9f0a..87148af2fd 100644 --- a/tensorflow/python/ops/parsing_ops.py +++ b/tensorflow/python/ops/parsing_ops.py @@ -88,12 +88,12 @@ def parse_example(serialized, ``` serialized = [ - features: - { feature: [ key: { "ft" value: float_list: { value: [1.0, 2.0] } } ] }, - features: - { feature: [] }, - features: - { feature: [ key: { "ft" value: float_list: { value: [3.0] } } ] } + features + { feature { key: "ft" value { float_list { value: [1.0, 2.0] } } } }, + features + { feature []}, + features + { feature { key: "ft" value { float_list { value: [3.0] } } } ] ``` @@ -109,14 +109,14 @@ def parse_example(serialized, ``` [ - features: { - feature: { key: "kw" value: { bytes_list: { value: [ "knit", "big" ] } } } - feature: { key: "gps" value: { float_list: { value: [] } } } + features { + feature { key: "kw" value { bytes_list { value: [ "knit", "big" ] } } } + feature { key: "gps" value { float_list { value: [] } } } }, - features: { - feature: { key: "kw" value: { bytes_list: { value: [ "emmy" ] } } } - feature: { key: "dank" value: { int64_list: { value: [ 42 ] } } } - feature: { key: "gps" value: { } } + features { + feature { key: "kw" value { bytes_list { value: [ "emmy" ] } } } + feature { key: "dank" value { int64_list { value: [ 42 ] } } } + feature { key: "gps" value { } } } ] ``` @@ -152,13 +152,13 @@ def parse_example(serialized, ``` [ - features: { - feature: { key: "age" value: { int64_list: { value: [ 0 ] } } } - feature: { key: "gender" value: { bytes_list: { value: [ "f" ] } } } + features { + feature { key: "age" value { int64_list { value: [ 0 ] } } } + feature { key: "gender" value { bytes_list { value: [ "f" ] } } } }, - features: { - feature: { key: "age" value: { int64_list: { value: [] } } } - feature: { key: "gender" value: { bytes_list: { value: [ "f" ] } } } + features { + feature { key: "age" value { int64_list { value: [] } } } + feature { key: "gender" value { bytes_list { value: [ "f" ] } } } } ] ``` @@ -204,6 +204,8 @@ def parse_example(serialized, The keys of the dict must match the dense_keys of the feature. dense_shapes: A list of tuples with the same length as `dense_keys`. The shape of the data for each dense feature referenced by `dense_keys`. + Required for any input tensors identified by dense_keys whose shapes are + anything other than [] or [1]. name: A name for this operation (optional). Returns: @@ -297,22 +299,22 @@ def parse_single_example(serialized, # pylint: disable=invalid-name For `SparseTensor`s, the first (batch) column of the indices matrix is removed (the indices matrix is a column vector), the values vector is unchanged, and - the first (batch_size) entry of the shape vector is removed (it is now a + the first (`batch_size`) entry of the shape vector is removed (it is now a single element vector). See also `parse_example`. Args: serialized: A scalar string Tensor, a single serialized Example. - See parse_example documentation for more details. + See `parse_example` documentation for more details. names: (Optional) A scalar string Tensor, the associated name. - See parse_example documentation for more details. - sparse_keys: See parse_example documentation for more details. - sparse_types: See parse_example documentation for more details. - dense_keys: See parse_example documentation for more details. - dense_types: See parse_example documentation for more details. - dense_defaults: See parse_example documentation for more details. - dense_shapes: See parse_example documentation for more details. + See `parse_example` documentation for more details. + sparse_keys: See `parse_example` documentation for more details. + sparse_types: See `parse_example` documentation for more details. + dense_keys: See `parse_example` documentation for more details. + dense_types: See `parse_example` documentation for more details. + dense_defaults: See `parse_example` documentation for more details. + dense_shapes: See `parse_example` documentation for more details. name: A name for this operation (optional). Returns: diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py index 37b1cb115e..cefeec54bb 100644 --- a/tensorflow/python/ops/random_ops.py +++ b/tensorflow/python/ops/random_ops.py @@ -19,10 +19,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util -from tensorflow.python.framework import types from tensorflow.python.framework import random_seed from tensorflow.python.ops import common_shapes from tensorflow.python.ops import gen_random_ops @@ -35,14 +35,14 @@ from tensorflow.python.ops.gen_random_ops import * def _ShapeTensor(shape): """Convert to an int32 or int64 tensor, defaulting to int32 if empty.""" if isinstance(shape, (tuple, list)) and not shape: - dtype = types.int32 + dtype = dtypes.int32 else: dtype = None return ops.convert_to_tensor(shape, dtype=dtype, name="shape") # pylint: disable=protected-access -def random_normal(shape, mean=0.0, stddev=1.0, dtype=types.float32, +def random_normal(shape, mean=0.0, stddev=1.0, dtype=dtypes.float32, seed=None, name=None): """Outputs random values from a normal distribution. @@ -80,7 +80,7 @@ def random_normal(shape, mean=0.0, stddev=1.0, dtype=types.float32, ops.NoGradient("RandomStandardNormal") -def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=types.float32, +def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=dtypes.float32, seed=None, name=None): """Outputs random values from a truncated normal distribution. @@ -123,7 +123,7 @@ ops.NoGradient("TruncatedNormal") def random_uniform(shape, minval=0.0, maxval=1.0, - dtype=types.float32, seed=None, + dtype=dtypes.float32, seed=None, name=None): """Outputs random values from a uniform distribution. diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index d168245989..1507371b40 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -46,10 +46,10 @@ import tensorflow.python.platform import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util -from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import constant_op from tensorflow.python.ops import gen_sparse_ops @@ -178,7 +178,7 @@ def sparse_reorder(sp_input, name=None): Reordering does not affect the shape of the `SparseTensor`. - For example, if sp_input has shape `[4, 5]` and `indices` / `values`: + For example, if `sp_input` has shape `[4, 5]` and `indices` / `values`: [0, 3]: b [0, 1]: a @@ -334,8 +334,8 @@ def sparse_to_indicator(sp_input, vocab_size, name=None): rank = indices_shape[1] ids = sp_input.values - if ids.dtype != types.int64: - ids = math_ops.cast(ids, types.int64) + if ids.dtype != dtypes.int64: + ids = math_ops.cast(ids, dtypes.int64) # Slice off the last dimension of indices, then then tack on the ids indices_columns_to_preserve = array_ops.slice( @@ -451,8 +451,8 @@ def sparse_fill_empty_rows(sp_input, default_value, name=None): default_value = ops.convert_to_tensor( default_value, dtype=sp_input.values.dtype) - num_rows = math_ops.cast(sp_input.shape[0], types.int32) - all_row_indices = math_ops.cast(math_ops.range(num_rows), types.int64) + num_rows = math_ops.cast(sp_input.shape[0], dtypes.int32) + all_row_indices = math_ops.cast(math_ops.range(num_rows), dtypes.int64) empty_row_indices, _ = array_ops.list_diff( all_row_indices, sp_input.indices[:, 0]) empty_row_indicator = gen_sparse_ops.sparse_to_dense( diff --git a/tensorflow/python/ops/sparse_ops_test.py b/tensorflow/python/ops/sparse_ops_test.py index 046625eefa..c6e91dcd71 100644 --- a/tensorflow/python/ops/sparse_ops_test.py +++ b/tensorflow/python/ops/sparse_ops_test.py @@ -23,9 +23,9 @@ import tensorflow.python.platform import numpy as np +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.framework import types from tensorflow.python.ops import constant_op from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import googletest @@ -41,9 +41,9 @@ class SparseToIndicatorTest(test_util.TensorFlowTestCase): val = np.array([0, 10, 13, 14, 32, 33]) shape = np.array([5, 6]) return ops.SparseTensor( - constant_op.constant(ind, types.int64), + constant_op.constant(ind, dtypes.int64), constant_op.constant(val, dtype), - constant_op.constant(shape, types.int64)) + constant_op.constant(shape, dtypes.int64)) def _SparseTensor_2x3x4(self, dtype): ind = np.array([ @@ -55,13 +55,13 @@ class SparseToIndicatorTest(test_util.TensorFlowTestCase): val = np.array([1, 10, 12, 103, 111, 113, 122]) shape = np.array([2, 3, 4]) return ops.SparseTensor( - constant_op.constant(ind, types.int64), + constant_op.constant(ind, dtypes.int64), constant_op.constant(val, dtype), - constant_op.constant(shape, types.int64)) + constant_op.constant(shape, dtypes.int64)) def testInt32(self): with self.test_session(use_gpu=False): - sp_input = self._SparseTensor_5x6(types.int32) + sp_input = self._SparseTensor_5x6(dtypes.int32) output = sparse_ops.sparse_to_indicator(sp_input, 50).eval() expected_output = np.zeros((5, 50), dtype=np.bool) @@ -73,7 +73,7 @@ class SparseToIndicatorTest(test_util.TensorFlowTestCase): def testInt64(self): with self.test_session(use_gpu=False): - sp_input = self._SparseTensor_5x6(types.int64) + sp_input = self._SparseTensor_5x6(dtypes.int64) output = sparse_ops.sparse_to_indicator(sp_input, 50).eval() expected_output = np.zeros((5, 50), dtype=np.bool) @@ -85,7 +85,7 @@ class SparseToIndicatorTest(test_util.TensorFlowTestCase): def testHigherRank(self): with self.test_session(use_gpu=False): - sp_input = self._SparseTensor_2x3x4(types.int64) + sp_input = self._SparseTensor_2x3x4(dtypes.int64) output = sparse_ops.sparse_to_indicator(sp_input, 200).eval() expected_output = np.zeros((2, 3, 200), dtype=np.bool) @@ -107,9 +107,9 @@ class SparseRetainTest(test_util.TensorFlowTestCase): val = np.array([0, 10, 13, 14, 32, 33]) shape = np.array([5, 6]) return ops.SparseTensor( - constant_op.constant(ind, types.int64), - constant_op.constant(val, types.int32), - constant_op.constant(shape, types.int64)) + constant_op.constant(ind, dtypes.int64), + constant_op.constant(val, dtypes.int32), + constant_op.constant(shape, dtypes.int64)) def testBasic(self): with self.test_session(use_gpu=False) as sess: @@ -153,9 +153,9 @@ class SparseFillEmptyRowsTest(test_util.TensorFlowTestCase): val = np.array([0, 10, 13, 14, 32, 33]) shape = np.array([5, 6]) return ops.SparseTensor( - constant_op.constant(ind, types.int64), - constant_op.constant(val, types.int32), - constant_op.constant(shape, types.int64)) + constant_op.constant(ind, dtypes.int64), + constant_op.constant(val, dtypes.int32), + constant_op.constant(shape, dtypes.int64)) def _SparseTensor_String5x6(self): ind = np.array([ @@ -165,18 +165,18 @@ class SparseFillEmptyRowsTest(test_util.TensorFlowTestCase): val = np.array(["a", "b", "c", "d", "e", "f"]) shape = np.array([5, 6]) return ops.SparseTensor( - constant_op.constant(ind, types.int64), - constant_op.constant(val, types.string), - constant_op.constant(shape, types.int64)) + constant_op.constant(ind, dtypes.int64), + constant_op.constant(val, dtypes.string), + constant_op.constant(shape, dtypes.int64)) def _SparseTensor_2x6(self): ind = np.array([[0, 0], [1, 0], [1, 3], [1, 4]]) val = np.array([0, 10, 13, 14]) shape = np.array([2, 6]) return ops.SparseTensor( - constant_op.constant(ind, types.int64), - constant_op.constant(val, types.int32), - constant_op.constant(shape, types.int64)) + constant_op.constant(ind, dtypes.int64), + constant_op.constant(val, dtypes.int32), + constant_op.constant(shape, dtypes.int64)) def testFillNumber(self): with self.test_session(use_gpu=False) as sess: @@ -207,7 +207,8 @@ class SparseFillEmptyRowsTest(test_util.TensorFlowTestCase): self.assertAllEqual( output.indices, [[0, 0], [1, 0], [1, 3], [1, 4], [2, 0], [3, 2], [3, 3], [4, 0]]) - self.assertAllEqual(output.values, ["a", "b", "c", "d", "", "e", "f", ""]) + self.assertAllEqual(output.values, + [b"a", b"b", b"c", b"d", b"", b"e", b"f", b""]) self.assertAllEqual(output.shape, [5, 6]) self.assertAllEqual(empty_row_indicator_out, np.array([0, 0, 1, 0, 1]).astype(np.bool)) diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 8ccbad16c6..c5b490ce4c 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -22,9 +22,9 @@ from __future__ import print_function import contextlib import six +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import types from tensorflow.python.ops import init_ops from tensorflow.python.ops import variables from tensorflow.python.platform import logging @@ -45,7 +45,7 @@ class _VariableStore(object): """Create a variable store.""" self._vars = {} # A dictionary of the stored TensorFlow variables. - def get_variable(self, name, shape=None, dtype=types.float32, + def get_variable(self, name, shape=None, dtype=dtypes.float32, initializer=None, reuse=None, trainable=True, collections=None): """Gets an existing variable with these parameters or create a new one. @@ -82,7 +82,7 @@ class _VariableStore(object): or when violating reuse during variable creation. """ should_check = reuse is not None - dtype = types.as_dtype(dtype) + dtype = dtypes.as_dtype(dtype) shape = tensor_shape.as_shape(shape) if name in self._vars: # Here we handle the case when returning an existing variable. @@ -158,7 +158,7 @@ class _VariableScope(object): """Set initializer for this scope.""" self._initializer = initializer - def get_variable(self, var_store, name, shape=None, dtype=types.float32, + def get_variable(self, var_store, name, shape=None, dtype=dtypes.float32, initializer=None, trainable=True, collections=None): """Gets an existing variable with this name or create a new one.""" if initializer is None and self._initializer: @@ -194,7 +194,7 @@ def _get_default_variable_store(): return store -def get_variable(name, shape=None, dtype=types.float32, initializer=None, +def get_variable(name, shape=None, dtype=dtypes.float32, initializer=None, trainable=True, collections=None): """Gets an existing variable with these parameters or create a new one. |