diff options
author | 2016-01-20 15:36:06 -0800 | |
---|---|---|
committer | 2016-01-20 17:20:16 -0800 | |
commit | 877fcd1a113797a1c5847dd5fdbef7868addded0 (patch) | |
tree | b41cd402d67458cc9cf60cbe650d46f72ecfa0de | |
parent | db7478e8998f7703c57a75a950c905ec0cb59d7b (diff) |
Prepare to hide tf.tensor_util
1. There is a new tf.unsupported module to hold things which some people use
but which we don't yet support.
2. tf.tensor_util.ConstantValue is now tf.unsupported.constant_value. Most
users use this, but tf.tensor_util.ConstantValue is still available; it
will be removed in a following commit.
3. tensor_util.MakeTensorShapeProto is now make_tensor_shape_proto. It looks
like all users of this access the tensor_util module directly (not through
tf), so for now it is not in unsupported.
This commit does not remove tensor_util from tf.__all__; a few more downstream
users must be changed before that can happen.
Change: 112626961
22 files changed, 205 insertions, 119 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index e19cf74aef..93d05516b1 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -16,7 +16,10 @@ load("/tensorflow/core/platform/default/build_config", "tf_proto_library_py") py_library( name = "python", - srcs = ["__init__.py"], + srcs = [ + "__init__.py", + "unsupported.py", + ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__pkg__"], deps = [ @@ -352,6 +355,7 @@ py_test( ":framework_test_lib", ":ops", ":platform_test", + "//tensorflow:tensorflow_py", ], ) diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index ec86293890..85632a2fd1 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -58,12 +58,13 @@ from tensorflow.python.client.client_lib import * # Ops from tensorflow.python.ops.standard_ops import * -# Bring learn, nn, image_ops, user_ops, compat as a subpackages +# Bring in subpackages from tensorflow.python.ops import learn from tensorflow.python.ops import nn from tensorflow.python.ops import image_ops as image from tensorflow.python.user_ops import user_ops from tensorflow.python.util import compat +from tensorflow.python import unsupported # Import the names from python/training.py as train.Name. from tensorflow.python.training import training as train @@ -80,7 +81,8 @@ from tensorflow.python.platform import test # Don't export modules except for the few we really want _whitelist = set([app, compat, errors, flags, image, learn, logging, nn, - python_io, resource_loader, test, train, user_ops]) + python_io, resource_loader, test, train, unsupported, + user_ops]) # TODO(b/25561952): tf.tensor_util is DEPRECATED. Please avoid. _whitelist.update([tensor_util]) # pylint: disable=undefined-variable __all__ = [name for name, x in locals().items() if not name.startswith('_') and diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py index 62d6e7ecfa..733ce54b34 100644 --- a/tensorflow/python/framework/gen_docs_combined.py +++ b/tensorflow/python/framework/gen_docs_combined.py @@ -45,12 +45,13 @@ Note: Functions taking `Tensor` arguments can also take anything accepted by def get_module_to_name(): - return {tf: 'tf', - tf.errors: 'tf.errors', - tf.image: 'tf.image', - tf.nn: 'tf.nn', - tf.train: 'tf.train', - tf.python_io: 'tf.python_io'} + return {tf: "tf", + tf.errors: "tf.errors", + tf.image: "tf.image", + tf.nn: "tf.nn", + tf.train: "tf.train", + tf.python_io: "tf.python_io", + tf.unsupported: "tf.unsupported",} def all_libraries(module_to_name, members, documented): # A list of (filename, docs.Library) pairs representing the individual files @@ -112,7 +113,8 @@ def all_libraries(module_to_name, members, documented): "Int64List", "Example", "InferenceExample", "FeatureList", "FeatureLists", "RankingExample", "SequenceExample"]), - library("script_ops", "Wraps python functions", prefix=PREFIX_TEXT) + library("script_ops", "Wraps python functions", prefix=PREFIX_TEXT), + library("unsupported", "Unsupported", tf.unsupported), ] _hidden_symbols = ["Event", "Summary", "xrange", diff --git a/tensorflow/python/framework/tensor_shape_test.py b/tensorflow/python/framework/tensor_shape_test.py index b9f99d685e..e33df6493a 100644 --- a/tensorflow/python/framework/tensor_shape_test.py +++ b/tensorflow/python/framework/tensor_shape_test.py @@ -256,20 +256,20 @@ class ShapeTest(test_util.TensorFlowTestCase): unknown / unknown # pylint: disable=pointless-statement def testConvertFromProto(self): - proto = tensor_util.MakeTensorShapeProto([]) + proto = tensor_util.make_tensor_shape_proto([]) self.assertEqual(tensor_shape.TensorShape([]), tensor_shape.TensorShape(proto)) self.assertEqual(tensor_shape.TensorShape([]), tensor_shape.as_shape(proto)) - proto = tensor_util.MakeTensorShapeProto([1, 37, 42]) + proto = tensor_util.make_tensor_shape_proto([1, 37, 42]) self.assertEqual(tensor_shape.TensorShape([1, 37, 42]), tensor_shape.TensorShape(proto)) self.assertEqual(tensor_shape.TensorShape([1, 37, 42]), tensor_shape.as_shape(proto)) partial_proto_shape = tensor_shape.as_shape( - tensor_util.MakeTensorShapeProto([-1, 37, 42])) + tensor_util.make_tensor_shape_proto([-1, 37, 42])) partial_shape = tensor_shape.TensorShape([None, 37, 42]) self.assertNotEqual(partial_proto_shape, partial_shape) self.assertEqual(partial_proto_shape[0].value, None) diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index af6ae4df5a..f4cdf8ebb9 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -126,14 +126,15 @@ def GetNumpyAppendFn(dtype): return GetFromNumpyDTypeDict(_NP_TO_APPEND_FN, dtype) -def MakeTensorShapeProto(shape): - """Create a TensorShapeProto. +# TODO(mrry,irving): Make this a method of `TensorShape`. +def make_tensor_shape_proto(shape): + """Converts a list of integers to a `TensorShapeProto`. Args: shape: List of integers representing the dimensions of the tensor. Returns: - A TensorShapeProto. + A `TensorShapeProto`. """ return tensor_shape_pb2.TensorShapeProto( dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=x) for x in shape]) @@ -368,7 +369,7 @@ def make_tensor_proto(values, dtype=None, shape=None): tensor_proto = tensor_pb2.TensorProto( dtype=numpy_dtype.as_datatype_enum, - tensor_shape=MakeTensorShapeProto(shape)) + tensor_shape=make_tensor_shape_proto(shape)) if is_same_size and numpy_dtype in _TENSOR_CONTENT_TYPES and shape_size > 1: tensor_proto.tensor_content = nparray.tostring() @@ -494,7 +495,7 @@ def ShapeEquals(tensor_proto, shape): return all(x == y for x, y in zip(tensor_shape_list, shape)) -def ConstantValue(tensor): +def constant_value(tensor): """Returns the constant value of the given tensor, if efficiently calculable. This function attempts to partially evaluate the given tensor, and @@ -539,32 +540,38 @@ def ConstantValue(tensor): else: return None elif tensor.op.type == "Range": - start = ConstantValue(tensor.op.inputs[0]) + start = constant_value(tensor.op.inputs[0]) if start is None: return None - limit = ConstantValue(tensor.op.inputs[1]) + limit = constant_value(tensor.op.inputs[1]) if limit is None: return None - delta = ConstantValue(tensor.op.inputs[2]) + delta = constant_value(tensor.op.inputs[2]) if delta is None: return None return np.arange(start, limit, delta, dtype=tensor.dtype.as_numpy_dtype) elif tensor.op.type == "Cast": - pre_cast = ConstantValue(tensor.op.inputs[0]) + pre_cast = constant_value(tensor.op.inputs[0]) if pre_cast is None: return None cast_dtype = dtypes.as_dtype(tensor.op.get_attr("DstT")) return pre_cast.astype(cast_dtype.as_numpy_dtype) elif tensor.op.type == "Concat": - dim = ConstantValue(tensor.op.inputs[0]) + dim = constant_value(tensor.op.inputs[0]) if dim is None: return None values = [] for x in tensor.op.inputs[1:]: - value = ConstantValue(x) + value = constant_value(x) if value is None: return None values.append(value) return np.concatenate(values, axis=dim) else: return None + + +# Add some temporary backwards compatibility aliases until all downstream code +# is changed. TODO(irving): Remove these aliases. +ConstantValue = constant_value +MakeTensorShapeProto = make_tensor_shape_proto diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py index 6d3823741d..d22c740e63 100644 --- a/tensorflow/python/framework/tensor_util_test.py +++ b/tensorflow/python/framework/tensor_util_test.py @@ -21,18 +21,13 @@ from __future__ import print_function import tensorflow.python.platform import numpy as np +import tensorflow as tf -from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_util -from tensorflow.python.framework import test_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import constant_op -from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops -from tensorflow.python.platform import googletest -class TensorUtilTest(test_util.TensorFlowTestCase): +class TensorUtilTest(tf.test.TestCase): def testFloat(self): t = tensor_util.make_tensor_proto(10.0) @@ -57,7 +52,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase): self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a) def testFloatTyped(self): - t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], dtype=dtypes.float32) + t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], dtype=tf.float32) self.assertProtoEquals(""" dtype: DT_FLOAT tensor_shape { dim { size: 3 } } @@ -68,7 +63,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase): self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a) def testFloatTypeCoerce(self): - t = tensor_util.make_tensor_proto([10, 20, 30], dtype=dtypes.float32) + t = tensor_util.make_tensor_proto([10, 20, 30], dtype=tf.float32) self.assertProtoEquals(""" dtype: DT_FLOAT tensor_shape { dim { size: 3 } } @@ -80,7 +75,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase): def testFloatTypeCoerceNdarray(self): arr = np.asarray([10, 20, 30], dtype="int") - t = tensor_util.make_tensor_proto(arr, dtype=dtypes.float32) + t = tensor_util.make_tensor_proto(arr, dtype=tf.float32) self.assertProtoEquals(""" dtype: DT_FLOAT tensor_shape { dim { size: 3 } } @@ -137,7 +132,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase): def testFloatTypesWithImplicitRepeat(self): for dtype, nptype in [ - (dtypes.float32, np.float32), (dtypes.float64, np.float64)]: + (tf.float32, np.float32), (tf.float64, np.float64)]: t = tensor_util.make_tensor_proto([10.0], shape=[3, 4], dtype=dtype) a = tensor_util.MakeNdarray(t) self.assertAllClose(np.array([[10.0, 10.0, 10.0, 10.0], @@ -168,10 +163,10 @@ class TensorUtilTest(test_util.TensorFlowTestCase): def testIntTypes(self): for dtype, nptype in [ - (dtypes.int32, np.int32), - (dtypes.uint8, np.uint8), - (dtypes.int16, np.int16), - (dtypes.int8, np.int8)]: + (tf.int32, np.int32), + (tf.uint8, np.uint8), + (tf.int16, np.int16), + (tf.int8, np.int8)]: # Test with array. t = tensor_util.make_tensor_proto([10, 20, 30], dtype=dtype) self.assertEquals(dtype, t.dtype) @@ -189,11 +184,11 @@ class TensorUtilTest(test_util.TensorFlowTestCase): def testIntTypesWithImplicitRepeat(self): for dtype, nptype in [ - (dtypes.int64, np.int64), - (dtypes.int32, np.int32), - (dtypes.uint8, np.uint8), - (dtypes.int16, np.int16), - (dtypes.int8, np.int8)]: + (tf.int64, np.int64), + (tf.int32, np.int32), + (tf.uint8, np.uint8), + (tf.int16, np.int16), + (tf.int8, np.int8)]: t = tensor_util.make_tensor_proto([10], shape=[3, 4], dtype=dtype) a = tensor_util.MakeNdarray(t) self.assertAllEqual(np.array([[10, 10, 10, 10], @@ -201,7 +196,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase): [10, 10, 10, 10]], dtype=nptype), a) def testLong(self): - t = tensor_util.make_tensor_proto(10, dtype=dtypes.int64) + t = tensor_util.make_tensor_proto(10, dtype=tf.int64) self.assertProtoEquals(""" dtype: DT_INT64 tensor_shape {} @@ -213,7 +208,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase): def testLongN(self): t = tensor_util.make_tensor_proto([10, 20, 30], shape=[1, 3], - dtype=dtypes.int64) + dtype=tf.int64) self.assertProtoEquals(""" dtype: DT_INT64 tensor_shape { dim { size: 1 } dim { size: 3 } } @@ -279,7 +274,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase): self.assertAllEqual(np.array([[b"a", b"ab"], [b"abc", b"abcd"]]), a) def testComplex(self): - t = tensor_util.make_tensor_proto((1+2j), dtype=dtypes.complex64) + t = tensor_util.make_tensor_proto((1+2j), dtype=tf.complex64) self.assertProtoEquals(""" dtype: DT_COMPLEX64 tensor_shape {} @@ -292,7 +287,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase): def testComplexWithImplicitRepeat(self): t = tensor_util.make_tensor_proto((1+1j), shape=[3, 4], - dtype=dtypes.complex64) + dtype=tf.complex64) a = tensor_util.MakeNdarray(t) self.assertAllClose(np.array([[(1+1j), (1+1j), (1+1j), (1+1j)], [(1+1j), (1+1j), (1+1j), (1+1j)], @@ -301,7 +296,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase): def testComplexN(self): t = tensor_util.make_tensor_proto([(1+2j), (3+4j), (5+6j)], shape=[1, 3], - dtype=dtypes.complex64) + dtype=tf.complex64) self.assertProtoEquals(""" dtype: DT_COMPLEX64 tensor_shape { dim { size: 1 } dim { size: 3 } } @@ -318,7 +313,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase): def testComplexNpArray(self): t = tensor_util.make_tensor_proto( - np.array([[(1+2j), (3+4j)], [(5+6j), (7+8j)]]), dtype=dtypes.complex64) + np.array([[(1+2j), (3+4j)], [(5+6j), (7+8j)]]), dtype=tf.complex64) # scomplex_val are real_0, imag_0, real_1, imag_1, ... self.assertProtoEquals(""" dtype: DT_COMPLEX64 @@ -357,81 +352,81 @@ class TensorUtilTest(test_util.TensorFlowTestCase): self.assertTrue(tensor_util.ShapeEquals(t, [2, 2])) self.assertTrue(tensor_util.ShapeEquals(t, (2, 2))) self.assertTrue( - tensor_util.ShapeEquals(t, tensor_util.MakeTensorShapeProto([2, 2]))) + tensor_util.ShapeEquals(t, tensor_util.make_tensor_shape_proto([2, 2]))) self.assertFalse(tensor_util.ShapeEquals(t, [5, 3])) self.assertFalse(tensor_util.ShapeEquals(t, [1, 4])) self.assertFalse(tensor_util.ShapeEquals(t, [4])) -class ConstantValueTest(test_util.TensorFlowTestCase): +class ConstantValueTest(tf.test.TestCase): def testConstant(self): np_val = np.random.rand(3, 4, 7).astype(np.float32) - tf_val = constant_op.constant(np_val) - self.assertAllClose(np_val, tensor_util.ConstantValue(tf_val)) + tf_val = tf.constant(np_val) + self.assertAllClose(np_val, tf.unsupported.constant_value(tf_val)) np_val = np.random.rand(3, 0, 7).astype(np.float32) - tf_val = constant_op.constant(np_val) - self.assertAllClose(np_val, tensor_util.ConstantValue(tf_val)) + tf_val = tf.constant(np_val) + self.assertAllClose(np_val, tf.unsupported.constant_value(tf_val)) def testUnknown(self): - tf_val = state_ops.variable_op(shape=[3, 4, 7], dtype=dtypes.float32) - self.assertIs(None, tensor_util.ConstantValue(tf_val)) + tf_val = state_ops.variable_op(shape=[3, 4, 7], dtype=tf.float32) + self.assertIs(None, tf.unsupported.constant_value(tf_val)) def testShape(self): np_val = np.array([1, 2, 3], dtype=np.int32) - tf_val = array_ops.shape(constant_op.constant(0.0, shape=[1, 2, 3])) - c_val = tensor_util.ConstantValue(tf_val) + tf_val = tf.shape(tf.constant(0.0, shape=[1, 2, 3])) + c_val = tf.unsupported.constant_value(tf_val) self.assertAllEqual(np_val, c_val) self.assertEqual(np.int32, c_val.dtype) def testSize(self): - tf_val = array_ops.size(constant_op.constant(0.0, shape=[1, 2, 3])) - c_val = tensor_util.ConstantValue(tf_val) + tf_val = tf.size(tf.constant(0.0, shape=[1, 2, 3])) + c_val = tf.unsupported.constant_value(tf_val) self.assertEqual(6, c_val) def testSizeOfScalar(self): - tf_val = array_ops.size(constant_op.constant(0.0)) - c_val = tensor_util.ConstantValue(tf_val) + tf_val = tf.size(tf.constant(0.0)) + c_val = tf.unsupported.constant_value(tf_val) self.assertEqual(1, c_val) self.assertEqual(np.int32, type(c_val)) def testRank(self): - tf_val = array_ops.rank(constant_op.constant(0.0, shape=[1, 2, 3])) - c_val = tensor_util.ConstantValue(tf_val) + tf_val = tf.rank(tf.constant(0.0, shape=[1, 2, 3])) + c_val = tf.unsupported.constant_value(tf_val) self.assertEqual(3, c_val) def testCast(self): np_val = np.random.rand(3, 4, 7).astype(np.float32) - tf_val = math_ops.cast(constant_op.constant(np_val), dtypes.float64) - c_val = tensor_util.ConstantValue(tf_val) + tf_val = tf.cast(tf.constant(np_val), tf.float64) + c_val = tf.unsupported.constant_value(tf_val) self.assertAllClose(np_val.astype(np.float64), c_val) np_val = np.random.rand(3, 0, 7).astype(np.float32) - tf_val = math_ops.cast(constant_op.constant(np_val), dtypes.float64) - c_val = tensor_util.ConstantValue(tf_val) + tf_val = tf.cast(tf.constant(np_val), tf.float64) + c_val = tf.unsupported.constant_value(tf_val) self.assertAllClose(np_val.astype(np.float64), c_val) def testConcat(self): np_val = np.random.rand(3, 4, 7).astype(np.float32) - tf_val = array_ops.concat( + tf_val = tf.concat( 0, [np_val[0:1, :, :], np_val[1:2, :, :], np_val[2:3, :, :]]) - c_val = tensor_util.ConstantValue(tf_val) + c_val = tf.unsupported.constant_value(tf_val) self.assertAllClose(np_val, c_val) - tf_val = array_ops.concat( - array_ops.placeholder(dtypes.int32), + tf_val = tf.concat( + tf.placeholder(tf.int32), [np_val[0, :, :], np_val[1, :, :], np_val[2, :, :]]) - c_val = tensor_util.ConstantValue(tf_val) + c_val = tf.unsupported.constant_value(tf_val) self.assertIs(None, c_val) - tf_val = array_ops.concat( + tf_val = tf.concat( 1, - [np_val[0, :, :], array_ops.placeholder(dtypes.float32), + [np_val[0, :, :], tf.placeholder(tf.float32), np_val[2, :, :]]) - c_val = tensor_util.ConstantValue(tf_val) + c_val = tf.unsupported.constant_value(tf_val) self.assertIs(None, c_val) if __name__ == "__main__": - googletest.main() + tf.test.main() diff --git a/tensorflow/python/kernel_tests/learn_test.py b/tensorflow/python/kernel_tests/learn_test.py index e74f7ac0df..802e1f31a2 100644 --- a/tensorflow/python/kernel_tests/learn_test.py +++ b/tensorflow/python/kernel_tests/learn_test.py @@ -32,7 +32,7 @@ from tensorflow.python.framework import tensor_util def assert_summary_scope(regexp): """Assert that all generated summaries match regexp.""" for summary in tf.get_collection(tf.GraphKeys.SUMMARIES): - tag = tensor_util.ConstantValue(summary.op.inputs[0]) + tag = tf.unsupported.constant_value(summary.op.inputs[0]) assert tag is not None, 'All summaries must have constant tags' tag = str(tag) assert isinstance(tag[0], six.string_types), tag[0] diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index 7304f5ddf9..0eb49e76c2 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -74,7 +74,7 @@ def _ConcatGrad(op, grad): for (begin, size) in zip(offset, sizes): out_grads.append(array_ops.slice(grad, begin, size)) elif isinstance(grad, ops.IndexedSlices): - concat_dim_static = tensor_util.ConstantValue(concat_dim) + concat_dim_static = tensor_util.constant_value(concat_dim) if concat_dim_static is None: raise ValueError("Can only compute IndexedSlices gradient with " "statically-known concat_dim") diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 084398ec64..01484a2b88 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -325,7 +325,7 @@ def _UnpackShape(op): @ops.RegisterShape("Concat") def _ConcatShape(op): - concat_dim = tensor_util.ConstantValue(op.inputs[0]) + concat_dim = tensor_util.constant_value(op.inputs[0]) if concat_dim is None: # Return an unknown shape with the same rank as the inputs, or an # unknown rank if no input's rank is known. @@ -713,8 +713,8 @@ def _SliceShape(op): ndims = rank_vector_shape.num_elements() if ndims is not None: input_shape.assert_has_rank(ndims) - begin_value = tensor_util.ConstantValue(op.inputs[1]) - sizes_value = tensor_util.ConstantValue(op.inputs[2]) + begin_value = tensor_util.constant_value(op.inputs[1]) + sizes_value = tensor_util.constant_value(op.inputs[2]) if sizes_value is not None: returned_dims = [] for i, slice_size in enumerate(sizes_value.ravel()): @@ -795,7 +795,7 @@ def _ExpandDimsShape(op): input_shape = op.inputs[0].get_shape() if input_shape.dims is None: return [tensor_shape.unknown_shape()] - dim = tensor_util.ConstantValue(op.inputs[1]) + dim = tensor_util.constant_value(op.inputs[1]) input_ndims = input_shape.ndims if dim < -input_ndims - 1 or dim > input_ndims: raise ValueError( @@ -865,7 +865,7 @@ def _ReshapeShape(op): else: num_elements = tensor_shape.Dimension(None) new_shape_shape = op.inputs[1].get_shape().with_rank_at_most(1) - new_shape = tensor_util.ConstantValue(op.inputs[1]) + new_shape = tensor_util.constant_value(op.inputs[1]) if new_shape is None: # Attempt to infer the rank of the output from the length of # new_shape. @@ -908,7 +908,7 @@ def _ReshapeShape(op): @ops.RegisterShape("BroadcastGradientArgs") def _BroadcastGradientArgsShape(op): """Shape function for the BroadcastGradientArgs op.""" - # TODO(mrry): Implement ConstantValue for BroadcastGradientArgs? + # TODO(mrry): Implement constant_value for BroadcastGradientArgs? op.inputs[0].get_shape().assert_has_rank(1) op.inputs[1].get_shape().assert_has_rank(1) return [tensor_shape.vector(None), tensor_shape.vector(None)] @@ -929,7 +929,7 @@ def _FillShape(op): """ dimensions_shape = op.inputs[0].get_shape().with_rank_at_most(1) op.inputs[1].get_shape().assert_is_compatible_with(tensor_shape.scalar()) - fill_dims = tensor_util.ConstantValue(op.inputs[0]) + fill_dims = tensor_util.constant_value(op.inputs[0]) if fill_dims is None: # Attempt to infer the rank of the output from the length of # dimensions. @@ -981,7 +981,7 @@ def _PadShape(op): input_shape = input_shape.with_rank(paddings_shape[0].value) paddings_shape = paddings_shape.merge_with( tensor_shape.matrix(input_shape.ndims, 2)) - paddings = tensor_util.ConstantValue(op.inputs[1]) + paddings = tensor_util.constant_value(op.inputs[1]) if paddings is None: return [tensor_shape.unknown_shape(ndims=input_shape.ndims)] else: @@ -1062,7 +1062,7 @@ def _TransposeShape(op): input_shape = op.inputs[0].get_shape() transpose_shape = op.inputs[1].get_shape().merge_with(tensor_shape.vector( input_shape.ndims)) - transpose_vec = tensor_util.ConstantValue(op.inputs[1]) + transpose_vec = tensor_util.constant_value(op.inputs[1]) if transpose_vec is None: return [tensor_shape.unknown_shape(ndims=transpose_shape[0].value)] else: @@ -1073,7 +1073,7 @@ def _TransposeShape(op): @ops.RegisterShape("Split") def _SplitShape(op): """Shape function for the Split op.""" - split_dim = tensor_util.ConstantValue(op.inputs[0]) + split_dim = tensor_util.constant_value(op.inputs[0]) num_split = len(op.outputs) input_shape = op.inputs[1].get_shape() if split_dim is None: @@ -1114,7 +1114,7 @@ def _TileShape(op): """ multiples_shape = op.inputs[1].get_shape().with_rank_at_most(1) input_shape = op.inputs[0].get_shape().with_rank(multiples_shape.num_elements()) - multiples = tensor_util.ConstantValue(op.inputs[1]) + multiples = tensor_util.constant_value(op.inputs[1]) if multiples is None: return [tensor_shape.unknown_shape(ndims=input_shape.ndims)] else: @@ -1130,7 +1130,7 @@ def _TileGradShape(op): """Shape function for the TileGrad op.""" multiples_shape = op.inputs[1].get_shape().with_rank_at_most(1) input_shape = op.inputs[0].get_shape().with_rank(multiples_shape.num_elements()) - multiples = tensor_util.ConstantValue(op.inputs[1]) + multiples = tensor_util.constant_value(op.inputs[1]) if multiples is None: return [tensor_shape.unknown_shape(ndims=input_shape.ndims)] else: @@ -1230,8 +1230,8 @@ def edit_distance(hypothesis, truth, normalize=True, name="edit_distance"): @ops.RegisterShape("EditDistance") def _EditDistanceShape(op): """Shape function for the EditDistance op.""" - hypothesis_shape = tensor_util.ConstantValue(op.inputs[2]) - truth_shape = tensor_util.ConstantValue(op.inputs[5]) + hypothesis_shape = tensor_util.constant_value(op.inputs[2]) + truth_shape = tensor_util.constant_value(op.inputs[5]) if hypothesis_shape is not None and truth_shape is not None: if len(hypothesis_shape) != len(truth_shape): raise ValueError( diff --git a/tensorflow/python/ops/attention_ops.py b/tensorflow/python/ops/attention_ops.py index 3bc13ca9b5..20caf5df75 100644 --- a/tensorflow/python/ops/attention_ops.py +++ b/tensorflow/python/ops/attention_ops.py @@ -42,7 +42,7 @@ def _ExtractGlimpseShape(op): offsets_shape = op.inputs[2].get_shape().merge_with( input_shape[:1].concatenate([2])) offsets_shape = offsets_shape - size_value = tensor_util.ConstantValue(op.inputs[1]) + size_value = tensor_util.constant_value(op.inputs[1]) if size_value is not None: height = size_value[0] width = size_value[1] diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py index c8bdd5317f..6454cf104b 100644 --- a/tensorflow/python/ops/data_flow_ops.py +++ b/tensorflow/python/ops/data_flow_ops.py @@ -300,7 +300,7 @@ class QueueBase(object): # NOTE(mrry): Not using a shape function because we need access to # the Queue object. op = ret[0].op - batch_dim = tensor_shape.Dimension(tensor_util.ConstantValue(op.inputs[1])) + batch_dim = tensor_shape.Dimension(tensor_util.constant_value(op.inputs[1])) for output, shape in zip(op.values(), self._shapes): output.set_shape(tensor_shape.TensorShape([batch_dim]).concatenate(shape)) diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py index a99a8ea2f5..78ecc98859 100644 --- a/tensorflow/python/ops/gradients.py +++ b/tensorflow/python/ops/gradients.py @@ -80,7 +80,7 @@ def _IndexedSlicesToTensor(value, dtype=None, name=None, as_ref=False): % str(value)) # TODO(mrry): Consider adding static shape information to # IndexedSlices, to avoid using numpy here. - dense_shape_value = tensor_util.ConstantValue(value.dense_shape) + dense_shape_value = tensor_util.constant_value(value.dense_shape) if dense_shape_value is not None: num_elements = np.prod(dense_shape_value) if num_elements >= _LARGE_SPARSE_NUM_ELEMENTS: diff --git a/tensorflow/python/ops/image_grad.py b/tensorflow/python/ops/image_grad.py index f2e58277ec..668adb1c26 100644 --- a/tensorflow/python/ops/image_grad.py +++ b/tensorflow/python/ops/image_grad.py @@ -70,7 +70,7 @@ def _ResizeBilinearGrad(op, grad): def _ResizeShape(op): """Shape function for the resize grad ops.""" input_shape = op.inputs[0].get_shape().with_rank(4) - size = tensor_util.ConstantValue(op.inputs[1]) + size = tensor_util.constant_value(op.inputs[1]) if size is not None: height = size[0] width = size[1] diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py index abc1739828..94a288afce 100644 --- a/tensorflow/python/ops/image_ops.py +++ b/tensorflow/python/ops/image_ops.py @@ -775,7 +775,7 @@ ops.RegisterShape('AdjustContrastv2')( def _ResizeShape(op): """Shape function for the resize_bilinear and resize_nearest_neighbor ops.""" input_shape = op.inputs[0].get_shape().with_rank(4) - size = tensor_util.ConstantValue(op.inputs[1]) + size = tensor_util.constant_value(op.inputs[1]) if size is not None: height = size[0] width = size[1] @@ -810,7 +810,7 @@ def _random_cropShape(op): input_shape = op.inputs[0].get_shape().with_rank(3) unused_size_shape = op.inputs[1].get_shape().merge_with( tensor_shape.vector(2)) - size = tensor_util.ConstantValue(op.inputs[1]) + size = tensor_util.constant_value(op.inputs[1]) if size is not None: height = size[0] width = size[1] diff --git a/tensorflow/python/ops/learn.py b/tensorflow/python/ops/learn.py index 50b5c56b42..016b14d67d 100644 --- a/tensorflow/python/ops/learn.py +++ b/tensorflow/python/ops/learn.py @@ -93,7 +93,7 @@ def xavier_initializer(n_inputs, n_outputs, uniform=True): def _assert_summary_tag_unique(tag): for summary in ops.get_collection(ops.GraphKeys.SUMMARIES): - old_tag = tensor_util.ConstantValue(summary.op.inputs[0]) + old_tag = tensor_util.constant_value(summary.op.inputs[0]) if tag == str(old_tag): raise ValueError('Conflict with summary tag: %s exists on summary %s %s' % (tag, summary, old_tag)) diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index c2a19fe31a..3abc04e51b 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -609,9 +609,9 @@ def range(start, limit=None, delta=1, name="range"): @ops.RegisterShape("Range") def _RangeShape(op): - start_value = tensor_util.ConstantValue(op.inputs[0]) - limit_value = tensor_util.ConstantValue(op.inputs[1]) - delta_value = tensor_util.ConstantValue(op.inputs[2]) + start_value = tensor_util.constant_value(op.inputs[0]) + limit_value = tensor_util.constant_value(op.inputs[1]) + delta_value = tensor_util.constant_value(op.inputs[2]) if start_value is None or limit_value is None or delta_value is None: return [tensor_shape.vector(None)] else: @@ -1280,7 +1280,7 @@ def _ArgOpShape(op): elif input_shape.ndims <= 1: return [tensor_shape.scalar()] - dimension = tensor_util.ConstantValue(op.inputs[1]) + dimension = tensor_util.constant_value(op.inputs[1]) if dimension is None: return [tensor_shape.unknown_shape(ndims=input_shape.ndims - 1)] elif 0 <= dimension and dimension < input_shape.ndims: @@ -1306,7 +1306,7 @@ def _ArgOpShape(op): def _ReductionShape(op): """Common shape function for reduction ops.""" input_shape = op.inputs[0].get_shape() - reduction_indices = tensor_util.ConstantValue(op.inputs[1]) + reduction_indices = tensor_util.constant_value(op.inputs[1]) keep_dims = op.get_attr("keep_dims") if reduction_indices is None or input_shape.ndims is None: if keep_dims: @@ -1375,7 +1375,7 @@ def _SparseSegmentReductionGradShape(op): unused_segment_ids_shape = op.inputs[2].get_shape().merge_with(indices_shape) unused_output_dim0_shape = op.inputs[3].get_shape().merge_with( tensor_shape.scalar()) - output_dim0 = tensor_util.ConstantValue(op.inputs[3]) + output_dim0 = tensor_util.constant_value(op.inputs[3]) if output_dim0 is not None: dim0 = output_dim0[0] else: @@ -1393,12 +1393,12 @@ def _UnsortedSegmentSumShape(op): if mid is None: return [tensor_shape.unknown_shape()] else: - num_segments = tensor_util.ConstantValue(op.inputs[2]) + num_segments = tensor_util.constant_value(op.inputs[2]) return [tensor_shape.TensorShape([num_segments]).concatenate( data_shape[mid:])] @ops.RegisterShape("LinSpace") def _LinspaceShape(op): - num = tensor_util.ConstantValue(op.inputs[2]) + num = tensor_util.constant_value(op.inputs[2]) return [tensor_shape.vector(num)] diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index ad05f823fb..9df7342cc1 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -296,7 +296,7 @@ def _TopKShape(op): """Shape function for TopK and TopKV2 ops.""" input_shape = op.inputs[0].get_shape().with_rank_at_least(1) if len(op.inputs) >= 2: - k = tensor_util.ConstantValue(op.inputs[1]) + k = tensor_util.constant_value(op.inputs[1]) else: k = op.get_attr("k") last = input_shape[-1].value @@ -352,7 +352,7 @@ def _MaxPoolWithArgMaxShape(op): @ops.RegisterShape("AvgPoolGrad") def _AvgPoolGradShape(op): """Shape function for the AvgPoolGrad op.""" - orig_input_shape = tensor_util.ConstantValue(op.inputs[0]) + orig_input_shape = tensor_util.constant_value(op.inputs[0]) if orig_input_shape is not None: return [tensor_shape.TensorShape(orig_input_shape.tolist())] else: @@ -366,7 +366,7 @@ def _AvgPoolGradShape(op): @ops.RegisterShape("Conv2DBackpropFilter") def _Conv2DBackpropFilterShape(op): """Shape function for the Conv2DBackpropFilter op.""" - filter_shape = tensor_util.ConstantValue(op.inputs[1]) + filter_shape = tensor_util.constant_value(op.inputs[1]) if filter_shape is not None: return [tensor_shape.TensorShape(filter_shape.tolist())] else: @@ -380,7 +380,7 @@ def _Conv2DBackpropFilterShape(op): @ops.RegisterShape("Conv2DBackpropInput") def _Conv2DBackpropInputShape(op): """Shape function for the Conv2DBackpropInput op.""" - input_shape = tensor_util.ConstantValue(op.inputs[0]) + input_shape = tensor_util.constant_value(op.inputs[0]) if input_shape is not None: return [tensor_shape.TensorShape(input_shape.tolist())] else: diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py index e210f34df1..ac81eaf676 100644 --- a/tensorflow/python/ops/parsing_ops.py +++ b/tensorflow/python/ops/parsing_ops.py @@ -391,7 +391,7 @@ def _parse_example_raw(serialized, dense_defaults_vec.append(default_value) - dense_shapes = [tensor_util.MakeTensorShapeProto(shape) + dense_shapes = [tensor_util.make_tensor_shape_proto(shape) if isinstance(shape, (list, tuple)) else shape for shape in dense_shapes] @@ -825,10 +825,10 @@ def _parse_single_sequence_example_raw(serialized, context_dense_defaults_vec.append(default_value) - context_dense_shapes = [tensor_util.MakeTensorShapeProto(shape) + context_dense_shapes = [tensor_util.make_tensor_shape_proto(shape) if isinstance(shape, (list, tuple)) else shape for shape in context_dense_shapes] - feature_list_dense_shapes = [tensor_util.MakeTensorShapeProto(shape) + feature_list_dense_shapes = [tensor_util.make_tensor_shape_proto(shape) if isinstance(shape, (list, tuple)) else shape for shape in feature_list_dense_shapes] diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py index 428d591acc..dac67b4fb4 100644 --- a/tensorflow/python/ops/random_ops.py +++ b/tensorflow/python/ops/random_ops.py @@ -217,7 +217,7 @@ ops.NoGradient("RandomUniform") @ops.RegisterShape("RandomUniform") @ops.RegisterShape("RandomUniformInt") def _RandomShape(op): - shape_val = tensor_util.ConstantValue(op.inputs[0]) + shape_val = tensor_util.constant_value(op.inputs[0]) if shape_val is not None: return [tensor_shape.TensorShape(shape_val.tolist())] else: diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index 1a7ce8e40f..1a9f9954a0 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -296,7 +296,7 @@ def _SparseSplitShape(op): @ops.RegisterShape("SparseToDense") def _SparseToDenseShape(op): - input_shape = tensor_util.ConstantValue(op.inputs[1]) + input_shape = tensor_util.constant_value(op.inputs[1]) if input_shape is not None: if np.ndim(input_shape) > 1: raise ValueError("Input shape should be a vector") diff --git a/tensorflow/python/unsupported.py b/tensorflow/python/unsupported.py new file mode 100644 index 0000000000..23c40766ab --- /dev/null +++ b/tensorflow/python/unsupported.py @@ -0,0 +1,34 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""This module includes unsupported and experimental features which are exposed +but not part of the supported public API. Anything in this module can change +without notice, even across a patch release. + +## Utilities + +@@constant_value +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.python.platform +from tensorflow.python.util.all_util import make_all + +# pylint: disable=unused-import +from tensorflow.python.framework.tensor_util import constant_value + +__all__ = make_all(__name__) diff --git a/tensorflow/python/util/all_util.py b/tensorflow/python/util/all_util.py new file mode 100644 index 0000000000..e88e4ea847 --- /dev/null +++ b/tensorflow/python/util/all_util.py @@ -0,0 +1,42 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Generate __all__ from a module docstring.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re +import sys + +_reference_pattern = re.compile(r'^@@(\w+)$') + + +def make_all(module_name): + """Generate `__all__` from a module's docstring. + + Usage: `make_all(__name__)`. The caller module must have a docstring, + and `__all__` will contain all symbols with `@@` references. + + Args: + module_name: The name of the module (usually `__name__`). + + Returns: + A list suitable for use as `__all__`. + """ + doc = sys.modules[module_name].__doc__ + return [m.group(1) for m in _reference_pattern.finditer(doc)] + +__all__ = ['make_all'] |