aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lookup
diff options
context:
space:
mode:
authorGravatar David Soergel <soergel@google.com>2017-06-06 14:38:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-06 14:42:25 -0700
commitb4951553961f84ac9aa499b4ded96ba8264d4604 (patch)
tree8f37205559caeb261008fbfe04beed26c4ad43f3 /tensorflow/contrib/lookup
parent51acad09c1b2fdc68e949c4c02154cc8f57ea78e (diff)
Add only string constants to ASSET_FILEPATHS collection.
PiperOrigin-RevId: 158192152
Diffstat (limited to 'tensorflow/contrib/lookup')
-rw-r--r--tensorflow/contrib/lookup/BUILD1
-rw-r--r--tensorflow/contrib/lookup/lookup_ops.py6
-rw-r--r--tensorflow/contrib/lookup/lookup_ops_test.py18
3 files changed, 24 insertions, 1 deletions
diff --git a/tensorflow/contrib/lookup/BUILD b/tensorflow/contrib/lookup/BUILD
index b0475c41c9..1090cecab5 100644
--- a/tensorflow/contrib/lookup/BUILD
+++ b/tensorflow/contrib/lookup/BUILD
@@ -19,6 +19,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
+ "//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py
index e49b62afa2..7600d30539 100644
--- a/tensorflow/contrib/lookup/lookup_ops.py
+++ b/tensorflow/contrib/lookup/lookup_ops.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import collections
import functools
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
@@ -485,7 +486,10 @@ class TextFileInitializer(TableInitializerBase):
name=scope)
# pylint: enable=protected-access
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
- ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename)
+ # If the filename tensor is anything other than a string constant (e.g., if
+ # it is a placeholder) then it does not make sense to track it as an asset.
+ if constant_op.is_constant(filename):
+ ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename)
return init_op
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 180dfefe29..f0499010d4 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -1198,6 +1198,24 @@ class IndexTableFromFile(test.TestCase):
self.assertRaises(errors_impl.OpError, ids.eval)
lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
+ self.assertEqual(1,
+ len(ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)))
+
+ def test_string_index_table_from_file_placeholder_filename(self):
+ vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
+ with self.test_session():
+ vocabulary_placeholder = array_ops.placeholder(dtypes.string, [])
+ table = lookup.index_table_from_file(
+ vocabulary_file=vocabulary_placeholder, num_oov_buckets=1)
+ ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
+
+ self.assertRaises(errors_impl.OpError, ids.eval)
+
+ feed_dict = {vocabulary_placeholder.name: vocabulary_file}
+ lookup_ops.tables_initializer().run(feed_dict=feed_dict)
+ self.assertAllEqual((1, 2, 3), ids.eval())
+ self.assertEqual(0,
+ len(ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)))
def test_int32_index_table_from_file(self):
vocabulary_file = self._createVocabFile(