aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lookup
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-01-16 13:00:13 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-16 13:05:58 -0800
commit0830ceefcdbd2e5f57e676f72a4abaf8d351bc28 (patch)
tree92bec1c19eede8352a03e586586faca7955e4fcb /tensorflow/contrib/lookup
parent05bcbb649ce567fd11cd52a03b992411b8470c32 (diff)
Improves the eager compatibility with tf.contrib.lookup.
Fixes #16160 PiperOrigin-RevId: 182098964
Diffstat (limited to 'tensorflow/contrib/lookup')
-rw-r--r--tensorflow/contrib/lookup/lookup_ops_test.py25
1 files changed, 12 insertions, 13 deletions
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 65aaaf85c3..f681b7b132 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -1656,23 +1656,22 @@ class InitializeTableFromFileOpTest(test.TestCase):
f.write("\n".join(values) + "\n")
return vocabulary_file
+ @test_util.run_in_graph_and_eager_modes()
def testInitializeStringTable(self):
vocabulary_file = self._createVocabFile("one_column_1.txt")
+ default_value = -1
+ table = lookup.HashTable(
+ lookup.TextFileInitializer(vocabulary_file, dtypes.string,
+ lookup.TextFileIndex.WHOLE_LINE,
+ dtypes.int64,
+ lookup.TextFileIndex.LINE_NUMBER),
+ default_value)
+ self.evaluate(table.init)
- with self.test_session():
- default_value = -1
- table = lookup.HashTable(
- lookup.TextFileInitializer(vocabulary_file, dtypes.string,
- lookup.TextFileIndex.WHOLE_LINE,
- dtypes.int64,
- lookup.TextFileIndex.LINE_NUMBER),
- default_value)
- table.init.run()
-
- output = table.lookup(constant_op.constant(["brain", "salad", "tank"]))
+ output = table.lookup(constant_op.constant(["brain", "salad", "tank"]))
- result = output.eval()
- self.assertAllEqual([0, 1, -1], result)
+ result = self.evaluate(output)
+ self.assertAllEqual([0, 1, -1], result)
def testInitializeInt64Table(self):
vocabulary_file = self._createVocabFile(