diff options
author | 2017-06-06 14:38:43 -0700 | |
---|---|---|
committer | 2017-06-06 14:42:25 -0700 | |
commit | b4951553961f84ac9aa499b4ded96ba8264d4604 (patch) | |
tree | 8f37205559caeb261008fbfe04beed26c4ad43f3 /tensorflow/contrib/lookup | |
parent | 51acad09c1b2fdc68e949c4c02154cc8f57ea78e (diff) |
Add only string constants to ASSET_FILEPATHS collection.
PiperOrigin-RevId: 158192152
Diffstat (limited to 'tensorflow/contrib/lookup')
-rw-r--r-- | tensorflow/contrib/lookup/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/lookup/lookup_ops.py | 6 | ||||
-rw-r--r-- | tensorflow/contrib/lookup/lookup_ops_test.py | 18 |
3 files changed, 24 insertions, 1 deletions
diff --git a/tensorflow/contrib/lookup/BUILD b/tensorflow/contrib/lookup/BUILD index b0475c41c9..1090cecab5 100644 --- a/tensorflow/contrib/lookup/BUILD +++ b/tensorflow/contrib/lookup/BUILD @@ -19,6 +19,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index e49b62afa2..7600d30539 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -21,6 +21,7 @@ from __future__ import print_function import collections import functools +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor @@ -485,7 +486,10 @@ class TextFileInitializer(TableInitializerBase): name=scope) # pylint: enable=protected-access ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) - ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename) + # If the filename tensor is anything other than a string constant (e.g., if + # it is a placeholder) then it does not make sense to track it as an asset. + if constant_op.is_constant(filename): + ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename) return init_op diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 180dfefe29..f0499010d4 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -1198,6 +1198,24 @@ class IndexTableFromFile(test.TestCase): self.assertRaises(errors_impl.OpError, ids.eval) lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) + self.assertEqual(1, + len(ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS))) + + def test_string_index_table_from_file_placeholder_filename(self): + vocabulary_file = self._createVocabFile("f2i_vocab1.txt") + with self.test_session(): + vocabulary_placeholder = array_ops.placeholder(dtypes.string, []) + table = lookup.index_table_from_file( + vocabulary_file=vocabulary_placeholder, num_oov_buckets=1) + ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) + + self.assertRaises(errors_impl.OpError, ids.eval) + + feed_dict = {vocabulary_placeholder.name: vocabulary_file} + lookup_ops.tables_initializer().run(feed_dict=feed_dict) + self.assertAllEqual((1, 2, 3), ids.eval()) + self.assertEqual(0, + len(ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS))) def test_int32_index_table_from_file(self): vocabulary_file = self._createVocabFile( |