author    Geoffrey Irving <geoffreyi@google.com>  2016-01-20 15:36:06 -0800
committer Manjunath Kudlur <keveman@gmail.com>    2016-01-20 17:20:16 -0800
commit 877fcd1a113797a1c5847dd5fdbef7868addded0 (patch)
tree   b41cd402d67458cc9cf60cbe650d46f72ecfa0de
parent db7478e8998f7703c57a75a950c905ec0cb59d7b (diff)
Prepare to hide tf.tensor_util
1. There is a new tf.unsupported module to hold things which some people use
   but which we don't yet support.
2. tf.tensor_util.ConstantValue is now tf.unsupported.constant_value. Most
   users use this. tf.tensor_util.ConstantValue is still available, but it
   will be removed in a following commit.
3. tensor_util.MakeTensorShapeProto is now tensor_util.make_tensor_shape_proto.
   All users of this appear to access the tensor_util module directly (not
   through tf), so for now it is not in unsupported.

This commit does not remove tensor_util from tf.__all__; a few more downstream
users must be changed before that can happen.

Change: 112626961
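For most downstream callers the migration is a one-line rename. A minimal
sketch of the new spelling (assuming a TensorFlow build that includes this
commit), mirroring the updated tests below:

    import tensorflow as tf

    # The shape of a statically shaped constant is computable without running
    # the graph, so constant_value can return it as a numpy array.
    tensor = tf.shape(tf.constant(0.0, shape=[1, 2, 3]))

    # Old spelling (deprecated, kept as an alias for now):
    #   value = tf.tensor_util.ConstantValue(tensor)
    # New spelling; returns None when the value is not statically known:
    value = tf.unsupported.constant_value(tensor)
    assert (value == [1, 2, 3]).all()

    # A placeholder has no statically known value:
    assert tf.unsupported.constant_value(tf.placeholder(tf.int32)) is None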
-rw-r--r--  tensorflow/python/BUILD                              |   6
-rw-r--r--  tensorflow/python/__init__.py                        |   6
-rw-r--r--  tensorflow/python/framework/gen_docs_combined.py     |  16
-rw-r--r--  tensorflow/python/framework/tensor_shape_test.py     |   6
-rw-r--r--  tensorflow/python/framework/tensor_util.py           |  29
-rw-r--r--  tensorflow/python/framework/tensor_util_test.py      | 105
-rw-r--r--  tensorflow/python/kernel_tests/learn_test.py         |   2
-rw-r--r--  tensorflow/python/ops/array_grad.py                  |   2
-rw-r--r--  tensorflow/python/ops/array_ops.py                   |  28
-rw-r--r--  tensorflow/python/ops/attention_ops.py               |   2
-rw-r--r--  tensorflow/python/ops/data_flow_ops.py               |   2
-rw-r--r--  tensorflow/python/ops/gradients.py                   |   2
-rw-r--r--  tensorflow/python/ops/image_grad.py                  |   2
-rw-r--r--  tensorflow/python/ops/image_ops.py                   |   4
-rw-r--r--  tensorflow/python/ops/learn.py                       |   2
-rw-r--r--  tensorflow/python/ops/math_ops.py                    |  16
-rw-r--r--  tensorflow/python/ops/nn_ops.py                      |   8
-rw-r--r--  tensorflow/python/ops/parsing_ops.py                 |   6
-rw-r--r--  tensorflow/python/ops/random_ops.py                  |   2
-rw-r--r--  tensorflow/python/ops/sparse_ops.py                  |   2
-rw-r--r--  tensorflow/python/unsupported.py                     |  34
-rw-r--r--  tensorflow/python/util/all_util.py                   |  42
22 files changed, 205 insertions, 119 deletions
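The MakeTensorShapeProto rename stays internal to the framework package rather
than moving into tf.unsupported. A small sketch of the new spelling (assuming
direct access to tensorflow.python.framework, as the updated tests below use):

    from tensorflow.python.framework import tensor_shape
    from tensorflow.python.framework import tensor_util

    # Old spelling: tensor_util.MakeTensorShapeProto (alias kept for now).
    proto = tensor_util.make_tensor_shape_proto([1, 37, 42])
    assert tensor_shape.TensorShape([1, 37, 42]) == tensor_shape.as_shape(proto)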
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index e19cf74aef..93d05516b1 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -16,7 +16,10 @@ load("/tensorflow/core/platform/default/build_config", "tf_proto_library_py")
py_library(
name = "python",
- srcs = ["__init__.py"],
+ srcs = [
+ "__init__.py",
+ "unsupported.py",
+ ],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__pkg__"],
deps = [
@@ -352,6 +355,7 @@ py_test(
":framework_test_lib",
":ops",
":platform_test",
+ "//tensorflow:tensorflow_py",
],
)
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index ec86293890..85632a2fd1 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -58,12 +58,13 @@ from tensorflow.python.client.client_lib import *
# Ops
from tensorflow.python.ops.standard_ops import *
-# Bring learn, nn, image_ops, user_ops, compat as a subpackages
+# Bring in subpackages
from tensorflow.python.ops import learn
from tensorflow.python.ops import nn
from tensorflow.python.ops import image_ops as image
from tensorflow.python.user_ops import user_ops
from tensorflow.python.util import compat
+from tensorflow.python import unsupported
# Import the names from python/training.py as train.Name.
from tensorflow.python.training import training as train
@@ -80,7 +81,8 @@ from tensorflow.python.platform import test
# Don't export modules except for the few we really want
_whitelist = set([app, compat, errors, flags, image, learn, logging, nn,
- python_io, resource_loader, test, train, user_ops])
+ python_io, resource_loader, test, train, unsupported,
+ user_ops])
# TODO(b/25561952): tf.tensor_util is DEPRECATED. Please avoid.
_whitelist.update([tensor_util]) # pylint: disable=undefined-variable
__all__ = [name for name, x in locals().items() if not name.startswith('_') and
diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py
index 62d6e7ecfa..733ce54b34 100644
--- a/tensorflow/python/framework/gen_docs_combined.py
+++ b/tensorflow/python/framework/gen_docs_combined.py
@@ -45,12 +45,13 @@ Note: Functions taking `Tensor` arguments can also take anything accepted by
def get_module_to_name():
- return {tf: 'tf',
- tf.errors: 'tf.errors',
- tf.image: 'tf.image',
- tf.nn: 'tf.nn',
- tf.train: 'tf.train',
- tf.python_io: 'tf.python_io'}
+ return {tf: "tf",
+ tf.errors: "tf.errors",
+ tf.image: "tf.image",
+ tf.nn: "tf.nn",
+ tf.train: "tf.train",
+ tf.python_io: "tf.python_io",
+ tf.unsupported: "tf.unsupported",}
def all_libraries(module_to_name, members, documented):
# A list of (filename, docs.Library) pairs representing the individual files
@@ -112,7 +113,8 @@ def all_libraries(module_to_name, members, documented):
"Int64List", "Example", "InferenceExample",
"FeatureList", "FeatureLists",
"RankingExample", "SequenceExample"]),
- library("script_ops", "Wraps python functions", prefix=PREFIX_TEXT)
+ library("script_ops", "Wraps python functions", prefix=PREFIX_TEXT),
+ library("unsupported", "Unsupported", tf.unsupported),
]
_hidden_symbols = ["Event", "Summary", "xrange",
diff --git a/tensorflow/python/framework/tensor_shape_test.py b/tensorflow/python/framework/tensor_shape_test.py
index b9f99d685e..e33df6493a 100644
--- a/tensorflow/python/framework/tensor_shape_test.py
+++ b/tensorflow/python/framework/tensor_shape_test.py
@@ -256,20 +256,20 @@ class ShapeTest(test_util.TensorFlowTestCase):
unknown / unknown # pylint: disable=pointless-statement
def testConvertFromProto(self):
- proto = tensor_util.MakeTensorShapeProto([])
+ proto = tensor_util.make_tensor_shape_proto([])
self.assertEqual(tensor_shape.TensorShape([]),
tensor_shape.TensorShape(proto))
self.assertEqual(tensor_shape.TensorShape([]),
tensor_shape.as_shape(proto))
- proto = tensor_util.MakeTensorShapeProto([1, 37, 42])
+ proto = tensor_util.make_tensor_shape_proto([1, 37, 42])
self.assertEqual(tensor_shape.TensorShape([1, 37, 42]),
tensor_shape.TensorShape(proto))
self.assertEqual(tensor_shape.TensorShape([1, 37, 42]),
tensor_shape.as_shape(proto))
partial_proto_shape = tensor_shape.as_shape(
- tensor_util.MakeTensorShapeProto([-1, 37, 42]))
+ tensor_util.make_tensor_shape_proto([-1, 37, 42]))
partial_shape = tensor_shape.TensorShape([None, 37, 42])
self.assertNotEqual(partial_proto_shape, partial_shape)
self.assertEqual(partial_proto_shape[0].value, None)
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index af6ae4df5a..f4cdf8ebb9 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -126,14 +126,15 @@ def GetNumpyAppendFn(dtype):
return GetFromNumpyDTypeDict(_NP_TO_APPEND_FN, dtype)
-def MakeTensorShapeProto(shape):
- """Create a TensorShapeProto.
+# TODO(mrry,irving): Make this a method of `TensorShape`.
+def make_tensor_shape_proto(shape):
+ """Converts a list of integers to a `TensorShapeProto`.
Args:
shape: List of integers representing the dimensions of the tensor.
Returns:
- A TensorShapeProto.
+ A `TensorShapeProto`.
"""
return tensor_shape_pb2.TensorShapeProto(
dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=x) for x in shape])
@@ -368,7 +369,7 @@ def make_tensor_proto(values, dtype=None, shape=None):
tensor_proto = tensor_pb2.TensorProto(
dtype=numpy_dtype.as_datatype_enum,
- tensor_shape=MakeTensorShapeProto(shape))
+ tensor_shape=make_tensor_shape_proto(shape))
if is_same_size and numpy_dtype in _TENSOR_CONTENT_TYPES and shape_size > 1:
tensor_proto.tensor_content = nparray.tostring()
@@ -494,7 +495,7 @@ def ShapeEquals(tensor_proto, shape):
return all(x == y for x, y in zip(tensor_shape_list, shape))
-def ConstantValue(tensor):
+def constant_value(tensor):
"""Returns the constant value of the given tensor, if efficiently calculable.
This function attempts to partially evaluate the given tensor, and
@@ -539,32 +540,38 @@ def ConstantValue(tensor):
else:
return None
elif tensor.op.type == "Range":
- start = ConstantValue(tensor.op.inputs[0])
+ start = constant_value(tensor.op.inputs[0])
if start is None:
return None
- limit = ConstantValue(tensor.op.inputs[1])
+ limit = constant_value(tensor.op.inputs[1])
if limit is None:
return None
- delta = ConstantValue(tensor.op.inputs[2])
+ delta = constant_value(tensor.op.inputs[2])
if delta is None:
return None
return np.arange(start, limit, delta, dtype=tensor.dtype.as_numpy_dtype)
elif tensor.op.type == "Cast":
- pre_cast = ConstantValue(tensor.op.inputs[0])
+ pre_cast = constant_value(tensor.op.inputs[0])
if pre_cast is None:
return None
cast_dtype = dtypes.as_dtype(tensor.op.get_attr("DstT"))
return pre_cast.astype(cast_dtype.as_numpy_dtype)
elif tensor.op.type == "Concat":
- dim = ConstantValue(tensor.op.inputs[0])
+ dim = constant_value(tensor.op.inputs[0])
if dim is None:
return None
values = []
for x in tensor.op.inputs[1:]:
- value = ConstantValue(x)
+ value = constant_value(x)
if value is None:
return None
values.append(value)
return np.concatenate(values, axis=dim)
else:
return None
+
+
+# Add some temporary backwards compatibility aliases until all downstream code
+# is changed. TODO(irving): Remove these aliases.
+ConstantValue = constant_value
+MakeTensorShapeProto = make_tensor_shape_proto
diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py
index 6d3823741d..d22c740e63 100644
--- a/tensorflow/python/framework/tensor_util_test.py
+++ b/tensorflow/python/framework/tensor_util_test.py
@@ -21,18 +21,13 @@ from __future__ import print_function
import tensorflow.python.platform
import numpy as np
+import tensorflow as tf
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
-from tensorflow.python.framework import test_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import constant_op
-from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
-from tensorflow.python.platform import googletest
-class TensorUtilTest(test_util.TensorFlowTestCase):
+class TensorUtilTest(tf.test.TestCase):
def testFloat(self):
t = tensor_util.make_tensor_proto(10.0)
@@ -57,7 +52,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase):
self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)
def testFloatTyped(self):
- t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], dtype=dtypes.float32)
+ t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], dtype=tf.float32)
self.assertProtoEquals("""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
@@ -68,7 +63,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase):
self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)
def testFloatTypeCoerce(self):
- t = tensor_util.make_tensor_proto([10, 20, 30], dtype=dtypes.float32)
+ t = tensor_util.make_tensor_proto([10, 20, 30], dtype=tf.float32)
self.assertProtoEquals("""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
@@ -80,7 +75,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase):
def testFloatTypeCoerceNdarray(self):
arr = np.asarray([10, 20, 30], dtype="int")
- t = tensor_util.make_tensor_proto(arr, dtype=dtypes.float32)
+ t = tensor_util.make_tensor_proto(arr, dtype=tf.float32)
self.assertProtoEquals("""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
@@ -137,7 +132,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase):
def testFloatTypesWithImplicitRepeat(self):
for dtype, nptype in [
- (dtypes.float32, np.float32), (dtypes.float64, np.float64)]:
+ (tf.float32, np.float32), (tf.float64, np.float64)]:
t = tensor_util.make_tensor_proto([10.0], shape=[3, 4], dtype=dtype)
a = tensor_util.MakeNdarray(t)
self.assertAllClose(np.array([[10.0, 10.0, 10.0, 10.0],
@@ -168,10 +163,10 @@ class TensorUtilTest(test_util.TensorFlowTestCase):
def testIntTypes(self):
for dtype, nptype in [
- (dtypes.int32, np.int32),
- (dtypes.uint8, np.uint8),
- (dtypes.int16, np.int16),
- (dtypes.int8, np.int8)]:
+ (tf.int32, np.int32),
+ (tf.uint8, np.uint8),
+ (tf.int16, np.int16),
+ (tf.int8, np.int8)]:
# Test with array.
t = tensor_util.make_tensor_proto([10, 20, 30], dtype=dtype)
self.assertEquals(dtype, t.dtype)
@@ -189,11 +184,11 @@ class TensorUtilTest(test_util.TensorFlowTestCase):
def testIntTypesWithImplicitRepeat(self):
for dtype, nptype in [
- (dtypes.int64, np.int64),
- (dtypes.int32, np.int32),
- (dtypes.uint8, np.uint8),
- (dtypes.int16, np.int16),
- (dtypes.int8, np.int8)]:
+ (tf.int64, np.int64),
+ (tf.int32, np.int32),
+ (tf.uint8, np.uint8),
+ (tf.int16, np.int16),
+ (tf.int8, np.int8)]:
t = tensor_util.make_tensor_proto([10], shape=[3, 4], dtype=dtype)
a = tensor_util.MakeNdarray(t)
self.assertAllEqual(np.array([[10, 10, 10, 10],
@@ -201,7 +196,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase):
[10, 10, 10, 10]], dtype=nptype), a)
def testLong(self):
- t = tensor_util.make_tensor_proto(10, dtype=dtypes.int64)
+ t = tensor_util.make_tensor_proto(10, dtype=tf.int64)
self.assertProtoEquals("""
dtype: DT_INT64
tensor_shape {}
@@ -213,7 +208,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase):
def testLongN(self):
t = tensor_util.make_tensor_proto([10, 20, 30], shape=[1, 3],
- dtype=dtypes.int64)
+ dtype=tf.int64)
self.assertProtoEquals("""
dtype: DT_INT64
tensor_shape { dim { size: 1 } dim { size: 3 } }
@@ -279,7 +274,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase):
self.assertAllEqual(np.array([[b"a", b"ab"], [b"abc", b"abcd"]]), a)
def testComplex(self):
- t = tensor_util.make_tensor_proto((1+2j), dtype=dtypes.complex64)
+ t = tensor_util.make_tensor_proto((1+2j), dtype=tf.complex64)
self.assertProtoEquals("""
dtype: DT_COMPLEX64
tensor_shape {}
@@ -292,7 +287,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase):
def testComplexWithImplicitRepeat(self):
t = tensor_util.make_tensor_proto((1+1j), shape=[3, 4],
- dtype=dtypes.complex64)
+ dtype=tf.complex64)
a = tensor_util.MakeNdarray(t)
self.assertAllClose(np.array([[(1+1j), (1+1j), (1+1j), (1+1j)],
[(1+1j), (1+1j), (1+1j), (1+1j)],
@@ -301,7 +296,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase):
def testComplexN(self):
t = tensor_util.make_tensor_proto([(1+2j), (3+4j), (5+6j)], shape=[1, 3],
- dtype=dtypes.complex64)
+ dtype=tf.complex64)
self.assertProtoEquals("""
dtype: DT_COMPLEX64
tensor_shape { dim { size: 1 } dim { size: 3 } }
@@ -318,7 +313,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase):
def testComplexNpArray(self):
t = tensor_util.make_tensor_proto(
- np.array([[(1+2j), (3+4j)], [(5+6j), (7+8j)]]), dtype=dtypes.complex64)
+ np.array([[(1+2j), (3+4j)], [(5+6j), (7+8j)]]), dtype=tf.complex64)
# scomplex_val are real_0, imag_0, real_1, imag_1, ...
self.assertProtoEquals("""
dtype: DT_COMPLEX64
@@ -357,81 +352,81 @@ class TensorUtilTest(test_util.TensorFlowTestCase):
self.assertTrue(tensor_util.ShapeEquals(t, [2, 2]))
self.assertTrue(tensor_util.ShapeEquals(t, (2, 2)))
self.assertTrue(
- tensor_util.ShapeEquals(t, tensor_util.MakeTensorShapeProto([2, 2])))
+ tensor_util.ShapeEquals(t, tensor_util.make_tensor_shape_proto([2, 2])))
self.assertFalse(tensor_util.ShapeEquals(t, [5, 3]))
self.assertFalse(tensor_util.ShapeEquals(t, [1, 4]))
self.assertFalse(tensor_util.ShapeEquals(t, [4]))
-class ConstantValueTest(test_util.TensorFlowTestCase):
+class ConstantValueTest(tf.test.TestCase):
def testConstant(self):
np_val = np.random.rand(3, 4, 7).astype(np.float32)
- tf_val = constant_op.constant(np_val)
- self.assertAllClose(np_val, tensor_util.ConstantValue(tf_val))
+ tf_val = tf.constant(np_val)
+ self.assertAllClose(np_val, tf.unsupported.constant_value(tf_val))
np_val = np.random.rand(3, 0, 7).astype(np.float32)
- tf_val = constant_op.constant(np_val)
- self.assertAllClose(np_val, tensor_util.ConstantValue(tf_val))
+ tf_val = tf.constant(np_val)
+ self.assertAllClose(np_val, tf.unsupported.constant_value(tf_val))
def testUnknown(self):
- tf_val = state_ops.variable_op(shape=[3, 4, 7], dtype=dtypes.float32)
- self.assertIs(None, tensor_util.ConstantValue(tf_val))
+ tf_val = state_ops.variable_op(shape=[3, 4, 7], dtype=tf.float32)
+ self.assertIs(None, tf.unsupported.constant_value(tf_val))
def testShape(self):
np_val = np.array([1, 2, 3], dtype=np.int32)
- tf_val = array_ops.shape(constant_op.constant(0.0, shape=[1, 2, 3]))
- c_val = tensor_util.ConstantValue(tf_val)
+ tf_val = tf.shape(tf.constant(0.0, shape=[1, 2, 3]))
+ c_val = tf.unsupported.constant_value(tf_val)
self.assertAllEqual(np_val, c_val)
self.assertEqual(np.int32, c_val.dtype)
def testSize(self):
- tf_val = array_ops.size(constant_op.constant(0.0, shape=[1, 2, 3]))
- c_val = tensor_util.ConstantValue(tf_val)
+ tf_val = tf.size(tf.constant(0.0, shape=[1, 2, 3]))
+ c_val = tf.unsupported.constant_value(tf_val)
self.assertEqual(6, c_val)
def testSizeOfScalar(self):
- tf_val = array_ops.size(constant_op.constant(0.0))
- c_val = tensor_util.ConstantValue(tf_val)
+ tf_val = tf.size(tf.constant(0.0))
+ c_val = tf.unsupported.constant_value(tf_val)
self.assertEqual(1, c_val)
self.assertEqual(np.int32, type(c_val))
def testRank(self):
- tf_val = array_ops.rank(constant_op.constant(0.0, shape=[1, 2, 3]))
- c_val = tensor_util.ConstantValue(tf_val)
+ tf_val = tf.rank(tf.constant(0.0, shape=[1, 2, 3]))
+ c_val = tf.unsupported.constant_value(tf_val)
self.assertEqual(3, c_val)
def testCast(self):
np_val = np.random.rand(3, 4, 7).astype(np.float32)
- tf_val = math_ops.cast(constant_op.constant(np_val), dtypes.float64)
- c_val = tensor_util.ConstantValue(tf_val)
+ tf_val = tf.cast(tf.constant(np_val), tf.float64)
+ c_val = tf.unsupported.constant_value(tf_val)
self.assertAllClose(np_val.astype(np.float64), c_val)
np_val = np.random.rand(3, 0, 7).astype(np.float32)
- tf_val = math_ops.cast(constant_op.constant(np_val), dtypes.float64)
- c_val = tensor_util.ConstantValue(tf_val)
+ tf_val = tf.cast(tf.constant(np_val), tf.float64)
+ c_val = tf.unsupported.constant_value(tf_val)
self.assertAllClose(np_val.astype(np.float64), c_val)
def testConcat(self):
np_val = np.random.rand(3, 4, 7).astype(np.float32)
- tf_val = array_ops.concat(
+ tf_val = tf.concat(
0, [np_val[0:1, :, :], np_val[1:2, :, :], np_val[2:3, :, :]])
- c_val = tensor_util.ConstantValue(tf_val)
+ c_val = tf.unsupported.constant_value(tf_val)
self.assertAllClose(np_val, c_val)
- tf_val = array_ops.concat(
- array_ops.placeholder(dtypes.int32),
+ tf_val = tf.concat(
+ tf.placeholder(tf.int32),
[np_val[0, :, :], np_val[1, :, :], np_val[2, :, :]])
- c_val = tensor_util.ConstantValue(tf_val)
+ c_val = tf.unsupported.constant_value(tf_val)
self.assertIs(None, c_val)
- tf_val = array_ops.concat(
+ tf_val = tf.concat(
1,
- [np_val[0, :, :], array_ops.placeholder(dtypes.float32),
+ [np_val[0, :, :], tf.placeholder(tf.float32),
np_val[2, :, :]])
- c_val = tensor_util.ConstantValue(tf_val)
+ c_val = tf.unsupported.constant_value(tf_val)
self.assertIs(None, c_val)
if __name__ == "__main__":
- googletest.main()
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/learn_test.py b/tensorflow/python/kernel_tests/learn_test.py
index e74f7ac0df..802e1f31a2 100644
--- a/tensorflow/python/kernel_tests/learn_test.py
+++ b/tensorflow/python/kernel_tests/learn_test.py
@@ -32,7 +32,7 @@ from tensorflow.python.framework import tensor_util
def assert_summary_scope(regexp):
"""Assert that all generated summaries match regexp."""
for summary in tf.get_collection(tf.GraphKeys.SUMMARIES):
- tag = tensor_util.ConstantValue(summary.op.inputs[0])
+ tag = tf.unsupported.constant_value(summary.op.inputs[0])
assert tag is not None, 'All summaries must have constant tags'
tag = str(tag)
assert isinstance(tag[0], six.string_types), tag[0]
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 7304f5ddf9..0eb49e76c2 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -74,7 +74,7 @@ def _ConcatGrad(op, grad):
for (begin, size) in zip(offset, sizes):
out_grads.append(array_ops.slice(grad, begin, size))
elif isinstance(grad, ops.IndexedSlices):
- concat_dim_static = tensor_util.ConstantValue(concat_dim)
+ concat_dim_static = tensor_util.constant_value(concat_dim)
if concat_dim_static is None:
raise ValueError("Can only compute IndexedSlices gradient with "
"statically-known concat_dim")
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 084398ec64..01484a2b88 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -325,7 +325,7 @@ def _UnpackShape(op):
@ops.RegisterShape("Concat")
def _ConcatShape(op):
- concat_dim = tensor_util.ConstantValue(op.inputs[0])
+ concat_dim = tensor_util.constant_value(op.inputs[0])
if concat_dim is None:
# Return an unknown shape with the same rank as the inputs, or an
# unknown rank if no input's rank is known.
@@ -713,8 +713,8 @@ def _SliceShape(op):
ndims = rank_vector_shape.num_elements()
if ndims is not None:
input_shape.assert_has_rank(ndims)
- begin_value = tensor_util.ConstantValue(op.inputs[1])
- sizes_value = tensor_util.ConstantValue(op.inputs[2])
+ begin_value = tensor_util.constant_value(op.inputs[1])
+ sizes_value = tensor_util.constant_value(op.inputs[2])
if sizes_value is not None:
returned_dims = []
for i, slice_size in enumerate(sizes_value.ravel()):
@@ -795,7 +795,7 @@ def _ExpandDimsShape(op):
input_shape = op.inputs[0].get_shape()
if input_shape.dims is None:
return [tensor_shape.unknown_shape()]
- dim = tensor_util.ConstantValue(op.inputs[1])
+ dim = tensor_util.constant_value(op.inputs[1])
input_ndims = input_shape.ndims
if dim < -input_ndims - 1 or dim > input_ndims:
raise ValueError(
@@ -865,7 +865,7 @@ def _ReshapeShape(op):
else:
num_elements = tensor_shape.Dimension(None)
new_shape_shape = op.inputs[1].get_shape().with_rank_at_most(1)
- new_shape = tensor_util.ConstantValue(op.inputs[1])
+ new_shape = tensor_util.constant_value(op.inputs[1])
if new_shape is None:
# Attempt to infer the rank of the output from the length of
# new_shape.
@@ -908,7 +908,7 @@ def _ReshapeShape(op):
@ops.RegisterShape("BroadcastGradientArgs")
def _BroadcastGradientArgsShape(op):
"""Shape function for the BroadcastGradientArgs op."""
- # TODO(mrry): Implement ConstantValue for BroadcastGradientArgs?
+ # TODO(mrry): Implement constant_value for BroadcastGradientArgs?
op.inputs[0].get_shape().assert_has_rank(1)
op.inputs[1].get_shape().assert_has_rank(1)
return [tensor_shape.vector(None), tensor_shape.vector(None)]
@@ -929,7 +929,7 @@ def _FillShape(op):
"""
dimensions_shape = op.inputs[0].get_shape().with_rank_at_most(1)
op.inputs[1].get_shape().assert_is_compatible_with(tensor_shape.scalar())
- fill_dims = tensor_util.ConstantValue(op.inputs[0])
+ fill_dims = tensor_util.constant_value(op.inputs[0])
if fill_dims is None:
# Attempt to infer the rank of the output from the length of
# dimensions.
@@ -981,7 +981,7 @@ def _PadShape(op):
input_shape = input_shape.with_rank(paddings_shape[0].value)
paddings_shape = paddings_shape.merge_with(
tensor_shape.matrix(input_shape.ndims, 2))
- paddings = tensor_util.ConstantValue(op.inputs[1])
+ paddings = tensor_util.constant_value(op.inputs[1])
if paddings is None:
return [tensor_shape.unknown_shape(ndims=input_shape.ndims)]
else:
@@ -1062,7 +1062,7 @@ def _TransposeShape(op):
input_shape = op.inputs[0].get_shape()
transpose_shape = op.inputs[1].get_shape().merge_with(tensor_shape.vector(
input_shape.ndims))
- transpose_vec = tensor_util.ConstantValue(op.inputs[1])
+ transpose_vec = tensor_util.constant_value(op.inputs[1])
if transpose_vec is None:
return [tensor_shape.unknown_shape(ndims=transpose_shape[0].value)]
else:
@@ -1073,7 +1073,7 @@ def _TransposeShape(op):
@ops.RegisterShape("Split")
def _SplitShape(op):
"""Shape function for the Split op."""
- split_dim = tensor_util.ConstantValue(op.inputs[0])
+ split_dim = tensor_util.constant_value(op.inputs[0])
num_split = len(op.outputs)
input_shape = op.inputs[1].get_shape()
if split_dim is None:
@@ -1114,7 +1114,7 @@ def _TileShape(op):
"""
multiples_shape = op.inputs[1].get_shape().with_rank_at_most(1)
input_shape = op.inputs[0].get_shape().with_rank(multiples_shape.num_elements())
- multiples = tensor_util.ConstantValue(op.inputs[1])
+ multiples = tensor_util.constant_value(op.inputs[1])
if multiples is None:
return [tensor_shape.unknown_shape(ndims=input_shape.ndims)]
else:
@@ -1130,7 +1130,7 @@ def _TileGradShape(op):
"""Shape function for the TileGrad op."""
multiples_shape = op.inputs[1].get_shape().with_rank_at_most(1)
input_shape = op.inputs[0].get_shape().with_rank(multiples_shape.num_elements())
- multiples = tensor_util.ConstantValue(op.inputs[1])
+ multiples = tensor_util.constant_value(op.inputs[1])
if multiples is None:
return [tensor_shape.unknown_shape(ndims=input_shape.ndims)]
else:
@@ -1230,8 +1230,8 @@ def edit_distance(hypothesis, truth, normalize=True, name="edit_distance"):
@ops.RegisterShape("EditDistance")
def _EditDistanceShape(op):
"""Shape function for the EditDistance op."""
- hypothesis_shape = tensor_util.ConstantValue(op.inputs[2])
- truth_shape = tensor_util.ConstantValue(op.inputs[5])
+ hypothesis_shape = tensor_util.constant_value(op.inputs[2])
+ truth_shape = tensor_util.constant_value(op.inputs[5])
if hypothesis_shape is not None and truth_shape is not None:
if len(hypothesis_shape) != len(truth_shape):
raise ValueError(
diff --git a/tensorflow/python/ops/attention_ops.py b/tensorflow/python/ops/attention_ops.py
index 3bc13ca9b5..20caf5df75 100644
--- a/tensorflow/python/ops/attention_ops.py
+++ b/tensorflow/python/ops/attention_ops.py
@@ -42,7 +42,7 @@ def _ExtractGlimpseShape(op):
offsets_shape = op.inputs[2].get_shape().merge_with(
input_shape[:1].concatenate([2]))
offsets_shape = offsets_shape
- size_value = tensor_util.ConstantValue(op.inputs[1])
+ size_value = tensor_util.constant_value(op.inputs[1])
if size_value is not None:
height = size_value[0]
width = size_value[1]
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index c8bdd5317f..6454cf104b 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -300,7 +300,7 @@ class QueueBase(object):
# NOTE(mrry): Not using a shape function because we need access to
# the Queue object.
op = ret[0].op
- batch_dim = tensor_shape.Dimension(tensor_util.ConstantValue(op.inputs[1]))
+ batch_dim = tensor_shape.Dimension(tensor_util.constant_value(op.inputs[1]))
for output, shape in zip(op.values(), self._shapes):
output.set_shape(tensor_shape.TensorShape([batch_dim]).concatenate(shape))
diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py
index a99a8ea2f5..78ecc98859 100644
--- a/tensorflow/python/ops/gradients.py
+++ b/tensorflow/python/ops/gradients.py
@@ -80,7 +80,7 @@ def _IndexedSlicesToTensor(value, dtype=None, name=None, as_ref=False):
% str(value))
# TODO(mrry): Consider adding static shape information to
# IndexedSlices, to avoid using numpy here.
- dense_shape_value = tensor_util.ConstantValue(value.dense_shape)
+ dense_shape_value = tensor_util.constant_value(value.dense_shape)
if dense_shape_value is not None:
num_elements = np.prod(dense_shape_value)
if num_elements >= _LARGE_SPARSE_NUM_ELEMENTS:
diff --git a/tensorflow/python/ops/image_grad.py b/tensorflow/python/ops/image_grad.py
index f2e58277ec..668adb1c26 100644
--- a/tensorflow/python/ops/image_grad.py
+++ b/tensorflow/python/ops/image_grad.py
@@ -70,7 +70,7 @@ def _ResizeBilinearGrad(op, grad):
def _ResizeShape(op):
"""Shape function for the resize grad ops."""
input_shape = op.inputs[0].get_shape().with_rank(4)
- size = tensor_util.ConstantValue(op.inputs[1])
+ size = tensor_util.constant_value(op.inputs[1])
if size is not None:
height = size[0]
width = size[1]
diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py
index abc1739828..94a288afce 100644
--- a/tensorflow/python/ops/image_ops.py
+++ b/tensorflow/python/ops/image_ops.py
@@ -775,7 +775,7 @@ ops.RegisterShape('AdjustContrastv2')(
def _ResizeShape(op):
"""Shape function for the resize_bilinear and resize_nearest_neighbor ops."""
input_shape = op.inputs[0].get_shape().with_rank(4)
- size = tensor_util.ConstantValue(op.inputs[1])
+ size = tensor_util.constant_value(op.inputs[1])
if size is not None:
height = size[0]
width = size[1]
@@ -810,7 +810,7 @@ def _random_cropShape(op):
input_shape = op.inputs[0].get_shape().with_rank(3)
unused_size_shape = op.inputs[1].get_shape().merge_with(
tensor_shape.vector(2))
- size = tensor_util.ConstantValue(op.inputs[1])
+ size = tensor_util.constant_value(op.inputs[1])
if size is not None:
height = size[0]
width = size[1]
diff --git a/tensorflow/python/ops/learn.py b/tensorflow/python/ops/learn.py
index 50b5c56b42..016b14d67d 100644
--- a/tensorflow/python/ops/learn.py
+++ b/tensorflow/python/ops/learn.py
@@ -93,7 +93,7 @@ def xavier_initializer(n_inputs, n_outputs, uniform=True):
def _assert_summary_tag_unique(tag):
for summary in ops.get_collection(ops.GraphKeys.SUMMARIES):
- old_tag = tensor_util.ConstantValue(summary.op.inputs[0])
+ old_tag = tensor_util.constant_value(summary.op.inputs[0])
if tag == str(old_tag):
raise ValueError('Conflict with summary tag: %s exists on summary %s %s' %
(tag, summary, old_tag))
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index c2a19fe31a..3abc04e51b 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -609,9 +609,9 @@ def range(start, limit=None, delta=1, name="range"):
@ops.RegisterShape("Range")
def _RangeShape(op):
- start_value = tensor_util.ConstantValue(op.inputs[0])
- limit_value = tensor_util.ConstantValue(op.inputs[1])
- delta_value = tensor_util.ConstantValue(op.inputs[2])
+ start_value = tensor_util.constant_value(op.inputs[0])
+ limit_value = tensor_util.constant_value(op.inputs[1])
+ delta_value = tensor_util.constant_value(op.inputs[2])
if start_value is None or limit_value is None or delta_value is None:
return [tensor_shape.vector(None)]
else:
@@ -1280,7 +1280,7 @@ def _ArgOpShape(op):
elif input_shape.ndims <= 1:
return [tensor_shape.scalar()]
- dimension = tensor_util.ConstantValue(op.inputs[1])
+ dimension = tensor_util.constant_value(op.inputs[1])
if dimension is None:
return [tensor_shape.unknown_shape(ndims=input_shape.ndims - 1)]
elif 0 <= dimension and dimension < input_shape.ndims:
@@ -1306,7 +1306,7 @@ def _ArgOpShape(op):
def _ReductionShape(op):
"""Common shape function for reduction ops."""
input_shape = op.inputs[0].get_shape()
- reduction_indices = tensor_util.ConstantValue(op.inputs[1])
+ reduction_indices = tensor_util.constant_value(op.inputs[1])
keep_dims = op.get_attr("keep_dims")
if reduction_indices is None or input_shape.ndims is None:
if keep_dims:
@@ -1375,7 +1375,7 @@ def _SparseSegmentReductionGradShape(op):
unused_segment_ids_shape = op.inputs[2].get_shape().merge_with(indices_shape)
unused_output_dim0_shape = op.inputs[3].get_shape().merge_with(
tensor_shape.scalar())
- output_dim0 = tensor_util.ConstantValue(op.inputs[3])
+ output_dim0 = tensor_util.constant_value(op.inputs[3])
if output_dim0 is not None:
dim0 = output_dim0[0]
else:
@@ -1393,12 +1393,12 @@ def _UnsortedSegmentSumShape(op):
if mid is None:
return [tensor_shape.unknown_shape()]
else:
- num_segments = tensor_util.ConstantValue(op.inputs[2])
+ num_segments = tensor_util.constant_value(op.inputs[2])
return [tensor_shape.TensorShape([num_segments]).concatenate(
data_shape[mid:])]
@ops.RegisterShape("LinSpace")
def _LinspaceShape(op):
- num = tensor_util.ConstantValue(op.inputs[2])
+ num = tensor_util.constant_value(op.inputs[2])
return [tensor_shape.vector(num)]
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index ad05f823fb..9df7342cc1 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -296,7 +296,7 @@ def _TopKShape(op):
"""Shape function for TopK and TopKV2 ops."""
input_shape = op.inputs[0].get_shape().with_rank_at_least(1)
if len(op.inputs) >= 2:
- k = tensor_util.ConstantValue(op.inputs[1])
+ k = tensor_util.constant_value(op.inputs[1])
else:
k = op.get_attr("k")
last = input_shape[-1].value
@@ -352,7 +352,7 @@ def _MaxPoolWithArgMaxShape(op):
@ops.RegisterShape("AvgPoolGrad")
def _AvgPoolGradShape(op):
"""Shape function for the AvgPoolGrad op."""
- orig_input_shape = tensor_util.ConstantValue(op.inputs[0])
+ orig_input_shape = tensor_util.constant_value(op.inputs[0])
if orig_input_shape is not None:
return [tensor_shape.TensorShape(orig_input_shape.tolist())]
else:
@@ -366,7 +366,7 @@ def _AvgPoolGradShape(op):
@ops.RegisterShape("Conv2DBackpropFilter")
def _Conv2DBackpropFilterShape(op):
"""Shape function for the Conv2DBackpropFilter op."""
- filter_shape = tensor_util.ConstantValue(op.inputs[1])
+ filter_shape = tensor_util.constant_value(op.inputs[1])
if filter_shape is not None:
return [tensor_shape.TensorShape(filter_shape.tolist())]
else:
@@ -380,7 +380,7 @@ def _Conv2DBackpropFilterShape(op):
@ops.RegisterShape("Conv2DBackpropInput")
def _Conv2DBackpropInputShape(op):
"""Shape function for the Conv2DBackpropInput op."""
- input_shape = tensor_util.ConstantValue(op.inputs[0])
+ input_shape = tensor_util.constant_value(op.inputs[0])
if input_shape is not None:
return [tensor_shape.TensorShape(input_shape.tolist())]
else:
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index e210f34df1..ac81eaf676 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -391,7 +391,7 @@ def _parse_example_raw(serialized,
dense_defaults_vec.append(default_value)
- dense_shapes = [tensor_util.MakeTensorShapeProto(shape)
+ dense_shapes = [tensor_util.make_tensor_shape_proto(shape)
if isinstance(shape, (list, tuple)) else shape
for shape in dense_shapes]
@@ -825,10 +825,10 @@ def _parse_single_sequence_example_raw(serialized,
context_dense_defaults_vec.append(default_value)
- context_dense_shapes = [tensor_util.MakeTensorShapeProto(shape)
+ context_dense_shapes = [tensor_util.make_tensor_shape_proto(shape)
if isinstance(shape, (list, tuple)) else shape
for shape in context_dense_shapes]
- feature_list_dense_shapes = [tensor_util.MakeTensorShapeProto(shape)
+ feature_list_dense_shapes = [tensor_util.make_tensor_shape_proto(shape)
if isinstance(shape, (list, tuple)) else shape
for shape in feature_list_dense_shapes]
diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py
index 428d591acc..dac67b4fb4 100644
--- a/tensorflow/python/ops/random_ops.py
+++ b/tensorflow/python/ops/random_ops.py
@@ -217,7 +217,7 @@ ops.NoGradient("RandomUniform")
@ops.RegisterShape("RandomUniform")
@ops.RegisterShape("RandomUniformInt")
def _RandomShape(op):
- shape_val = tensor_util.ConstantValue(op.inputs[0])
+ shape_val = tensor_util.constant_value(op.inputs[0])
if shape_val is not None:
return [tensor_shape.TensorShape(shape_val.tolist())]
else:
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index 1a7ce8e40f..1a9f9954a0 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -296,7 +296,7 @@ def _SparseSplitShape(op):
@ops.RegisterShape("SparseToDense")
def _SparseToDenseShape(op):
- input_shape = tensor_util.ConstantValue(op.inputs[1])
+ input_shape = tensor_util.constant_value(op.inputs[1])
if input_shape is not None:
if np.ndim(input_shape) > 1:
raise ValueError("Input shape should be a vector")
diff --git a/tensorflow/python/unsupported.py b/tensorflow/python/unsupported.py
new file mode 100644
index 0000000000..23c40766ab
--- /dev/null
+++ b/tensorflow/python/unsupported.py
@@ -0,0 +1,34 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""This module includes unsupported and experimental features which are exposed
+but not part of the supported public API. Anything in this module can change
+without notice, even across a patch release.
+
+## Utilities
+
+@@constant_value
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow.python.platform
+from tensorflow.python.util.all_util import make_all
+
+# pylint: disable=unused-import
+from tensorflow.python.framework.tensor_util import constant_value
+
+__all__ = make_all(__name__)
diff --git a/tensorflow/python/util/all_util.py b/tensorflow/python/util/all_util.py
new file mode 100644
index 0000000000..e88e4ea847
--- /dev/null
+++ b/tensorflow/python/util/all_util.py
@@ -0,0 +1,42 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Generate __all__ from a module docstring."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+import sys
+
+_reference_pattern = re.compile(r'^@@(\w+)$', flags=re.MULTILINE)
+
+
+def make_all(module_name):
+ """Generate `__all__` from a module's docstring.
+
+ Usage: `make_all(__name__)`. The caller module must have a docstring,
+ and `__all__` will contain all symbols with `@@` references.
+
+ Args:
+ module_name: The name of the module (usually `__name__`).
+
+ Returns:
+ A list suitable for use as `__all__`.
+ """
+ doc = sys.modules[module_name].__doc__
+ return [m.group(1) for m in _reference_pattern.finditer(doc)]
+
+__all__ = ['make_all']
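
As a usage note, make_all lets a module's docstring double as its export list:
any name referenced on its own line as @@name becomes part of __all__, exactly
as unsupported.py does above. A hypothetical module sketch (names invented for
illustration only):

    """Demo utilities (hypothetical module, for illustration only).

    @@public_helper
    """
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function

    from tensorflow.python.util.all_util import make_all


    def public_helper():
      """Referenced as @@public_helper in the docstring, so it is exported."""


    def _private_helper():
      """Not referenced in the docstring, so make_all omits it."""


    __all__ = make_all(__name__)  # == ['public_helper']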