author    Yutaka Leon <yleon@google.com>    2017-05-04 12:31:30 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-05-04 13:50:10 -0700
commit dd140f79e06a81c52cd8fc9ec6cda975a78a401f (patch)
tree   ca8cb309a8853c31074e649f4d9642ed9a2bacd0
parent e46a12bc9fbcea1fef224daa47eb9f1cf9e56472 (diff)
Organize the lookup table ops into their own lookup_ops.cc file instead of data_flow_ops.cc
Change: 155119120
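For downstream code this is a pure import move: `tables_initializer` (and the related lookup-table ops) now come from `lookup_ops` rather than `data_flow_ops`, with the call sites otherwise unchanged. A minimal sketch of the migration, assuming the TF 1.x graph/session API in use at this commit:

    # Minimal migration sketch (TF 1.x internals as of this commit).
    from tensorflow.python.client import session
    from tensorflow.python.ops import lookup_ops   # was: data_flow_ops
    from tensorflow.python.ops import variables

    with session.Session() as sess:
        sess.run(variables.global_variables_initializer())
        # was: sess.run(data_flow_ops.tables_initializer())
        sess.run(lookup_ops.tables_initializer())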
-rw-r--r--  tensorflow/contrib/layers/python/layers/feature_column_ops_test.py              | 106
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py  |   6
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/estimator.py                   |   4
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/head_test.py                   |  10
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py | 6
-rw-r--r--  tensorflow/contrib/learn/python/learn/graph_actions.py                          |  11
-rw-r--r--  tensorflow/contrib/learn/python/learn/utils/export.py                           |  18
-rw-r--r--  tensorflow/contrib/lookup/BUILD                                                  |   2
-rw-r--r--  tensorflow/contrib/lookup/lookup_ops_test.py                                     |  64
-rw-r--r--  tensorflow/contrib/slim/python/slim/learning.py                                  |   4
-rw-r--r--  tensorflow/core/BUILD                                                            |   3
-rw-r--r--  tensorflow/core/kernels/BUILD                                                    |  10
-rw-r--r--  tensorflow/core/ops/data_flow_ops.cc                                             | 598
-rw-r--r--  tensorflow/core/ops/lookup_ops.cc                                                | 666
-rw-r--r--  tensorflow/python/BUILD                                                          |  30
-rw-r--r--  tensorflow/python/estimator/estimator_test.py                                    |   9
-rw-r--r--  tensorflow/python/feature_column/BUILD                                           |   2
-rw-r--r--  tensorflow/python/feature_column/feature_column_test.py                          |   8
-rw-r--r--  tensorflow/python/feature_column/lookup_ops.py                                   |  54
-rw-r--r--  tensorflow/python/ops/data_flow_ops.py                                           |  42
-rw-r--r--  tensorflow/python/ops/lookup_ops.py                                              |  77
-rw-r--r--  tensorflow/python/ops/standard_ops.py                                            |   1
-rw-r--r--  tensorflow/python/saved_model/main_op_impl.py                                    |   4
-rw-r--r--  tensorflow/python/training/monitored_session.py                                  |   4
-rw-r--r--  tensorflow/python/training/saver_test_utils.py                                   |  14
-rw-r--r--  tensorflow/python/training/supervisor.py                                         |   8
26 files changed, 950 insertions(+), 811 deletions(-)
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
index b2dad0162e..a09cc53571 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
@@ -33,9 +33,9 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
@@ -224,7 +224,7 @@ class TransformerTest(test.TestCase):
self.assertEqual(len(output), 1)
self.assertIn(keys_sparse, output)
with self.test_session():
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertEqual(output[keys_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[keys_sparse].values.eval(), [1, 2, 0])
self.assertAllEqual(output[keys_sparse].indices.eval(),
@@ -242,7 +242,7 @@ class TransformerTest(test.TestCase):
output = feature_column_ops._Transformer(features).transform(keys_sparse)
with self.test_session():
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
# While the input is a dense Tensor, the output should be a SparseTensor.
self.assertIsInstance(output, sparse_tensor.SparseTensor)
self.assertEqual(output.dtype, dtypes.int64)
@@ -311,7 +311,7 @@ class TransformerTest(test.TestCase):
self.assertIn(weighted_ids, output)
with self.test_session():
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual(output[weighted_ids][0].dense_shape.eval(),
ids_tensor.dense_shape.eval())
self.assertAllEqual(output[weighted_ids][0].indices.eval(),
@@ -341,7 +341,7 @@ class TransformerTest(test.TestCase):
self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output)
with self.test_session():
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0])
self.assertAllEqual(output[vocab_sparse].indices.eval(),
@@ -363,7 +363,7 @@ class TransformerTest(test.TestCase):
self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output)
with self.test_session():
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1])
self.assertAllEqual(output[vocab_sparse].indices.eval(),
@@ -387,7 +387,7 @@ class TransformerTest(test.TestCase):
self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output)
with self.test_session():
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0])
self.assertAllEqual(output[vocab_sparse].indices.eval(),
@@ -409,7 +409,7 @@ class TransformerTest(test.TestCase):
self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output)
with self.test_session():
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1])
self.assertAllEqual(output[vocab_sparse].indices.eval(),
@@ -601,7 +601,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
one_hot_column, embedding_column, real_valued_column])
with self.test_session():
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual(output.eval().shape, [3, 2 + 4 + 10])
def testRealValuedColumn(self):
@@ -714,7 +714,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[one_hot_column])
with self.test_session():
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual([[0, 0, 10., 0], [0, 20., 0, 0], [30., 0, 40., 0]],
output.eval())
@@ -732,7 +732,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
with self.test_session():
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]],
output.eval())
@@ -750,7 +750,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
with self.test_session():
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]],
output.eval())
@@ -784,7 +784,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[one_hot_sparse])
with self.test_session():
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual([3, 10], output.eval().shape)
def testEmbeddingColumnSucceedsForDNN(self):
@@ -891,7 +891,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[embeded_sparse])
with self.test_session():
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual(output.eval().shape, [2, 10])
def testEmbeddingColumnWithIntegerWeightedSparseColumnSucceedsForDNN(self):
@@ -914,7 +914,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[embeded_sparse])
with self.test_session():
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual(output.eval().shape, [2, 10])
def testEmbeddingColumnWithCrossedColumnSucceedsForDNN(self):
@@ -965,7 +965,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError,
"Error creating input layer for column: ids_weighted_by_weights"):
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
feature_column_ops.input_from_feature_columns(features, [weighted_ids])
def testCrossedColumnFailsForDNN(self):
@@ -1072,7 +1072,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[embeded_sparse])
with self.test_session():
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
# score: (sum of weights)
self.assertAllEqual(output.eval(), [[10.], [50.], [0.]])
@@ -1310,7 +1310,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
expected_input_shape = np.array([4, 3, 4])
@@ -1344,7 +1344,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
expected_input_shape = np.array([4, 3, hash_buckets])
@@ -1374,7 +1374,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
self.assertAllEqual(expected_input_shape, model_input.shape)
@@ -1403,7 +1403,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
self.assertAllEqual(expected_input_shape, model_input.shape)
@@ -1433,7 +1433,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
embedding_weights)
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
model_input, gradients = sess.run([model_input_tensor, gradient_tensor])
expected_input_shape = [4, 3, embedding_dimension]
@@ -1500,7 +1500,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
expected_input_shape = [
@@ -1581,7 +1581,7 @@ class WeightedSumTest(test.TestCase):
features, [weighted_ids], num_outputs=5)
with self.test_session():
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
def testWeightedSparseColumnWithDenseInputTensor(self):
@@ -1597,7 +1597,7 @@ class WeightedSumTest(test.TestCase):
with self.test_session():
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
def testCrossedColumn(self):
@@ -1651,7 +1651,7 @@ class WeightedSumTest(test.TestCase):
features, [movies], num_outputs=1))
with self.test_session() as sess:
variables_lib.initialize_all_variables().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
weights = column_to_variable[movies][0]
self.assertEqual(weights.get_shape(), (3, 1))
@@ -1726,7 +1726,7 @@ class WeightedSumTest(test.TestCase):
features, [age, language], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllClose(output.eval(), [[0.], [0.]])
@@ -1766,7 +1766,7 @@ class WeightedSumTest(test.TestCase):
self.assertEqual(len(variables), 1)
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllClose(output.eval(), [[0.], [0.]])
@@ -1830,7 +1830,7 @@ class WeightedSumTest(test.TestCase):
features, [weighted_language], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllClose(output.eval(), [[0.], [0.]])
@@ -1858,7 +1858,7 @@ class WeightedSumTest(test.TestCase):
features, [language], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
# score: 0.1 + language_weight['hindi'] + language_weight['english']
sess.run(bias.assign([0.1]))
@@ -1881,7 +1881,7 @@ class WeightedSumTest(test.TestCase):
features, [movies], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
weights = column_to_variable[movies][0]
self.assertEqual(weights.get_shape(), (15, 1))
@@ -1915,7 +1915,7 @@ class WeightedSumTest(test.TestCase):
features, [country_language], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
weights = column_to_variable[country_language][0]
sess.run(weights.assign(weights + 0.4))
@@ -1939,7 +1939,7 @@ class WeightedSumTest(test.TestCase):
features, [language_language], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
weights = column_to_variable[language_language][0]
sess.run(weights.assign(weights + 0.4))
@@ -1972,7 +1972,7 @@ class WeightedSumTest(test.TestCase):
features, [country_language], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
weights = column_to_variable[country_language][0]
sess.run(weights.assign(weights + 0.4))
@@ -2013,7 +2013,7 @@ class WeightedSumTest(test.TestCase):
scope=scope))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertEqual(2, len(column_to_variable[country]))
self.assertEqual(3, len(column_to_variable[language]))
@@ -2050,7 +2050,7 @@ class WeightedSumTest(test.TestCase):
features, [country, age, incomes], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
incomes_weights = column_to_variable[incomes][0]
sess.run(incomes_weights.assign([[0.1], [0.2], [0.3]]))
@@ -2086,7 +2086,7 @@ class WeightedSumTest(test.TestCase):
features, [country, age, height, incomes], num_outputs=5))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
height_weights = column_to_variable[height][0]
sess.run(
@@ -2116,7 +2116,7 @@ class WeightedSumTest(test.TestCase):
features, [bucket], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
sess.run(column_to_variable[bucket][0].assign([[0.1], [0.2], [0.3],
[0.4]]))
@@ -2144,7 +2144,7 @@ class WeightedSumTest(test.TestCase):
features, [bucket, country], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
# dimension = 2, bucket_size = 4, num_classes = 1
sess.run(column_to_variable[bucket][0].assign(
@@ -2173,7 +2173,7 @@ class WeightedSumTest(test.TestCase):
features, [bucket, country], num_outputs=5))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
# dimension = 2, bucket_size = 4, num_classes = 5
sess.run(column_to_variable[bucket][0].assign(
@@ -2209,7 +2209,7 @@ class WeightedSumTest(test.TestCase):
features, [country_price], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
weights = column_to_variable[country_price][0]
sess.run(weights.assign(weights + 0.4))
@@ -2248,7 +2248,7 @@ class WeightedSumTest(test.TestCase):
features, [country_language_price], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
weights = column_to_variable[country_language_price][0]
sess.run(weights.assign(weights + 0.4))
@@ -2272,7 +2272,7 @@ class WeightedSumTest(test.TestCase):
features, [product], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0]
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]])
@@ -2287,7 +2287,7 @@ class WeightedSumTest(test.TestCase):
features, [product], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0]
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]])
@@ -2302,7 +2302,7 @@ class WeightedSumTest(test.TestCase):
features, [product], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0]
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
self.assertAllClose(output.eval(), [[0.6], [0.7]])
@@ -2323,7 +2323,7 @@ class WeightedSumTest(test.TestCase):
features, [product], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0]
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]])
@@ -2335,7 +2335,7 @@ class WeightedSumTest(test.TestCase):
features, [feature_column.real_valued_column("age")], num_outputs=3)
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
sess.run(bias.assign([0.1, 0.2, 0.3]))
self.assertAllClose(output.eval(), [[0.1, 0.2, 0.3], [0.1, 0.2, 0.3],
[0.1, 0.2, 0.3], [0.1, 0.2, 0.3]])
@@ -2349,7 +2349,7 @@ class WeightedSumTest(test.TestCase):
features, [column], num_outputs=3))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
weights = column_to_variable[column][0]
self.assertEqual(weights.get_shape(), (1, 3))
sess.run(weights.assign([[0.01, 0.03, 0.05]]))
@@ -2373,7 +2373,7 @@ class WeightedSumTest(test.TestCase):
features, [column], num_outputs=3))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
weights = column_to_variable[column][0]
self.assertEqual(weights.get_shape(), (5, 3))
sess.run(
@@ -2399,7 +2399,7 @@ class WeightedSumTest(test.TestCase):
features, [column], num_outputs=3))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
weights = column_to_variable[column][0]
self.assertEqual(weights.get_shape(), (5, 3))
@@ -2439,7 +2439,7 @@ class WeightedSumTest(test.TestCase):
features, [column], num_outputs=3))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
weights = column_to_variable[column][0]
self.assertEqual(weights.get_shape(), (5, 3))
@@ -2468,7 +2468,7 @@ class WeightedSumTest(test.TestCase):
features, [column], num_outputs=3))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
weights = column_to_variable[column][0]
self.assertEqual(weights.get_shape(), (5, 3))
@@ -2533,7 +2533,7 @@ class ParseExampleTest(test.TestCase):
self.assertIn(bucket, output)
self.assertIn(wire_cast, output)
with self.test_session():
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual(output[bucket].eval(), [[2, 3, 0]])
self.assertAllEqual(output[wire_cast].indices.eval(), [[0, 0], [0, 1]])
self.assertAllEqual(output[wire_cast].values.eval(), [2, 0])
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
index 43b3d2a78f..58072500d1 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
@@ -38,8 +38,8 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
@@ -157,7 +157,7 @@ class DynamicRnnEstimatorTest(test.TestCase):
self.context_feature_columns)
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
- sess.run(data_flow_ops.tables_initializer())
+ sess.run(lookup_ops.tables_initializer())
sequence_input_val = sess.run(sequence_input)
expected_shape = np.array([
3, # expected batch size
@@ -178,7 +178,7 @@ class DynamicRnnEstimatorTest(test.TestCase):
# Obtain values of activations and final state.
with session.Session() as sess:
sess.run(variables.global_variables_initializer())
- sess.run(data_flow_ops.tables_initializer())
+ sess.run(lookup_ops.tables_initializer())
activations, final_state = sess.run([activations_t, final_state_t])
expected_activations_shape = np.array([3, 2, self.NUM_LABEL_COLUMNS])
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 74a6da20d4..36f843ba8e 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -57,7 +57,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
@@ -1292,7 +1292,7 @@ class Estimator(BaseEstimator):
init_op = control_flow_ops.group(
variables.local_variables_initializer(),
resources.initialize_resources(resources.shared_resources()),
- data_flow_ops.tables_initializer())
+ lookup_ops.tables_initializer())
# Perform the export
builder = saved_model_builder.SavedModelBuilder(export_dir)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
index 207a189a94..d5777088de 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
@@ -32,7 +32,7 @@ from tensorflow.core.framework import summary_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses as losses_lib
from tensorflow.python.platform import test
@@ -1214,7 +1214,7 @@ class MultiClassHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn,
logits=((1., 0., 0.), (0., 0., 1.),))
with session.Session():
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual(
[0, 2],
model_fn_ops.predictions["classes"].eval())
@@ -1266,7 +1266,7 @@ class MultiClassHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn,
logits=((1., 0., 0.), (0., 0., 1.),))
with session.Session():
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual(
[b"key0", b"key2"],
model_fn_ops.predictions["classes"].eval())
@@ -1301,7 +1301,7 @@ class MultiClassHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn,
logits=((1., 0., 0.),))
with session.Session():
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
@@ -1327,7 +1327,7 @@ class MultiClassHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn,
logits=((0., 0., 1.),))
with session.Session():
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py
index f5bd03429c..feea6c5fed 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py
@@ -35,8 +35,8 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
@@ -55,7 +55,7 @@ class PrepareInputsForRnnTest(test.TestCase):
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
- sess.run(data_flow_ops.initialize_all_tables())
+ sess.run(lookup_ops.tables_initializer())
features_val = sess.run(features_by_time)
self.assertAllEqual(expected, features_val)
@@ -316,7 +316,7 @@ class StateSavingRnnEstimatorTest(test.TestCase):
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
- sess.run(data_flow_ops.initialize_all_tables())
+ sess.run(lookup_ops.tables_initializer())
actual_sequence, actual_context = sess.run(
[sequence, context])
assert_equal(expected_sequence, actual_sequence)
diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py
index 4b7867f2d0..98365c05f6 100644
--- a/tensorflow/contrib/learn/python/learn/graph_actions.py
+++ b/tensorflow/contrib/learn/python/learn/graph_actions.py
@@ -37,8 +37,8 @@ from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
@@ -429,11 +429,14 @@ def _get_ready_op():
def _get_local_init_op():
+ """Returns the local init ops to initialize tables and local variables."""
local_init_op = _get_first_op_from_collection(
ops.GraphKeys.LOCAL_INIT_OP)
if local_init_op is None:
- op_list = [variables.local_variables_initializer(),
- data_flow_ops.tables_initializer()]
+ op_list = [
+ variables.local_variables_initializer(),
+ lookup_ops.tables_initializer()
+ ]
if op_list:
local_init_op = control_flow_ops.group(*op_list)
ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op)
@@ -680,7 +683,7 @@ def run_feeds_iter(output_dict, feed_dicts, restore_checkpoint_path=None):
else:
session.run(variables.global_variables_initializer())
session.run(variables.local_variables_initializer())
- session.run(data_flow_ops.tables_initializer())
+ session.run(lookup_ops.tables_initializer())
coord = coordinator.Coordinator()
threads = None
try:
diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py
index b53be29283..36a1f5f60c 100644
--- a/tensorflow/contrib/learn/python/learn/utils/export.py
+++ b/tensorflow/contrib/learn/python/learn/utils/export.py
@@ -28,7 +28,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as tf_saver
@@ -67,17 +67,17 @@ def _export_graph(graph, saver, checkpoint_path, export_dir,
with graph.as_default():
with tf_session.Session('') as session:
variables.local_variables_initializer()
- data_flow_ops.tables_initializer()
+ lookup_ops.tables_initializer()
saver.restore(session, checkpoint_path)
export = exporter.Exporter(saver)
- export.init(init_op=control_flow_ops.group(
- variables.local_variables_initializer(),
- data_flow_ops.tables_initializer()),
- default_graph_signature=default_graph_signature,
- named_graph_signatures=named_graph_signatures,
- assets_collection=ops.get_collection(
- ops.GraphKeys.ASSET_FILEPATHS))
+ export.init(
+ init_op=control_flow_ops.group(
+ variables.local_variables_initializer(),
+ lookup_ops.tables_initializer()),
+ default_graph_signature=default_graph_signature,
+ named_graph_signatures=named_graph_signatures,
+ assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS))
return export.export(export_dir, contrib_variables.get_global_step(),
session, exports_to_keep=exports_to_keep)
diff --git a/tensorflow/contrib/lookup/BUILD b/tensorflow/contrib/lookup/BUILD
index 5966c86dfb..bbbd340352 100644
--- a/tensorflow/contrib/lookup/BUILD
+++ b/tensorflow/contrib/lookup/BUILD
@@ -30,11 +30,11 @@ py_test(
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:data_flow_ops",
"//tensorflow/python:errors",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:lookup_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
"//tensorflow/python:variables",
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 0ec40a63f2..5ec169b6db 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -31,7 +31,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import saver
@@ -125,7 +125,7 @@ class HashTableOpTest(test.TestCase):
table3 = lookup.HashTable(
lookup.KeyValueTensorInitializer(keys, values), default_val)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual(3, table1.size().eval())
self.assertAllEqual(3, table2.size().eval())
self.assertAllEqual(3, table3.size().eval())
@@ -1184,7 +1184,7 @@ class IndexTableFromFile(test.TestCase):
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
self.assertRaises(errors_impl.OpError, ids.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
def test_int32_index_table_from_file(self):
@@ -1198,7 +1198,7 @@ class IndexTableFromFile(test.TestCase):
constant_op.constant((1, -1000, 11), dtype=dtypes.int32))
self.assertRaises(errors_impl.OpError, ids.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
def test_int64_index_table_from_file(self):
@@ -1212,7 +1212,7 @@ class IndexTableFromFile(test.TestCase):
constant_op.constant((1, -1000, 11), dtype=dtypes.int64))
self.assertRaises(errors_impl.OpError, ids.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
def test_index_table_from_file_with_default_value(self):
@@ -1224,7 +1224,7 @@ class IndexTableFromFile(test.TestCase):
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
self.assertRaises(errors_impl.OpError, ids.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, default_value), ids.eval())
def test_index_table_from_file_with_oov_buckets(self):
@@ -1236,7 +1236,7 @@ class IndexTableFromFile(test.TestCase):
constant_op.constant(["salad", "surgery", "tarkus", "toccata"]))
self.assertRaises(errors_impl.OpError, ids.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual(
(
1, # From vocabulary file.
@@ -1259,7 +1259,7 @@ class IndexTableFromFile(test.TestCase):
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
self.assertRaises(errors_impl.OpError, ids.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((1, -1, -1), ids.eval())
self.assertEqual(2, table.size().eval())
@@ -1286,7 +1286,7 @@ class IndexTableFromFile(test.TestCase):
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
self.assertRaises(errors_impl.OpError, ids.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, -1), ids.eval())
self.assertEqual(3, table.size().eval())
@@ -1345,7 +1345,7 @@ class IndexTableFromTensor(test.TestCase):
ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus")))
self.assertRaises(errors_impl.OpError, ids.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
def test_int32_index_table_from_tensor_with_tensor_init(self):
@@ -1356,7 +1356,7 @@ class IndexTableFromTensor(test.TestCase):
constant_op.constant((1, -1000, 11), dtype=dtypes.int32))
self.assertRaises(errors_impl.OpError, ids.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
def test_int64_index_table_from_tensor_with_tensor_init(self):
@@ -1367,7 +1367,7 @@ class IndexTableFromTensor(test.TestCase):
constant_op.constant((1, -1000, 11), dtype=dtypes.int64))
self.assertRaises(errors_impl.OpError, ids.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
def test_index_table_from_tensor_with_default_value(self):
@@ -1378,7 +1378,7 @@ class IndexTableFromTensor(test.TestCase):
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
self.assertRaises(errors_impl.OpError, ids.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, default_value), ids.eval())
def test_index_table_from_tensor_missing_mapping(self):
@@ -1394,7 +1394,7 @@ class IndexTableFromTensor(test.TestCase):
self.assertRaises(errors_impl.OpError, ids.eval)
with self.assertRaisesRegexp(
errors_impl.OpError, "keys and values cannot be empty"):
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
def test_index_table_from_tensor_with_invalid_hashers(self):
with self.test_session():
@@ -1422,7 +1422,7 @@ class StringToIndexTest(test.TestCase):
indices = lookup.string_to_index(feats, mapping=mapping_strings)
self.assertRaises(errors_impl.OpError, indices.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, -1), indices.eval())
@@ -1433,7 +1433,7 @@ class StringToIndexTest(test.TestCase):
_ = lookup.string_to_index(feats, mapping=mapping_strings)
self.assertRaises(errors_impl.OpError,
- data_flow_ops.tables_initializer().run)
+ lookup_ops.tables_initializer().run)
def test_string_to_index_with_default_value(self):
default_value = -42
@@ -1444,7 +1444,7 @@ class StringToIndexTest(test.TestCase):
feats, mapping=mapping_strings, default_value=default_value)
self.assertRaises(errors_impl.OpError, indices.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, default_value), indices.eval())
@@ -1463,7 +1463,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
vocabulary_file=vocabulary_file)
features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64))
self.assertRaises(errors_impl.OpError, features.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
features.eval())
@@ -1475,7 +1475,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
vocabulary_file=vocabulary_file, default_value=default_value)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
self.assertRaises(errors_impl.OpError, features.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((b"salad", b"surgery", default_value),
features.eval())
@@ -1489,7 +1489,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
default_value=default_value)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
self.assertRaises(errors_impl.OpError, features.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((b"salad", default_value, default_value),
features.eval())
@@ -1501,7 +1501,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
self.assertRaises(errors_impl.OpError, features.eval)
- init = data_flow_ops.tables_initializer()
+ init = lookup_ops.tables_initializer()
self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"Invalid vocab_size", init.run)
@@ -1513,7 +1513,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
self.assertRaises(errors_impl.OpError, features.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((b"salad", b"surgery", b"UNK"), features.eval())
@@ -1528,7 +1528,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
indices = constant_op.constant([0, 1, 2, 3], dtypes.int64)
features = table.lookup(indices)
self.assertRaises(errors_impl.OpError, features.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
features.eval())
@@ -1540,7 +1540,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
mapping=mapping_strings)
indices = constant_op.constant([0, 1, 4], dtypes.int64)
features = table.lookup(indices)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((b"hello", b"hello", b"UNK"), features.eval())
def test_index_to_string_with_default_value(self):
@@ -1553,7 +1553,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
features = table.lookup(indices)
self.assertRaises(errors_impl.OpError, features.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((b"salad", b"surgery", default_value),
features.eval())
@@ -1567,7 +1567,7 @@ class IndexToStringTest(test.TestCase):
feats = lookup.index_to_string(indices, mapping=mapping_strings)
self.assertRaises(errors_impl.OpError, feats.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
feats.eval())
@@ -1577,11 +1577,11 @@ class IndexToStringTest(test.TestCase):
mapping_strings = constant_op.constant(["hello", "hello"])
indices = constant_op.constant([0, 1, 4], dtypes.int64)
feats = lookup.index_to_string(indices, mapping=mapping_strings)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((b"hello", b"hello", b"UNK"), feats.eval())
self.assertRaises(errors_impl.OpError,
- data_flow_ops.tables_initializer().run)
+ lookup_ops.tables_initializer().run)
def test_index_to_string_with_default_value(self):
default_value = b"NONE"
@@ -1592,7 +1592,7 @@ class IndexToStringTest(test.TestCase):
indices, mapping=mapping_strings, default_value=default_value)
self.assertRaises(errors_impl.OpError, feats.eval)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
self.assertAllEqual((b"salad", b"surgery", default_value), feats.eval())
@@ -1755,7 +1755,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
default_value,
shared_name=shared_name)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
input_string = constant_op.constant(["brain", "salad", "tank"])
@@ -2081,7 +2081,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
hasher_spec=lookup.StrongHashSpec((1, 2)),
name="table2")
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
input_string = constant_op.constant(
["fruit", "brain", "salad", "surgery", "UNK"])
@@ -2167,7 +2167,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
default_value2),
oov_buckets)
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
input_string_1 = constant_op.constant(
["brain", "salad", "surgery", "UNK"])
diff --git a/tensorflow/contrib/slim/python/slim/learning.py b/tensorflow/contrib/slim/python/slim/learning.py
index 5ced8a4f08..b70d612f55 100644
--- a/tensorflow/contrib/slim/python/slim/learning.py
+++ b/tensorflow/contrib/slim/python/slim/learning.py
@@ -261,7 +261,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
@@ -657,7 +657,7 @@ def train(train_op,
if local_init_op == _USE_DEFAULT:
local_init_op = control_flow_ops.group(
tf_variables.local_variables_initializer(),
- data_flow_ops.tables_initializer())
+ lookup_ops.tables_initializer())
if sync_optimizer is not None and isinstance(
sync_optimizer, sync_replicas_optimizer.SyncReplicasOptimizer):
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 119bc0f899..435618ace7 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -506,6 +506,7 @@ tf_gen_op_libs(
"image_ops",
"io_ops",
"linalg_ops",
+ "lookup_ops",
"logging_ops",
"math_ops",
"nn_ops",
@@ -582,6 +583,7 @@ cc_library(
":image_ops_op_lib",
":io_ops_op_lib",
":linalg_ops_op_lib",
+ ":lookup_ops_op_lib",
":logging_ops_op_lib",
":math_ops_op_lib",
":nn_ops_op_lib",
@@ -708,6 +710,7 @@ cc_library(
"//tensorflow/core/kernels:image",
"//tensorflow/core/kernels:io",
"//tensorflow/core/kernels:linalg",
+ "//tensorflow/core/kernels:lookup",
"//tensorflow/core/kernels:logging",
"//tensorflow/core/kernels:math",
"//tensorflow/core/kernels:multinomial_op",
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 0847d1279b..02ab30a04f 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1327,6 +1327,14 @@ cc_library(
],
)
+cc_library(
+ name = "lookup",
+ deps = [
+ ":lookup_table_init_op",
+ ":lookup_table_op",
+ ],
+)
+
DATA_FLOW_DEPS = [
":bounds_check",
":concat_lib",
@@ -1450,10 +1458,10 @@ LOOKUP_DEPS = [
":initializable_lookup_table",
":lookup_util",
"//tensorflow/core:core_cpu",
- "//tensorflow/core:data_flow_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core:lookup_ops_op_lib",
]
tf_kernel_library(
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index f35a1bb648..032ede6459 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -1876,604 +1876,6 @@ size: The number of incomplete elements (i.e. those with some of their value
// --------------------------------------------------------------------------
-REGISTER_OP("LookupTableFind")
- .Input("table_handle: Ref(string)")
- .Input("keys: Tin")
- .Input("default_value: Tout")
- .Output("values: Tout")
- .Attr("Tin: type")
- .Attr("Tout: type")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle handle;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
- DimensionHandle unused_dim;
- TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
-
- // Default value must be scalar or vector.
- ShapeHandle unused;
- TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused));
- c->set_output(0, c->UnknownShape());
- return Status::OK();
- })
- .Doc(R"doc(
-Looks up keys in a table, outputs the corresponding values.
-
-The tensor `keys` must of the same type as the keys of the table.
-The output `values` is of the type of the table values.
-
-The scalar `default_value` is the value output for keys not present in the
-table. It must also be of the same type as the table values.
-
-table_handle: Handle to the table.
-keys: Any shape. Keys to look up.
-values: Same shape as `keys`. Values found in the table, or `default_values`
- for missing keys.
-)doc");
-
-REGISTER_OP("LookupTableFindV2")
- .Input("table_handle: resource")
- .Input("keys: Tin")
- .Input("default_value: Tout")
- .Output("values: Tout")
- .Attr("Tin: type")
- .Attr("Tout: type")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle handle;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
-
- // Default value must be scalar or vector.
- ShapeHandle unused;
- TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused));
- c->set_output(0, c->UnknownShape());
- return Status::OK();
- })
- .Doc(R"doc(
-Looks up keys in a table, outputs the corresponding values.
-
-The tensor `keys` must of the same type as the keys of the table.
-The output `values` is of the type of the table values.
-
-The scalar `default_value` is the value output for keys not present in the
-table. It must also be of the same type as the table values.
-
-table_handle: Handle to the table.
-keys: Any shape. Keys to look up.
-values: Same shape as `keys`. Values found in the table, or `default_values`
- for missing keys.
-)doc");
-
-REGISTER_OP("LookupTableInsert")
- .Input("table_handle: Ref(string)")
- .Input("keys: Tin")
- .Input("values: Tout")
- .Attr("Tin: type")
- .Attr("Tout: type")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle handle;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
- DimensionHandle unused_dim;
- TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
-
- // TODO(ebrevdo): Validate keys and values shape.
- return Status::OK();
- })
- .Doc(R"doc(
-Updates the table to associates keys with values.
-
-The tensor `keys` must be of the same type as the keys of the table.
-The tensor `values` must be of the type of the table values.
-
-table_handle: Handle to the table.
-keys: Any shape. Keys to look up.
-values: Values to associate with keys.
-)doc");
-
-REGISTER_OP("LookupTableInsertV2")
- .Input("table_handle: resource")
- .Input("keys: Tin")
- .Input("values: Tout")
- .Attr("Tin: type")
- .Attr("Tout: type")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle handle;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
-
- // TODO: Validate keys and values shape.
- return Status::OK();
- })
- .Doc(R"doc(
-Updates the table to associates keys with values.
-
-The tensor `keys` must be of the same type as the keys of the table.
-The tensor `values` must be of the type of the table values.
-
-table_handle: Handle to the table.
-keys: Any shape. Keys to look up.
-values: Values to associate with keys.
-)doc");
-
-REGISTER_OP("LookupTableSize")
- .Input("table_handle: Ref(string)")
- .Output("size: int64")
- .SetShapeFn(TwoElementVectorInputsAndScalarOutputs)
- .Doc(R"doc(
-Computes the number of elements in the given table.
-
-table_handle: Handle to the table.
-size: Scalar that contains number of elements in the table.
-)doc");
-
-REGISTER_OP("LookupTableSizeV2")
- .Input("table_handle: resource")
- .Output("size: int64")
- .SetShapeFn(ScalarAndTwoElementVectorInputsAndScalarOutputs)
- .Doc(R"doc(
-Computes the number of elements in the given table.
-
-table_handle: Handle to the table.
-size: Scalar that contains number of elements in the table.
-)doc");
-
-REGISTER_OP("LookupTableExport")
- .Input("table_handle: Ref(string)")
- .Output("keys: Tkeys")
- .Output("values: Tvalues")
- .Attr("Tkeys: type")
- .Attr("Tvalues: type")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle handle;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
- DimensionHandle unused_dim;
- TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
-
- ShapeHandle values = c->UnknownShape();
- TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values));
- ShapeHandle keys = c->Vector(c->Dim(values, 0));
- c->set_output(0, keys);
- c->set_output(1, values);
- return Status::OK();
- })
- .Doc(R"doc(
-Outputs all keys and values in the table.
-
-table_handle: Handle to the table.
-keys: Vector of all keys present in the table.
-values: Tensor of all values in the table. Indexed in parallel with `keys`.
-)doc");
-
-REGISTER_OP("LookupTableExportV2")
- .Input("table_handle: resource")
- .Output("keys: Tkeys")
- .Output("values: Tvalues")
- .Attr("Tkeys: type")
- .Attr("Tvalues: type")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle handle;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
-
- ShapeHandle values = c->UnknownShape();
- TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values));
- ShapeHandle keys = c->Vector(c->Dim(values, 0));
- c->set_output(0, keys);
- c->set_output(1, values);
- return Status::OK();
- })
- .Doc(R"doc(
-Outputs all keys and values in the table.
-
-table_handle: Handle to the table.
-keys: Vector of all keys present in the table.
-values: Tensor of all values in the table. Indexed in parallel with `keys`.
-)doc");
-
-REGISTER_OP("LookupTableImport")
- .Input("table_handle: Ref(string)")
- .Input("keys: Tin")
- .Input("values: Tout")
- .Attr("Tin: type")
- .Attr("Tout: type")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle handle;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
- DimensionHandle unused_dim;
- TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
-
- // TODO(ebrevdo): Validate keys and values shape.
- return Status::OK();
- })
- .Doc(R"doc(
-Replaces the contents of the table with the specified keys and values.
-
-The tensor `keys` must be of the same type as the keys of the table.
-The tensor `values` must be of the type of the table values.
-
-table_handle: Handle to the table.
-keys: Any shape. Keys to look up.
-values: Values to associate with keys.
-)doc");
-
-REGISTER_OP("LookupTableImportV2")
- .Input("table_handle: resource")
- .Input("keys: Tin")
- .Input("values: Tout")
- .Attr("Tin: type")
- .Attr("Tout: type")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle handle;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
-
- // TODO: Validate keys and values shape.
- return Status::OK();
- })
- .Doc(R"doc(
-Replaces the contents of the table with the specified keys and values.
-
-The tensor `keys` must be of the same type as the keys of the table.
-The tensor `values` must be of the type of the table values.
-
-table_handle: Handle to the table.
-keys: Any shape. Keys to look up.
-values: Values to associate with keys.
-)doc");
-
-REGISTER_OP("HashTable")
- .Output("table_handle: Ref(string)")
- .Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Attr("use_node_name_sharing: bool = false")
- .Attr("key_dtype: type")
- .Attr("value_dtype: type")
- .SetIsStateful()
- .SetShapeFn(TwoElementOutput)
- .Doc(R"doc(
-Creates a non-initialized hash table.
-
-This op creates a hash table, specifying the type of its keys and values.
-Before using the table you will have to initialize it. After initialization the
-table will be immutable.
-
-table_handle: Handle to a table.
-container: If non-empty, this table is placed in the given container.
- Otherwise, a default container is used.
-shared_name: If non-empty, this table is shared under the given name across
- multiple sessions.
-use_node_name_sharing: If true and shared_name is empty, the table is shared
- using the node name.
-key_dtype: Type of the table keys.
-value_dtype: Type of the table values.
-)doc");
-
-REGISTER_OP("HashTableV2")
- .Output("table_handle: resource")
- .Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Attr("use_node_name_sharing: bool = false")
- .Attr("key_dtype: type")
- .Attr("value_dtype: type")
- .SetIsStateful()
- .SetShapeFn(ScalarOutput)
- .Doc(R"doc(
-Creates a non-initialized hash table.
-
-This op creates a hash table, specifying the type of its keys and values.
-Before using the table you will have to initialize it. After initialization the
-table will be immutable.
-
-table_handle: Handle to a table.
-container: If non-empty, this table is placed in the given container.
- Otherwise, a default container is used.
-shared_name: If non-empty, this table is shared under the given name across
- multiple sessions.
-use_node_name_sharing: If true and shared_name is empty, the table is shared
- using the node name.
-key_dtype: Type of the table keys.
-value_dtype: Type of the table values.
-)doc");
-
-REGISTER_OP("MutableHashTable")
- .Output("table_handle: Ref(string)")
- .Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Attr("use_node_name_sharing: bool = false")
- .Attr("key_dtype: type")
- .Attr("value_dtype: type")
- .SetIsStateful()
- .SetShapeFn(TwoElementOutput)
- .Doc(R"doc(
-Creates an empty hash table.
-
-This op creates a mutable hash table, specifying the type of its keys and
-values. Each value must be a scalar. Data can be inserted into the table using
-the insert operations. It does not support the initialization operation.
-
-table_handle: Handle to a table.
-container: If non-empty, this table is placed in the given container.
- Otherwise, a default container is used.
-shared_name: If non-empty, this table is shared under the given name across
- multiple sessions.
-use_node_name_sharing: If true and shared_name is empty, the table is shared
- using the node name.
-key_dtype: Type of the table keys.
-value_dtype: Type of the table values.
-)doc");
-
-REGISTER_OP("MutableHashTableV2")
- .Output("table_handle: resource")
- .Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Attr("use_node_name_sharing: bool = false")
- .Attr("key_dtype: type")
- .Attr("value_dtype: type")
- .SetIsStateful()
- .SetShapeFn(ScalarOutput)
- .Doc(R"doc(
-Creates an empty hash table.
-
-This op creates a mutable hash table, specifying the type of its keys and
-values. Each value must be a scalar. Data can be inserted into the table using
-the insert operations. It does not support the initialization operation.
-
-table_handle: Handle to a table.
-container: If non-empty, this table is placed in the given container.
- Otherwise, a default container is used.
-shared_name: If non-empty, this table is shared under the given name across
- multiple sessions.
-use_node_name_sharing: If true and shared_name is empty, the table is shared
- using the node name.
-key_dtype: Type of the table keys.
-value_dtype: Type of the table values.
-)doc");
-
-REGISTER_OP("MutableHashTableOfTensors")
- .Output("table_handle: Ref(string)")
- .Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Attr("use_node_name_sharing: bool = false")
- .Attr("key_dtype: type")
- .Attr("value_dtype: type")
- .Attr("value_shape: shape = {}")
- .SetIsStateful()
- .SetShapeFn(TwoElementOutput)
- .Doc(R"doc(
-Creates an empty hash table.
-
-This op creates a mutable hash table, specifying the type of its keys and
-values. Each value must be a vector. Data can be inserted into the table using
-the insert operations. It does not support the initialization operation.
-
-table_handle: Handle to a table.
-container: If non-empty, this table is placed in the given container.
- Otherwise, a default container is used.
-shared_name: If non-empty, this table is shared under the given name across
- multiple sessions.
-key_dtype: Type of the table keys.
-value_dtype: Type of the table values.
-)doc");
-
-REGISTER_OP("MutableHashTableOfTensorsV2")
- .Output("table_handle: resource")
- .Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Attr("use_node_name_sharing: bool = false")
- .Attr("key_dtype: type")
- .Attr("value_dtype: type")
- .Attr("value_shape: shape = {}")
- .SetIsStateful()
- .SetShapeFn(ScalarOutput)
- .Doc(R"doc(
-Creates an empty hash table.
-
-This op creates a mutable hash table, specifying the type of its keys and
-values. Each value must be a vector. Data can be inserted into the table using
-the insert operations. It does not support the initialization operation.
-
-table_handle: Handle to a table.
-container: If non-empty, this table is placed in the given container.
- Otherwise, a default container is used.
-shared_name: If non-empty, this table is shared under the given name across
- multiple sessions.
-key_dtype: Type of the table keys.
-value_dtype: Type of the table values.
-)doc");
-
-REGISTER_OP("MutableDenseHashTable")
- .Input("empty_key: key_dtype")
- .Output("table_handle: Ref(string)")
- .Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Attr("use_node_name_sharing: bool = false")
- .Attr("key_dtype: type")
- .Attr("value_dtype: type")
- .Attr("value_shape: shape = {}")
- .Attr("initial_num_buckets: int = 131072") // 2^17
- .Attr("max_load_factor: float = 0.8")
- .SetIsStateful()
- .SetShapeFn(TwoElementOutput)
- .Doc(R"doc(
-Creates an empty hash table that uses tensors as the backing store. It uses
-"open addressing" with quadratic reprobing to resolve collisions.
-
-This op creates a mutable hash table, specifying the type of its keys and
-values. Each value must be a scalar. Data can be inserted into the table using
-the insert operations. It does not support the initialization operation.
-
-empty_key: The key used to represent empty key buckets internally. Must not
- be used in insert or lookup operations.
-table_handle: Handle to a table.
-container: If non-empty, this table is placed in the given container.
- Otherwise, a default container is used.
-shared_name: If non-empty, this table is shared under the given name across
- multiple sessions.
-key_dtype: Type of the table keys.
-value_dtype: Type of the table values.
-value_shape: The shape of each value.
-initial_num_buckets: The initial number of hash table buckets. Must be a power
- to 2.
-max_load_factor: The maximum ratio between number of entries and number of
- buckets before growing the table. Must be between 0 and 1.
-)doc");
-
-REGISTER_OP("MutableDenseHashTableV2")
- .Input("empty_key: key_dtype")
- .Output("table_handle: resource")
- .Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Attr("use_node_name_sharing: bool = false")
- .Attr("key_dtype: type")
- .Attr("value_dtype: type")
- .Attr("value_shape: shape = {}")
- .Attr("initial_num_buckets: int = 131072") // 2^17
- .Attr("max_load_factor: float = 0.8")
- .SetIsStateful()
- .SetShapeFn(ScalarOutput)
- .Doc(R"doc(
-Creates an empty hash table that uses tensors as the backing store. It uses
-"open addressing" with quadratic reprobing to resolve collisions.
-
-This op creates a mutable hash table, specifying the type of its keys and
-values. Each value must be a scalar. Data can be inserted into the table using
-the insert operations. It does not support the initialization operation.
-
-empty_key: The key used to represent empty key buckets internally. Must not
- be used in insert or lookup operations.
-table_handle: Handle to a table.
-container: If non-empty, this table is placed in the given container.
- Otherwise, a default container is used.
-shared_name: If non-empty, this table is shared under the given name across
- multiple sessions.
-key_dtype: Type of the table keys.
-value_dtype: Type of the table values.
-value_shape: The shape of each value.
-initial_num_buckets: The initial number of hash table buckets. Must be a power
- to 2.
-max_load_factor: The maximum ratio between number of entries and number of
- buckets before growing the table. Must be between 0 and 1.
-)doc");
-
-REGISTER_OP("InitializeTable")
- .Input("table_handle: Ref(string)")
- .Input("keys: Tkey")
- .Input("values: Tval")
- .Attr("Tkey: type")
- .Attr("Tval: type")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle handle;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
- DimensionHandle unused_dim;
- TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
-
- ShapeHandle keys;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys));
- TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys));
- return Status::OK();
- })
- .Doc(R"doc(
-Table initializer that takes two tensors for keys and values respectively.
-
-table_handle: Handle to a table which will be initialized.
-keys: Keys of type Tkey.
-values: Values of type Tval.
-)doc");
-
-REGISTER_OP("InitializeTableV2")
- .Input("table_handle: resource")
- .Input("keys: Tkey")
- .Input("values: Tval")
- .Attr("Tkey: type")
- .Attr("Tval: type")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle handle;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
-
- ShapeHandle keys;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys));
- TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys));
- return Status::OK();
- })
- .Doc(R"doc(
-Table initializer that takes two tensors for keys and values respectively.
-
-table_handle: Handle to a table which will be initialized.
-keys: Keys of type Tkey.
-values: Values of type Tval.
-)doc");
-
-REGISTER_OP("InitializeTableFromTextFile")
- .Input("table_handle: Ref(string)")
- .Input("filename: string")
- .Attr("key_index: int >= -2")
- .Attr("value_index: int >= -2")
- .Attr("vocab_size: int >= -1 = -1")
- .Attr("delimiter: string = '\t'")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle handle;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
- DimensionHandle unused_dim;
- TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
-
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle));
- return Status::OK();
- })
- .Doc(R"doc(
-Initializes a table from a text file.
-
-It inserts one key-value pair into the table for each line of the file.
-The key and value is extracted from the whole line content, elements from the
-split line based on `delimiter` or the line number (starting from zero).
-Where to extract the key and value from a line is specified by `key_index` and
-`value_index`.
-
-- A value of -1 means use the line number(starting from zero), expects `int64`.
-- A value of -2 means use the whole line content, expects `string`.
-- A value >= 0 means use the index (starting at zero) of the split line based
- on `delimiter`.
-
-table_handle: Handle to a table which will be initialized.
-filename: Filename of a vocabulary text file.
-key_index: Column index in a line to get the table `key` values from.
-value_index: Column index that represents information of a line to get the table
- `value` values from.
-vocab_size: Number of elements of the file, use -1 if unknown.
-delimiter: Delimiter to separate fields in a line.
-)doc");
-
-REGISTER_OP("InitializeTableFromTextFileV2")
- .Input("table_handle: resource")
- .Input("filename: string")
- .Attr("key_index: int >= -2")
- .Attr("value_index: int >= -2")
- .Attr("vocab_size: int >= -1 = -1")
- .Attr("delimiter: string = '\t'")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle handle;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
-
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle));
- return Status::OK();
- })
- .Doc(R"doc(
-Initializes a table from a text file.
-
-It inserts one key-value pair into the table for each line of the file.
-The key and value is extracted from the whole line content, elements from the
-split line based on `delimiter` or the line number (starting from zero).
-Where to extract the key and value from a line is specified by `key_index` and
-`value_index`.
-
-- A value of -1 means use the line number(starting from zero), expects `int64`.
-- A value of -2 means use the whole line content, expects `string`.
-- A value >= 0 means use the index (starting at zero) of the split line based
- on `delimiter`.
-
-table_handle: Handle to a table which will be initialized.
-filename: Filename of a vocabulary text file.
-key_index: Column index in a line to get the table `key` values from.
-value_index: Column index that represents information of a line to get the table
- `value` values from.
-vocab_size: Number of elements of the file, use -1 if unknown.
-delimiter: Delimiter to separate fields in a line.
-)doc");
-
REGISTER_OP("GetSessionHandle")
.Input("value: T")
.Output("handle: string")
diff --git a/tensorflow/core/ops/lookup_ops.cc b/tensorflow/core/ops/lookup_ops.cc
new file mode 100644
index 0000000000..498a65690d
--- /dev/null
+++ b/tensorflow/core/ops/lookup_ops.cc
@@ -0,0 +1,666 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def_builder.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+
+// --------------------------------------------------------------------------
+
+namespace {
+Status TwoElementVectorInputsAndScalarOutputs(InferenceContext* c) {
+ ShapeHandle handle;
+ DimensionHandle unused_handle;
+ for (int i = 0; i < c->num_inputs(); ++i) {
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle));
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle));
+ }
+ for (int i = 0; i < c->num_outputs(); ++i) {
+ c->set_output(i, c->Scalar());
+ }
+ return Status::OK();
+}
+
+Status ScalarAndTwoElementVectorInputsAndScalarOutputs(InferenceContext* c) {
+ ShapeHandle handle;
+ DimensionHandle unused_handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+ for (int i = 1; i < c->num_inputs(); ++i) {
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle));
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle));
+ }
+ for (int i = 0; i < c->num_outputs(); ++i) {
+ c->set_output(i, c->Scalar());
+ }
+ return Status::OK();
+}
+
+Status TwoElementOutput(InferenceContext* c) {
+ c->set_output(0, c->Vector(2));
+ return Status::OK();
+}
+
+Status ScalarOutput(InferenceContext* c) {
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+}
+} // namespace
+
+REGISTER_OP("LookupTableFind")
+ .Input("table_handle: Ref(string)")
+ .Input("keys: Tin")
+ .Input("default_value: Tout")
+ .Output("values: Tout")
+ .Attr("Tin: type")
+ .Attr("Tout: type")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
+ DimensionHandle unused_dim;
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
+
+ // Default value must be scalar or vector.
+ ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused));
+ c->set_output(0, c->UnknownShape());
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Looks up keys in a table, outputs the corresponding values.
+
+The tensor `keys` must be of the same type as the keys of the table.
+The output `values` is of the type of the table values.
+
+The scalar `default_value` is the value output for keys not present in the
+table. It must also be of the same type as the table values.
+
+table_handle: Handle to the table.
+keys: Any shape. Keys to look up.
+values: Same shape as `keys`. Values found in the table, or `default_value`
+ for missing keys.
+)doc");
+
+REGISTER_OP("LookupTableFindV2")
+ .Input("table_handle: resource")
+ .Input("keys: Tin")
+ .Input("default_value: Tout")
+ .Output("values: Tout")
+ .Attr("Tin: type")
+ .Attr("Tout: type")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+
+ // Default value must be scalar or vector.
+ ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused));
+ c->set_output(0, c->UnknownShape());
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Looks up keys in a table, outputs the corresponding values.
+
+The tensor `keys` must be of the same type as the keys of the table.
+The output `values` is of the type of the table values.
+
+The scalar `default_value` is the value output for keys not present in the
+table. It must also be of the same type as the table values.
+
+table_handle: Handle to the table.
+keys: Any shape. Keys to look up.
+values: Same shape as `keys`. Values found in the table, or `default_value`
+ for missing keys.
+)doc");
+
+REGISTER_OP("LookupTableInsert")
+ .Input("table_handle: Ref(string)")
+ .Input("keys: Tin")
+ .Input("values: Tout")
+ .Attr("Tin: type")
+ .Attr("Tout: type")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
+ DimensionHandle unused_dim;
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
+
+ // TODO(ebrevdo): Validate keys and values shape.
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Updates the table to associate keys with values.
+
+The tensor `keys` must be of the same type as the keys of the table.
+The tensor `values` must be of the type of the table values.
+
+table_handle: Handle to the table.
+keys: Any shape. Keys to look up.
+values: Values to associate with keys.
+)doc");
+
+REGISTER_OP("LookupTableInsertV2")
+ .Input("table_handle: resource")
+ .Input("keys: Tin")
+ .Input("values: Tout")
+ .Attr("Tin: type")
+ .Attr("Tout: type")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+
+ // TODO: Validate keys and values shape.
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Updates the table to associate keys with values.
+
+The tensor `keys` must be of the same type as the keys of the table.
+The tensor `values` must be of the type of the table values.
+
+table_handle: Handle to the table.
+keys: Any shape. Keys to look up.
+values: Values to associate with keys.
+)doc");
+
+REGISTER_OP("LookupTableSize")
+ .Input("table_handle: Ref(string)")
+ .Output("size: int64")
+ .SetShapeFn(TwoElementVectorInputsAndScalarOutputs)
+ .Doc(R"doc(
+Computes the number of elements in the given table.
+
+table_handle: Handle to the table.
+size: Scalar that contains number of elements in the table.
+)doc");
+
+REGISTER_OP("LookupTableSizeV2")
+ .Input("table_handle: resource")
+ .Output("size: int64")
+ .SetShapeFn(ScalarAndTwoElementVectorInputsAndScalarOutputs)
+ .Doc(R"doc(
+Computes the number of elements in the given table.
+
+table_handle: Handle to the table.
+size: Scalar that contains number of elements in the table.
+)doc");
+
+REGISTER_OP("LookupTableExport")
+ .Input("table_handle: Ref(string)")
+ .Output("keys: Tkeys")
+ .Output("values: Tvalues")
+ .Attr("Tkeys: type")
+ .Attr("Tvalues: type")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
+ DimensionHandle unused_dim;
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
+
+ ShapeHandle values = c->UnknownShape();
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values));
+ ShapeHandle keys = c->Vector(c->Dim(values, 0));
+ c->set_output(0, keys);
+ c->set_output(1, values);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Outputs all keys and values in the table.
+
+table_handle: Handle to the table.
+keys: Vector of all keys present in the table.
+values: Tensor of all values in the table. Indexed in parallel with `keys`.
+)doc");
+
+REGISTER_OP("LookupTableExportV2")
+ .Input("table_handle: resource")
+ .Output("keys: Tkeys")
+ .Output("values: Tvalues")
+ .Attr("Tkeys: type")
+ .Attr("Tvalues: type")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+
+ ShapeHandle values = c->UnknownShape();
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values));
+ ShapeHandle keys = c->Vector(c->Dim(values, 0));
+ c->set_output(0, keys);
+ c->set_output(1, values);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Outputs all keys and values in the table.
+
+table_handle: Handle to the table.
+keys: Vector of all keys present in the table.
+values: Tensor of all values in the table. Indexed in parallel with `keys`.
+)doc");
+
+REGISTER_OP("LookupTableImport")
+ .Input("table_handle: Ref(string)")
+ .Input("keys: Tin")
+ .Input("values: Tout")
+ .Attr("Tin: type")
+ .Attr("Tout: type")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
+ DimensionHandle unused_dim;
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
+
+ // TODO(ebrevdo): Validate keys and values shape.
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Replaces the contents of the table with the specified keys and values.
+
+The tensor `keys` must be of the same type as the keys of the table.
+The tensor `values` must be of the type of the table values.
+
+table_handle: Handle to the table.
+keys: Any shape. Keys to look up.
+values: Values to associate with keys.
+)doc");
+
+REGISTER_OP("LookupTableImportV2")
+ .Input("table_handle: resource")
+ .Input("keys: Tin")
+ .Input("values: Tout")
+ .Attr("Tin: type")
+ .Attr("Tout: type")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+
+ // TODO: Validate keys and values shape.
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Replaces the contents of the table with the specified keys and values.
+
+The tensor `keys` must be of the same type as the keys of the table.
+The tensor `values` must be of the type of the table values.
+
+table_handle: Handle to the table.
+keys: Any shape. Keys to look up.
+values: Values to associate with keys.
+)doc");
+
+REGISTER_OP("HashTable")
+ .Output("table_handle: Ref(string)")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("use_node_name_sharing: bool = false")
+ .Attr("key_dtype: type")
+ .Attr("value_dtype: type")
+ .SetIsStateful()
+ .SetShapeFn(TwoElementOutput)
+ .Doc(R"doc(
+Creates an uninitialized hash table.
+
+This op creates a hash table, specifying the type of its keys and values.
+Before using the table you will have to initialize it. After initialization the
+table will be immutable.
+
+table_handle: Handle to a table.
+container: If non-empty, this table is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this table is shared under the given name across
+ multiple sessions.
+use_node_name_sharing: If true and shared_name is empty, the table is shared
+ using the node name.
+key_dtype: Type of the table keys.
+value_dtype: Type of the table values.
+)doc");
+
+REGISTER_OP("HashTableV2")
+ .Output("table_handle: resource")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("use_node_name_sharing: bool = false")
+ .Attr("key_dtype: type")
+ .Attr("value_dtype: type")
+ .SetIsStateful()
+ .SetShapeFn(ScalarOutput)
+ .Doc(R"doc(
+Creates an uninitialized hash table.
+
+This op creates a hash table, specifying the type of its keys and values.
+Before using the table you will have to initialize it. After initialization the
+table will be immutable.
+
+table_handle: Handle to a table.
+container: If non-empty, this table is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this table is shared under the given name across
+ multiple sessions.
+use_node_name_sharing: If true and shared_name is empty, the table is shared
+ using the node name.
+key_dtype: Type of the table keys.
+value_dtype: Type of the table values.
+)doc");
+
+REGISTER_OP("MutableHashTable")
+ .Output("table_handle: Ref(string)")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("use_node_name_sharing: bool = false")
+ .Attr("key_dtype: type")
+ .Attr("value_dtype: type")
+ .SetIsStateful()
+ .SetShapeFn(TwoElementOutput)
+ .Doc(R"doc(
+Creates an empty hash table.
+
+This op creates a mutable hash table, specifying the type of its keys and
+values. Each value must be a scalar. Data can be inserted into the table using
+the insert operations. It does not support the initialization operation.
+
+table_handle: Handle to a table.
+container: If non-empty, this table is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this table is shared under the given name across
+ multiple sessions.
+use_node_name_sharing: If true and shared_name is empty, the table is shared
+ using the node name.
+key_dtype: Type of the table keys.
+value_dtype: Type of the table values.
+)doc");
+
+REGISTER_OP("MutableHashTableV2")
+ .Output("table_handle: resource")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("use_node_name_sharing: bool = false")
+ .Attr("key_dtype: type")
+ .Attr("value_dtype: type")
+ .SetIsStateful()
+ .SetShapeFn(ScalarOutput)
+ .Doc(R"doc(
+Creates an empty hash table.
+
+This op creates a mutable hash table, specifying the type of its keys and
+values. Each value must be a scalar. Data can be inserted into the table using
+the insert operations. It does not support the initialization operation.
+
+table_handle: Handle to a table.
+container: If non-empty, this table is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this table is shared under the given name across
+ multiple sessions.
+use_node_name_sharing: If true and shared_name is empty, the table is shared
+ using the node name.
+key_dtype: Type of the table keys.
+value_dtype: Type of the table values.
+)doc");
+
+REGISTER_OP("MutableHashTableOfTensors")
+ .Output("table_handle: Ref(string)")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("use_node_name_sharing: bool = false")
+ .Attr("key_dtype: type")
+ .Attr("value_dtype: type")
+ .Attr("value_shape: shape = {}")
+ .SetIsStateful()
+ .SetShapeFn(TwoElementOutput)
+ .Doc(R"doc(
+Creates an empty hash table.
+
+This op creates a mutable hash table, specifying the type of its keys and
+values. Each value must be a vector. Data can be inserted into the table using
+the insert operations. It does not support the initialization operation.
+
+table_handle: Handle to a table.
+container: If non-empty, this table is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this table is shared under the given name across
+ multiple sessions.
+key_dtype: Type of the table keys.
+value_dtype: Type of the table values.
+)doc");
+
+REGISTER_OP("MutableHashTableOfTensorsV2")
+ .Output("table_handle: resource")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("use_node_name_sharing: bool = false")
+ .Attr("key_dtype: type")
+ .Attr("value_dtype: type")
+ .Attr("value_shape: shape = {}")
+ .SetIsStateful()
+ .SetShapeFn(ScalarOutput)
+ .Doc(R"doc(
+Creates an empty hash table.
+
+This op creates a mutable hash table, specifying the type of its keys and
+values. Each value must be a vector. Data can be inserted into the table using
+the insert operations. It does not support the initialization operation.
+
+table_handle: Handle to a table.
+container: If non-empty, this table is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this table is shared under the given name across
+ multiple sessions.
+key_dtype: Type of the table keys.
+value_dtype: Type of the table values.
+)doc");
+
+REGISTER_OP("MutableDenseHashTable")
+ .Input("empty_key: key_dtype")
+ .Output("table_handle: Ref(string)")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("use_node_name_sharing: bool = false")
+ .Attr("key_dtype: type")
+ .Attr("value_dtype: type")
+ .Attr("value_shape: shape = {}")
+ .Attr("initial_num_buckets: int = 131072") // 2^17
+ .Attr("max_load_factor: float = 0.8")
+ .SetIsStateful()
+ .SetShapeFn(TwoElementOutput)
+ .Doc(R"doc(
+Creates an empty hash table that uses tensors as the backing store. It uses
+"open addressing" with quadratic reprobing to resolve collisions.
+
+This op creates a mutable hash table, specifying the type of its keys and
+values. Each value must be a scalar. Data can be inserted into the table using
+the insert operations. It does not support the initialization operation.
+
+empty_key: The key used to represent empty key buckets internally. Must not
+ be used in insert or lookup operations.
+table_handle: Handle to a table.
+container: If non-empty, this table is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this table is shared under the given name across
+ multiple sessions.
+key_dtype: Type of the table keys.
+value_dtype: Type of the table values.
+value_shape: The shape of each value.
+initial_num_buckets: The initial number of hash table buckets. Must be a power
+ of 2.
+max_load_factor: The maximum ratio between number of entries and number of
+ buckets before growing the table. Must be between 0 and 1.
+)doc");
+
+REGISTER_OP("MutableDenseHashTableV2")
+ .Input("empty_key: key_dtype")
+ .Output("table_handle: resource")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("use_node_name_sharing: bool = false")
+ .Attr("key_dtype: type")
+ .Attr("value_dtype: type")
+ .Attr("value_shape: shape = {}")
+ .Attr("initial_num_buckets: int = 131072") // 2^17
+ .Attr("max_load_factor: float = 0.8")
+ .SetIsStateful()
+ .SetShapeFn(ScalarOutput)
+ .Doc(R"doc(
+Creates an empty hash table that uses tensors as the backing store. It uses
+"open addressing" with quadratic reprobing to resolve collisions.
+
+This op creates a mutable hash table, specifying the type of its keys and
+values. Each value must be a scalar. Data can be inserted into the table using
+the insert operations. It does not support the initialization operation.
+
+empty_key: The key used to represent empty key buckets internally. Must not
+ be used in insert or lookup operations.
+table_handle: Handle to a table.
+container: If non-empty, this table is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this table is shared under the given name across
+ multiple sessions.
+key_dtype: Type of the table keys.
+value_dtype: Type of the table values.
+value_shape: The shape of each value.
+initial_num_buckets: The initial number of hash table buckets. Must be a power
+ of 2.
+max_load_factor: The maximum ratio between number of entries and number of
+ buckets before growing the table. Must be between 0 and 1.
+)doc");
+
+REGISTER_OP("InitializeTable")
+ .Input("table_handle: Ref(string)")
+ .Input("keys: Tkey")
+ .Input("values: Tval")
+ .Attr("Tkey: type")
+ .Attr("Tval: type")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
+ DimensionHandle unused_dim;
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
+
+ ShapeHandle keys;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys));
+ TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Table initializer that takes two tensors for keys and values respectively.
+
+table_handle: Handle to a table which will be initialized.
+keys: Keys of type Tkey.
+values: Values of type Tval.
+)doc");
+
+REGISTER_OP("InitializeTableV2")
+ .Input("table_handle: resource")
+ .Input("keys: Tkey")
+ .Input("values: Tval")
+ .Attr("Tkey: type")
+ .Attr("Tval: type")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+
+ ShapeHandle keys;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys));
+ TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Table initializer that takes two tensors for keys and values respectively.
+
+table_handle: Handle to a table which will be initialized.
+keys: Keys of type Tkey.
+values: Values of type Tval.
+)doc");
+
+REGISTER_OP("InitializeTableFromTextFile")
+ .Input("table_handle: Ref(string)")
+ .Input("filename: string")
+ .Attr("key_index: int >= -2")
+ .Attr("value_index: int >= -2")
+ .Attr("vocab_size: int >= -1 = -1")
+ .Attr("delimiter: string = '\t'")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
+ DimensionHandle unused_dim;
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
+
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Initializes a table from a text file.
+
+It inserts one key-value pair into the table for each line of the file.
+The key and value are extracted from the whole line content, from elements of
+the line split on `delimiter`, or from the line number (starting from zero).
+Where to extract the key and value from a line is specified by `key_index` and
+`value_index`.
+
+- A value of -1 means use the line number (starting from zero), expects `int64`.
+- A value of -2 means use the whole line content, expects `string`.
+- A value >= 0 means use the index (starting at zero) of the split line based
+ on `delimiter`.
+
+table_handle: Handle to a table which will be initialized.
+filename: Filename of a vocabulary text file.
+key_index: Column index in a line to get the table `key` values from.
+value_index: Column index in a line to get the table `value` values from.
+vocab_size: Number of elements in the file; use -1 if unknown.
+delimiter: Delimiter to separate fields in a line.
+)doc");
+
+REGISTER_OP("InitializeTableFromTextFileV2")
+ .Input("table_handle: resource")
+ .Input("filename: string")
+ .Attr("key_index: int >= -2")
+ .Attr("value_index: int >= -2")
+ .Attr("vocab_size: int >= -1 = -1")
+ .Attr("delimiter: string = '\t'")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Initializes a table from a text file.
+
+It inserts one key-value pair into the table for each line of the file.
+The key and value are extracted from the whole line content, from elements of
+the line split on `delimiter`, or from the line number (starting from zero).
+Where to extract the key and value from a line is specified by `key_index` and
+`value_index`.
+
+- A value of -1 means use the line number (starting from zero), expects `int64`.
+- A value of -2 means use the whole line content, expects `string`.
+- A value >= 0 means use the index (starting at zero) of the split line based
+ on `delimiter`.
+
+table_handle: Handle to a table which will be initialized.
+filename: Filename of a vocabulary text file.
+key_index: Column index in a line to get the table `key` values from.
+value_index: Column index in a line to get the table `value` values from.
+vocab_size: Number of elements in the file; use -1 if unknown.
+delimiter: Delimiter to separate fields in a line.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 817d157da2..9fd5ada71e 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1022,7 +1022,6 @@ tf_gen_op_wrapper_private_py(
require_shape_functions = True,
visibility = [
"//learning/brain/python/ops:__pkg__",
- "//tensorflow/python/feature_column:__pkg__",
"//tensorflow/python/kernel_tests:__pkg__",
],
)
@@ -1058,6 +1057,16 @@ tf_gen_op_wrapper_private_py(
)
tf_gen_op_wrapper_private_py(
+ name = "lookup_ops_gen",
+ require_shape_functions = True,
+ visibility = [
+ "//learning/brain/python/ops:__pkg__",
+ "//tensorflow/python/feature_column:__pkg__",
+ "//tensorflow/python/kernel_tests:__pkg__",
+ ],
+)
+
+tf_gen_op_wrapper_private_py(
name = "math_ops_gen",
require_shape_functions = True,
visibility = [
@@ -1475,6 +1484,20 @@ py_library(
)
py_library(
+ name = "lookup_ops",
+ srcs = ["ops/lookup_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":array_ops",
+ ":framework",
+ ":framework_for_generated_wrappers",
+ ":lookup_ops_gen",
+ ":math_ops",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
name = "math_grad",
srcs = ["ops/math_grad.py"],
srcs_version = "PY2AND3",
@@ -1862,6 +1885,7 @@ py_library(
":io_ops",
":linalg_ops",
":logging_ops",
+ ":lookup_ops",
":math_grad",
":math_ops",
":numerics",
@@ -2269,6 +2293,7 @@ py_library(
":io_ops",
":io_ops_gen",
":lib",
+ ":lookup_ops",
":math_ops",
":platform",
":protos_all_py",
@@ -2991,6 +3016,7 @@ cuda_py_tests(
":framework",
":framework_for_generated_wrappers",
":framework_test_lib",
+ ":lookup_ops",
":gradients",
":math_ops",
":nn_grad",
@@ -3021,7 +3047,7 @@ py_library(
srcs = ["training/saver_test_utils.py"],
srcs_version = "PY2AND3",
deps = [
- ":data_flow_ops_gen",
+ ":lookup_ops_gen",
":training",
],
)
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index f70c285f04..b8064f0a77 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -42,8 +42,8 @@ from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import state_ops
@@ -1391,9 +1391,10 @@ class EstimatorExportTest(test.TestCase):
my_int = variables.Variable(1, name='my_int',
collections=[ops.GraphKeys.LOCAL_VARIABLES])
scores = constant_op.constant([3.])
- with ops.control_dependencies(
- [variables.local_variables_initializer(),
- data_flow_ops.tables_initializer()]):
+ with ops.control_dependencies([
+ variables.local_variables_initializer(),
+ lookup_ops.tables_initializer()
+ ]):
assign_op = state_ops.assign(my_int, 12345)
# local_init_op must be an Operation, not a Tensor.
diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD
index d734273845..ac7aef96ac 100644
--- a/tensorflow/python/feature_column/BUILD
+++ b/tensorflow/python/feature_column/BUILD
@@ -80,9 +80,9 @@ py_library(
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
- "//tensorflow/python:data_flow_ops_gen",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:lookup_ops_gen",
"//tensorflow/python:math_ops",
"//tensorflow/python:string_ops",
"//tensorflow/python:training",
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index d85142abcf..ad67a082dc 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -31,7 +31,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
@@ -41,7 +41,7 @@ from tensorflow.python.platform import test
def _initialized_session():
sess = session.Session()
sess.run(variables_lib.global_variables_initializer())
- sess.run(data_flow_ops.tables_initializer())
+ sess.run(lookup_ops.tables_initializer())
return sess
@@ -1277,7 +1277,7 @@ class VocabularyCategoricalColumnTest(test.TestCase):
# pylint: enable=protected-access
with self.assertRaisesRegexp(errors.OpError, 'file_does_not_exist'):
with self.test_session():
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
def test_invalid_vocabulary_size(self):
with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'):
@@ -1307,7 +1307,7 @@ class VocabularyCategoricalColumnTest(test.TestCase):
# pylint: enable=protected-access
with self.assertRaisesRegexp(errors.OpError, 'Invalid vocab_size'):
with self.test_session():
- data_flow_ops.tables_initializer().run()
+ lookup_ops.tables_initializer().run()
def test_invalid_num_oov_buckets(self):
with self.assertRaisesRegexp(ValueError, 'Invalid num_oov_buckets'):
diff --git a/tensorflow/python/feature_column/lookup_ops.py b/tensorflow/python/feature_column/lookup_ops.py
index 13a67fa518..8225b47b20 100644
--- a/tensorflow/python/feature_column/lookup_ops.py
+++ b/tensorflow/python/feature_column/lookup_ops.py
@@ -27,7 +27,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gen_data_flow_ops
+from tensorflow.python.ops import gen_lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.training.saver import BaseSaverBuilder
@@ -151,7 +151,7 @@ class InitializableLookupTableBase(LookupInterface):
with ops.name_scope(name, "%s_Size" % self._name,
[self._table_ref]) as scope:
# pylint: disable=protected-access
- return gen_data_flow_ops._lookup_table_size(self._table_ref, name=scope)
+ return gen_lookup_ops._lookup_table_size(self._table_ref, name=scope)
# pylint: enable=protected-access
def lookup(self, keys, name=None):
@@ -182,7 +182,7 @@ class InitializableLookupTableBase(LookupInterface):
name, "%s_Lookup" % self._name,
(self._table_ref, key_tensor, self._default_value)) as scope:
# pylint: disable=protected-access
- values = gen_data_flow_ops._lookup_table_find(
+ values = gen_lookup_ops._lookup_table_find(
self._table_ref, key_tensor, self._default_value, name=scope)
# pylint: enable=protected-access
@@ -229,7 +229,7 @@ class HashTable(InitializableLookupTableBase):
with ops.name_scope(
name, "hash_table", (initializer, default_value)) as scope:
# pylint: disable=protected-access
- table_ref = gen_data_flow_ops._hash_table(
+ table_ref = gen_lookup_ops._hash_table(
shared_name=shared_name,
key_dtype=initializer.key_dtype,
value_dtype=initializer.value_dtype,
@@ -308,10 +308,8 @@ class KeyValueTensorInitializer(TableInitializerBase):
self._name,
values=(table.table_ref, self._keys, self._values)) as scope:
# pylint: disable=protected-access
- init_op = gen_data_flow_ops._initialize_table(table.table_ref,
- self._keys,
- self._values,
- name=scope)
+ init_op = gen_lookup_ops._initialize_table(
+ table.table_ref, self._keys, self._values, name=scope)
# pylint: enable=protected-access
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
return init_op
@@ -477,7 +475,7 @@ class TextFileInitializer(TableInitializerBase):
dtypes.string,
name="asset_filepath")
# pylint: disable=protected-access
- init_op = gen_data_flow_ops._initialize_table_from_text_file(
+ init_op = gen_lookup_ops._initialize_table_from_text_file(
table.table_ref,
filename,
self._key_index,
@@ -1333,14 +1331,14 @@ class MutableHashTable(LookupInterface):
use_node_name_sharing = checkpoint and shared_name is None
# pylint: disable=protected-access
if self._default_value.get_shape().ndims == 0:
- self._table_ref = gen_data_flow_ops._mutable_hash_table(
+ self._table_ref = gen_lookup_ops._mutable_hash_table(
shared_name=shared_name,
use_node_name_sharing=use_node_name_sharing,
key_dtype=key_dtype,
value_dtype=value_dtype,
name=name)
else:
- self._table_ref = gen_data_flow_ops._mutable_hash_table_of_tensors(
+ self._table_ref = gen_lookup_ops._mutable_hash_table_of_tensors(
shared_name=shared_name,
use_node_name_sharing=use_node_name_sharing,
key_dtype=key_dtype,
@@ -1368,7 +1366,7 @@ class MutableHashTable(LookupInterface):
with ops.name_scope(name, "%s_Size" % self._name,
[self._table_ref]) as name:
# pylint: disable=protected-access
- return gen_data_flow_ops._lookup_table_size(self._table_ref, name=name)
+ return gen_lookup_ops._lookup_table_size(self._table_ref, name=name)
def lookup(self, keys, name=None):
"""Looks up `keys` in a table, outputs the corresponding values.
@@ -1394,10 +1392,8 @@ class MutableHashTable(LookupInterface):
with ops.name_scope(name, "%s_lookup_table_find" % self._name,
(self._table_ref, keys, self._default_value)) as name:
# pylint: disable=protected-access
- values = gen_data_flow_ops._lookup_table_find(self._table_ref,
- keys,
- self._default_value,
- name=name)
+ values = gen_lookup_ops._lookup_table_find(
+ self._table_ref, keys, self._default_value, name=name)
values.set_shape(keys.get_shape().concatenate(self._value_shape))
return values
@@ -1423,7 +1419,7 @@ class MutableHashTable(LookupInterface):
with ops.name_scope(name, "%s_lookup_table_insert" % self._name,
[self._table_ref, keys, values]) as name:
# pylint: disable=protected-access
- op = gen_data_flow_ops._lookup_table_insert(
+ op = gen_lookup_ops._lookup_table_insert(
self._table_ref, keys, values, name=name)
return op
@@ -1440,11 +1436,8 @@ class MutableHashTable(LookupInterface):
with ops.name_scope(name, "%s_lookup_table_export_values" % self._name,
[self._table_ref]) as name:
# pylint: disable=protected-access
- exported_keys, exported_values = gen_data_flow_ops._lookup_table_export(
- self._table_ref,
- self._key_dtype,
- self._value_dtype,
- name=name)
+ exported_keys, exported_values = gen_lookup_ops._lookup_table_export(
+ self._table_ref, self._key_dtype, self._value_dtype, name=name)
exported_values.set_shape(exported_keys.get_shape().concatenate(
self._value_shape))
@@ -1464,7 +1457,7 @@ class MutableHashTable(LookupInterface):
def restore(self, restored_tensors, unused_restored_shapes):
# pylint: disable=protected-access
- return gen_data_flow_ops._lookup_table_import(
+ return gen_lookup_ops._lookup_table_import(
self.op._table_ref, restored_tensors[0], restored_tensors[1])
@@ -1539,7 +1532,7 @@ class MutableDenseHashTable(LookupInterface):
use_node_name_sharing = checkpoint and shared_name is None
empty_key = ops.convert_to_tensor(empty_key, dtype=key_dtype)
# pylint: disable=protected-access
- self._table_ref = gen_data_flow_ops._mutable_dense_hash_table(
+ self._table_ref = gen_lookup_ops._mutable_dense_hash_table(
empty_key=empty_key,
shared_name=shared_name,
use_node_name_sharing=use_node_name_sharing,
@@ -1567,7 +1560,7 @@ class MutableDenseHashTable(LookupInterface):
with ops.name_scope(name, "%s_Size" % self._name,
[self._table_ref]) as name:
# pylint: disable=protected-access
- return gen_data_flow_ops._lookup_table_size(self._table_ref, name=name)
+ return gen_lookup_ops._lookup_table_size(self._table_ref, name=name)
def lookup(self, keys, name=None):
"""Looks up `keys` in a table, outputs the corresponding values.
@@ -1593,7 +1586,7 @@ class MutableDenseHashTable(LookupInterface):
with ops.name_scope(name, "%s_lookup_table_find" % self._name,
[self._table_ref, keys]) as name:
# pylint: disable=protected-access
- values = gen_data_flow_ops._lookup_table_find(
+ values = gen_lookup_ops._lookup_table_find(
self._table_ref, keys, self._default_value, name=name)
if keys.get_shape().ndims is not None and keys.get_shape().ndims > 0:
@@ -1623,7 +1616,7 @@ class MutableDenseHashTable(LookupInterface):
with ops.name_scope(name, "%s_lookup_table_insert" % self._name,
[self._table_ref, keys, values]) as name:
# pylint: disable=protected-access
- op = gen_data_flow_ops._lookup_table_insert(
+ op = gen_lookup_ops._lookup_table_insert(
self._table_ref, keys, values, name=name)
return op
@@ -1640,7 +1633,7 @@ class MutableDenseHashTable(LookupInterface):
with ops.name_scope(name, "%s_lookup_table_export_values" % self._name,
[self._table_ref]) as name:
# pylint: disable=protected-access
- exported_keys, exported_values = gen_data_flow_ops._lookup_table_export(
+ exported_keys, exported_values = gen_lookup_ops._lookup_table_export(
self._table_ref, self._key_dtype, self._value_dtype, name=name)
exported_values.set_shape(exported_keys.get_shape().concatenate(
@@ -1661,6 +1654,5 @@ class MutableDenseHashTable(LookupInterface):
def restore(self, restored_tensors, unused_restored_shapes):
# pylint: disable=protected-access
- return gen_data_flow_ops._lookup_table_import(self.op._table_ref,
- restored_tensors[0],
- restored_tensors[1])
+ return gen_lookup_ops._lookup_table_import(
+ self.op._table_ref, restored_tensors[0], restored_tensors[1])
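The scalar-vs-vector branch in the constructor above picks between the plain and `_of_tensors` kernels; a sketch (contrib API) of the resulting shape contract, where the lookup output shape is `keys.shape + value_shape`:

    import tensorflow as tf

    # A rank-1 default_value selects the *_of_tensors kernel; each stored
    # value is a length-2 vector.
    table = tf.contrib.lookup.MutableHashTable(
        key_dtype=tf.string, value_dtype=tf.int64, default_value=[0, 0])
    out = table.lookup(tf.constant([["a", "b"], ["c", "d"]]))
    print(out.shape)  # -> (2, 2, 2): keys shape (2, 2) + value_shape (2,)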
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 95e803e2aa..9a208613ad 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -38,7 +38,6 @@ from tensorflow.python.ops import math_ops
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_data_flow_ops import *
# pylint: enable=wildcard-import
-from tensorflow.python.util.deprecation import deprecated
def _as_type_list(dtypes):
@@ -1037,47 +1036,6 @@ class Barrier(object):
self._barrier_ref, name=name)
-@deprecated("2017-03-02", "Use `tf.tables_initializer` instead.")
-def initialize_all_tables(name="init_all_tables"):
- """Returns an Op that initializes all tables of the default graph.
-
- Args:
- name: Optional name for the initialization op.
-
- Returns:
- An Op that initializes all tables. Note that if there are
- not tables the returned Op is a NoOp.
- """
- return tables_initializer(name)
-
-
-def tables_initializer(name="init_all_tables"):
- """Returns an Op that initializes all tables of the default graph.
-
- Args:
- name: Optional name for the initialization op.
-
- Returns:
- An Op that initializes all tables. Note that if there are
- not tables the returned Op is a NoOp.
- """
- initializers = ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS)
- if initializers:
- return control_flow_ops.group(*initializers, name=name)
- return control_flow_ops.no_op(name=name)
-
-
-ops.NotDifferentiable("LookupTableFind")
-ops.NotDifferentiable("LookupTableInsert")
-ops.NotDifferentiable("LookupTableSize")
-ops.NotDifferentiable("HashTable")
-ops.NotDifferentiable("InitializeTable")
-ops.NotDifferentiable("InitializeTableFromTextFile")
-ops.NotDifferentiable("MutableDenseHashTable")
-ops.NotDifferentiable("MutableHashTable")
-ops.NotDifferentiable("MutableHashTableOfTensors")
-
-
class ConditionalAccumulatorBase(object):
"""A conditional accumulator for aggregating gradients.
diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py
new file mode 100644
index 0000000000..54dba9e38e
--- /dev/null
+++ b/tensorflow/python/ops/lookup_ops.py
@@ -0,0 +1,77 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#==============================================================================
+"""Data Flow Operations."""
+# pylint: disable=g-bad-name
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+# go/tf-wildcard-import
+# pylint: disable=wildcard-import
+from tensorflow.python.ops.gen_lookup_ops import *
+# pylint: enable=wildcard-import
+from tensorflow.python.util.deprecation import deprecated
+
+
+@deprecated("2017-03-02", "Use `tf.tables_initializer` instead.")
+def initialize_all_tables(name="init_all_tables"):
+ """Returns an Op that initializes all tables of the default graph.
+
+ Args:
+ name: Optional name for the initialization op.
+
+ Returns:
+ An Op that initializes all tables. Note that if there are
+ no tables, the returned Op is a NoOp.
+ """
+ return tables_initializer(name)
+
+
+def tables_initializer(name="init_all_tables"):
+ """Returns an Op that initializes all tables of the default graph.
+
+ Args:
+ name: Optional name for the initialization op.
+
+ Returns:
+ An Op that initializes all tables. Note that if there are
+ no tables, the returned Op is a NoOp.
+ """
+ initializers = ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS)
+ if initializers:
+ return control_flow_ops.group(*initializers, name=name)
+ return control_flow_ops.no_op(name=name)
+
+
+ops.NotDifferentiable("LookupTableFind")
+ops.NotDifferentiable("LookupTableFindV2")
+ops.NotDifferentiable("LookupTableInsert")
+ops.NotDifferentiable("LookupTableInsertV2")
+ops.NotDifferentiable("LookupTableSize")
+ops.NotDifferentiable("LookupTableSizeV2")
+ops.NotDifferentiable("HashTable")
+ops.NotDifferentiable("HashTableV2")
+ops.NotDifferentiable("InitializeTable")
+ops.NotDifferentiable("InitializeTableV2")
+ops.NotDifferentiable("InitializeTableFromTextFile")
+ops.NotDifferentiable("InitializeTableFromTextFileV2")
+ops.NotDifferentiable("MutableDenseHashTable")
+ops.NotDifferentiable("MutableDenseHashTableV2")
+ops.NotDifferentiable("MutableHashTable")
+ops.NotDifferentiable("MutableHashTableV2")
+ops.NotDifferentiable("MutableHashTableOfTensors")
+ops.NotDifferentiable("MutableHashTableOfTensorsV2")
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
index 09e04d4247..a39d28490c 100644
--- a/tensorflow/python/ops/standard_ops.py
+++ b/tensorflow/python/ops/standard_ops.py
@@ -57,6 +57,7 @@ from tensorflow.python.ops.io_ops import *
from tensorflow.python.ops.linalg_ops import *
from tensorflow.python.ops.logging_ops import Print
from tensorflow.python.ops.logging_ops import get_summary_op
+from tensorflow.python.ops.lookup_ops import *
from tensorflow.python.ops.math_ops import *
from tensorflow.python.ops.numerics import *
from tensorflow.python.ops.parsing_ops import *
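The wildcard import keeps the public endpoint stable across the move: `tf.tables_initializer` now resolves through `lookup_ops` rather than `data_flow_ops`, with identical behavior. A one-line check:

    import tensorflow as tf

    # Still a NoOp on a graph with no tables, exactly as before the move.
    init = tf.tables_initializer()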
diff --git a/tensorflow/python/saved_model/main_op_impl.py b/tensorflow/python/saved_model/main_op_impl.py
index 66cf9d4d8a..355fd57bf1 100644
--- a/tensorflow/python/saved_model/main_op_impl.py
+++ b/tensorflow/python/saved_model/main_op_impl.py
@@ -20,7 +20,7 @@ from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import data_flow_ops as tf_data_flow_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
@@ -35,7 +35,7 @@ def main_op():
"""
init = variables.global_variables_initializer()
init_local = variables.local_variables_initializer()
- init_tables = tf_data_flow_ops.tables_initializer()
+ init_tables = lookup_ops.tables_initializer()
return control_flow_ops.group(init, init_local, init_tables)
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index 4c81af56ad..fcec3ed97c 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -26,7 +26,7 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
@@ -238,7 +238,7 @@ class Scaffold(object):
@staticmethod
def _default_local_init_op():
return control_flow_ops.group(variables.local_variables_initializer(),
- data_flow_ops.tables_initializer())
+ lookup_ops.tables_initializer())
def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
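A sketch of what the default `local_init_op` buys callers: with `MonitoredSession`, tables come up without any explicit initializer run (assuming the contrib wrapper below).

    import tensorflow as tf

    table = tf.contrib.lookup.index_table_from_tensor(
        mapping=tf.constant(["a", "b"]))
    ids = table.lookup(tf.constant(["b", "a"]))

    # Scaffold's default local_init_op (built above) runs tables_initializer.
    with tf.train.MonitoredSession() as sess:
        print(sess.run(ids))  # -> [1, 0]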
diff --git a/tensorflow/python/training/saver_test_utils.py b/tensorflow/python/training/saver_test_utils.py
index 5f31e2aa53..6a73565f82 100644
--- a/tensorflow/python/training/saver_test_utils.py
+++ b/tensorflow/python/training/saver_test_utils.py
@@ -20,7 +20,7 @@ from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as ops_lib
-from tensorflow.python.ops import gen_data_flow_ops
+from tensorflow.python.ops import gen_lookup_ops
from tensorflow.python.training import saver as saver_module
@@ -34,7 +34,7 @@ class CheckpointedOp(object):
# pylint: disable=protected-access
def __init__(self, name, table_ref=None):
if table_ref is None:
- self.table_ref = gen_data_flow_ops._mutable_hash_table(
+ self.table_ref = gen_lookup_ops._mutable_hash_table(
key_dtype=dtypes.string, value_dtype=dtypes.float32, name=name)
else:
self.table_ref = table_ref
@@ -52,10 +52,10 @@ class CheckpointedOp(object):
return self._saveable
def insert(self, keys, values):
- return gen_data_flow_ops._lookup_table_insert(self.table_ref, keys, values)
+ return gen_lookup_ops._lookup_table_insert(self.table_ref, keys, values)
def lookup(self, keys, default):
- return gen_data_flow_ops._lookup_table_find(self.table_ref, keys, default)
+ return gen_lookup_ops._lookup_table_find(self.table_ref, keys, default)
def keys(self):
return self._export()[0]
@@ -64,8 +64,8 @@ class CheckpointedOp(object):
return self._export()[1]
def _export(self):
- return gen_data_flow_ops._lookup_table_export(self.table_ref, dtypes.string,
- dtypes.float32)
+ return gen_lookup_ops._lookup_table_export(self.table_ref, dtypes.string,
+ dtypes.float32)
class CustomSaveable(saver_module.BaseSaverBuilder.SaveableObject):
"""A custom saveable for CheckpointedOp."""
@@ -81,6 +81,6 @@ class CheckpointedOp(object):
super(CheckpointedOp.CustomSaveable, self).__init__(table, specs, name)
def restore(self, restore_tensors, shapes):
- return gen_data_flow_ops._lookup_table_import(
+ return gen_lookup_ops._lookup_table_import(
self.op.table_ref, restore_tensors[0], restore_tensors[1])
# pylint: enable=protected-access
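`CheckpointedOp` exercises the export/import kernels directly; the public analogue is a minimal sketch like this one, assuming a mutable table built with `checkpoint=True` (the default) registers a `SaveableObject` that a plain `Saver` picks up. The checkpoint path is hypothetical.

    import tensorflow as tf

    table = tf.contrib.lookup.MutableHashTable(
        key_dtype=tf.string, value_dtype=tf.float32, default_value=-1.0,
        checkpoint=True)
    insert_op = table.insert(tf.constant(["k"]), tf.constant([3.0]))
    saver = tf.train.Saver()  # collects the table's SaveableObject

    with tf.Session() as sess:
        sess.run(insert_op)
        saver.save(sess, "/tmp/table_ckpt")  # hypothetical path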
diff --git a/tensorflow/python/training/supervisor.py b/tensorflow/python/training/supervisor.py
index 277c11386d..230ed1db68 100644
--- a/tensorflow/python/training/supervisor.py
+++ b/tensorflow/python/training/supervisor.py
@@ -27,7 +27,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary as _summary
@@ -426,8 +426,10 @@ class Supervisor(object):
local_init_op = self._get_first_op_from_collection(
ops.GraphKeys.LOCAL_INIT_OP)
if local_init_op is None:
- op_list = [variables.local_variables_initializer(),
- data_flow_ops.tables_initializer()]
+ op_list = [
+ variables.local_variables_initializer(),
+ lookup_ops.tables_initializer()
+ ]
if op_list:
local_init_op = control_flow_ops.group(*op_list)
ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op)
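And the matching sketch for `Supervisor`: the grouped `local_init_op` built above means tables are ready inside `managed_session` (the logdir is hypothetical).

    import tensorflow as tf

    table = tf.contrib.lookup.index_table_from_tensor(
        mapping=tf.constant(["a", "b"]))
    ids = table.lookup(tf.constant(["a"]))

    sv = tf.train.Supervisor(logdir="/tmp/sv_logs")  # hypothetical logdir
    with sv.managed_session() as sess:  # local_init_op already run here
        print(sess.run(ids))  # -> [0]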