author    Manjunath Kudlur <keveman@gmail.com> 2015-11-25 08:48:47 -0800
committer Manjunath Kudlur <keveman@gmail.com> 2015-11-25 08:48:47 -0800
commit    854f49bd43588c062b046384f239f64a3d819702 (patch)
tree      c2373bf71ef65ae4c116ea703947141281c1eace /tensorflow/python/ops
parent    9c3043ff3bf31a6a81810b4ce9e87ef936f1f529 (diff)
TensorFlow: Upstream changes to git
Changes:
- Updates to docs
- Several changes for Python 3 compatibility
- Added license headers

Base CL: 108710566
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r-- tensorflow/python/ops/array_ops.py              |  18
-rw-r--r-- tensorflow/python/ops/candidate_sampling_ops.py |   2
-rw-r--r-- tensorflow/python/ops/clip_ops.py               |  16
-rw-r--r-- tensorflow/python/ops/common_shapes.py          |   8
-rw-r--r-- tensorflow/python/ops/constant_op.py            |  10
-rw-r--r-- tensorflow/python/ops/control_flow_ops.py       |  19
-rw-r--r-- tensorflow/python/ops/data_flow_grad.py         |  11
-rw-r--r-- tensorflow/python/ops/data_flow_ops.py          |  10
-rw-r--r-- tensorflow/python/ops/embedding_ops.py          |  12
-rw-r--r-- tensorflow/python/ops/gradients.py              |  61
-rw-r--r-- tensorflow/python/ops/gradients_test.py         |  24
-rw-r--r-- tensorflow/python/ops/image_grad.py             |  57
-rw-r--r-- tensorflow/python/ops/image_grad_test.py        |  85
-rw-r--r-- tensorflow/python/ops/image_ops.py              |  16
-rw-r--r-- tensorflow/python/ops/init_ops.py               |  16
-rw-r--r-- tensorflow/python/ops/io_ops.py                 |   4
-rw-r--r-- tensorflow/python/ops/math_grad.py              |  31
-rw-r--r-- tensorflow/python/ops/math_ops.py               |  38
-rw-r--r-- tensorflow/python/ops/nn.py                     |  93
-rw-r--r-- tensorflow/python/ops/nn_ops.py                 |   2
-rw-r--r-- tensorflow/python/ops/nn_test.py                | 112
-rw-r--r-- tensorflow/python/ops/numerics.py               |   6
-rw-r--r-- tensorflow/python/ops/op_def_library.py         |  43
-rw-r--r-- tensorflow/python/ops/op_def_library_test.py    | 340
-rw-r--r-- tensorflow/python/ops/parsing_ops.py            |  58
-rw-r--r-- tensorflow/python/ops/random_ops.py             |  10
-rw-r--r-- tensorflow/python/ops/sparse_ops.py             |  12
-rw-r--r-- tensorflow/python/ops/sparse_ops_test.py        |  43
-rw-r--r-- tensorflow/python/ops/variable_scope.py         |  10
29 files changed, 687 insertions(+), 480 deletions(-)
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 9cccc8771f..50f3facf2e 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -65,10 +65,10 @@ import sys
import tensorflow.python.platform
import numpy as np
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
-from tensorflow.python.framework import types
from tensorflow.python.ops import common_shapes
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
@@ -460,9 +460,9 @@ def transpose(a, perm=None, name="transpose"):
[3 6]]
# Equivalently
- tf.transpose(x perm=[0, 1]) ==> [[1 4]
- [2 5]
- [3 6]]
+ tf.transpose(x, perm=[1, 0]) ==> [[1 4]
+ [2 5]
+ [3 6]]
# 'perm' is more useful for n-dimensional tensors, for n > 2
# 'x' is [[[1 2 3]
@@ -502,7 +502,7 @@ def transpose(a, perm=None, name="transpose"):
return ret
-def zeros(shape, dtype=types.float32, name=None):
+def zeros(shape, dtype=dtypes.float32, name=None):
"""Creates a tensor with all elements set to zero.
This operation returns a tensor of type `dtype` with shape `shape` and
@@ -528,7 +528,7 @@ def zeros(shape, dtype=types.float32, name=None):
else:
shape = ops.convert_to_tensor(shape, name="shape")
output = fill(shape, constant(0, dtype=dtype), name=name)
- assert output.dtype.base_dtype == types.as_dtype(dtype).base_dtype
+ assert output.dtype.base_dtype == dtypes.as_dtype(dtype).base_dtype
return output
@@ -594,12 +594,12 @@ def ones_like(tensor, dtype=None, name=None):
return ones(ones_shape, dtype=dtype, name=name)
-def zeros_initializer(shape, dtype=types.float32):
+def zeros_initializer(shape, dtype=dtypes.float32):
"""An adaptor for zeros() to match the Initializer spec."""
return zeros(shape, dtype)
-def ones(shape, dtype=types.float32, name=None):
+def ones(shape, dtype=dtypes.float32, name=None):
"""Creates a tensor with all elements set to 1.
This operation returns a tensor of type `dtype` with shape `shape` and all
@@ -625,7 +625,7 @@ def ones(shape, dtype=types.float32, name=None):
else:
shape = ops.convert_to_tensor(shape, name="shape")
output = fill(shape, constant(1, dtype=dtype), name=name)
- assert output.dtype.base_dtype == types.as_dtype(dtype).base_dtype
+ assert output.dtype.base_dtype == dtypes.as_dtype(dtype).base_dtype
return output
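The corrected `tf.transpose` example above (perm=[1, 0], not perm=[0, 1]) can be checked with NumPy, whose `np.transpose` uses the same axis-permutation semantics; a minimal sketch, independent of TensorFlow:

```
import numpy as np

# 'x' is the 2-D example from the docstring above.
x = np.array([[1, 2, 3],
              [4, 5, 6]])

# perm=[1, 0] swaps the two axes, matching tf.transpose(x, perm=[1, 0]).
print(np.transpose(x, axes=(1, 0)))
# [[1 4]
#  [2 5]
#  [3 6]]
```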
diff --git a/tensorflow/python/ops/candidate_sampling_ops.py b/tensorflow/python/ops/candidate_sampling_ops.py
index 0a3e068523..6b8ecc13ea 100644
--- a/tensorflow/python/ops/candidate_sampling_ops.py
+++ b/tensorflow/python/ops/candidate_sampling_ops.py
@@ -311,7 +311,7 @@ def all_candidate_sampler(true_classes, num_true, num_sampled, unique,
def compute_accidental_hits(true_classes, sampled_candidates, num_true,
seed=None, name=None):
- """Compute the ids of positions in sampled_candidates matching true_classes.
+ """Compute the position ids in `sampled_candidates` matching `true_classes`.
In Candidate Sampling, this operation facilitates virtually removing
sampled classes which happen to match target classes. This is done
diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py
index da66610e57..a2b39d6594 100644
--- a/tensorflow/python/ops/clip_ops.py
+++ b/tensorflow/python/ops/clip_ops.py
@@ -22,8 +22,8 @@ import collections
import six
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import types
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import math_ops
@@ -65,7 +65,7 @@ def clip_by_norm(t, clip_norm, name=None):
"""Clips tensor values to a maximum L2-norm.
Given a tensor `t`, and a maximum clip value `clip_norm`, this operation
- normalizes `t` so that its L2-norm is less than or equal to `clip_norm'.
+ normalizes `t` so that its L2-norm is less than or equal to `clip_norm`.
Specifically, if the L2-norm is already less than or equal to `clip_norm`,
then `t` is not modified. If the L2-norm is greater than `clip_norm`, then
this operation returns a tensor of the same type and shape as `t` with its
@@ -146,18 +146,18 @@ def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):
if you've already computed the global norm for `t_list`, you can specify
the global norm with `use_norm`.
- To perform the clipping, the values t_list[i] are set to:
+ To perform the clipping, the values `t_list[i]` are set to:
- `t_list[i] * clip_norm / max(global_norm, clip_norm)`
+ t_list[i] * clip_norm / max(global_norm, clip_norm)
where:
- `global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))`
+ global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))
If `clip_norm > global_norm` then the entries in `t_list` remain as they are,
otherwise they're all shrunk by the global ratio.
- Any of the entries of `t_list` that are of type None are ignored.
+ Any of the entries of `t_list` that are of type `None` are ignored.
This is the correct way to perform gradient clipping (for example, see
R. Pascanu, T. Mikolov, and Y. Bengio, "On the difficulty of training
@@ -219,7 +219,7 @@ def clip_by_average_norm(t, clip_norm, name=None):
Given a tensor `t`, and a maximum clip value `clip_norm`, this operation
normalizes `t` so that its average L2-norm is less than or equal to
- `clip_norm'. Specifically, if the average L2-norm is already less than or
+ `clip_norm`. Specifically, if the average L2-norm is already less than or
equal to `clip_norm`, then `t` is not modified. If the average L2-norm is
greater than `clip_norm`, then this operation returns a tensor of the same
type and shape as `t` with its values set to:
@@ -244,7 +244,7 @@ def clip_by_average_norm(t, clip_norm, name=None):
# Calculate L2-norm per element, clip elements by ratio of clip_norm to
# L2-norm per element
- n_element = math_ops.cast(array_ops.size(t), types.float32)
+ n_element = math_ops.cast(array_ops.size(t), dtypes.float32)
l2norm_inv = math_ops.rsqrt(
math_ops.reduce_sum(t * t, math_ops.range(array_ops.rank(t))))
tclip = array_ops.identity(
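The `clip_by_global_norm` formulas documented above are easy to sanity-check outside TensorFlow; a minimal NumPy sketch (illustrative only, with the `None` pass-through behavior the docstring describes):

```
import numpy as np

def clip_by_global_norm_np(t_list, clip_norm):
  # global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))
  global_norm = np.sqrt(sum(np.sum(t * t) for t in t_list if t is not None))
  # t_list[i] * clip_norm / max(global_norm, clip_norm); None is ignored.
  scale = clip_norm / max(global_norm, clip_norm)
  return [None if t is None else t * scale for t in t_list], global_norm

clipped, norm = clip_by_global_norm_np(
    [np.array([3.0, 4.0]), None, np.array([0.0, 12.0])], clip_norm=5.0)
print(norm)        # 13.0 = sqrt(5**2 + 12**2)
print(clipped[0])  # [3., 4.] scaled by 5/13
```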
diff --git a/tensorflow/python/ops/common_shapes.py b/tensorflow/python/ops/common_shapes.py
index b3decb2651..c0615f1d3b 100644
--- a/tensorflow/python/ops/common_shapes.py
+++ b/tensorflow/python/ops/common_shapes.py
@@ -130,10 +130,10 @@ def get2d_conv_output_size(input_height, input_width, filter_height,
# Compute number of rows in the output, based on the padding.
if input_height.value is None or filter_height.value is None:
out_rows = None
- elif padding_type == "VALID":
+ elif padding_type == b"VALID":
out_rows = ((input_height.value - filter_height.value + row_stride) //
row_stride)
- elif padding_type == "SAME":
+ elif padding_type == b"SAME":
out_rows = (input_height.value + row_stride - 1) // row_stride
else:
raise ValueError("Invalid value for padding: %r" % padding_type)
@@ -141,10 +141,10 @@ def get2d_conv_output_size(input_height, input_width, filter_height,
# Compute number of columns in the output, based on the padding.
if input_width.value is None or filter_width.value is None:
out_cols = None
- elif padding_type == "VALID":
+ elif padding_type == b"VALID":
out_cols = ((input_width.value - filter_width.value + col_stride) //
col_stride)
- elif padding_type == "SAME":
+ elif padding_type == b"SAME":
out_cols = (input_width.value + col_stride - 1) // col_stride
return out_rows, out_cols
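The `b"VALID"`/`b"SAME"` comparisons above exist because under Python 3 the op's padding attr arrives as `bytes` rather than `str`. The output-size arithmetic itself is small enough to restate as a standalone sketch (one spatial dimension, assuming positive sizes and strides):

```
def conv_output_size(input_size, filter_size, stride, padding):
  """One dimension of get2d_conv_output_size, on plain ints."""
  if padding == b"VALID":
    # No padding: count only full filter placements.
    return (input_size - filter_size + stride) // stride
  elif padding == b"SAME":
    # Zero-padded so every strided input position yields an output.
    return (input_size + stride - 1) // stride
  raise ValueError("Invalid value for padding: %r" % padding)

print(conv_output_size(10, 3, 2, b"VALID"))  # 4
print(conv_output_size(10, 3, 2, b"SAME"))   # 5
```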
diff --git a/tensorflow/python/ops/constant_op.py b/tensorflow/python/ops/constant_op.py
index 9cf5e99a1e..f2aaad37a9 100644
--- a/tensorflow/python/ops/constant_op.py
+++ b/tensorflow/python/ops/constant_op.py
@@ -105,10 +105,10 @@ import tensorflow.python.platform
import numpy as np
from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
-from tensorflow.python.framework import types
def constant(value, dtype=None, shape=None, name="Const"):
@@ -182,10 +182,10 @@ def _tensor_shape_tensor_conversion_function(s, dtype=None, name=None):
raise ValueError(
"Cannot convert a partially known TensorShape to a Tensor: %s" % s)
if dtype is not None:
- if dtype not in (types.int32, types.int64):
+ if dtype not in (dtypes.int32, dtypes.int64):
raise TypeError("Cannot convert a TensorShape to dtype: %s" % dtype)
else:
- dtype = types.int32
+ dtype = dtypes.int32
if name is None:
name = "shape_as_tensor"
return constant(s.as_list(), dtype=dtype, name=name)
@@ -197,10 +197,10 @@ def _dimension_tensor_conversion_function(d, dtype=None, name=None):
if d.value is None:
raise ValueError("Cannot convert an unknown Dimension to a Tensor: %s" % d)
if dtype is not None:
- if dtype not in (types.int32, types.int64):
+ if dtype not in (dtypes.int32, dtypes.int64):
raise TypeError("Cannot convert a TensorShape to dtype: %s" % dtype)
else:
- dtype = types.int32
+ dtype = dtypes.int32
if name is None:
name = "shape_as_tensor"
return constant(d.value, dtype=dtype, name=name)
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index cd15d921cf..8eb1bd79bf 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -69,9 +69,9 @@ from __future__ import print_function
import six
from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import types
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import common_shapes
from tensorflow.python.ops import constant_op
@@ -443,7 +443,7 @@ def _GetRealValue(value):
# The begin position of the slice at slice_index.
slice_index = forward_ctxt.grad_context.index
- b1 = array_ops.zeros(elem_rank_vec, dtype=types.int32)
+ b1 = array_ops.zeros(elem_rank_vec, dtype=dtypes.int32)
b = array_ops.concat(0, [array_ops.expand_dims(slice_index, 0), b1])
# The slice at slice_index.
@@ -1134,7 +1134,7 @@ def _AsTensorList(x, p):
Returns:
A list of Tensors or IndexedSlices.
"""
- if not isinstance(x, list) and not isinstance(x, _basetuple):
+ if not isinstance(x, (list, _basetuple)):
x = [x]
l = []
@@ -1249,7 +1249,10 @@ def group(*inputs, **kwargs):
# 2-level tree. The root node is the returned NoOp node.
# deps contains 1 NoOp node for each device.
deps = []
- for dev in sorted(six.iterkeys(ops_on_device)):
+ def device_key(dev):
+ """A sort key that allows None to be compared to strings."""
+ return "" if dev is None else dev
+ for dev in sorted(six.iterkeys(ops_on_device), key=device_key):
deps.append(_GroupControlDeps(dev, ops_on_device[dev]))
return _GroupControlDeps(None, deps, name=name)
@@ -1332,7 +1335,7 @@ def fold(fn, elems, elem_shape, name=None):
d0 = elem_shape[0]
n = math_ops.div(s0, d0)
b1 = array_ops.zeros(array_ops.expand_dims(array_ops.rank(elems) - 1, 0),
- dtype=types.int32)
+ dtype=dtypes.int32)
# Initialize the output with slice 0
b = array_ops.concat(0, [[0], b1])
o = array_ops.slice(elems, b, elem_shape)
@@ -1374,7 +1377,7 @@ def case(pred_fn_pairs, default, exclusive=False, name="case"):
Expressions:
```
- f1 = lambda: tf.onstant(17)
+ f1 = lambda: tf.constant(17)
f2 = lambda: tf.constant(23)
r = case([(tf.less(x, y), f1)], default=f2)
```
@@ -1428,7 +1431,7 @@ def case(pred_fn_pairs, default, exclusive=False, name="case"):
if not isinstance(tup, _basetuple) or len(tup) != 2:
raise TypeError("Each entry in pred_fn_pairs must be a 2-tuple")
pred, fn = tup
- if pred.dtype != types.bool:
+ if pred.dtype != dtypes.bool:
raise TypeError("pred must be of type bool: %s", pred.name)
if not callable(fn):
raise TypeError("fn for pred %s must be callable." % pred.name)
@@ -1468,7 +1471,7 @@ def case(pred_fn_pairs, default, exclusive=False, name="case"):
# TODO(ebrevdo): Add Where() for DT_BOOL, replace with Size(Where(preds))
preds_c = array_ops.concat(0, preds, name="preds_c")
num_true_conditions = math_ops.reduce_sum(
- math_ops.cast(preds_c, types.int32), name="num_true_conds")
+ math_ops.cast(preds_c, dtypes.int32), name="num_true_conds")
at_most_one_true_condition = math_ops.less(
num_true_conditions, constant_op.constant(2, name="two_true_conds"))
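The `device_key` helper introduced in `group()` above is a Python 3 fix: `sorted()` can no longer compare `None` (ops with no assigned device) against device strings. A minimal reproduction of the failure and the workaround:

```
devices = ["/gpu:0", None, "/cpu:0"]

# Python 3: sorted(devices) raises
#   TypeError: '<' not supported between instances of 'NoneType' and 'str'

def device_key(dev):
  """A sort key that allows None to be compared to strings."""
  return "" if dev is None else dev

print(sorted(devices, key=device_key))  # [None, '/cpu:0', '/gpu:0']
```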
diff --git a/tensorflow/python/ops/data_flow_grad.py b/tensorflow/python/ops/data_flow_grad.py
index 219a503fd7..c46c29ed57 100644
--- a/tensorflow/python/ops/data_flow_grad.py
+++ b/tensorflow/python/ops/data_flow_grad.py
@@ -19,8 +19,8 @@ from __future__ import division
from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import types
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import data_flow_ops
@@ -36,8 +36,8 @@ def _DynamicStitchGrads(op, grad):
indices_grad = [None] * num_values
def AsInt32(x):
- return (x if op.inputs[0].dtype == types.int32 else
- math_ops.cast(x, types.int32))
+ return (x if op.inputs[0].dtype == dtypes.int32 else
+ math_ops.cast(x, dtypes.int32))
inputs = [AsInt32(op.inputs[i]) for i in xrange(num_values)]
if isinstance(grad, ops.IndexedSlices):
output_shape = array_ops.shape(op.outputs[0])
@@ -54,3 +54,8 @@ ops.NoGradient("QueueDequeue")
ops.NoGradient("QueueDequeueMany")
ops.NoGradient("QueueClose")
ops.NoGradient("QueueSize")
+
+ops.NoGradient("Stack")
+ops.NoGradient("StackPush")
+ops.NoGradient("StackPop")
+ops.NoGradient("StackClose")
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index c50fad3211..261193715c 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -125,12 +125,12 @@ class QueueBase(object):
A `QueueBase` object.
Raises:
- TypeError: when `queues` is not a list of `QueueBase` objects,
+ TypeError: When `queues` is not a list of `QueueBase` objects,
or when the data types of `queues` are not all the same.
"""
if ((not queues) or
(not isinstance(queues, list)) or
- (not all([isinstance(x, QueueBase) for x in queues]))):
+ (not all(isinstance(x, QueueBase) for x in queues))):
raise TypeError("A list of queues expected")
dtypes = queues[0].dtypes
@@ -458,6 +458,12 @@ ops.RegisterShape("QueueEnqueue")(common_shapes.unknown_shape)
ops.RegisterShape("QueueEnqueueMany")(common_shapes.unknown_shape)
+ops.RegisterShape("Stack")(common_shapes.scalar_shape)
+ops.RegisterShape("StackPush")(common_shapes.unknown_shape)
+ops.RegisterShape("StackPop")(common_shapes.unknown_shape)
+ops.RegisterShape("StackClose")(common_shapes.unknown_shape)
+
+
@ops.RegisterShape("QueueClose")
def _ScalarToVoidShape(op):
"""Shape function for ops that take a scalar and produce no outputs."""
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index 9fb1cf99f2..ff39c3c751 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -19,8 +19,8 @@ from __future__ import division
from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import types
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
@@ -43,7 +43,7 @@ def embedding_lookup(params, ids, name=None):
Args:
params: A list of tensors with the same shape and type.
- ids: A `Tensor` with type `int32` containing the ids to be looked
+ ids: A `Tensor` with type `int32` or `int64` containing the ids to be looked
up in `params`.
name: A name for the operation (optional).
@@ -69,8 +69,8 @@ def embedding_lookup(params, ids, name=None):
original_indices = math_ops.range(array_ops.size(flat_ids))
# Compute flat_ids % partitions for each id
ids_mod_p = flat_ids % np
- if ids_mod_p.dtype != types.int32:
- ids_mod_p = math_ops.cast(ids_mod_p, types.int32)
+ if ids_mod_p.dtype != dtypes.int32:
+ ids_mod_p = math_ops.cast(ids_mod_p, dtypes.int32)
# Partition single list of ids based on ids % np into np separate lists
plist = data_flow_ops.dynamic_partition(flat_ids, ids_mod_p, np)
# Similarly, partition the original indices.
@@ -178,8 +178,8 @@ def embedding_lookup_sparse(params, sp_ids, sp_weights,
with ops.op_scope(params + [sp_ids], name, "embedding_lookup_sparse") as name:
segment_ids = sp_ids.indices[:, 0]
- if segment_ids.dtype != types.int32:
- segment_ids = math_ops.cast(segment_ids, types.int32)
+ if segment_ids.dtype != dtypes.int32:
+ segment_ids = math_ops.cast(segment_ids, dtypes.int32)
ids = sp_ids.values
if ignore_weights:
diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py
index 72c7bc2499..b790d9af6c 100644
--- a/tensorflow/python/ops/gradients.py
+++ b/tensorflow/python/ops/gradients.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Implements the graph generation for computation of gradients."""
from __future__ import absolute_import
@@ -27,16 +26,17 @@ import tensorflow.python.platform
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
-from tensorflow.python.framework import types
# pylint: disable=unused-import
from tensorflow.python.ops import array_grad
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import control_flow_grad
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import image_grad
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import linalg_grad
from tensorflow.python.ops import math_grad
@@ -45,7 +45,6 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.platform import logging
-
# Warn the user if we convert a sparse representation to dense with at
# least this number of elements.
_LARGE_SPARSE_NUM_ELEMENTS = 100000000
@@ -69,8 +68,8 @@ def _IndexedSlicesToTensor(value, dtype=None, name=None):
"""
if dtype and not dtype.is_compatible_with(value.dtype):
raise ValueError(
- "Tensor conversion requested dtype %s for IndexedSlices with dtype %s"
- % (dtype.name, value.dtype.name))
+ "Tensor conversion requested dtype %s for IndexedSlices with dtype %s" %
+ (dtype.name, value.dtype.name))
if value.dense_shape is None:
raise ValueError(
"Tensor conversion requested for IndexedSlices without dense_shape: %s"
@@ -88,11 +87,14 @@ def _IndexedSlicesToTensor(value, dtype=None, name=None):
warnings.warn(
"Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
"This may consume a large amount of memory.")
- return math_ops.unsorted_segment_sum(
- value.values, value.indices, value.dense_shape[0], name=name)
+ return math_ops.unsorted_segment_sum(value.values,
+ value.indices,
+ value.dense_shape[0],
+ name=name)
-ops.register_tensor_conversion_function(ops.IndexedSlices, _IndexedSlicesToTensor)
+ops.register_tensor_conversion_function(ops.IndexedSlices,
+ _IndexedSlicesToTensor)
def _MarkReachedOps(from_ops, reached_ops):
@@ -236,14 +238,16 @@ def _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops):
y = ys[i]
if grad_y is None:
with ops.device(_GetGradsDevice(y.op, colocate_gradients_with_ops)):
- grad_ys[i] = array_ops.fill(array_ops.shape(y),
- constant_op.constant(1, dtype=y.dtype))
+ grad_ys[i] = array_ops.fill(
+ array_ops.shape(y),
+ constant_op.constant(1,
+ dtype=y.dtype))
else:
if grad_y.dtype != y.dtype:
raise ValueError("Y and ys_grad must be of the same type, "
"not y: %s, ys_grad: %s " %
- (types.as_dtype(y.dtype).name,
- types.as_dtype(grad_y.dtype).name))
+ (dtypes.as_dtype(y.dtype).name,
+ dtypes.as_dtype(grad_y.dtype).name))
return grad_ys
@@ -265,11 +269,10 @@ def _VerifyGeneratedGradients(grads, op):
inp = op.inputs[i]
if grad is not None:
if not grad.dtype.is_compatible_with(inp.dtype):
- raise ValueError(
- "Gradient type %s generated for op %s does "
- "not match input type %s" %
- (types.as_dtype(grad.dtype).name, op.node_def,
- types.as_dtype(inp.dtype).name))
+ raise ValueError("Gradient type %s generated for op %s does "
+ "not match input type %s" %
+ (dtypes.as_dtype(grad.dtype).name, op.node_def,
+ dtypes.as_dtype(inp.dtype).name))
def _StopOps(from_ops, pending_count):
@@ -301,7 +304,10 @@ def _StopOps(from_ops, pending_count):
return stop_ops
-def gradients(ys, xs, grad_ys=None, name="gradients",
+def gradients(ys,
+ xs,
+ grad_ys=None,
+ name="gradients",
colocate_gradients_with_ops=False,
gate_gradients=False,
aggregation_method=None):
@@ -319,7 +325,7 @@ def gradients(ys, xs, grad_ys=None, name="gradients",
`grad_ys` is a list of tensors of the same length as `ys` that holds
the initial gradients for each y in `ys`. When `grad_ys` is None,
we fill in a tensor of '1's of the shape of y for each y in `ys`. A
- user can provide their own initial 'grad_ys` to compute the
+ user can provide their own initial `grad_ys` to compute the
derivatives using a different initial gradient for each y (e.g., if
one wanted to weight the gradient differently for each value in
each y).
@@ -369,8 +375,8 @@ def gradients(ys, xs, grad_ys=None, name="gradients",
# to the xs.
to_ops = [t.op for t in ys]
from_ops = [t.op for t in xs]
- pending_count, has_control_flow = _PendingCount(
- ops.get_default_graph(), to_ops, from_ops)
+ pending_count, has_control_flow = _PendingCount(ops.get_default_graph(),
+ to_ops, from_ops)
# Iterate over the collected ops.
#
@@ -418,9 +424,9 @@ def gradients(ys, xs, grad_ys=None, name="gradients",
# output, it means that the cost does not depend on output[i],
# therefore dC/doutput[i] is 0.
for i, out_grad in enumerate(out_grads):
- if (not out_grad
- and types.as_dtype(op.outputs[i].dtype).base_dtype in (
- types.float32, types.float64)):
+ if (not out_grad and
+ dtypes.as_dtype(op.outputs[i].dtype).base_dtype in
+ (dtypes.float32, dtypes.float64)):
# Only floating-point outputs get a zero gradient. Gradient
# functions should ignore the gradient for other outputs.
out_grads[i] = array_ops.zeros_like(op.outputs[i])
@@ -485,7 +491,8 @@ def _GetGrad(grads, t):
"""Gets gradient for tensor "t"."""
op = t.op
op_grads = grads.get(op)
- if not op_grads: return None
+ if not op_grads:
+ return None
t_grad = op_grads[t.value_index]
assert not isinstance(t_grad, list), (
"gradients list should have been aggregated by now.")
@@ -622,8 +629,8 @@ def _AggregatedGrads(grads, op, has_control_flow, aggregation_method=None):
# indices.
out_grads[i] = ops.IndexedSlices(
array_ops.concat(0, [x.values for x in out_grad]),
- array_ops.concat(0, [x.indices for x in out_grad]),
- out_grad[0].dense_shape)
+ array_ops.concat(0, [x.indices
+ for x in out_grad]), out_grad[0].dense_shape)
else:
out_grads[i] = []
return out_grads
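The `_IndexedSlicesToTensor` conversion reformatted above densifies by scatter-adding the slice values into a zero tensor of the dense shape; a NumPy sketch of what `unsorted_segment_sum` computes in that case (illustrative only):

```
import numpy as np

def unsorted_segment_sum_np(values, indices, num_segments):
  """Scatter-add rows of `values` into `num_segments` output rows."""
  out = np.zeros((num_segments,) + values.shape[1:], dtype=values.dtype)
  for row, idx in zip(values, indices):
    out[idx] += row  # duplicate indices accumulate, as in the op
  return out

values = np.array([[1., 2.], [3., 4.]])
indices = np.array([2, 0])
print(unsorted_segment_sum_np(values, indices, num_segments=4))
# [[3. 4.]
#  [0. 0.]
#  [1. 2.]
#  [0. 0.]]
```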
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 9a87319697..5afc8a779b 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -24,9 +24,9 @@ import tensorflow.python.platform
import numpy as np
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
-from tensorflow.python.framework import types
# pylint: disable=unused-import
from tensorflow.python.ops import array_grad
from tensorflow.python.ops import array_ops
@@ -68,7 +68,7 @@ def _OpsBetween(graph, to_ops, from_ops):
reached_ops[op._id] = True
gradients._MarkReachedOps(from_ops, reached_ops)
between_ops = gradients._GatherInputs(to_ops, reached_ops)
- between_ops.sort(lambda x, y: y._id - x._id)
+ between_ops.sort(key=lambda x: -x._id)
return between_ops
@@ -246,13 +246,13 @@ class GradientsTest(test_util.TensorFlowTestCase):
@ops.RegisterGradient("TestOp")
def _TestOpGrad(op, float_grad, string_grad):
"""Gradient function for TestOp."""
- self.assertEquals(float_grad.dtype, types.float32)
+ self.assertEquals(float_grad.dtype, dtypes.float32)
self.assertFalse(string_grad)
return float_grad
ops.RegisterShape("TestOp")(None)
c = constant(1.0)
- x, y = g.create_op("TestOp", [c], [types.float32, types.string]).outputs
+ x, y = g.create_op("TestOp", [c], [dtypes.float32, dtypes.string]).outputs
z = x * 2.0
w = z * 3.0
grads = gradients.gradients(z, [c])
@@ -314,7 +314,7 @@ class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase):
c = constant_op.constant(np_val)
c_sparse = math_ops._as_indexed_slices(c)
c_sparse = ops.IndexedSlices(
- c_sparse.values, math_ops.cast(c_sparse.indices, types.int64),
+ c_sparse.values, math_ops.cast(c_sparse.indices, dtypes.int64),
c_sparse.dense_shape)
self.assertAllEqual(np_val.shape, c_sparse.dense_shape.eval())
c_dense = math_ops.mul(c_sparse, 1.0)
@@ -322,16 +322,16 @@ class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase):
def testWarnings(self):
# Smaller than the threshold: no warning.
- c_sparse = ops.IndexedSlices(array_ops.placeholder(types.float32),
- array_ops.placeholder(types.int32),
+ c_sparse = ops.IndexedSlices(array_ops.placeholder(dtypes.float32),
+ array_ops.placeholder(dtypes.int32),
constant([4, 4, 4, 4]))
with warnings.catch_warnings(record=True) as w:
math_ops.mul(c_sparse, 1.0)
self.assertEqual(0, len(w))
# Greater than or equal to the threshold: warning.
- c_sparse = ops.IndexedSlices(array_ops.placeholder(types.float32),
- array_ops.placeholder(types.int32),
+ c_sparse = ops.IndexedSlices(array_ops.placeholder(dtypes.float32),
+ array_ops.placeholder(dtypes.int32),
constant([100, 100, 100, 100]))
with warnings.catch_warnings(record=True) as w:
math_ops.mul(c_sparse, 1.0)
@@ -341,9 +341,9 @@ class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase):
in str(w[0].message))
# Unknown dense shape: warning.
- c_sparse = ops.IndexedSlices(array_ops.placeholder(types.float32),
- array_ops.placeholder(types.int32),
- array_ops.placeholder(types.int32))
+ c_sparse = ops.IndexedSlices(array_ops.placeholder(dtypes.float32),
+ array_ops.placeholder(dtypes.int32),
+ array_ops.placeholder(dtypes.int32))
with warnings.catch_warnings(record=True) as w:
math_ops.mul(c_sparse, 1.0)
self.assertEqual(1, len(w))
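The `between_ops.sort(...)` change above is another Python 3 migration: the `cmp`-style argument to `list.sort` was removed, so the descending-`_id` comparator becomes a key function. The two forms order identically, e.g.:

```
class FakeOp(object):
  """Stand-in for an Operation with an _id field (hypothetical)."""

  def __init__(self, op_id):
    self._id = op_id

ops_list = [FakeOp(3), FakeOp(1), FakeOp(2)]

# Python 2 only: ops_list.sort(lambda x, y: y._id - x._id)
ops_list.sort(key=lambda x: -x._id)  # portable descending sort
print([op._id for op in ops_list])   # [3, 2, 1]
```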
diff --git a/tensorflow/python/ops/image_grad.py b/tensorflow/python/ops/image_grad.py
new file mode 100644
index 0000000000..55f1c74f0e
--- /dev/null
+++ b/tensorflow/python/ops/image_grad.py
@@ -0,0 +1,57 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Contains Gradient functions for image ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import gen_image_ops
+
+
+@ops.RegisterGradient("ResizeNearestNeighbor")
+def _ResizeNearestNeighborGrad(op, grad):
+ """The derivatives for nearest neighbor resizing.
+
+ Args:
+ op: The ResizeNearestNeighbor op.
+ grad: The tensor representing the gradient w.r.t. the output.
+
+ Returns:
+ The gradients w.r.t. the input and the output.
+ """
+ grads = gen_image_ops.resize_nearest_neighbor_grad(
+ grad, op.inputs[0].get_shape()[1:3])
+ return [grads, None]
+
+
+@ops.RegisterShape("ResizeNearestNeighborGrad")
+def _ResizeShape(op):
+ """Shape function for the resize grad ops."""
+ input_shape = op.inputs[0].get_shape().with_rank(4)
+ size = tensor_util.ConstantValue(op.inputs[1])
+ if size is not None:
+ height = size[0]
+ width = size[1]
+ else:
+ height = None
+ width = None
+ return [
+ tensor_shape.TensorShape([input_shape[0], height, width, input_shape[3]])
+ ]
diff --git a/tensorflow/python/ops/image_grad_test.py b/tensorflow/python/ops/image_grad_test.py
new file mode 100644
index 0000000000..488e741e5b
--- /dev/null
+++ b/tensorflow/python/ops/image_grad_test.py
@@ -0,0 +1,85 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Python ops defined in image_grad.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=g-bad-import-order,
+# pylint: disable=unused-import
+import tensorflow.python.platform
+from tensorflow.python.kernel_tests import gradient_checker as gc
+
+import numpy as np
+import tensorflow as tf
+# pylint: enable=g-bad-import-order
+# pylint: enable=unused-import
+
+
+class ResizeNearestNeighborOpTest(tf.test.TestCase):
+
+ def testShapeIsCorrectAfterOp(self):
+ in_shape = [1, 2, 2, 1]
+ out_shape = [1, 4, 6, 1]
+
+ x = np.arange(0, 4).reshape(in_shape).astype(np.float32)
+
+ with self.test_session() as sess:
+ input_tensor = tf.constant(x, shape=in_shape)
+ resize_out = tf.image.resize_nearest_neighbor(input_tensor,
+ out_shape[1:3])
+ self.assertEqual(out_shape, list(resize_out.get_shape()))
+
+ resize_out = sess.run(resize_out)
+ self.assertEqual(out_shape, list(resize_out.shape))
+
+ def testGradFromResizeToLargerInBothDims(self):
+ in_shape = [1, 2, 3, 1]
+ out_shape = [1, 4, 6, 1]
+
+ x = np.arange(0, 6).reshape(in_shape).astype(np.float32)
+
+ with self.test_session():
+ input_tensor = tf.constant(x, shape=in_shape)
+ resize_out = tf.image.resize_nearest_neighbor(input_tensor,
+ out_shape[1:3])
+ err = gc.ComputeGradientError(input_tensor,
+ in_shape,
+ resize_out,
+ out_shape,
+ x_init_value=x)
+ self.assertLess(err, 1e-3)
+
+ def testGradFromResizeToSmallerInBothDims(self):
+ in_shape = [1, 4, 6, 1]
+ out_shape = [1, 2, 3, 1]
+
+ x = np.arange(0, 24).reshape(in_shape).astype(np.float32)
+
+ with self.test_session():
+ input_tensor = tf.constant(x, shape=in_shape)
+ resize_out = tf.image.resize_nearest_neighbor(input_tensor,
+ out_shape[1:3])
+ err = gc.ComputeGradientError(input_tensor,
+ in_shape,
+ resize_out,
+ out_shape,
+ x_init_value=x)
+ self.assertLess(err, 1e-3)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py
index a95e8e9765..7be02b220f 100644
--- a/tensorflow/python/ops/image_ops.py
+++ b/tensorflow/python/ops/image_ops.py
@@ -55,10 +55,6 @@ image = tf.image.decode_jpeg(...)
resized_image = tf.image.resize_bilinear(image, [299, 299])
```
-<i>Maybe refer to the Queue examples that show how to add images to a Queue
-after resizing them to a fixed size, and how to dequeue batches of resized
-images from the Queue.</i>
-
@@resize_images
@@resize_area
@@ -109,11 +105,11 @@ import math
import tensorflow.python.platform
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
-from tensorflow.python.framework import types
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import common_shapes
@@ -354,7 +350,7 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height,
This op cuts a rectangular part out of `image`. The top-left corner of the
returned image is at `offset_height, offset_width` in `image`, and its
lower-right corner is at
- `offset_height + target_height, offset_width + target_width'.
+ `offset_height + target_height, offset_width + target_width`.
Args:
image: 3-D tensor with shape `[height, width, channels]`
@@ -554,7 +550,7 @@ def per_image_whitening(image):
height, width, depth = _ImageDimensions(image)
num_pixels = height * width * depth
- image = math_ops.cast(image, dtype=types.float32)
+ image = math_ops.cast(image, dtype=dtypes.float32)
image_mean = math_ops.reduce_mean(image)
variance = (math_ops.reduce_mean(math_ops.square(image)) -
@@ -592,7 +588,7 @@ def random_brightness(image, max_delta, seed=None):
3-D tensor of images of shape `[height, width, channels]`
Raises:
- ValueError: if max_delta is negative.
+ ValueError: if `max_delta` is negative.
"""
_Check3DImage(image)
@@ -664,8 +660,8 @@ def adjust_brightness(image, delta, min_value=None, max_value=None):
with ops.op_scope([image, delta, min_value, max_value], None,
'adjust_brightness') as name:
adjusted = math_ops.add(
- math_ops.cast(image, types.float32),
- math_ops.cast(delta, types.float32),
+ math_ops.cast(image, dtypes.float32),
+ math_ops.cast(delta, dtypes.float32),
name=name)
if image.dtype.is_integer:
rounded = math_ops.round(adjusted)
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py
index 141fb8c13a..8c52a232cb 100644
--- a/tensorflow/python/ops/init_ops.py
+++ b/tensorflow/python/ops/init_ops.py
@@ -19,7 +19,7 @@ from __future__ import division
from __future__ import print_function
import math
-from tensorflow.python.framework import types
+from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import math_ops
@@ -38,7 +38,7 @@ def constant_initializer(value=0.0):
Returns:
An initializer that generates tensors with a single value.
"""
- def _initializer(shape, dtype=types.float32):
+ def _initializer(shape, dtype=dtypes.float32):
return constant_op.constant(value, dtype=dtype, shape=shape)
return _initializer
@@ -57,7 +57,7 @@ def random_uniform_initializer(minval=0.0, maxval=1.0, seed=None):
Returns:
An initializer that generates tensors with a uniform distribution.
"""
- def _initializer(shape, dtype=types.float32):
+ def _initializer(shape, dtype=dtypes.float32):
return random_ops.random_uniform(shape, minval, maxval, dtype, seed=seed)
return _initializer
@@ -76,7 +76,7 @@ def random_normal_initializer(mean=0.0, stddev=1.0, seed=None):
Returns:
An initializer that generates tensors with a normal distribution.
"""
- def _initializer(shape, dtype=types.float32):
+ def _initializer(shape, dtype=dtypes.float32):
return random_ops.random_normal(shape, mean, stddev, dtype, seed=seed)
return _initializer
@@ -101,7 +101,7 @@ def truncated_normal_initializer(mean=0.0, stddev=1.0, seed=None):
An initializer that generates tensors with a truncated normal
distribution.
"""
- def _initializer(shape, dtype=types.float32):
+ def _initializer(shape, dtype=dtypes.float32):
return random_ops.truncated_normal(shape, mean, stddev, dtype, seed=seed)
return _initializer
@@ -132,7 +132,7 @@ def uniform_unit_scaling_initializer(factor=1.0, seed=None):
Returns:
An initializer that generates tensors with unit variance.
"""
- def _initializer(shape, dtype=types.float32):
+ def _initializer(shape, dtype=dtypes.float32):
input_size = 1.0
# Estimating input size is not possible to do perfectly, but we try.
# The estimate, obtained by multiplying all dimensions but the last one,
@@ -146,7 +146,7 @@ def uniform_unit_scaling_initializer(factor=1.0, seed=None):
# TODO(vrv): Unhide when we are ready to expose this publicly.
-def _random_walk(shape, nonlinearity, dtype=types.float32, seed=None,
+def _random_walk(shape, nonlinearity, dtype=dtypes.float32, seed=None,
name="random_walk"):
"""Create a random tensor such that backprop neither vanishes nor explodes.
@@ -200,7 +200,7 @@ class _RandomWalkInitializer(object):
self._nonlinearity = nonlinearity
self._seed = seed
- def __call__(self, shape, dtype=types.float32):
+ def __call__(self, shape, dtype=dtypes.float32):
"""Generate a tensor used to initialize a variable."""
return random_ops._random_walk(shape, self._nonlinearity, dtype,
seed=self._seed)
diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py
index 69d4ee6266..97b491bd53 100644
--- a/tensorflow/python/ops/io_ops.py
+++ b/tensorflow/python/ops/io_ops.py
@@ -121,9 +121,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import types
from tensorflow.python.ops import common_shapes
from tensorflow.python.ops import gen_io_ops
# pylint: disable=wildcard-import
@@ -183,7 +183,7 @@ def _restore_slice(file_pattern, tensor_name, shape_and_slice, tensor_type,
Returns:
A tensor of type "tensor_type".
"""
- base_type = types.as_dtype(tensor_type).base_dtype
+ base_type = dtypes.as_dtype(tensor_type).base_dtype
return gen_io_ops._restore_slice(
file_pattern, tensor_name, shape_and_slice, base_type,
preferred_shard, name=name)
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 9f9e0943ea..12444916ce 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import types
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import data_flow_ops
@@ -288,7 +288,7 @@ def _MulGrad(op, grad):
sx = array_ops.shape(x)
sy = array_ops.shape(y)
rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
- if x.dtype.base_dtype == types.complex64:
+ if x.dtype.base_dtype == dtypes.complex64:
return (array_ops.reshape(math_ops.reduce_sum(grad * math_ops.conj(y), rx), sx),
array_ops.reshape(math_ops.reduce_sum(math_ops.conj(x) * grad, ry), sy))
else:
@@ -412,29 +412,14 @@ def _SparseMatMulGrad(op, grad):
assert t1 in is_sparse and t2 in is_sparse
t1_sparse = is_sparse[t1]
t2_sparse = is_sparse[t2]
- if not t1_sparse and not t2_sparse:
- return math_ops.matmul(t1, t2,
- transpose_a=transpose_a,
- transpose_b=transpose_b)
- transpose_out = False
- if not t1_sparse:
- transpose_out = True
- t1, t2 = t2, t1
- t1_sparse, t2_sparse = t2_sparse, t1_sparse
- assert t1_sparse
- transpose_a, transpose_b = not transpose_b, not transpose_a
-
if transpose_b:
t2 = array_ops.transpose(t2)
transpose_b = False
- m = math_ops.matmul(t1, t2,
- transpose_a=transpose_a,
- transpose_b=transpose_b,
- a_is_sparse=t1_sparse,
- b_is_sparse=t2_sparse)
- if transpose_out:
- m = array_ops.transpose(m)
- return m
+ return math_ops.matmul(t1, t2,
+ transpose_a=transpose_a,
+ transpose_b=transpose_b,
+ a_is_sparse=t1_sparse,
+ b_is_sparse=t2_sparse)
if not t_a and not t_b:
return (_SparseMatMul(grad, op.inputs[1], transpose_b=True),
@@ -515,7 +500,7 @@ def _ConjGrad(_, grad):
@ops.RegisterGradient("Cast")
def _CastGrad(op, grad):
- t = [types.float32, types.float64, types.bfloat16]
+ t = [dtypes.float32, dtypes.float64, dtypes.bfloat16]
src_type = op.inputs[0].dtype.base_dtype
dst_type = grad.dtype.base_dtype
if src_type in t and dst_type in t:
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 29c2fcf432..2555b4ba17 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -156,10 +156,10 @@ import tensorflow.python.platform
import numpy as np
import six.moves
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
-from tensorflow.python.framework import types
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import common_shapes
from tensorflow.python.ops import gen_math_ops
@@ -196,7 +196,7 @@ def abs(x, name=None):
"""
with ops.op_scope([x], name, "Abs") as name:
x = ops.convert_to_tensor(x, name="x")
- if x.dtype == types.complex64:
+ if x.dtype == dtypes.complex64:
return gen_math_ops.complex_abs(x, name=name)
return gen_math_ops._abs(x, name=name)
@@ -333,7 +333,7 @@ def to_float(x, name="ToFloat"):
Raises:
TypeError: If `x` cannot be cast to the `float32`.
"""
- return cast(x, types.float32, name=name)
+ return cast(x, dtypes.float32, name=name)
def to_double(x, name="ToDouble"):
@@ -349,7 +349,7 @@ def to_double(x, name="ToDouble"):
Raises:
TypeError: If `x` cannot be cast to the `float64`.
"""
- return cast(x, types.float64, name=name)
+ return cast(x, dtypes.float64, name=name)
def to_int32(x, name="ToInt32"):
@@ -365,7 +365,7 @@ def to_int32(x, name="ToInt32"):
Raises:
TypeError: If `x` cannot be cast to the `int32`.
"""
- return cast(x, types.int32, name=name)
+ return cast(x, dtypes.int32, name=name)
def to_int64(x, name="ToInt64"):
@@ -381,7 +381,7 @@ def to_int64(x, name="ToInt64"):
Raises:
TypeError: If `x` cannot be cast to the `int64`.
"""
- return cast(x, types.int64, name=name)
+ return cast(x, dtypes.int64, name=name)
def to_bfloat16(x, name="ToBFloat16"):
@@ -397,7 +397,7 @@ def to_bfloat16(x, name="ToBFloat16"):
Raises:
TypeError: If `x` cannot be cast to the `bfloat16`.
"""
- return cast(x, types.bfloat16, name=name)
+ return cast(x, dtypes.bfloat16, name=name)
ops.Tensor._override_operator("__neg__", neg)
@@ -437,14 +437,14 @@ def _OverrideBinaryOperatorHelper(func, op_name):
# Conversion table for __truediv__. None entries mean no conversion required.
_TRUEDIV_TABLE = {
- types.uint8: types.float32,
- types.int8: types.float32,
- types.int16: types.float32,
- types.int32: types.float64,
- types.int64: types.float64,
- types.float32: None,
- types.float64: None,
- types.complex64: None,
+ dtypes.uint8: dtypes.float32,
+ dtypes.int8: dtypes.float32,
+ dtypes.int16: dtypes.float32,
+ dtypes.int32: dtypes.float64,
+ dtypes.int64: dtypes.float64,
+ dtypes.float32: None,
+ dtypes.float64: None,
+ dtypes.complex64: None,
}
@@ -891,7 +891,7 @@ def matmul(a, b,
with ops.op_scope([a, b], name, "MatMul") as name:
a = ops.convert_to_tensor(a, name="a")
b = ops.convert_to_tensor(b, name="b")
- if a.dtype == types.float32 and (a_is_sparse or b_is_sparse):
+ if a.dtype == dtypes.float32 and (a_is_sparse or b_is_sparse):
return sparse_matmul(a, b,
transpose_a=transpose_a,
transpose_b=transpose_b,
@@ -953,14 +953,14 @@ def _as_indexed_slices_list(inputs):
raise TypeError("Expected a list or tuple, not a %s" % type(inputs))
outputs = [_as_indexed_slices(i) for i in inputs]
with_int32_index = [o.indices for o in outputs
- if o.indices.dtype == types.int32]
+ if o.indices.dtype == dtypes.int32]
if not with_int32_index or len(with_int32_index) == len(outputs):
return outputs
casted_outputs = []
for o in outputs:
- if o.indices.dtype == types.int32:
+ if o.indices.dtype == dtypes.int32:
casted_outputs.append(
- ops.IndexedSlices(o.values, cast(o.indices, types.int64),
+ ops.IndexedSlices(o.values, cast(o.indices, dtypes.int64),
o.dense_shape))
else:
casted_outputs.append(o)
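The renamed `_TRUEDIV_TABLE` above maps each operand dtype to the dtype both sides are cast to before `__truediv__` runs, with `None` meaning no cast is needed. A standalone sketch of that lookup (using dtype names as strings; not the real dispatch code):

```
# Integer operands are cast up to a float dtype before true division.
_TRUEDIV_SKETCH = {
    "uint8": "float32",
    "int32": "float64",
    "float32": None,  # None: already a float type, no cast needed
}

def truediv_cast(dtype_name):
  target = _TRUEDIV_SKETCH[dtype_name]  # KeyError: unsupported dtype
  return dtype_name if target is None else target

print(truediv_cast("int32"))    # float64
print(truediv_cast("float32"))  # float32
```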
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 749faaf73a..17160f909e 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -204,9 +204,9 @@ from __future__ import division
from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import types
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import candidate_sampling_ops
from tensorflow.python.ops import constant_op
@@ -360,7 +360,7 @@ def zero_fraction(value, name=None):
value = ops.convert_to_tensor(value, name="value")
zero = constant_op.constant(0, dtype=value.dtype, name="zero")
return math_ops.reduce_mean(math_ops.cast(math_ops.equal(value, zero),
- types.float32))
+ dtypes.float32))
def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):
@@ -633,35 +633,42 @@ def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled,
sum to 1 per-example.
Args:
- weights: tensor of label embeddings with shape = [num_classes, dim].
- biases: tensor of num_classes label biases
- inputs: tensor with shape = [batch_size, dim] corresponding to forward
- activations of the input network
- labels: int tensor with shape [batch_size, num_true]
- num_sampled: number of label classes to sample per batch
- num_classes: number of possible label classes in the data (e.g. vocab size)
- num_true: number of target classes per example (default: 1)
+ weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
+ objects whose concatenation along dimension 0 has shape
+ `[num_classes, dim]`. The (possibly-sharded) class embeddings.
+ biases: A `Tensor` of shape `[num_classes]`. The class biases.
+ inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
+ activations of the input network.
+ labels: A `Tensor` of type `int64` and shape `[batch_size,
+ num_true]`. The target classes. Note that this format differs from
+ the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
+ num_sampled: An `int`. The number of classes to randomly sample per batch.
+ num_classes: An `int`. The number of possible classes.
+ num_true: An `int`. The number of target classes per training example.
sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
- `sampled_expected_count`) returned by a `*_candidate_sampler` function
- to use (if None, we default to `log_uniform_candidate_sampler`)
- subtract_log_q: subtract the log expected count of the labels in the sample
- to get the logits of the true labels (default: True)
- Turn off for Negative Sampling.
- remove_accidental_hits: whether to remove "accidental hits" where a sampled
- label equals the true labels (bool, default: False)
+ `sampled_expected_count`) returned by a `*_candidate_sampler` function.
+ (if None, we default to `log_uniform_candidate_sampler`)
+ subtract_log_q: A `bool`. whether to subtract the log expected count of
+ the labels in the sample to get the logits of the true labels.
+ Default is True. Turn off for Negative Sampling.
+ remove_accidental_hits: A `bool`. whether to remove "accidental hits"
+ where a sampled class equals one of the target classes. Default is
+ False.
name: A name for the operation (optional).
-
Returns:
- out_logits, out_labels: tensors with shape
- `[batch_size, num_true + num_sampled]` for passing to either
- `sigmoid_cross_entropy_with_logits` (NCE)
- or `softmax_cross_entropy_with_logits` (sampled softmax).
+ out_logits, out_labels: `Tensor` objects each with shape
+ `[batch_size, num_true + num_sampled]`, for passing to either
+ `nn.sigmoid_cross_entropy_with_logits` (NCE) or
+ `nn.softmax_cross_entropy_with_logits` (sampled softmax).
"""
+ if not isinstance(weights, list):
+ weights = [weights]
+
with ops.op_scope(
- [weights, biases, inputs, labels], name, "compute_sampled_logits"):
- if labels.dtype != types.int64:
- labels = math_ops.cast(labels, types.int64)
+ weights + [biases, inputs, labels], name, "compute_sampled_logits"):
+ if labels.dtype != dtypes.int64:
+ labels = math_ops.cast(labels, dtypes.int64)
labels_flat = array_ops.reshape(labels, [-1])
# Sample the negative labels.
@@ -726,7 +733,7 @@ def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled,
# This is how SparseToDense expects the indices.
acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1])
acc_ids_2d_int32 = array_ops.reshape(math_ops.cast(
- acc_ids, types.int32), [-1, 1])
+ acc_ids, dtypes.int32), [-1, 1])
sparse_indices = array_ops.concat(
1, [acc_indices_2d, acc_ids_2d_int32], "sparse_indices")
# Create sampled_logits_shape = [batch_size, num_sampled]
@@ -777,17 +784,19 @@ def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes,
with an otherwise unused class.
Args:
- weights: A `Tensor` of shape [num_classes, dim]. The class embeddings.
- biases: A `Tensor` of shape [num_classes]. The class biases.
- inputs: A `Tensor` of shape [batch_size, dim]. The forward
+ weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
+ objects whose concatenation along dimension 0 has shape
+ [num_classes, dim]. The (possibly-sharded) class embeddings.
+ biases: A `Tensor` of shape `[num_classes]`. The class biases.
+ inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
activations of the input network.
labels: A `Tensor` of type `int64` and shape `[batch_size,
- num_true]`. The target classes.
+ num_true]`. The target classes.
num_sampled: An `int`. The number of classes to randomly sample per batch.
num_classes: An `int`. The number of possible classes.
num_true: An `int`. The number of target classes per training example.
- sampled_values: a tuple of `(sampled_candidates, true_expected_count,
- sampled_expected_count)` returned by a `*_candidate_sampler` function.
+ sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
+ `sampled_expected_count`) returned by a `*_candidate_sampler` function.
(if None, we default to `log_uniform_candidate_sampler`)
remove_accidental_hits: A `bool`. Whether to remove "accidental hits"
where a sampled class equals one of the target classes. If set to
@@ -799,7 +808,7 @@ def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes,
name: A name for the operation (optional).
Returns:
- A batch_size 1-D tensor of per-example NCE losses.
+ A `batch_size` 1-D tensor of per-example NCE losses.
"""
logits, labels = _compute_sampled_logits(
weights, biases, inputs, labels, num_sampled, num_classes,
@@ -838,18 +847,20 @@ def sampled_softmax_loss(weights, biases, inputs, labels, num_sampled,
Also see Section 3 of http://arxiv.org/abs/1412.2007 for the math.
Args:
- weights: A `Tensor` of shape [num_classes, dim]. The class embeddings.
- biases: A `Tensor` of shape [num_classes]. The class biases.
- inputs: A `Tensor` of shape [batch_size, dim]. The forward
+ weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
+ objects whose concatenation along dimension 0 has shape
+ [num_classes, dim]. The (possibly-sharded) class embeddings.
+ biases: A `Tensor` of shape `[num_classes]`. The class biases.
+ inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
activations of the input network.
labels: A `Tensor` of type `int64` and shape `[batch_size,
- num_true]`. The target classes. Note that this format differs from
- the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
+ num_true]`. The target classes. Note that this format differs from
+ the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
num_sampled: An `int`. The number of classes to randomly sample per batch.
num_classes: An `int`. The number of possible classes.
num_true: An `int`. The number of target classes per training example.
- sampled_values: a tuple of `(sampled_candidates, true_expected_count,
- sampled_expected_count)` returned by a `*_candidate_sampler` function.
+ sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
+ `sampled_expected_count`) returned by a `*_candidate_sampler` function.
(if None, we default to `log_uniform_candidate_sampler`)
remove_accidental_hits: A `bool`. whether to remove "accidental hits"
where a sampled class equals one of the target classes. Default is
@@ -857,7 +868,7 @@ def sampled_softmax_loss(weights, biases, inputs, labels, num_sampled,
name: A name for the operation (optional).
Returns:
- A batch_size 1-D tensor of per-example sampled softmax losses.
+ A `batch_size` 1-D tensor of per-example sampled softmax losses.
"""
logits, labels = _compute_sampled_logits(
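The expanded docstrings above add a sharded form for `weights`: a list of tensors whose concatenation along dimension 0 has shape `[num_classes, dim]`. A NumPy sketch of that equivalence (illustrative only):

```
import numpy as np

num_classes, dim = 6, 4
full = np.arange(num_classes * dim, dtype=np.float32).reshape(num_classes, dim)

# Two shards whose concatenation along dimension 0 rebuilds `full`.
shards = [full[:3], full[3:]]
assert np.array_equal(np.concatenate(shards, axis=0), full)

# Embedding rows looked up against the concatenation match the full table.
ids = np.array([0, 4, 5])
print(np.concatenate(shards, axis=0)[ids])
```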
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index f0709dd2a9..1eb8ef4c69 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -22,10 +22,10 @@ from __future__ import print_function
import tensorflow.python.platform
import numpy as np
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
-from tensorflow.python.framework import types
from tensorflow.python.ops import common_shapes
from tensorflow.python.ops import gen_nn_ops
# pylint: disable=wildcard-import
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index c8962f60af..65e28978ba 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -25,8 +25,8 @@ import tensorflow.python.platform
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
-from tensorflow.python.framework import types
from tensorflow.python.kernel_tests import gradient_checker as gc
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
@@ -50,7 +50,7 @@ class SigmoidCrossEntropyWithLogitsTest(test_util.TensorFlowTestCase):
pred = [min(max(p, eps), 1 - eps) for p in pred]
return [-z * log(y) - (1 - z) * log(1 - y) for y, z in zip(pred, targets)]
- def _Inputs(self, x=None, y=None, dtype=types.float64, sizes=None):
+ def _Inputs(self, x=None, y=None, dtype=dtypes.float64, sizes=None):
x = [-100, -2, -2, 0, 2, 2, 2, 100] if x is None else x
y = [0, 0, 1, 0, 0, 1, 0.5, 1] if y is None else y
assert len(x) == len(y)
@@ -70,7 +70,7 @@ class SigmoidCrossEntropyWithLogitsTest(test_util.TensorFlowTestCase):
def testLogisticOutput(self):
for use_gpu in [True, False]:
with self.test_session(use_gpu=use_gpu):
- logits, targets, losses = self._Inputs(dtype=types.float32)
+ logits, targets, losses = self._Inputs(dtype=dtypes.float32)
loss = nn.sigmoid_cross_entropy_with_logits(logits, targets)
np_loss = np.array(losses).astype(np.float32)
tf_loss = loss.eval()
@@ -79,7 +79,7 @@ class SigmoidCrossEntropyWithLogitsTest(test_util.TensorFlowTestCase):
def testLogisticOutputMultiDim(self):
for use_gpu in [True, False]:
with self.test_session(use_gpu=use_gpu):
- logits, targets, losses = self._Inputs(dtype=types.float32,
+ logits, targets, losses = self._Inputs(dtype=dtypes.float32,
sizes=[2, 2, 2])
loss = nn.sigmoid_cross_entropy_with_logits(logits, targets)
np_loss = np.array(losses).astype(np.float32)
@@ -168,9 +168,9 @@ class DeConv2DTest(test_util.TensorFlowTestCase):
f_shape = [3, 3, 2, 3]
x = constant_op.constant(1.0, shape=x_shape, name="x",
- dtype=types.float32)
+ dtype=dtypes.float32)
f = constant_op.constant(1.0, shape=f_shape, name="filter",
- dtype=types.float32)
+ dtype=dtypes.float32)
output = nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME")
value = output.eval()
@@ -205,9 +205,9 @@ class DeConv2DTest(test_util.TensorFlowTestCase):
f_shape = [3, 3, 2, 3]
x = constant_op.constant(1.0, shape=x_shape, name="x",
- dtype=types.float32)
+ dtype=dtypes.float32)
f = constant_op.constant(1.0, shape=f_shape, name="filter",
- dtype=types.float32)
+ dtype=dtypes.float32)
output = nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME")
value = output.eval()
@@ -237,9 +237,9 @@ class DeConv2DTest(test_util.TensorFlowTestCase):
f_shape = [3, 3, 2, 3]
x = constant_op.constant(1.0, shape=x_shape, name="x",
- dtype=types.float32)
+ dtype=dtypes.float32)
f = constant_op.constant(1.0, shape=f_shape, name="filter",
- dtype=types.float32)
+ dtype=dtypes.float32)
output = nn.deconv2d(x, f, y_shape, strides=strides, padding="VALID")
value = output.eval()
@@ -281,8 +281,8 @@ class DeConv2DTest(test_util.TensorFlowTestCase):
x_val = np.random.random_sample(x_shape).astype(np.float64)
f_val = np.random.random_sample(f_shape).astype(np.float64)
with self.test_session():
- x = constant_op.constant(x_val, name="x", dtype=types.float32)
- f = constant_op.constant(f_val, name="f", dtype=types.float32)
+ x = constant_op.constant(x_val, name="x", dtype=dtypes.float32)
+ f = constant_op.constant(f_val, name="f", dtype=dtypes.float32)
output = nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME")
err = gc.ComputeGradientError([x, f], [x_shape, f_shape], output, y_shape)
print("DeConv gradient err = %g " % err)
@@ -355,7 +355,7 @@ class DropoutTest(test_util.TensorFlowTestCase):
with self.test_session():
t = constant_op.constant(1.0,
shape=[x_dim, y_dim],
- dtype=types.float32)
+ dtype=dtypes.float32)
dropout = nn.dropout(t, keep_prob)
final_count = 0
self.assertEqual([x_dim, y_dim], dropout.get_shape())
@@ -384,7 +384,7 @@ class DropoutTest(test_util.TensorFlowTestCase):
with self.test_session():
t = constant_op.constant(1.0,
shape=[x_dim, y_dim],
- dtype=types.float32)
+ dtype=dtypes.float32)
dropout = nn.dropout(t, keep_prob, noise_shape=[x_dim, 1])
self.assertEqual([x_dim, y_dim], dropout.get_shape())
final_count = 0
@@ -410,7 +410,7 @@ class DropoutTest(test_util.TensorFlowTestCase):
with self.test_session():
t = constant_op.constant(1.0,
shape=[x_dim, y_dim],
- dtype=types.float32)
+ dtype=dtypes.float32)
dropout = nn.dropout(t, keep_prob, noise_shape=[x_dim, 1])
self.assertEqual([x_dim, y_dim], dropout.get_shape())
for _ in xrange(0, num_iter):
@@ -431,8 +431,8 @@ class DropoutTest(test_util.TensorFlowTestCase):
with self.test_session():
t = constant_op.constant(1.0,
shape=[x_dim, y_dim],
- dtype=types.float32)
- keep_prob_placeholder = array_ops.placeholder(types.float32)
+ dtype=dtypes.float32)
+ keep_prob_placeholder = array_ops.placeholder(dtypes.float32)
dropout = nn.dropout(t, keep_prob_placeholder)
final_count = 0
self.assertEqual([x_dim, y_dim], dropout.get_shape())
@@ -453,9 +453,9 @@ class DropoutTest(test_util.TensorFlowTestCase):
x_dim = 40
y_dim = 30
keep_prob = 0.5
- x = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=types.float32)
+ x = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
dropout_x = nn.dropout(
- x, keep_prob, noise_shape=array_ops.placeholder(types.int32))
+ x, keep_prob, noise_shape=array_ops.placeholder(dtypes.int32))
self.assertEqual(x.get_shape(), dropout_x.get_shape())
def testInvalidKeepProb(self):
@@ -463,7 +463,7 @@ class DropoutTest(test_util.TensorFlowTestCase):
y_dim = 30
t = constant_op.constant(1.0,
shape=[x_dim, y_dim],
- dtype=types.float32)
+ dtype=dtypes.float32)
with self.assertRaises(ValueError):
nn.dropout(t, -1.0)
with self.assertRaises(ValueError):
@@ -471,9 +471,9 @@ class DropoutTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError):
nn.dropout(t, [0.0, 1.0])
with self.assertRaises(ValueError):
- nn.dropout(t, array_ops.placeholder(types.float64))
+ nn.dropout(t, array_ops.placeholder(dtypes.float64))
with self.assertRaises(ValueError):
- nn.dropout(t, array_ops.placeholder(types.float32, shape=[2]))
+ nn.dropout(t, array_ops.placeholder(dtypes.float32, shape=[2]))
def testShapedDropoutShapeError(self):
# Runs shaped dropout and verifies an error is thrown on misshapen noise.
@@ -482,7 +482,7 @@ class DropoutTest(test_util.TensorFlowTestCase):
keep_prob = 0.5
t = constant_op.constant(1.0,
shape=[x_dim, y_dim],
- dtype=types.float32)
+ dtype=dtypes.float32)
with self.assertRaises(ValueError):
_ = nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim + 10])
with self.assertRaises(ValueError):
@@ -641,7 +641,7 @@ class MomentsTest(test_util.TensorFlowTestCase):
assert len(shape) == 4
x_numpy = np.random.normal(size=shape).astype(np.float32)
- x = array_ops.placeholder(types.float32, shape=[None] * len(shape))
+ x = array_ops.placeholder(dtypes.float32, shape=[None] * len(shape))
axes = [0, 1, 2] if global_norm else [0]
mean, var = nn.moments(x, axes)
@@ -722,6 +722,7 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase):
self._num_classes = 5
self._dim = 10
self._batch_size = 3
+ self._num_shards = 3
def _GenerateTestInputs(self):
np.random.seed(0)
@@ -729,8 +730,11 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase):
biases = np.random.randn(self._num_classes).astype(np.float32)
hidden_acts = np.random.randn(self._batch_size, self._dim).astype(
np.float32)
-
- return weights, biases, hidden_acts
+ sharded_weights = [
+ weights[[row for row in range(self._num_classes)
+ if row % self._num_shards == shard]]
+ for shard in range(self._num_shards)]
+ return weights, biases, hidden_acts, sharded_weights
def _ComputeSampledLogitsNP(self, true_w, true_b, sampled_w, sampled_b,
hidden_acts,
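The list comprehension above shards rows round-robin; a standalone sketch of what `sharded_weights` ends up holding for `num_classes=5` and `num_shards=3`:
```
import numpy as np

num_classes, num_shards, dim = 5, 3, 10
weights = np.random.randn(num_classes, dim).astype(np.float32)
sharded_weights = [weights[[row for row in range(num_classes)
                            if row % num_shards == shard]]
                   for shard in range(num_shards)]
# shard 0 -> rows [0, 3], shard 1 -> rows [1, 4], shard 2 -> rows [2]
assert [s.shape[0] for s in sharded_weights] == [2, 2, 1]
```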
@@ -763,11 +767,14 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase):
subtract_log_q, remove_accidental_hits,
name="sampled_loss_TF"):
# Should be called from within a `with test_session():` block
- weights_tf = constant_op.constant(weights)
+ if isinstance(weights, list):
+ weights_tf = [constant_op.constant(shard) for shard in weights]
+ else:
+ weights_tf = constant_op.constant(weights)
biases_tf = constant_op.constant(biases)
hidden_acts_tf = constant_op.constant(hidden_acts,
shape=(self._batch_size, self._dim))
- labels_tf = constant_op.constant(labels, dtype=types.int64,
+ labels_tf = constant_op.constant(labels, dtype=dtypes.int64,
shape=(self._batch_size, num_true))
pred_logits_tf, pred_labels_tf = nn._compute_sampled_logits(
@@ -780,7 +787,7 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase):
def testComputeSampledLogitsShapes(self):
# We just check that the shapes of the returned values are correct.
- weights, biases, hidden_acts = self._GenerateTestInputs()
+ weights, biases, hidden_acts, _ = self._GenerateTestInputs()
sampled = [1, 0, 2, 3]
num_sampled = len(sampled)
true_exp = sampled_exp = [1., 1., 1., 1.]
@@ -811,7 +818,7 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase):
def testComputeSampledLogitsValues(self):
# Here we check the actual numerics.
- weights, biases, hidden_acts = self._GenerateTestInputs()
+ weights, biases, hidden_acts, sharded_weights = self._GenerateTestInputs()
eps = 1e-3
sampled = [1, 0, 2, 3]
num_sampled = len(sampled)
@@ -887,6 +894,23 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase):
self.assertAllClose(logits_np, logits_tf_val, eps)
self.assertAllClose(labels_np, labels_tf_val, eps)
+ # Test 4: Same as Test 1, but with sharded weights.
+ logits_np, labels_np = self._ComputeSampledLogitsNP(
+ true_w, true_b, sampled_w, sampled_b, hidden_acts,
+ num_true=num_true_test)
+ logits_tf, labels_tf = self._ComputeSampledLogitsTF(
+ sharded_weights, biases, hidden_acts, labels, num_sampled,
+ self._num_classes,
+ num_true=num_true_test,
+ sampled_vals=test_sampled_vals,
+ subtract_log_q=False,
+ remove_accidental_hits=False,
+ name="sampled_loss_test1_num_true%d" % num_true_test)
+
+ logits_tf_val, labels_tf_val = sess.run([logits_tf, labels_tf])
+ self.assertAllClose(logits_np, logits_tf_val, eps)
+ self.assertAllClose(labels_np, labels_tf_val, eps)
+
def testNCELoss(self):
# A simple test to verify the numerics.
@@ -898,7 +922,7 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase):
pred = np.minimum(np.maximum(pred, eps), 1 - eps)
return -targets * np.log(pred) - (1. - targets) * np.log(1. - pred)
- weights, biases, hidden_acts = self._GenerateTestInputs()
+ weights, biases, hidden_acts, sharded_weights = self._GenerateTestInputs()
labels = [0, 1, 2]
true_w, true_b = weights[labels], biases[labels]
sampled = [1, 0, 2, 3]
@@ -932,6 +956,17 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase):
self.assertAllClose(nce_loss_np, nce_loss_tf.eval(), 1e-4)
+ # Test with sharded weights
+ nce_loss_tf = nn.nce_loss(
+ [constant_op.constant(shard) for shard in sharded_weights],
+ biases_tf, inputs_tf, labels_tf,
+ num_sampled=1,
+ num_classes=self._num_classes,
+ num_true=1,
+ sampled_values=test_sampled_vals)
+
+ self.assertAllClose(nce_loss_np, nce_loss_tf.eval(), 1e-4)
+
def testSampledSoftmaxLoss(self):
# A simple test to verify the numerics.
@@ -943,7 +978,7 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase):
pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True)
return -np.sum(targets * np.log(pred + 1.0e-20), axis=1)
- weights, biases, hidden_acts = self._GenerateTestInputs()
+ weights, biases, hidden_acts, sharded_weights = self._GenerateTestInputs()
labels = [0, 1, 2]
true_w, true_b = weights[labels], biases[labels]
sampled = [1, 0, 2, 3]
@@ -977,6 +1012,19 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase):
self.assertAllClose(
sampled_softmax_loss_np, sampled_softmax_loss_tf.eval(), 1e-4)
+ # Test with sharded weights
+ sampled_softmax_loss_tf = nn.sampled_softmax_loss(
+ [constant_op.constant(shard) for shard in sharded_weights],
+ biases_tf, inputs_tf, labels_tf,
+ num_sampled=1,
+ num_classes=self._num_classes,
+ num_true=1,
+ sampled_values=test_sampled_vals,
+ remove_accidental_hits=False)
+
+ self.assertAllClose(
+ sampled_softmax_loss_np, sampled_softmax_loss_tf.eval(), 1e-4)
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/ops/numerics.py b/tensorflow/python/ops/numerics.py
index b2312332f1..8525da35f5 100644
--- a/tensorflow/python/ops/numerics.py
+++ b/tensorflow/python/ops/numerics.py
@@ -19,8 +19,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import types
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -45,7 +45,7 @@ def verify_tensor_all_finite(t, msg, name=None):
def add_check_numerics_ops():
- """Connect a check_numerics to every floating point tensor.
+ """Connect a `check_numerics` to every floating point tensor.
`check_numerics` operations themselves are added for each `float` or `double`
tensor in the graph. For all ops in the graph, the `check_numerics` op for
@@ -62,7 +62,7 @@ def add_check_numerics_ops():
# added, and ops can only be added once their inputs are added.
for op in ops.get_default_graph().get_operations():
for output in op.outputs:
- if output.dtype in [types.float32, types.float64]:
+ if output.dtype in [dtypes.float32, dtypes.float64]:
message = op.name + ":" + str(output.value_index)
with ops.control_dependencies(check_op):
check_op = [array_ops.check_numerics(output, message=message)]
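For illustration, a hedged usage sketch of this function; the `tf.add_check_numerics_ops` endpoint name and the session-run pattern are assumptions, not part of this patch:
```
import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 3])
y = tf.log(x)  # -inf for zero entries, nan for negative ones

# One grouped op that depends on a check_numerics for every float tensor.
check_op = tf.add_check_numerics_ops()

with tf.Session() as sess:
    # If any float32/float64 tensor is non-finite, this run raises an
    # InvalidArgumentError naming the offending op output.
    sess.run([y, check_op], feed_dict={x: np.ones((2, 3), np.float32)})
```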
diff --git a/tensorflow/python/ops/op_def_library.py b/tensorflow/python/ops/op_def_library.py
index 562299fdd6..ad0406d43d 100644
--- a/tensorflow/python/ops/op_def_library.py
+++ b/tensorflow/python/ops/op_def_library.py
@@ -19,7 +19,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numbers
import six
from tensorflow.core.framework import attr_value_pb2
@@ -27,11 +26,12 @@ from tensorflow.core.framework import op_def_pb2
from tensorflow.core.framework import tensor_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import types as types_lib
from tensorflow.python.ops import constant_op
from tensorflow.python.platform import logging
+from tensorflow.python.util import compat
def _Attr(op_def, name):
@@ -55,8 +55,8 @@ def _SatisfiesTypeConstraint(dtype, attr_def):
if dtype not in allowed_list:
raise TypeError(
"DataType %s for attr '%s' not in list of allowed values: %s" %
- (types_lib.as_dtype(dtype).name, attr_def.name,
- ", ".join(types_lib.as_dtype(x).name for x in allowed_list)))
+ (dtypes.as_dtype(dtype).name, attr_def.name,
+ ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
def _IsListParameter(arg):
@@ -137,7 +137,7 @@ def _Restructure(l, structure):
def _MakeFloat(v, arg_name):
- if not isinstance(v, numbers.Real):
+ if not isinstance(v, compat.real_types):
raise TypeError("Expected float for argument '%s' not %s." %
(arg_name, repr(v)))
return float(v)
@@ -155,11 +155,10 @@ def _MakeInt(v, arg_name):
def _MakeStr(v, arg_name):
- if not isinstance(v, six.string_types):
+ if not isinstance(v, compat.bytes_or_text_types):
raise TypeError("Expected string for argument '%s' not %s." %
(arg_name, repr(v)))
- # TODO(irving): Figure out what to do here from Python 3
- return str(v) # Convert unicode strings to bytes.
+ return compat.as_bytes(v) # Convert unicode strings to bytes.
def _MakeBool(v, arg_name):
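`compat.as_bytes` is new in this patch; a sketch of the Python 2/3 semantics it is assumed to provide (an approximation for the reader, not the module's actual code):
```
def as_bytes(bytes_or_text):
    """Assumed behavior: UTF-8-encode unicode text, pass bytes through."""
    if isinstance(bytes_or_text, bytes):
        return bytes_or_text
    if isinstance(bytes_or_text, str):  # unicode text on Python 3
        return bytes_or_text.encode("utf-8")
    raise TypeError("Expected binary or unicode string, got %r"
                    % (bytes_or_text,))
```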
@@ -171,7 +170,7 @@ def _MakeBool(v, arg_name):
def _MakeType(v, attr_def):
try:
- v = types_lib.as_dtype(v)
+ v = dtypes.as_dtype(v)
except TypeError:
raise TypeError("Expected DataType for argument '%s' not %s." %
(attr_def.name, repr(v)))
@@ -381,7 +380,7 @@ class OpDefLibrary(object):
try:
values = ops.convert_n_to_tensor_or_indexed_slices(
values, name=input_arg.name,
- dtype=types_lib.as_dtype(dtype).base_dtype if dtype else None)
+ dtype=dtypes.as_dtype(dtype).base_dtype if dtype else None)
except (TypeError, ValueError):
assert dtype is not None, "Should not fail if dtype is None"
assert input_arg.number_attr, "Should be number_attr case"
@@ -394,11 +393,11 @@ class OpDefLibrary(object):
(input_name, op_type_name, observed))
if input_arg.type != types_pb2.DT_INVALID:
raise TypeError("%s that do not match expected type %s." %
- (prefix, types_lib.as_dtype(dtype).name))
+ (prefix, dtypes.as_dtype(dtype).name))
elif input_arg.type_attr in attrs:
raise TypeError("%s that do not match type %s inferred from "
"earlier arguments." %
- (prefix, types_lib.as_dtype(dtype).name))
+ (prefix, dtypes.as_dtype(dtype).name))
else:
raise TypeError("%s that don't all match." % prefix)
@@ -423,11 +422,11 @@ class OpDefLibrary(object):
(input_name, op_type_name, observed))
if input_arg.type != types_pb2.DT_INVALID:
raise TypeError("%s expected type of %s." %
- (prefix, types_lib.as_dtype(input_arg.type).name))
+ (prefix, dtypes.as_dtype(input_arg.type).name))
else:
raise TypeError(
"%s type %s of argument '%s'." %
- (prefix, types_lib.as_dtype(attrs[input_arg.type_attr]).name,
+ (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name,
inferred_from[input_arg.type_attr]))
types = [values.dtype]
@@ -501,8 +500,8 @@ class OpDefLibrary(object):
"Input '%s' of '%s' Op has type list of %s that does not "
"match type list %s of argument '%s'." %
(input_name, op_type_name,
- ", ".join(types_lib.as_dtype(x).name for x in attr_value),
- ", ".join(types_lib.as_dtype(x).name
+ ", ".join(dtypes.as_dtype(x).name for x in attr_value),
+ ", ".join(dtypes.as_dtype(x).name
for x in attrs[input_arg.type_list_attr]),
inferred_from[input_arg.type_list_attr]))
else:
@@ -567,8 +566,9 @@ class OpDefLibrary(object):
if attr_value.s not in attr_def.allowed_values.list.s:
raise ValueError(
"Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
- (key, op_type_name, attr_value.s,
- '", "'.join(attr_def.allowed_values.list.s)))
+ (key, op_type_name, compat.as_text(attr_value.s),
+ '", "'.join(map(compat.as_text,
+ attr_def.allowed_values.list.s))))
elif attr_def.type == "list(string)":
attr_value.list.s.extend([_MakeStr(x, key) for x in value])
if attr_def.HasField("allowed_values"):
@@ -576,8 +576,9 @@ class OpDefLibrary(object):
if x not in attr_def.allowed_values.list.s:
raise ValueError(
"Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
- (key, op_type_name, x,
- '", "'.join(attr_def.allowed_values.list.s)))
+ (key, op_type_name, compat.as_text(x),
+ '", "'.join(map(compat.as_text,
+ attr_def.allowed_values.list.s))))
elif attr_def.type == "int":
attr_value.i = _MakeInt(value, key)
if attr_def.has_minimum:
@@ -640,7 +641,7 @@ class OpDefLibrary(object):
types = [arg.type]
output_structure.append(None)
if arg.is_ref:
- types = [types_lib.as_dtype(x).as_ref for x in types]
+ types = [dtypes.as_dtype(x).as_ref for x in types]
output_types.extend(types)
if keywords:
diff --git a/tensorflow/python/ops/op_def_library_test.py b/tensorflow/python/ops/op_def_library_test.py
index 4a8cdbe9a4..4b9d66e8a6 100644
--- a/tensorflow/python/ops/op_def_library_test.py
+++ b/tensorflow/python/ops/op_def_library_test.py
@@ -23,10 +23,10 @@ from google.protobuf import text_format
from tensorflow.core.framework import op_def_pb2
from tensorflow.core.framework import tensor_shape_pb2
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import types
from tensorflow.python.ops.op_def_library import OpDefLibrary
from tensorflow.python.platform import googletest
@@ -104,13 +104,13 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
def testNoRegisteredOpFails(self):
with self.assertRaises(RuntimeError) as cm:
self._lib.apply_op("unknown", g=self._g)
- self.assertEqual(cm.exception.message, "Unrecognized Op name unknown")
+ self.assertEqual(str(cm.exception), "Unrecognized Op name unknown")
def testAddOpValidation(self):
with self.assertRaises(TypeError) as cm:
self._add_op("name: 'MissingTypeAttr' "
"input_arg { name: 'a' type_attr: 'T' } ")
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"Inconsistent OpDef for 'MissingTypeAttr', "
"missing attr 'T'")
@@ -119,13 +119,13 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
"output_arg { name: 'a' type_attr: 'T' } "
"attr { name: 'T' type: 'int' }")
self.assertEqual(
- cm.exception.message,
+ str(cm.exception),
"Attr 'T' of 'BadTypeAttr' used as a type_attr but has type int")
with self.assertRaises(TypeError) as cm:
self._add_op("name: 'MissingNumberAttr' "
"input_arg { name: 'a' type: DT_INT32 number_attr: 'N' } ")
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"Inconsistent OpDef for 'MissingNumberAttr', "
"missing attr 'N'")
@@ -134,21 +134,21 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
"output_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "
"attr { name: 'N' type: 'type' }")
self.assertEqual(
- cm.exception.message,
+ str(cm.exception),
"Attr 'N' of 'BadNumberAttr' used as a number_attr but has type type")
with self.assertRaises(TypeError) as cm:
self._add_op("name: 'TwoTypesA' "
"input_arg { name: 'a' type: DT_INT32 type_attr: 'T' } "
"attr { name: 'T' type: 'type' }")
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"Arg 'a' of 'TwoTypesA' must have one type field not 2")
with self.assertRaises(TypeError) as cm:
self._add_op("name: 'TwoTypesB' "
"input_arg { name: 'a' type: DT_INT32 type_list_attr: 'T' } "
"attr { name: 'T' type: 'list(type)' }")
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"Arg 'a' of 'TwoTypesB' must have one type field not 2")
with self.assertRaises(TypeError) as cm:
@@ -157,17 +157,17 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
"type_list_attr: 'U' } "
"attr { name: 'T' type: 'type' } "
"attr { name: 'U' type: 'list(type)' }")
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"Arg 'a' of 'ThreeTypes' must have one type field not 3")
with self.assertRaises(TypeError) as cm:
self._add_op("name: 'NoTypes' output_arg { name: 'a' } ")
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"Arg 'a' of 'NoTypes' must have one type field not 0")
def testSimple(self):
out = self._lib.apply_op("Simple", a=3)
- self.assertEquals(types.float32, out.dtype)
+ self.assertEqual(dtypes.float32, out.dtype)
self.assertProtoEquals("""
name: 'Simple' op: 'Simple' input: 'Simple/a'
""", out.op.node_def)
@@ -190,39 +190,39 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
def testSimpleFailures(self):
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("Simple", a="Bad string")
- self.assertEqual(cm.exception.message,
- "Expected int32, got 'Bad string' instead.")
+ self.assertEqual(str(cm.exception),
+ "Expected int32, got 'Bad string' of type 'str' instead.")
with self.assertRaises(TypeError) as cm:
- self._lib.apply_op("Simple", a=self.Tensor(types.string))
- self.assertEqual(cm.exception.message,
+ self._lib.apply_op("Simple", a=self.Tensor(dtypes.string))
+ self.assertEqual(str(cm.exception),
"Input 'a' of 'Simple' Op has type string "
"that does not match expected type of int32.")
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("Simple", a=6, extra="bogus")
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"apply_op() got unexpected keyword arguments: extra")
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("Simple", a=6, extra1="bogus", extra2="also_bogus")
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"apply_op() got unexpected keyword arguments: extra1, "
"extra2")
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("Simple")
- self.assertEqual(cm.exception.message, "No argument for input a")
+ self.assertEqual(str(cm.exception), "No argument for input a")
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("Simple", wrong=7)
- self.assertEqual(cm.exception.message, "No argument for input a")
+ self.assertEqual(str(cm.exception), "No argument for input a")
with self.assertRaises(TypeError) as cm:
- self._lib.apply_op("Simple", a=[self.Tensor(types.int32)])
- self.assertStartsWith(
- cm.exception.message,
- "Expected int32, got list containing Tensors instead.")
+ self._lib.apply_op("Simple", a=[self.Tensor(dtypes.int32)])
+ self.assertStartsWith(str(cm.exception),
+ "Expected int32, got list containing Tensors of type "
+ "'_Message' instead.")
def testReservedInput(self):
self._add_op("name: 'ReservedInput' "
@@ -239,34 +239,34 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
"attr { name: 'T' type: 'type' }")
out = self._lib.apply_op("Polymorphic", a=7, name="p")
- self.assertEquals(types.int32, out.dtype)
+ self.assertEqual(dtypes.int32, out.dtype)
self.assertProtoEquals("""
name: 'p' op: 'Polymorphic' input: 'p/a'
attr { key: 'T' value { type: DT_INT32 } }
""", out.op.node_def)
out = self._lib.apply_op("Polymorphic", a="s", name="q")
- self.assertEquals(types.string, out.dtype)
+ self.assertEqual(dtypes.string, out.dtype)
self.assertProtoEquals("""
name: 'q' op: 'Polymorphic' input: 'q/a'
attr { key: 'T' value { type: DT_STRING } }
""", out.op.node_def)
out = self._lib.apply_op("Polymorphic", a=["s", "t", "u"], name="r")
- self.assertEquals(types.string, out.dtype)
+ self.assertEqual(dtypes.string, out.dtype)
self.assertProtoEquals("""
name: 'r' op: 'Polymorphic' input: 'r/a'
attr { key: 'T' value { type: DT_STRING } }
""", out.op.node_def)
with self.assertRaises(TypeError) as cm:
- self._lib.apply_op("Polymorphic", a="s", T=types.string)
- self.assertEqual(cm.exception.message,
+ self._lib.apply_op("Polymorphic", a="s", T=dtypes.string)
+ self.assertEqual(str(cm.exception),
"Should not specify value for inferred attr 'T'.")
with self.assertRaises(TypeError) as cm:
- self._lib.apply_op("Polymorphic", a=[self.Tensor(types.bool)])
- self.assertEqual(cm.exception.message,
+ self._lib.apply_op("Polymorphic", a=[self.Tensor(dtypes.bool)])
+ self.assertEqual(str(cm.exception),
"List of Tensors when single Tensor expected")
def testPolymorphicOut(self):
@@ -274,15 +274,15 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
"output_arg { name: 'out' type_attr: 'T' } "
"attr { name: 'T' type: 'type' }")
- out = self._lib.apply_op("PolymorphicOut", T=types.int32, name="p")
- self.assertEquals(types.int32, out.dtype)
+ out = self._lib.apply_op("PolymorphicOut", T=dtypes.int32, name="p")
+ self.assertEqual(dtypes.int32, out.dtype)
self.assertProtoEquals("""
name: 'p' op: 'PolymorphicOut'
attr { key: 'T' value { type: DT_INT32 } }
""", out.op.node_def)
- out = self._lib.apply_op("PolymorphicOut", T=types.bool, name="q")
- self.assertEquals(types.bool, out.dtype)
+ out = self._lib.apply_op("PolymorphicOut", T=dtypes.bool, name="q")
+ self.assertEqual(dtypes.bool, out.dtype)
self.assertProtoEquals("""
name: 'q' op: 'PolymorphicOut'
attr { key: 'T' value { type: DT_BOOL } }
@@ -290,12 +290,12 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("PolymorphicOut")
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"No argument for attr T")
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("PolymorphicOut", T=None)
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"Expected DataType for argument 'T' not None.")
def testPolymorphicDefaultOut(self):
@@ -305,15 +305,15 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
" default_value { type: DT_STRING } }")
out = self._lib.apply_op("PolymorphicDefaultOut", T=None, name="p")
- self.assertEquals(types.string, out.dtype)
+ self.assertEqual(dtypes.string, out.dtype)
self.assertProtoEquals("""
name: 'p' op: 'PolymorphicDefaultOut'
attr { key: 'T' value { type: DT_STRING } }
""", out.op.node_def)
- out = self._lib.apply_op("PolymorphicDefaultOut", T=types.bool,
+ out = self._lib.apply_op("PolymorphicDefaultOut", T=dtypes.bool,
name="q")
- self.assertEquals(types.bool, out.dtype)
+ self.assertEqual(dtypes.bool, out.dtype)
self.assertProtoEquals("""
name: 'q' op: 'PolymorphicDefaultOut'
attr { key: 'T' value { type: DT_BOOL } }
@@ -327,14 +327,14 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
"attr { name: 'T' type: 'type' }")
out = self._lib.apply_op("Binary", a=8, b=9, name="b")
- self.assertEquals(types.int32, out.dtype)
+ self.assertEqual(dtypes.int32, out.dtype)
self.assertProtoEquals("""
name: 'b' op: 'Binary' input: 'b/a' input: 'b/b'
attr { key: 'T' value { type: DT_INT32 } }
""", out.op.node_def)
out = self._lib.apply_op("Binary", a="left", b="right", name="c")
- self.assertEquals(types.string, out.dtype)
+ self.assertEqual(dtypes.string, out.dtype)
self.assertProtoEquals("""
name: 'c' op: 'Binary' input: 'c/a' input: 'c/b'
attr { key: 'T' value { type: DT_STRING } }
@@ -342,13 +342,13 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("Binary", a="left", b=12)
- self.assertEqual(cm.exception.message,
- "Expected string, got 12 instead.")
+ self.assertEqual(str(cm.exception),
+ "Expected string, got 12 of type 'int' instead.")
with self.assertRaises(TypeError) as cm:
- self._lib.apply_op("Binary", a=self.Tensor(types.string),
- b=self.Tensor(types.int32))
- self.assertEqual(cm.exception.message,
+ self._lib.apply_op("Binary", a=self.Tensor(dtypes.string),
+ b=self.Tensor(dtypes.int32))
+ self.assertEqual(str(cm.exception),
"Input 'b' of 'Binary' Op has type int32 "
"that does not match type string of argument 'a'.")
@@ -360,14 +360,14 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
" type: DT_STRING type: DT_BOOL } } }")
out = self._lib.apply_op("Restrict", a="foo", name="g")
- self.assertEquals(types.string, out.dtype)
+ self.assertEqual(dtypes.string, out.dtype)
self.assertProtoEquals("""
name: 'g' op: 'Restrict' input: 'g/a'
attr { key: 'T' value { type: DT_STRING } }
""", out.op.node_def)
out = self._lib.apply_op("Restrict", a=True, name="h")
- self.assertEquals(types.bool, out.dtype)
+ self.assertEqual(dtypes.bool, out.dtype)
self.assertProtoEquals("""
name: 'h' op: 'Restrict' input: 'h/a'
attr { key: 'T' value { type: DT_BOOL } }
@@ -375,7 +375,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("Restrict", a=17)
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"DataType int32 for attr 'T' "
"not in list of allowed values: "
"string, bool")
@@ -404,7 +404,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("TypeList", a=17)
- self.assertStartsWith(cm.exception.message,
+ self.assertStartsWith(str(cm.exception),
"Expected list for 'a' "
"argument to 'TypeList' Op, not ")
@@ -429,7 +429,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("TypeListTwice", a=["foo", True], b=["bar", 6])
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"Input 'b' of 'TypeListTwice' Op has type list of "
"string, int32 that does not match type list "
"string, bool of argument 'a'.")
@@ -439,18 +439,18 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
"output_arg { name: 'out' type_list_attr: 'T' } "
"attr { name: 'T' type: 'list(type)' }")
- out, = self._lib.apply_op("OutTypeList", T=[types.float32], name="x")
- self.assertEquals(types.float32, out.dtype)
+ out, = self._lib.apply_op("OutTypeList", T=[dtypes.float32], name="x")
+ self.assertEqual(dtypes.float32, out.dtype)
self.assertProtoEquals("""
name: 'x' op: 'OutTypeList'
attr { key: 'T' value { list { type: DT_FLOAT } } }
""", out.op.node_def)
out1, out2 = self._lib.apply_op("OutTypeList",
- T=[types.int32, types.bool],
+ T=[dtypes.int32, dtypes.bool],
name="w")
- self.assertEquals(types.int32, out1.dtype)
- self.assertEquals(types.bool, out2.dtype)
+ self.assertEqual(dtypes.int32, out1.dtype)
+ self.assertEqual(dtypes.bool, out2.dtype)
self.assertProtoEquals("""
name: 'w' op: 'OutTypeList'
attr { key: 'T' value { list { type: DT_INT32 type: DT_BOOL } } }
@@ -460,8 +460,8 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
self.assertEqual([], out)
with self.assertRaises(TypeError) as cm:
- self._lib.apply_op("OutTypeList", T=types.int32)
- self.assertEqual(cm.exception.message, "Expected list for attr T")
+ self._lib.apply_op("OutTypeList", T=dtypes.int32)
+ self.assertEqual(str(cm.exception), "Expected list for attr T")
def testTypeListRestrict(self):
self._add_op("name: 'TypeListRestrict' "
@@ -477,7 +477,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("TypeListRestrict", a=[True, 12])
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"DataType int32 for attr 'T' "
"not in list of allowed values: string, bool")
@@ -488,10 +488,10 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
" type: DT_STRING type: DT_BOOL } } }")
out1, out2 = self._lib.apply_op("OutTypeListRestrict",
- t=[types.bool, types.string],
+ t=[dtypes.bool, dtypes.string],
name="u")
- self.assertEquals(types.bool, out1.dtype)
- self.assertEquals(types.string, out2.dtype)
+ self.assertEqual(dtypes.bool, out1.dtype)
+ self.assertEqual(dtypes.string, out2.dtype)
self.assertProtoEquals("""
name: 'u' op: 'OutTypeListRestrict'
attr { key: 't' value { list { type: DT_BOOL type: DT_STRING } } }
@@ -499,8 +499,8 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("OutTypeListRestrict",
- t=[types.string, types.int32])
- self.assertEqual(cm.exception.message,
+ t=[dtypes.string, dtypes.int32])
+ self.assertEqual(str(cm.exception),
"DataType int32 for attr 't' "
"not in list of allowed values: string, bool")
@@ -518,22 +518,22 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("Attr", a="bad")
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"Expected int for argument 'a' not 'bad'.")
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("Attr", a=[12])
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"Expected int for argument 'a' not [12].")
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("Attr", a=None)
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"Expected int for argument 'a' not None.")
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("Attr")
- self.assertEqual(cm.exception.message, "No argument for attr a")
+ self.assertEqual(str(cm.exception), "No argument for attr a")
def testAttrFloat(self):
self._add_op("name: 'AttrFloat' attr { name: 'a' type: 'float' }")
@@ -550,7 +550,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("AttrFloat", a="bad")
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"Expected float for argument 'a' not 'bad'.")
def testAttrBool(self):
@@ -568,17 +568,17 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("AttrBool", a=0)
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"Expected bool for argument 'a' not 0.")
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("AttrBool", a=1)
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"Expected bool for argument 'a' not 1.")
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("AttrBool", a=[])
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"Expected bool for argument 'a' not [].")
def testAttrBoolList(self):
@@ -597,7 +597,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("AttrBoolList", a=[0])
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"Expected bool for argument 'a' not 0.")
def testAttrMin(self):
@@ -610,7 +610,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError) as cm:
self._lib.apply_op("AttrMin", a=2)
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"Attr 'a' of 'AttrMin' Op passed 2 less than minimum 5.")
def testAttrListMin(self):
@@ -625,7 +625,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError) as cm:
self._lib.apply_op("AttrListMin", a=[17])
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"Attr 'a' of 'AttrListMin' Op "
"passed list of length 1 less than minimum 2.")
@@ -641,7 +641,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError) as cm:
self._lib.apply_op("AttrEnum", a="invalid")
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
'Attr \'a\' of \'AttrEnum\' Op '
'passed string \'invalid\' not in: '
'"apples", "oranges".')
@@ -659,7 +659,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError) as cm:
self._lib.apply_op("AttrEnumList", a=["apples", "invalid", "oranges"])
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
'Attr \'a\' of \'AttrEnumList\' Op '
'passed string \'invalid\' not '
'in: "apples", "oranges".')
@@ -705,7 +705,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
# TODO(josh11b): Re-enable this test once we stop promoting scalars to shapes.
# with self.assertRaises(TypeError) as cm:
# self._lib.apply_op("AttrShape", a=5)
- # self.assertEqual(cm.exception.message,
+ # self.assertEqual(str(cm.exception),
# "Don't know how to convert 5 to a TensorShapeProto for "
# "argument 'a'")
@@ -817,42 +817,42 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("NIntsIn", a=["foo", "bar"])
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"Tensors in list passed to 'a' of 'NIntsIn' Op have types "
"[string, string] that do not match expected type int32.")
with self.assertRaises(TypeError) as cm:
- self._lib.apply_op("NIntsIn", a=[self.Tensor(types.string),
- self.Tensor(types.string)])
- self.assertEqual(cm.exception.message,
+ self._lib.apply_op("NIntsIn", a=[self.Tensor(dtypes.string),
+ self.Tensor(dtypes.string)])
+ self.assertEqual(str(cm.exception),
"Tensors in list passed to 'a' of 'NIntsIn' Op have "
"types [string, string] that do not match expected type "
"int32.")
with self.assertRaises(ValueError) as cm:
self._lib.apply_op("NIntsIn", a=[99])
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"List argument 'a' to 'NIntsIn' Op "
"with length 1 shorter than "
"minimum length 2.")
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("NIntsIn", a=[38, "bar"])
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"Tensors in list passed to 'a' of 'NIntsIn' Op have types "
"[int32, string] that do not match expected type int32.")
with self.assertRaises(TypeError) as cm:
- self._lib.apply_op("NIntsIn", a=[self.Tensor(types.int32),
- self.Tensor(types.string)])
- self.assertEqual(cm.exception.message,
+ self._lib.apply_op("NIntsIn", a=[self.Tensor(dtypes.int32),
+ self.Tensor(dtypes.string)])
+ self.assertEqual(str(cm.exception),
"Tensors in list passed to 'a' of 'NIntsIn' Op "
"have types [int32, string] that do not match expected "
"type int32.")
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("NIntsIn", a=17)
- self.assertStartsWith(cm.exception.message,
+ self.assertStartsWith(str(cm.exception),
"Expected list for 'a' argument "
"to 'NIntsIn' Op, not ")
@@ -885,7 +885,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
""", op.node_def)
op = self._lib.apply_op("NPolymorphicIn",
- a=[1, self.Tensor(types.float32, name="x")],
+ a=[1, self.Tensor(dtypes.float32, name="x")],
name="q")
self.assertProtoEquals("""
name: 'q' op: 'NPolymorphicIn' input: 'q/a_0' input: 'x'
@@ -895,33 +895,33 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError) as cm:
self._lib.apply_op("NPolymorphicIn", a=[99])
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"List argument 'a' to 'NPolymorphicIn' Op with length 1 "
"shorter than minimum length 2.")
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("NPolymorphicIn", a=[38, "bar"])
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"All tensors passed to 'a' of 'NPolymorphicIn' "
"Op must have the same type.")
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("NPolymorphicIn",
- a=[38, self.Tensor(types.string)])
- self.assertEqual(cm.exception.message,
+ a=[38, self.Tensor(dtypes.string)])
+ self.assertEqual(str(cm.exception),
"Tensors in list passed to 'a' of 'NPolymorphicIn' Op "
"have types [int32, string] that don't all match.")
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("NPolymorphicIn",
- a=["abcd", self.Tensor(types.int32)])
- self.assertEqual(cm.exception.message,
+ a=["abcd", self.Tensor(dtypes.int32)])
+ self.assertEqual(str(cm.exception),
"Tensors in list passed to 'a' of 'NPolymorphicIn' Op "
"have types [string, int32] that don't all match.")
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("NPolymorphicIn", a=17)
- self.assertStartsWith(cm.exception.message,
+ self.assertStartsWith(str(cm.exception),
"Expected list for 'a' argument "
"to 'NPolymorphicIn' Op, not ")
@@ -951,7 +951,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("NPolymorphicRestrictIn", a=[1, 2])
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"DataType int32 for attr 'T' "
"not in list of allowed values: string, bool")
@@ -975,7 +975,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError) as cm:
self._lib.apply_op("NInTwice", a=[1, 2, 3], b=["too short"])
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"List argument 'b' to 'NInTwice' Op "
"with length 1 must match "
"length 3 of argument 'a'.")
@@ -997,23 +997,23 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError) as cm:
self._lib.apply_op("NInPolymorphicTwice", a=[1, 2, 3], b=[5])
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"List argument 'b' to 'NInPolymorphicTwice' Op "
"with length 1 "
"must match length 3 of argument 'a'.")
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("NInPolymorphicTwice", a=[1, 2], b=["one", "two"])
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"Tensors in list passed to 'b' of 'NInPolymorphicTwice' "
"Op have types [string, string] that do not match type "
"int32 inferred from earlier arguments.")
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("NInPolymorphicTwice",
- a=[self.Tensor(types.int32)],
- b=[self.Tensor(types.string)])
- self.assertEqual(cm.exception.message,
+ a=[self.Tensor(dtypes.int32)],
+ b=[self.Tensor(dtypes.string)])
+ self.assertEqual(str(cm.exception),
"Tensors in list passed to 'b' of "
"'NInPolymorphicTwice' Op have types [string] that do not "
"match type int32 inferred from earlier arguments.")
@@ -1046,8 +1046,8 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
""", op.node_def)
op = self._lib.apply_op("NInTwoTypeVariables",
- a=[self.Tensor(types.int32, name="q")],
- b=[self.Tensor(types.string, name="r")],
+ a=[self.Tensor(dtypes.int32, name="q")],
+ b=[self.Tensor(dtypes.string, name="r")],
name="p")
self.assertProtoEquals("""
name: 'p' op: 'NInTwoTypeVariables' input: 'q' input: 'r'
@@ -1058,7 +1058,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError) as cm:
self._lib.apply_op("NInTwoTypeVariables", a=[1, 2, 3], b=["5"])
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"List argument 'b' to 'NInTwoTypeVariables' Op "
"with length 1 "
"must match length 3 of argument 'a'.")
@@ -1090,22 +1090,22 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("InPolymorphicTwice", a=[], b=[3, 4, 5])
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"Don't know how to infer type variable from empty input "
"list passed to input 'a' of 'InPolymorphicTwice' Op.")
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("InPolymorphicTwice", a=[1, 2], b=["one", "two"])
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"Tensors in list passed to 'b' of 'InPolymorphicTwice' Op "
"have types [string, string] that do not match type int32 "
"inferred from earlier arguments.")
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("InPolymorphicTwice",
- a=[self.Tensor(types.int32)],
- b=[self.Tensor(types.string)])
- self.assertEqual(cm.exception.message,
+ a=[self.Tensor(dtypes.int32)],
+ b=[self.Tensor(dtypes.string)])
+ self.assertEqual(str(cm.exception),
"Tensors in list passed to 'b' of 'InPolymorphicTwice' "
"Op have types [string] that do not match type int32 "
"inferred from earlier arguments.")
@@ -1116,31 +1116,31 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
"attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")
out1, out2 = self._lib.apply_op("NIntsOut", N=2, name="n")
- self.assertEquals(types.int32, out1.dtype)
- self.assertEquals(types.int32, out2.dtype)
+ self.assertEqual(dtypes.int32, out1.dtype)
+ self.assertEqual(dtypes.int32, out2.dtype)
self.assertProtoEquals("""
name: 'n' op: 'NIntsOut' attr { key: 'N' value { i: 2 } }
""", out1.op.node_def)
out1, out2, out3, out4, out5 = self._lib.apply_op(
"NIntsOut", N=5, name="o")
- self.assertEquals(types.int32, out1.dtype)
- self.assertEquals(types.int32, out2.dtype)
- self.assertEquals(types.int32, out3.dtype)
- self.assertEquals(types.int32, out4.dtype)
- self.assertEquals(types.int32, out5.dtype)
+ self.assertEqual(dtypes.int32, out1.dtype)
+ self.assertEqual(dtypes.int32, out2.dtype)
+ self.assertEqual(dtypes.int32, out3.dtype)
+ self.assertEqual(dtypes.int32, out4.dtype)
+ self.assertEqual(dtypes.int32, out5.dtype)
self.assertProtoEquals("""
name: 'o' op: 'NIntsOut' attr { key: 'N' value { i: 5 } }
""", out5.op.node_def)
with self.assertRaises(ValueError) as cm:
self._lib.apply_op("NIntsOut", N=1)
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"Attr 'N' of 'NIntsOut' Op passed 1 less than minimum 2.")
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("NIntsOut", N=[3])
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"Expected int for argument 'N' not [3].")
def testNIntsOutDefault(self):
@@ -1151,16 +1151,16 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
out1, out2, out3 = self._lib.apply_op(
"NIntsOutDefault", N=None, name="z")
- self.assertEquals(types.int32, out1.dtype)
- self.assertEquals(types.int32, out2.dtype)
- self.assertEquals(types.int32, out3.dtype)
+ self.assertEqual(dtypes.int32, out1.dtype)
+ self.assertEqual(dtypes.int32, out2.dtype)
+ self.assertEqual(dtypes.int32, out3.dtype)
self.assertProtoEquals("""
name: 'z' op: 'NIntsOutDefault' attr { key: 'N' value { i: 3 } }
""", out1.op.node_def)
out1, out2 = self._lib.apply_op("NIntsOutDefault", N=2, name="y")
- self.assertEquals(types.int32, out1.dtype)
- self.assertEquals(types.int32, out2.dtype)
+ self.assertEqual(dtypes.int32, out1.dtype)
+ self.assertEqual(dtypes.int32, out2.dtype)
self.assertProtoEquals("""
name: 'y' op: 'NIntsOutDefault' attr { key: 'N' value { i: 2 } }
""", out2.op.node_def)
@@ -1172,9 +1172,9 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
"attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")
out1, out2 = self._lib.apply_op("NPolymorphicOut", N=2,
- T=types.int32, name="n")
- self.assertEquals(types.int32, out1.dtype)
- self.assertEquals(types.int32, out2.dtype)
+ T=dtypes.int32, name="n")
+ self.assertEqual(dtypes.int32, out1.dtype)
+ self.assertEqual(dtypes.int32, out2.dtype)
self.assertProtoEquals("""
name: 'n' op: 'NPolymorphicOut'
attr { key: 'T' value { type: DT_INT32 } }
@@ -1182,10 +1182,10 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
""", out1.op.node_def)
out1, out2, out3 = self._lib.apply_op(
- "NPolymorphicOut", T=types.string, N=3, name="o")
- self.assertEquals(types.string, out1.dtype)
- self.assertEquals(types.string, out2.dtype)
- self.assertEquals(types.string, out3.dtype)
+ "NPolymorphicOut", T=dtypes.string, N=3, name="o")
+ self.assertEqual(dtypes.string, out1.dtype)
+ self.assertEqual(dtypes.string, out2.dtype)
+ self.assertEqual(dtypes.string, out3.dtype)
self.assertProtoEquals("""
name: 'o' op: 'NPolymorphicOut'
attr { key: 'T' value { type: DT_STRING } }
@@ -1193,15 +1193,15 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
""", out3.op.node_def)
with self.assertRaises(ValueError) as cm:
- self._lib.apply_op("NPolymorphicOut", N=1, T=types.string)
- self.assertEqual(cm.exception.message,
+ self._lib.apply_op("NPolymorphicOut", N=1, T=dtypes.string)
+ self.assertEqual(str(cm.exception),
"Attr 'N' of 'NPolymorphicOut' Op "
"passed 1 less than minimum 2.")
with self.assertRaises(TypeError) as cm:
- self._lib.apply_op("NPolymorphicOut", N=3, T=[types.string])
+ self._lib.apply_op("NPolymorphicOut", N=3, T=[dtypes.string])
self.assertEqual(
- cm.exception.message,
+ str(cm.exception),
"Expected DataType for argument 'T' not [tf.string].")
def testNPolymorphicOutDefault(self):
@@ -1214,8 +1214,8 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
out1, out2 = self._lib.apply_op(
"NPolymorphicOutDefault", N=None, T=None, name="r")
- self.assertEquals(types.bool, out1.dtype)
- self.assertEquals(types.bool, out2.dtype)
+ self.assertEqual(dtypes.bool, out1.dtype)
+ self.assertEqual(dtypes.bool, out2.dtype)
self.assertProtoEquals("""
name: 'r' op: 'NPolymorphicOutDefault'
attr { key: 'T' value { type: DT_BOOL } }
@@ -1224,9 +1224,9 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
out1, out2, out3 = self._lib.apply_op(
"NPolymorphicOutDefault", N=3, T=None, name="s")
- self.assertEquals(types.bool, out1.dtype)
- self.assertEquals(types.bool, out2.dtype)
- self.assertEquals(types.bool, out3.dtype)
+ self.assertEqual(dtypes.bool, out1.dtype)
+ self.assertEqual(dtypes.bool, out2.dtype)
+ self.assertEqual(dtypes.bool, out3.dtype)
self.assertProtoEquals("""
name: 's' op: 'NPolymorphicOutDefault'
attr { key: 'T' value { type: DT_BOOL } }
@@ -1234,9 +1234,9 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
""", out1.op.node_def)
out1, out2 = self._lib.apply_op(
- "NPolymorphicOutDefault", N=None, T=types.int32, name="t")
- self.assertEquals(types.int32, out1.dtype)
- self.assertEquals(types.int32, out2.dtype)
+ "NPolymorphicOutDefault", N=None, T=dtypes.int32, name="t")
+ self.assertEqual(dtypes.int32, out1.dtype)
+ self.assertEqual(dtypes.int32, out2.dtype)
self.assertProtoEquals("""
name: 't' op: 'NPolymorphicOutDefault'
attr { key: 'T' value { type: DT_INT32 } }
@@ -1244,10 +1244,10 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
""", out1.op.node_def)
out1, out2, out3 = self._lib.apply_op(
- "NPolymorphicOutDefault", N=3, T=types.int32, name="u")
- self.assertEquals(types.int32, out1.dtype)
- self.assertEquals(types.int32, out2.dtype)
- self.assertEquals(types.int32, out3.dtype)
+ "NPolymorphicOutDefault", N=3, T=dtypes.int32, name="u")
+ self.assertEqual(dtypes.int32, out1.dtype)
+ self.assertEqual(dtypes.int32, out2.dtype)
+ self.assertEqual(dtypes.int32, out3.dtype)
self.assertProtoEquals("""
name: 'u' op: 'NPolymorphicOutDefault'
attr { key: 'T' value { type: DT_INT32 } }
@@ -1262,10 +1262,10 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
"attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")
out1, out2, out3 = self._lib.apply_op(
- "NPolymorphicRestrictOut", N=3, T=types.bool, name="u")
- self.assertEquals(types.bool, out1.dtype)
- self.assertEquals(types.bool, out2.dtype)
- self.assertEquals(types.bool, out3.dtype)
+ "NPolymorphicRestrictOut", N=3, T=dtypes.bool, name="u")
+ self.assertEqual(dtypes.bool, out1.dtype)
+ self.assertEqual(dtypes.bool, out2.dtype)
+ self.assertEqual(dtypes.bool, out3.dtype)
self.assertProtoEquals("""
name: 'u' op: 'NPolymorphicRestrictOut'
attr { key: 'T' value { type: DT_BOOL } }
@@ -1273,8 +1273,8 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
""", out1.op.node_def)
with self.assertRaises(TypeError) as cm:
- self._lib.apply_op("NPolymorphicRestrictOut", N=2, T=types.int32)
- self.assertEqual(cm.exception.message,
+ self._lib.apply_op("NPolymorphicRestrictOut", N=2, T=dtypes.int32)
+ self.assertEqual(str(cm.exception),
"DataType int32 for attr 'T' "
"not in list of allowed values: string, bool")
@@ -1286,8 +1286,8 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
"output_arg { name: 'a' type_attr: 'T' is_ref: true } "
"attr { name: 'T' type: 'type' } ")
- out = self._lib.apply_op("RefOut", T=types.bool, name="o")
- self.assertEquals(types.bool_ref, out.dtype)
+ out = self._lib.apply_op("RefOut", T=dtypes.bool, name="o")
+ self.assertEqual(dtypes.bool_ref, out.dtype)
self.assertProtoEquals("""
name: 'o' op: 'RefOut'
attr { key: 'T' value { type: DT_BOOL } }
@@ -1300,7 +1300,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
""", op.node_def)
# Can pass ref to non-ref input.
- out = self._lib.apply_op("RefOut", T=types.int32, name="r")
+ out = self._lib.apply_op("RefOut", T=dtypes.int32, name="r")
out = self._lib.apply_op("Simple", a=out, name="s")
self.assertProtoEquals("""
name: 's' op: 'Simple' input: 'r'
@@ -1309,7 +1309,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
# Can't pass non-ref to ref input.
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("RefIn", a=2)
- self.assertEqual(cm.exception.message,
+ self.assertEqual(str(cm.exception),
"Input 'a' of 'RefIn' Op requires l-value input")
def testSpecifyDevice(self):
@@ -1340,9 +1340,9 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
a, b = self._lib.apply_op("MixedStruct", n_a=n_a)
self.assertTrue(isinstance(a, list))
self.assertEqual(n_a, len(a))
- self.assertTrue(all(x.dtype == types.int32 for x in a))
+ self.assertTrue(all(x.dtype == dtypes.int32 for x in a))
self.assertTrue(isinstance(b, ops.Tensor))
- self.assertEqual(types.float32, b.dtype)
+ self.assertEqual(dtypes.float32, b.dtype)
def testStructuredOutputMultipleLists(self):
self._add_op("name: 'ComplexStruct' "
@@ -1355,15 +1355,15 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
for n_a in [0, 1, 3]:
for n_b in [0, 1, 3]:
for t_c in [[],
- [types.int32],
- [types.int32, types.float32]]:
+ [dtypes.int32],
+ [dtypes.int32, dtypes.float32]]:
a, b, c = self._lib.apply_op("ComplexStruct",
n_a=n_a, n_b=n_b, t_c=t_c)
self.assertEqual(n_a, len(a))
- self.assertTrue(all(x.dtype == types.int32 for x in a))
+ self.assertTrue(all(x.dtype == dtypes.int32 for x in a))
self.assertEqual(n_b, len(b))
- self.assertTrue(all(x.dtype == types.int64 for x in b))
+ self.assertTrue(all(x.dtype == dtypes.int64 for x in b))
self.assertEqual(t_c, [x.dtype for x in c])
@@ -1387,19 +1387,19 @@ class OpDefLibraryGraphTest(test_util.TensorFlowTestCase):
def testNoGraph(self):
out = self._lib.apply_op("Simple", a=3)
- self.assertEquals(out.graph, ops.get_default_graph())
+ self.assertEqual(out.graph, ops.get_default_graph())
def testDefaultGraph(self):
with self._g.as_default():
out = self._lib.apply_op("Simple", a=3)
- self.assertEquals(out.graph, self._g)
+ self.assertEqual(out.graph, self._g)
def testIgnoreDefaultGraphWithGraphArgument(self):
default_g = ops.Graph()
with default_g.as_default():
out = self._lib.apply_op("Simple", a=3, g=self._g)
- self.assertEquals(ops.get_default_graph(), default_g)
- self.assertEquals(out.graph, self._g)
+ self.assertEqual(ops.get_default_graph(), default_g)
+ self.assertEqual(out.graph, self._g)
def testDifferentGraphFails(self):
a = self._lib.apply_op("Simple", a=3, g=self._g)
@@ -1407,7 +1407,7 @@ class OpDefLibraryGraphTest(test_util.TensorFlowTestCase):
b = self._lib.apply_op("Simple", a=4, g=other_g)
with self.assertRaises(ValueError) as cm:
self._lib.apply_op("Binary", a=a, b=b)
- self.assertTrue("must be from the same graph" in cm.exception.message)
+ self.assertTrue("must be from the same graph" in str(cm.exception))
def testDifferentGraphFailsWithGraphArgument(self):
other_g = ops.Graph()
@@ -1416,7 +1416,7 @@ class OpDefLibraryGraphTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError) as cm:
self._lib.apply_op("Binary", a=a, b=b, g=self._g)
self.assertTrue(
- "not from the passed-in graph" in cm.exception.message)
+ "not from the passed-in graph" in str(cm.exception))
if __name__ == "__main__":
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index a6408f9f0a..87148af2fd 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -88,12 +88,12 @@ def parse_example(serialized,
```
serialized = [
- features:
- { feature: [ key: { "ft" value: float_list: { value: [1.0, 2.0] } } ] },
- features:
- { feature: [] },
- features:
- { feature: [ key: { "ft" value: float_list: { value: [3.0] } } ] }
+ features
+ { feature { key: "ft" value { float_list { value: [1.0, 2.0] } } } },
+ features
+ { feature [] },
+ features
+ { feature { key: "ft" value { float_list { value: [3.0] } } } }
]
```
@@ -109,14 +109,14 @@ def parse_example(serialized,
```
[
- features: {
- feature: { key: "kw" value: { bytes_list: { value: [ "knit", "big" ] } } }
- feature: { key: "gps" value: { float_list: { value: [] } } }
+ features {
+ feature { key: "kw" value { bytes_list { value: [ "knit", "big" ] } } }
+ feature { key: "gps" value { float_list { value: [] } } }
},
- features: {
- feature: { key: "kw" value: { bytes_list: { value: [ "emmy" ] } } }
- feature: { key: "dank" value: { int64_list: { value: [ 42 ] } } }
- feature: { key: "gps" value: { } }
+ features {
+ feature { key: "kw" value { bytes_list { value: [ "emmy" ] } } }
+ feature { key: "dank" value { int64_list { value: [ 42 ] } } }
+ feature { key: "gps" value { } }
}
]
```
@@ -152,13 +152,13 @@ def parse_example(serialized,
```
[
- features: {
- feature: { key: "age" value: { int64_list: { value: [ 0 ] } } }
- feature: { key: "gender" value: { bytes_list: { value: [ "f" ] } } }
+ features {
+ feature { key: "age" value { int64_list { value: [ 0 ] } } }
+ feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
},
- features: {
- feature: { key: "age" value: { int64_list: { value: [] } } }
- feature: { key: "gender" value: { bytes_list: { value: [ "f" ] } } }
+ features {
+ feature { key: "age" value { int64_list { value: [] } } }
+ feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
}
]
```
@@ -204,6 +204,8 @@ def parse_example(serialized,
The keys of the dict must match the dense_keys of the feature.
dense_shapes: A list of tuples with the same length as `dense_keys`.
The shape of the data for each dense feature referenced by `dense_keys`.
+ Required for any input tensors identified by dense_keys whose shapes are
+ anything other than [] or [1].
name: A name for this operation (optional).
Returns:
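As a hedged sketch of the `dense_shapes` requirement just documented (the feature names, sizes, and the placeholder feed are invented for illustration): a one-element feature can rely on the default shape, but a fixed-length vector must declare its shape explicitly.

```
import tensorflow as tf

serialized = tf.placeholder(tf.string, shape=[None])  # batch of Example protos

parsed = tf.parse_example(
    serialized,
    dense_keys=["age", "embedding"],
    dense_types=[tf.int64, tf.float32],
    # "age" holds one int64 per example, so [1] could be omitted; "embedding"
    # is a 128-float vector, so its shape must be listed, per the Args above.
    dense_shapes=[[1], [128]],
    dense_defaults={"age": tf.constant([0], dtype=tf.int64)})
```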
@@ -297,22 +299,22 @@ def parse_single_example(serialized, # pylint: disable=invalid-name
For `SparseTensor`s, the first (batch) column of the indices matrix is removed
(the indices matrix is a column vector), the values vector is unchanged, and
- the first (batch_size) entry of the shape vector is removed (it is now a
+ the first (`batch_size`) entry of the shape vector is removed (it is now a
single element vector).
See also `parse_example`.
Args:
serialized: A scalar string Tensor, a single serialized Example.
- See parse_example documentation for more details.
+ See `parse_example` documentation for more details.
names: (Optional) A scalar string Tensor, the associated name.
- See parse_example documentation for more details.
- sparse_keys: See parse_example documentation for more details.
- sparse_types: See parse_example documentation for more details.
- dense_keys: See parse_example documentation for more details.
- dense_types: See parse_example documentation for more details.
- dense_defaults: See parse_example documentation for more details.
- dense_shapes: See parse_example documentation for more details.
+ See `parse_example` documentation for more details.
+ sparse_keys: See `parse_example` documentation for more details.
+ sparse_types: See `parse_example` documentation for more details.
+ dense_keys: See `parse_example` documentation for more details.
+ dense_types: See `parse_example` documentation for more details.
+ dense_defaults: See `parse_example` documentation for more details.
+ dense_shapes: See `parse_example` documentation for more details.
name: A name for this operation (optional).
Returns:
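And the single-example variant described above, sketched under the same invented "ft" feature; the comment spells out the batch-dimension removal the docstring promises.

```
import tensorflow as tf

serialized_one = tf.placeholder(tf.string, shape=[])  # a single Example proto

single = tf.parse_single_example(
    serialized_one, sparse_keys=["ft"], sparse_types=[tf.float32])
# Relative to parse_example, the "ft" SparseTensor loses its batch column:
# indices become a column vector, values are unchanged, and the leading
# batch_size entry of the shape vector is dropped.
```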
diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py
index 37b1cb115e..cefeec54bb 100644
--- a/tensorflow/python/ops/random_ops.py
+++ b/tensorflow/python/ops/random_ops.py
@@ -19,10 +19,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
-from tensorflow.python.framework import types
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import common_shapes
from tensorflow.python.ops import gen_random_ops
@@ -35,14 +35,14 @@ from tensorflow.python.ops.gen_random_ops import *
def _ShapeTensor(shape):
"""Convert to an int32 or int64 tensor, defaulting to int32 if empty."""
if isinstance(shape, (tuple, list)) and not shape:
- dtype = types.int32
+ dtype = dtypes.int32
else:
dtype = None
return ops.convert_to_tensor(shape, dtype=dtype, name="shape")
# pylint: disable=protected-access
-def random_normal(shape, mean=0.0, stddev=1.0, dtype=types.float32,
+def random_normal(shape, mean=0.0, stddev=1.0, dtype=dtypes.float32,
seed=None, name=None):
"""Outputs random values from a normal distribution.
@@ -80,7 +80,7 @@ def random_normal(shape, mean=0.0, stddev=1.0, dtype=types.float32,
ops.NoGradient("RandomStandardNormal")
-def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=types.float32,
+def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=dtypes.float32,
seed=None, name=None):
"""Outputs random values from a truncated normal distribution.
@@ -123,7 +123,7 @@ ops.NoGradient("TruncatedNormal")
def random_uniform(shape, minval=0.0, maxval=1.0,
- dtype=types.float32, seed=None,
+ dtype=dtypes.float32, seed=None,
name=None):
"""Outputs random values from a uniform distribution.
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index d168245989..1507371b40 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -46,10 +46,10 @@ import tensorflow.python.platform
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
-from tensorflow.python.framework import types
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import gen_sparse_ops
@@ -178,7 +178,7 @@ def sparse_reorder(sp_input, name=None):
Reordering does not affect the shape of the `SparseTensor`.
- For example, if sp_input has shape `[4, 5]` and `indices` / `values`:
+ For example, if `sp_input` has shape `[4, 5]` and `indices` / `values`:
[0, 3]: b
[0, 1]: a
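A sketch of the reordering described above, written in the same internal style as this module's tests; only the index order changes, never the shape:

```
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import sparse_ops

# Entries deliberately out of row-major order: [0, 3] listed before [0, 1].
sp_input = ops.SparseTensor(
    constant_op.constant([[0, 3], [0, 1]], dtypes.int64),
    constant_op.constant(["b", "a"]),
    constant_op.constant([4, 5], dtypes.int64))
# Canonical row-major order afterwards: [0, 1]: a, then [0, 3]: b.
sp_output = sparse_ops.sparse_reorder(sp_input)
```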
@@ -334,8 +334,8 @@ def sparse_to_indicator(sp_input, vocab_size, name=None):
rank = indices_shape[1]
ids = sp_input.values
- if ids.dtype != types.int64:
- ids = math_ops.cast(ids, types.int64)
+ if ids.dtype != dtypes.int64:
+ ids = math_ops.cast(ids, dtypes.int64)
# Slice off the last dimension of indices, then tack on the ids
indices_columns_to_preserve = array_ops.slice(
@@ -451,8 +451,8 @@ def sparse_fill_empty_rows(sp_input, default_value, name=None):
default_value = ops.convert_to_tensor(
default_value, dtype=sp_input.values.dtype)
- num_rows = math_ops.cast(sp_input.shape[0], types.int32)
- all_row_indices = math_ops.cast(math_ops.range(num_rows), types.int64)
+ num_rows = math_ops.cast(sp_input.shape[0], dtypes.int32)
+ all_row_indices = math_ops.cast(math_ops.range(num_rows), dtypes.int64)
empty_row_indices, _ = array_ops.list_diff(
all_row_indices, sp_input.indices[:, 0])
empty_row_indicator = gen_sparse_ops.sparse_to_dense(
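The two casts above feed the empty-row computation: enumerate [0, num_rows) as int64, then `list_diff` it against the row coordinates that actually occur. A standalone sketch of just that step, with invented inputs:

```
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import math_ops

# Only rows 0 and 1 carry entries in a 5-row SparseTensor.
indices = constant_op.constant([[0, 0], [1, 0], [1, 3], [1, 4]], dtypes.int64)
all_row_indices = math_ops.cast(math_ops.range(5), dtypes.int64)
# list_diff keeps the rows never seen in indices[:, 0], here [2, 3, 4].
empty_row_indices, _ = array_ops.list_diff(all_row_indices, indices[:, 0])
```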
diff --git a/tensorflow/python/ops/sparse_ops_test.py b/tensorflow/python/ops/sparse_ops_test.py
index 046625eefa..c6e91dcd71 100644
--- a/tensorflow/python/ops/sparse_ops_test.py
+++ b/tensorflow/python/ops/sparse_ops_test.py
@@ -23,9 +23,9 @@ import tensorflow.python.platform
import numpy as np
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
-from tensorflow.python.framework import types
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import googletest
@@ -41,9 +41,9 @@ class SparseToIndicatorTest(test_util.TensorFlowTestCase):
val = np.array([0, 10, 13, 14, 32, 33])
shape = np.array([5, 6])
return ops.SparseTensor(
- constant_op.constant(ind, types.int64),
+ constant_op.constant(ind, dtypes.int64),
constant_op.constant(val, dtype),
- constant_op.constant(shape, types.int64))
+ constant_op.constant(shape, dtypes.int64))
def _SparseTensor_2x3x4(self, dtype):
ind = np.array([
@@ -55,13 +55,13 @@ class SparseToIndicatorTest(test_util.TensorFlowTestCase):
val = np.array([1, 10, 12, 103, 111, 113, 122])
shape = np.array([2, 3, 4])
return ops.SparseTensor(
- constant_op.constant(ind, types.int64),
+ constant_op.constant(ind, dtypes.int64),
constant_op.constant(val, dtype),
- constant_op.constant(shape, types.int64))
+ constant_op.constant(shape, dtypes.int64))
def testInt32(self):
with self.test_session(use_gpu=False):
- sp_input = self._SparseTensor_5x6(types.int32)
+ sp_input = self._SparseTensor_5x6(dtypes.int32)
output = sparse_ops.sparse_to_indicator(sp_input, 50).eval()
expected_output = np.zeros((5, 50), dtype=np.bool)
@@ -73,7 +73,7 @@ class SparseToIndicatorTest(test_util.TensorFlowTestCase):
def testInt64(self):
with self.test_session(use_gpu=False):
- sp_input = self._SparseTensor_5x6(types.int64)
+ sp_input = self._SparseTensor_5x6(dtypes.int64)
output = sparse_ops.sparse_to_indicator(sp_input, 50).eval()
expected_output = np.zeros((5, 50), dtype=np.bool)
@@ -85,7 +85,7 @@ class SparseToIndicatorTest(test_util.TensorFlowTestCase):
def testHigherRank(self):
with self.test_session(use_gpu=False):
- sp_input = self._SparseTensor_2x3x4(types.int64)
+ sp_input = self._SparseTensor_2x3x4(dtypes.int64)
output = sparse_ops.sparse_to_indicator(sp_input, 200).eval()
expected_output = np.zeros((2, 3, 200), dtype=np.bool)
@@ -107,9 +107,9 @@ class SparseRetainTest(test_util.TensorFlowTestCase):
val = np.array([0, 10, 13, 14, 32, 33])
shape = np.array([5, 6])
return ops.SparseTensor(
- constant_op.constant(ind, types.int64),
- constant_op.constant(val, types.int32),
- constant_op.constant(shape, types.int64))
+ constant_op.constant(ind, dtypes.int64),
+ constant_op.constant(val, dtypes.int32),
+ constant_op.constant(shape, dtypes.int64))
def testBasic(self):
with self.test_session(use_gpu=False) as sess:
@@ -153,9 +153,9 @@ class SparseFillEmptyRowsTest(test_util.TensorFlowTestCase):
val = np.array([0, 10, 13, 14, 32, 33])
shape = np.array([5, 6])
return ops.SparseTensor(
- constant_op.constant(ind, types.int64),
- constant_op.constant(val, types.int32),
- constant_op.constant(shape, types.int64))
+ constant_op.constant(ind, dtypes.int64),
+ constant_op.constant(val, dtypes.int32),
+ constant_op.constant(shape, dtypes.int64))
def _SparseTensor_String5x6(self):
ind = np.array([
@@ -165,18 +165,18 @@ class SparseFillEmptyRowsTest(test_util.TensorFlowTestCase):
val = np.array(["a", "b", "c", "d", "e", "f"])
shape = np.array([5, 6])
return ops.SparseTensor(
- constant_op.constant(ind, types.int64),
- constant_op.constant(val, types.string),
- constant_op.constant(shape, types.int64))
+ constant_op.constant(ind, dtypes.int64),
+ constant_op.constant(val, dtypes.string),
+ constant_op.constant(shape, dtypes.int64))
def _SparseTensor_2x6(self):
ind = np.array([[0, 0], [1, 0], [1, 3], [1, 4]])
val = np.array([0, 10, 13, 14])
shape = np.array([2, 6])
return ops.SparseTensor(
- constant_op.constant(ind, types.int64),
- constant_op.constant(val, types.int32),
- constant_op.constant(shape, types.int64))
+ constant_op.constant(ind, dtypes.int64),
+ constant_op.constant(val, dtypes.int32),
+ constant_op.constant(shape, dtypes.int64))
def testFillNumber(self):
with self.test_session(use_gpu=False) as sess:
@@ -207,7 +207,8 @@ class SparseFillEmptyRowsTest(test_util.TensorFlowTestCase):
self.assertAllEqual(
output.indices,
[[0, 0], [1, 0], [1, 3], [1, 4], [2, 0], [3, 2], [3, 3], [4, 0]])
- self.assertAllEqual(output.values, ["a", "b", "c", "d", "", "e", "f", ""])
+ self.assertAllEqual(output.values,
+ [b"a", b"b", b"c", b"d", b"", b"e", b"f", b""])
self.assertAllEqual(output.shape, [5, 6])
self.assertAllEqual(empty_row_indicator_out,
np.array([0, 0, 1, 0, 1]).astype(np.bool))
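The `b"a"` literals above are the Python 3 counterpart of the `str()` fixes earlier in this change: string tensors evaluate to `bytes`, which never compare equal to `str` on Python 3. A tiny sketch of the pattern:

```
values = [b"a", b"b", b""]           # what a string tensor's eval() returns
expected = ["a", "b", ""]
# Encode the expectations, not the results, mirroring the test fix above.
assert values == [v.encode("utf-8") for v in expected]
```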
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 8ccbad16c6..c5b490ce4c 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -22,9 +22,9 @@ from __future__ import print_function
import contextlib
import six
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import types
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import logging
@@ -45,7 +45,7 @@ class _VariableStore(object):
"""Create a variable store."""
self._vars = {} # A dictionary of the stored TensorFlow variables.
- def get_variable(self, name, shape=None, dtype=types.float32,
+ def get_variable(self, name, shape=None, dtype=dtypes.float32,
initializer=None, reuse=None, trainable=True,
collections=None):
"""Gets an existing variable with these parameters or create a new one.
@@ -82,7 +82,7 @@ class _VariableStore(object):
or when violating reuse during variable creation.
"""
should_check = reuse is not None
- dtype = types.as_dtype(dtype)
+ dtype = dtypes.as_dtype(dtype)
shape = tensor_shape.as_shape(shape)
if name in self._vars:
# Here we handle the case when returning an existing variable.
@@ -158,7 +158,7 @@ class _VariableScope(object):
"""Set initializer for this scope."""
self._initializer = initializer
- def get_variable(self, var_store, name, shape=None, dtype=types.float32,
+ def get_variable(self, var_store, name, shape=None, dtype=dtypes.float32,
initializer=None, trainable=True, collections=None):
"""Gets an existing variable with this name or create a new one."""
if initializer is None and self._initializer:
@@ -194,7 +194,7 @@ def _get_default_variable_store():
return store
-def get_variable(name, shape=None, dtype=types.float32, initializer=None,
+def get_variable(name, shape=None, dtype=dtypes.float32, initializer=None,
trainable=True, collections=None):
"""Gets an existing variable with these parameters or create a new one.