diff options
author | Yutaka Leon <yleon@google.com> | 2017-05-04 12:31:30 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-05-04 13:50:10 -0700 |
commit | dd140f79e06a81c52cd8fc9ec6cda975a78a401f (patch) | |
tree | ca8cb309a8853c31074e649f4d9642ed9a2bacd0 | |
parent | e46a12bc9fbcea1fef224daa47eb9f1cf9e56472 (diff) |
Organize the lookup table ops into it's own lookup_ops.cc file instead of data_flow_ops.cc
Change: 155119120
26 files changed, 950 insertions, 811 deletions
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py index b2dad0162e..a09cc53571 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py @@ -33,9 +33,9 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops -from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope @@ -224,7 +224,7 @@ class TransformerTest(test.TestCase): self.assertEqual(len(output), 1) self.assertIn(keys_sparse, output) with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertEqual(output[keys_sparse].values.dtype, dtypes.int64) self.assertAllEqual(output[keys_sparse].values.eval(), [1, 2, 0]) self.assertAllEqual(output[keys_sparse].indices.eval(), @@ -242,7 +242,7 @@ class TransformerTest(test.TestCase): output = feature_column_ops._Transformer(features).transform(keys_sparse) with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() # While the input is a dense Tensor, the output should be a SparseTensor. self.assertIsInstance(output, sparse_tensor.SparseTensor) self.assertEqual(output.dtype, dtypes.int64) @@ -311,7 +311,7 @@ class TransformerTest(test.TestCase): self.assertIn(weighted_ids, output) with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual(output[weighted_ids][0].dense_shape.eval(), ids_tensor.dense_shape.eval()) self.assertAllEqual(output[weighted_ids][0].indices.eval(), @@ -341,7 +341,7 @@ class TransformerTest(test.TestCase): self.assertEqual(len(output), 1) self.assertIn(vocab_sparse, output) with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64) self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0]) self.assertAllEqual(output[vocab_sparse].indices.eval(), @@ -363,7 +363,7 @@ class TransformerTest(test.TestCase): self.assertEqual(len(output), 1) self.assertIn(vocab_sparse, output) with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64) self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1]) self.assertAllEqual(output[vocab_sparse].indices.eval(), @@ -387,7 +387,7 @@ class TransformerTest(test.TestCase): self.assertEqual(len(output), 1) self.assertIn(vocab_sparse, output) with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64) self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0]) self.assertAllEqual(output[vocab_sparse].indices.eval(), @@ -409,7 +409,7 @@ class TransformerTest(test.TestCase): self.assertEqual(len(output), 1) self.assertIn(vocab_sparse, output) with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64) self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1]) self.assertAllEqual(output[vocab_sparse].indices.eval(), @@ -601,7 +601,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): one_hot_column, embedding_column, real_valued_column]) with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual(output.eval().shape, [3, 2 + 4 + 10]) def testRealValuedColumn(self): @@ -714,7 +714,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): [one_hot_column]) with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual([[0, 0, 10., 0], [0, 20., 0, 0], [30., 0, 40., 0]], output.eval()) @@ -732,7 +732,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]], output.eval()) @@ -750,7 +750,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]], output.eval()) @@ -784,7 +784,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): [one_hot_sparse]) with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual([3, 10], output.eval().shape) def testEmbeddingColumnSucceedsForDNN(self): @@ -891,7 +891,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): [embeded_sparse]) with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual(output.eval().shape, [2, 10]) def testEmbeddingColumnWithIntegerWeightedSparseColumnSucceedsForDNN(self): @@ -914,7 +914,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): [embeded_sparse]) with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual(output.eval().shape, [2, 10]) def testEmbeddingColumnWithCrossedColumnSucceedsForDNN(self): @@ -965,7 +965,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): with self.assertRaisesRegexp( ValueError, "Error creating input layer for column: ids_weighted_by_weights"): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() feature_column_ops.input_from_feature_columns(features, [weighted_ids]) def testCrossedColumnFailsForDNN(self): @@ -1072,7 +1072,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): [embeded_sparse]) with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() # score: (sum of weights) self.assertAllEqual(output.eval(), [[10.], [50.], [0.]]) @@ -1310,7 +1310,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() model_input = sess.run(model_input_tensor) expected_input_shape = np.array([4, 3, 4]) @@ -1344,7 +1344,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() model_input = sess.run(model_input_tensor) expected_input_shape = np.array([4, 3, hash_buckets]) @@ -1374,7 +1374,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() model_input = sess.run(model_input_tensor) self.assertAllEqual(expected_input_shape, model_input.shape) @@ -1403,7 +1403,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() model_input = sess.run(model_input_tensor) self.assertAllEqual(expected_input_shape, model_input.shape) @@ -1433,7 +1433,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): embedding_weights) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() model_input, gradients = sess.run([model_input_tensor, gradient_tensor]) expected_input_shape = [4, 3, embedding_dimension] @@ -1500,7 +1500,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() model_input = sess.run(model_input_tensor) expected_input_shape = [ @@ -1581,7 +1581,7 @@ class WeightedSumTest(test.TestCase): features, [weighted_ids], num_outputs=5) with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual(logits.eval().shape, [2, 5]) def testWeightedSparseColumnWithDenseInputTensor(self): @@ -1597,7 +1597,7 @@ class WeightedSumTest(test.TestCase): with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual(logits.eval().shape, [2, 5]) def testCrossedColumn(self): @@ -1651,7 +1651,7 @@ class WeightedSumTest(test.TestCase): features, [movies], num_outputs=1)) with self.test_session() as sess: variables_lib.initialize_all_variables().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[movies][0] self.assertEqual(weights.get_shape(), (3, 1)) @@ -1726,7 +1726,7 @@ class WeightedSumTest(test.TestCase): features, [age, language], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllClose(output.eval(), [[0.], [0.]]) @@ -1766,7 +1766,7 @@ class WeightedSumTest(test.TestCase): self.assertEqual(len(variables), 1) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllClose(output.eval(), [[0.], [0.]]) @@ -1830,7 +1830,7 @@ class WeightedSumTest(test.TestCase): features, [weighted_language], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllClose(output.eval(), [[0.], [0.]]) @@ -1858,7 +1858,7 @@ class WeightedSumTest(test.TestCase): features, [language], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() # score: 0.1 + language_weight['hindi'] + language_weight['english'] sess.run(bias.assign([0.1])) @@ -1881,7 +1881,7 @@ class WeightedSumTest(test.TestCase): features, [movies], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[movies][0] self.assertEqual(weights.get_shape(), (15, 1)) @@ -1915,7 +1915,7 @@ class WeightedSumTest(test.TestCase): features, [country_language], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[country_language][0] sess.run(weights.assign(weights + 0.4)) @@ -1939,7 +1939,7 @@ class WeightedSumTest(test.TestCase): features, [language_language], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[language_language][0] sess.run(weights.assign(weights + 0.4)) @@ -1972,7 +1972,7 @@ class WeightedSumTest(test.TestCase): features, [country_language], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[country_language][0] sess.run(weights.assign(weights + 0.4)) @@ -2013,7 +2013,7 @@ class WeightedSumTest(test.TestCase): scope=scope)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertEqual(2, len(column_to_variable[country])) self.assertEqual(3, len(column_to_variable[language])) @@ -2050,7 +2050,7 @@ class WeightedSumTest(test.TestCase): features, [country, age, incomes], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() incomes_weights = column_to_variable[incomes][0] sess.run(incomes_weights.assign([[0.1], [0.2], [0.3]])) @@ -2086,7 +2086,7 @@ class WeightedSumTest(test.TestCase): features, [country, age, height, incomes], num_outputs=5)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() height_weights = column_to_variable[height][0] sess.run( @@ -2116,7 +2116,7 @@ class WeightedSumTest(test.TestCase): features, [bucket], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() sess.run(column_to_variable[bucket][0].assign([[0.1], [0.2], [0.3], [0.4]])) @@ -2144,7 +2144,7 @@ class WeightedSumTest(test.TestCase): features, [bucket, country], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() # dimension = 2, bucket_size = 4, num_classes = 1 sess.run(column_to_variable[bucket][0].assign( @@ -2173,7 +2173,7 @@ class WeightedSumTest(test.TestCase): features, [bucket, country], num_outputs=5)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() # dimension = 2, bucket_size = 4, num_classes = 5 sess.run(column_to_variable[bucket][0].assign( @@ -2209,7 +2209,7 @@ class WeightedSumTest(test.TestCase): features, [country_price], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[country_price][0] sess.run(weights.assign(weights + 0.4)) @@ -2248,7 +2248,7 @@ class WeightedSumTest(test.TestCase): features, [country_language_price], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[country_language_price][0] sess.run(weights.assign(weights + 0.4)) @@ -2272,7 +2272,7 @@ class WeightedSumTest(test.TestCase): features, [product], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() product_weights = column_to_variable[product][0] sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]])) self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]]) @@ -2287,7 +2287,7 @@ class WeightedSumTest(test.TestCase): features, [product], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() product_weights = column_to_variable[product][0] sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]])) self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]]) @@ -2302,7 +2302,7 @@ class WeightedSumTest(test.TestCase): features, [product], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() product_weights = column_to_variable[product][0] sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]])) self.assertAllClose(output.eval(), [[0.6], [0.7]]) @@ -2323,7 +2323,7 @@ class WeightedSumTest(test.TestCase): features, [product], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() product_weights = column_to_variable[product][0] sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]])) self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]]) @@ -2335,7 +2335,7 @@ class WeightedSumTest(test.TestCase): features, [feature_column.real_valued_column("age")], num_outputs=3) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() sess.run(bias.assign([0.1, 0.2, 0.3])) self.assertAllClose(output.eval(), [[0.1, 0.2, 0.3], [0.1, 0.2, 0.3], [0.1, 0.2, 0.3], [0.1, 0.2, 0.3]]) @@ -2349,7 +2349,7 @@ class WeightedSumTest(test.TestCase): features, [column], num_outputs=3)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[column][0] self.assertEqual(weights.get_shape(), (1, 3)) sess.run(weights.assign([[0.01, 0.03, 0.05]])) @@ -2373,7 +2373,7 @@ class WeightedSumTest(test.TestCase): features, [column], num_outputs=3)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[column][0] self.assertEqual(weights.get_shape(), (5, 3)) sess.run( @@ -2399,7 +2399,7 @@ class WeightedSumTest(test.TestCase): features, [column], num_outputs=3)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[column][0] self.assertEqual(weights.get_shape(), (5, 3)) @@ -2439,7 +2439,7 @@ class WeightedSumTest(test.TestCase): features, [column], num_outputs=3)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[column][0] self.assertEqual(weights.get_shape(), (5, 3)) @@ -2468,7 +2468,7 @@ class WeightedSumTest(test.TestCase): features, [column], num_outputs=3)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[column][0] self.assertEqual(weights.get_shape(), (5, 3)) @@ -2533,7 +2533,7 @@ class ParseExampleTest(test.TestCase): self.assertIn(bucket, output) self.assertIn(wire_cast, output) with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual(output[bucket].eval(), [[2, 3, 0]]) self.assertAllEqual(output[wire_cast].indices.eval(), [[0, 0], [0, 1]]) self.assertAllEqual(output[wire_cast].values.eval(), [2, 0]) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py index 43b3d2a78f..58072500d1 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py @@ -38,8 +38,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops -from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables @@ -157,7 +157,7 @@ class DynamicRnnEstimatorTest(test.TestCase): self.context_feature_columns) with self.test_session() as sess: sess.run(variables.global_variables_initializer()) - sess.run(data_flow_ops.tables_initializer()) + sess.run(lookup_ops.tables_initializer()) sequence_input_val = sess.run(sequence_input) expected_shape = np.array([ 3, # expected batch size @@ -178,7 +178,7 @@ class DynamicRnnEstimatorTest(test.TestCase): # Obtain values of activations and final state. with session.Session() as sess: sess.run(variables.global_variables_initializer()) - sess.run(data_flow_ops.tables_initializer()) + sess.run(lookup_ops.tables_initializer()) activations, final_state = sess.run([activations_t, final_state_t]) expected_activations_shape = np.array([3, 2, self.NUM_LABEL_COLUMNS]) diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 74a6da20d4..36f843ba8e 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -57,7 +57,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import resources from tensorflow.python.ops import variables from tensorflow.python.platform import gfile @@ -1292,7 +1292,7 @@ class Estimator(BaseEstimator): init_op = control_flow_ops.group( variables.local_variables_initializer(), resources.initialize_resources(resources.shared_resources()), - data_flow_ops.tables_initializer()) + lookup_ops.tables_initializer()) # Perform the export builder = saved_model_builder.SavedModelBuilder(export_dir) diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index 207a189a94..d5777088de 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -32,7 +32,7 @@ from tensorflow.core.framework import summary_pb2 from tensorflow.python.client import session from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor -from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.ops.losses import losses as losses_lib from tensorflow.python.platform import test @@ -1214,7 +1214,7 @@ class MultiClassHeadTest(test.TestCase): train_op_fn=head_lib.no_op_train_fn, logits=((1., 0., 0.), (0., 0., 1.),)) with session.Session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual( [0, 2], model_fn_ops.predictions["classes"].eval()) @@ -1266,7 +1266,7 @@ class MultiClassHeadTest(test.TestCase): train_op_fn=head_lib.no_op_train_fn, logits=((1., 0., 0.), (0., 0., 1.),)) with session.Session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual( [b"key0", b"key2"], model_fn_ops.predictions["classes"].eval()) @@ -1301,7 +1301,7 @@ class MultiClassHeadTest(test.TestCase): train_op_fn=head_lib.no_op_train_fn, logits=((1., 0., 0.),)) with session.Session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertIsNone(model_fn_ops.train_op) _assert_no_variables(self) _assert_summary_tags(self, ["loss"]) @@ -1327,7 +1327,7 @@ class MultiClassHeadTest(test.TestCase): train_op_fn=head_lib.no_op_train_fn, logits=((0., 0., 1.),)) with session.Session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertIsNone(model_fn_ops.train_op) _assert_no_variables(self) _assert_summary_tags(self, ["loss"]) diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py index f5bd03429c..feea6c5fed 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py @@ -35,8 +35,8 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops -from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import init_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables @@ -55,7 +55,7 @@ class PrepareInputsForRnnTest(test.TestCase): with self.test_session() as sess: sess.run(variables.global_variables_initializer()) - sess.run(data_flow_ops.initialize_all_tables()) + sess.run(lookup_ops.tables_initializer()) features_val = sess.run(features_by_time) self.assertAllEqual(expected, features_val) @@ -316,7 +316,7 @@ class StateSavingRnnEstimatorTest(test.TestCase): with self.test_session() as sess: sess.run(variables.global_variables_initializer()) - sess.run(data_flow_ops.initialize_all_tables()) + sess.run(lookup_ops.tables_initializer()) actual_sequence, actual_context = sess.run( [sequence, context]) assert_equal(expected_sequence, actual_sequence) diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py index 4b7867f2d0..98365c05f6 100644 --- a/tensorflow/contrib/learn/python/learn/graph_actions.py +++ b/tensorflow/contrib/learn/python/learn/graph_actions.py @@ -37,8 +37,8 @@ from tensorflow.python.client import session as tf_session from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import resources from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging @@ -429,11 +429,14 @@ def _get_ready_op(): def _get_local_init_op(): + """Returns the local init ops to initialize tables and local variables.""" local_init_op = _get_first_op_from_collection( ops.GraphKeys.LOCAL_INIT_OP) if local_init_op is None: - op_list = [variables.local_variables_initializer(), - data_flow_ops.tables_initializer()] + op_list = [ + variables.local_variables_initializer(), + lookup_ops.tables_initializer() + ] if op_list: local_init_op = control_flow_ops.group(*op_list) ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op) @@ -680,7 +683,7 @@ def run_feeds_iter(output_dict, feed_dicts, restore_checkpoint_path=None): else: session.run(variables.global_variables_initializer()) session.run(variables.local_variables_initializer()) - session.run(data_flow_ops.tables_initializer()) + session.run(lookup_ops.tables_initializer()) coord = coordinator.Coordinator() threads = None try: diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py index b53be29283..36a1f5f60c 100644 --- a/tensorflow/contrib/learn/python/learn/utils/export.py +++ b/tensorflow/contrib/learn/python/learn/utils/export.py @@ -28,7 +28,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import saver as tf_saver @@ -67,17 +67,17 @@ def _export_graph(graph, saver, checkpoint_path, export_dir, with graph.as_default(): with tf_session.Session('') as session: variables.local_variables_initializer() - data_flow_ops.tables_initializer() + lookup_ops.tables_initializer() saver.restore(session, checkpoint_path) export = exporter.Exporter(saver) - export.init(init_op=control_flow_ops.group( - variables.local_variables_initializer(), - data_flow_ops.tables_initializer()), - default_graph_signature=default_graph_signature, - named_graph_signatures=named_graph_signatures, - assets_collection=ops.get_collection( - ops.GraphKeys.ASSET_FILEPATHS)) + export.init( + init_op=control_flow_ops.group( + variables.local_variables_initializer(), + lookup_ops.tables_initializer()), + default_graph_signature=default_graph_signature, + named_graph_signatures=named_graph_signatures, + assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)) return export.export(export_dir, contrib_variables.get_global_step(), session, exports_to_keep=exports_to_keep) diff --git a/tensorflow/contrib/lookup/BUILD b/tensorflow/contrib/lookup/BUILD index 5966c86dfb..bbbd340352 100644 --- a/tensorflow/contrib/lookup/BUILD +++ b/tensorflow/contrib/lookup/BUILD @@ -30,11 +30,11 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", - "//tensorflow/python:data_flow_ops", "//tensorflow/python:errors", "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:lookup_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", "//tensorflow/python:variables", diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 0ec40a63f2..5ec169b6db 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -31,7 +31,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops -from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import saver @@ -125,7 +125,7 @@ class HashTableOpTest(test.TestCase): table3 = lookup.HashTable( lookup.KeyValueTensorInitializer(keys, values), default_val) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual(3, table1.size().eval()) self.assertAllEqual(3, table2.size().eval()) self.assertAllEqual(3, table3.size().eval()) @@ -1184,7 +1184,7 @@ class IndexTableFromFile(test.TestCase): ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) def test_int32_index_table_from_file(self): @@ -1198,7 +1198,7 @@ class IndexTableFromFile(test.TestCase): constant_op.constant((1, -1000, 11), dtype=dtypes.int32)) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) def test_int64_index_table_from_file(self): @@ -1212,7 +1212,7 @@ class IndexTableFromFile(test.TestCase): constant_op.constant((1, -1000, 11), dtype=dtypes.int64)) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) def test_index_table_from_file_with_default_value(self): @@ -1224,7 +1224,7 @@ class IndexTableFromFile(test.TestCase): ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, default_value), ids.eval()) def test_index_table_from_file_with_oov_buckets(self): @@ -1236,7 +1236,7 @@ class IndexTableFromFile(test.TestCase): constant_op.constant(["salad", "surgery", "tarkus", "toccata"])) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual( ( 1, # From vocabulary file. @@ -1259,7 +1259,7 @@ class IndexTableFromFile(test.TestCase): ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, -1, -1), ids.eval()) self.assertEqual(2, table.size().eval()) @@ -1286,7 +1286,7 @@ class IndexTableFromFile(test.TestCase): ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, -1), ids.eval()) self.assertEqual(3, table.size().eval()) @@ -1345,7 +1345,7 @@ class IndexTableFromTensor(test.TestCase): ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus"))) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) def test_int32_index_table_from_tensor_with_tensor_init(self): @@ -1356,7 +1356,7 @@ class IndexTableFromTensor(test.TestCase): constant_op.constant((1, -1000, 11), dtype=dtypes.int32)) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) def test_int64_index_table_from_tensor_with_tensor_init(self): @@ -1367,7 +1367,7 @@ class IndexTableFromTensor(test.TestCase): constant_op.constant((1, -1000, 11), dtype=dtypes.int64)) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) def test_index_table_from_tensor_with_default_value(self): @@ -1378,7 +1378,7 @@ class IndexTableFromTensor(test.TestCase): ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, default_value), ids.eval()) def test_index_table_from_tensor_missing_mapping(self): @@ -1394,7 +1394,7 @@ class IndexTableFromTensor(test.TestCase): self.assertRaises(errors_impl.OpError, ids.eval) with self.assertRaisesRegexp( errors_impl.OpError, "keys and values cannot be empty"): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() def test_index_table_from_tensor_with_invalid_hashers(self): with self.test_session(): @@ -1422,7 +1422,7 @@ class StringToIndexTest(test.TestCase): indices = lookup.string_to_index(feats, mapping=mapping_strings) self.assertRaises(errors_impl.OpError, indices.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, -1), indices.eval()) @@ -1433,7 +1433,7 @@ class StringToIndexTest(test.TestCase): _ = lookup.string_to_index(feats, mapping=mapping_strings) self.assertRaises(errors_impl.OpError, - data_flow_ops.tables_initializer().run) + lookup_ops.tables_initializer().run) def test_string_to_index_with_default_value(self): default_value = -42 @@ -1444,7 +1444,7 @@ class StringToIndexTest(test.TestCase): feats, mapping=mapping_strings, default_value=default_value) self.assertRaises(errors_impl.OpError, indices.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, default_value), indices.eval()) @@ -1463,7 +1463,7 @@ class IndexToStringTableFromFileTest(test.TestCase): vocabulary_file=vocabulary_file) features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64)) self.assertRaises(errors_impl.OpError, features.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), features.eval()) @@ -1475,7 +1475,7 @@ class IndexToStringTableFromFileTest(test.TestCase): vocabulary_file=vocabulary_file, default_value=default_value) features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) self.assertRaises(errors_impl.OpError, features.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"salad", b"surgery", default_value), features.eval()) @@ -1489,7 +1489,7 @@ class IndexToStringTableFromFileTest(test.TestCase): default_value=default_value) features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) self.assertRaises(errors_impl.OpError, features.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"salad", default_value, default_value), features.eval()) @@ -1501,7 +1501,7 @@ class IndexToStringTableFromFileTest(test.TestCase): features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) self.assertRaises(errors_impl.OpError, features.eval) - init = data_flow_ops.tables_initializer() + init = lookup_ops.tables_initializer() self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "Invalid vocab_size", init.run) @@ -1513,7 +1513,7 @@ class IndexToStringTableFromFileTest(test.TestCase): features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) self.assertRaises(errors_impl.OpError, features.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"salad", b"surgery", b"UNK"), features.eval()) @@ -1528,7 +1528,7 @@ class IndexToStringTableFromTensorTest(test.TestCase): indices = constant_op.constant([0, 1, 2, 3], dtypes.int64) features = table.lookup(indices) self.assertRaises(errors_impl.OpError, features.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), features.eval()) @@ -1540,7 +1540,7 @@ class IndexToStringTableFromTensorTest(test.TestCase): mapping=mapping_strings) indices = constant_op.constant([0, 1, 4], dtypes.int64) features = table.lookup(indices) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"hello", b"hello", b"UNK"), features.eval()) def test_index_to_string_with_default_value(self): @@ -1553,7 +1553,7 @@ class IndexToStringTableFromTensorTest(test.TestCase): features = table.lookup(indices) self.assertRaises(errors_impl.OpError, features.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"salad", b"surgery", default_value), features.eval()) @@ -1567,7 +1567,7 @@ class IndexToStringTest(test.TestCase): feats = lookup.index_to_string(indices, mapping=mapping_strings) self.assertRaises(errors_impl.OpError, feats.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), feats.eval()) @@ -1577,11 +1577,11 @@ class IndexToStringTest(test.TestCase): mapping_strings = constant_op.constant(["hello", "hello"]) indices = constant_op.constant([0, 1, 4], dtypes.int64) feats = lookup.index_to_string(indices, mapping=mapping_strings) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"hello", b"hello", b"UNK"), feats.eval()) self.assertRaises(errors_impl.OpError, - data_flow_ops.tables_initializer().run) + lookup_ops.tables_initializer().run) def test_index_to_string_with_default_value(self): default_value = b"NONE" @@ -1592,7 +1592,7 @@ class IndexToStringTest(test.TestCase): indices, mapping=mapping_strings, default_value=default_value) self.assertRaises(errors_impl.OpError, feats.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"salad", b"surgery", default_value), feats.eval()) @@ -1755,7 +1755,7 @@ class InitializeTableFromFileOpTest(test.TestCase): default_value, shared_name=shared_name) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() input_string = constant_op.constant(["brain", "salad", "tank"]) @@ -2081,7 +2081,7 @@ class IdTableWithHashBucketsTest(test.TestCase): hasher_spec=lookup.StrongHashSpec((1, 2)), name="table2") - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() input_string = constant_op.constant( ["fruit", "brain", "salad", "surgery", "UNK"]) @@ -2167,7 +2167,7 @@ class IdTableWithHashBucketsTest(test.TestCase): default_value2), oov_buckets) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() input_string_1 = constant_op.constant( ["brain", "salad", "surgery", "UNK"]) diff --git a/tensorflow/contrib/slim/python/slim/learning.py b/tensorflow/contrib/slim/python/slim/learning.py index 5ced8a4f08..b70d612f55 100644 --- a/tensorflow/contrib/slim/python/slim/learning.py +++ b/tensorflow/contrib/slim/python/slim/learning.py @@ -261,7 +261,7 @@ from tensorflow.python.framework import ops from tensorflow.python.lib.io import file_io from tensorflow.python.ops import clip_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import tf_logging as logging @@ -657,7 +657,7 @@ def train(train_op, if local_init_op == _USE_DEFAULT: local_init_op = control_flow_ops.group( tf_variables.local_variables_initializer(), - data_flow_ops.tables_initializer()) + lookup_ops.tables_initializer()) if sync_optimizer is not None and isinstance( sync_optimizer, sync_replicas_optimizer.SyncReplicasOptimizer): diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 119bc0f899..435618ace7 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -506,6 +506,7 @@ tf_gen_op_libs( "image_ops", "io_ops", "linalg_ops", + "lookup_ops", "logging_ops", "math_ops", "nn_ops", @@ -582,6 +583,7 @@ cc_library( ":image_ops_op_lib", ":io_ops_op_lib", ":linalg_ops_op_lib", + ":lookup_ops_op_lib", ":logging_ops_op_lib", ":math_ops_op_lib", ":nn_ops_op_lib", @@ -708,6 +710,7 @@ cc_library( "//tensorflow/core/kernels:image", "//tensorflow/core/kernels:io", "//tensorflow/core/kernels:linalg", + "//tensorflow/core/kernels:lookup", "//tensorflow/core/kernels:logging", "//tensorflow/core/kernels:math", "//tensorflow/core/kernels:multinomial_op", diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 0847d1279b..02ab30a04f 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1327,6 +1327,14 @@ cc_library( ], ) +cc_library( + name = "lookup", + deps = [ + ":lookup_table_init_op", + ":lookup_table_op", + ], +) + DATA_FLOW_DEPS = [ ":bounds_check", ":concat_lib", @@ -1450,10 +1458,10 @@ LOOKUP_DEPS = [ ":initializable_lookup_table", ":lookup_util", "//tensorflow/core:core_cpu", - "//tensorflow/core:data_flow_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:lookup_ops_op_lib", ] tf_kernel_library( diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index f35a1bb648..032ede6459 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -1876,604 +1876,6 @@ size: The number of incomplete elements (i.e. those with some of their value // -------------------------------------------------------------------------- -REGISTER_OP("LookupTableFind") - .Input("table_handle: Ref(string)") - .Input("keys: Tin") - .Input("default_value: Tout") - .Output("values: Tout") - .Attr("Tin: type") - .Attr("Tout: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); - DimensionHandle unused_dim; - TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); - - // Default value must be scalar or vector. - ShapeHandle unused; - TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused)); - c->set_output(0, c->UnknownShape()); - return Status::OK(); - }) - .Doc(R"doc( -Looks up keys in a table, outputs the corresponding values. - -The tensor `keys` must of the same type as the keys of the table. -The output `values` is of the type of the table values. - -The scalar `default_value` is the value output for keys not present in the -table. It must also be of the same type as the table values. - -table_handle: Handle to the table. -keys: Any shape. Keys to look up. -values: Same shape as `keys`. Values found in the table, or `default_values` - for missing keys. -)doc"); - -REGISTER_OP("LookupTableFindV2") - .Input("table_handle: resource") - .Input("keys: Tin") - .Input("default_value: Tout") - .Output("values: Tout") - .Attr("Tin: type") - .Attr("Tout: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); - - // Default value must be scalar or vector. - ShapeHandle unused; - TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused)); - c->set_output(0, c->UnknownShape()); - return Status::OK(); - }) - .Doc(R"doc( -Looks up keys in a table, outputs the corresponding values. - -The tensor `keys` must of the same type as the keys of the table. -The output `values` is of the type of the table values. - -The scalar `default_value` is the value output for keys not present in the -table. It must also be of the same type as the table values. - -table_handle: Handle to the table. -keys: Any shape. Keys to look up. -values: Same shape as `keys`. Values found in the table, or `default_values` - for missing keys. -)doc"); - -REGISTER_OP("LookupTableInsert") - .Input("table_handle: Ref(string)") - .Input("keys: Tin") - .Input("values: Tout") - .Attr("Tin: type") - .Attr("Tout: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); - DimensionHandle unused_dim; - TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); - - // TODO(ebrevdo): Validate keys and values shape. - return Status::OK(); - }) - .Doc(R"doc( -Updates the table to associates keys with values. - -The tensor `keys` must be of the same type as the keys of the table. -The tensor `values` must be of the type of the table values. - -table_handle: Handle to the table. -keys: Any shape. Keys to look up. -values: Values to associate with keys. -)doc"); - -REGISTER_OP("LookupTableInsertV2") - .Input("table_handle: resource") - .Input("keys: Tin") - .Input("values: Tout") - .Attr("Tin: type") - .Attr("Tout: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); - - // TODO: Validate keys and values shape. - return Status::OK(); - }) - .Doc(R"doc( -Updates the table to associates keys with values. - -The tensor `keys` must be of the same type as the keys of the table. -The tensor `values` must be of the type of the table values. - -table_handle: Handle to the table. -keys: Any shape. Keys to look up. -values: Values to associate with keys. -)doc"); - -REGISTER_OP("LookupTableSize") - .Input("table_handle: Ref(string)") - .Output("size: int64") - .SetShapeFn(TwoElementVectorInputsAndScalarOutputs) - .Doc(R"doc( -Computes the number of elements in the given table. - -table_handle: Handle to the table. -size: Scalar that contains number of elements in the table. -)doc"); - -REGISTER_OP("LookupTableSizeV2") - .Input("table_handle: resource") - .Output("size: int64") - .SetShapeFn(ScalarAndTwoElementVectorInputsAndScalarOutputs) - .Doc(R"doc( -Computes the number of elements in the given table. - -table_handle: Handle to the table. -size: Scalar that contains number of elements in the table. -)doc"); - -REGISTER_OP("LookupTableExport") - .Input("table_handle: Ref(string)") - .Output("keys: Tkeys") - .Output("values: Tvalues") - .Attr("Tkeys: type") - .Attr("Tvalues: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); - DimensionHandle unused_dim; - TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); - - ShapeHandle values = c->UnknownShape(); - TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values)); - ShapeHandle keys = c->Vector(c->Dim(values, 0)); - c->set_output(0, keys); - c->set_output(1, values); - return Status::OK(); - }) - .Doc(R"doc( -Outputs all keys and values in the table. - -table_handle: Handle to the table. -keys: Vector of all keys present in the table. -values: Tensor of all values in the table. Indexed in parallel with `keys`. -)doc"); - -REGISTER_OP("LookupTableExportV2") - .Input("table_handle: resource") - .Output("keys: Tkeys") - .Output("values: Tvalues") - .Attr("Tkeys: type") - .Attr("Tvalues: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); - - ShapeHandle values = c->UnknownShape(); - TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values)); - ShapeHandle keys = c->Vector(c->Dim(values, 0)); - c->set_output(0, keys); - c->set_output(1, values); - return Status::OK(); - }) - .Doc(R"doc( -Outputs all keys and values in the table. - -table_handle: Handle to the table. -keys: Vector of all keys present in the table. -values: Tensor of all values in the table. Indexed in parallel with `keys`. -)doc"); - -REGISTER_OP("LookupTableImport") - .Input("table_handle: Ref(string)") - .Input("keys: Tin") - .Input("values: Tout") - .Attr("Tin: type") - .Attr("Tout: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); - DimensionHandle unused_dim; - TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); - - // TODO(ebrevdo): Validate keys and values shape. - return Status::OK(); - }) - .Doc(R"doc( -Replaces the contents of the table with the specified keys and values. - -The tensor `keys` must be of the same type as the keys of the table. -The tensor `values` must be of the type of the table values. - -table_handle: Handle to the table. -keys: Any shape. Keys to look up. -values: Values to associate with keys. -)doc"); - -REGISTER_OP("LookupTableImportV2") - .Input("table_handle: resource") - .Input("keys: Tin") - .Input("values: Tout") - .Attr("Tin: type") - .Attr("Tout: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); - - // TODO: Validate keys and values shape. - return Status::OK(); - }) - .Doc(R"doc( -Replaces the contents of the table with the specified keys and values. - -The tensor `keys` must be of the same type as the keys of the table. -The tensor `values` must be of the type of the table values. - -table_handle: Handle to the table. -keys: Any shape. Keys to look up. -values: Values to associate with keys. -)doc"); - -REGISTER_OP("HashTable") - .Output("table_handle: Ref(string)") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("use_node_name_sharing: bool = false") - .Attr("key_dtype: type") - .Attr("value_dtype: type") - .SetIsStateful() - .SetShapeFn(TwoElementOutput) - .Doc(R"doc( -Creates a non-initialized hash table. - -This op creates a hash table, specifying the type of its keys and values. -Before using the table you will have to initialize it. After initialization the -table will be immutable. - -table_handle: Handle to a table. -container: If non-empty, this table is placed in the given container. - Otherwise, a default container is used. -shared_name: If non-empty, this table is shared under the given name across - multiple sessions. -use_node_name_sharing: If true and shared_name is empty, the table is shared - using the node name. -key_dtype: Type of the table keys. -value_dtype: Type of the table values. -)doc"); - -REGISTER_OP("HashTableV2") - .Output("table_handle: resource") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("use_node_name_sharing: bool = false") - .Attr("key_dtype: type") - .Attr("value_dtype: type") - .SetIsStateful() - .SetShapeFn(ScalarOutput) - .Doc(R"doc( -Creates a non-initialized hash table. - -This op creates a hash table, specifying the type of its keys and values. -Before using the table you will have to initialize it. After initialization the -table will be immutable. - -table_handle: Handle to a table. -container: If non-empty, this table is placed in the given container. - Otherwise, a default container is used. -shared_name: If non-empty, this table is shared under the given name across - multiple sessions. -use_node_name_sharing: If true and shared_name is empty, the table is shared - using the node name. -key_dtype: Type of the table keys. -value_dtype: Type of the table values. -)doc"); - -REGISTER_OP("MutableHashTable") - .Output("table_handle: Ref(string)") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("use_node_name_sharing: bool = false") - .Attr("key_dtype: type") - .Attr("value_dtype: type") - .SetIsStateful() - .SetShapeFn(TwoElementOutput) - .Doc(R"doc( -Creates an empty hash table. - -This op creates a mutable hash table, specifying the type of its keys and -values. Each value must be a scalar. Data can be inserted into the table using -the insert operations. It does not support the initialization operation. - -table_handle: Handle to a table. -container: If non-empty, this table is placed in the given container. - Otherwise, a default container is used. -shared_name: If non-empty, this table is shared under the given name across - multiple sessions. -use_node_name_sharing: If true and shared_name is empty, the table is shared - using the node name. -key_dtype: Type of the table keys. -value_dtype: Type of the table values. -)doc"); - -REGISTER_OP("MutableHashTableV2") - .Output("table_handle: resource") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("use_node_name_sharing: bool = false") - .Attr("key_dtype: type") - .Attr("value_dtype: type") - .SetIsStateful() - .SetShapeFn(ScalarOutput) - .Doc(R"doc( -Creates an empty hash table. - -This op creates a mutable hash table, specifying the type of its keys and -values. Each value must be a scalar. Data can be inserted into the table using -the insert operations. It does not support the initialization operation. - -table_handle: Handle to a table. -container: If non-empty, this table is placed in the given container. - Otherwise, a default container is used. -shared_name: If non-empty, this table is shared under the given name across - multiple sessions. -use_node_name_sharing: If true and shared_name is empty, the table is shared - using the node name. -key_dtype: Type of the table keys. -value_dtype: Type of the table values. -)doc"); - -REGISTER_OP("MutableHashTableOfTensors") - .Output("table_handle: Ref(string)") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("use_node_name_sharing: bool = false") - .Attr("key_dtype: type") - .Attr("value_dtype: type") - .Attr("value_shape: shape = {}") - .SetIsStateful() - .SetShapeFn(TwoElementOutput) - .Doc(R"doc( -Creates an empty hash table. - -This op creates a mutable hash table, specifying the type of its keys and -values. Each value must be a vector. Data can be inserted into the table using -the insert operations. It does not support the initialization operation. - -table_handle: Handle to a table. -container: If non-empty, this table is placed in the given container. - Otherwise, a default container is used. -shared_name: If non-empty, this table is shared under the given name across - multiple sessions. -key_dtype: Type of the table keys. -value_dtype: Type of the table values. -)doc"); - -REGISTER_OP("MutableHashTableOfTensorsV2") - .Output("table_handle: resource") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("use_node_name_sharing: bool = false") - .Attr("key_dtype: type") - .Attr("value_dtype: type") - .Attr("value_shape: shape = {}") - .SetIsStateful() - .SetShapeFn(ScalarOutput) - .Doc(R"doc( -Creates an empty hash table. - -This op creates a mutable hash table, specifying the type of its keys and -values. Each value must be a vector. Data can be inserted into the table using -the insert operations. It does not support the initialization operation. - -table_handle: Handle to a table. -container: If non-empty, this table is placed in the given container. - Otherwise, a default container is used. -shared_name: If non-empty, this table is shared under the given name across - multiple sessions. -key_dtype: Type of the table keys. -value_dtype: Type of the table values. -)doc"); - -REGISTER_OP("MutableDenseHashTable") - .Input("empty_key: key_dtype") - .Output("table_handle: Ref(string)") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("use_node_name_sharing: bool = false") - .Attr("key_dtype: type") - .Attr("value_dtype: type") - .Attr("value_shape: shape = {}") - .Attr("initial_num_buckets: int = 131072") // 2^17 - .Attr("max_load_factor: float = 0.8") - .SetIsStateful() - .SetShapeFn(TwoElementOutput) - .Doc(R"doc( -Creates an empty hash table that uses tensors as the backing store. It uses -"open addressing" with quadratic reprobing to resolve collisions. - -This op creates a mutable hash table, specifying the type of its keys and -values. Each value must be a scalar. Data can be inserted into the table using -the insert operations. It does not support the initialization operation. - -empty_key: The key used to represent empty key buckets internally. Must not - be used in insert or lookup operations. -table_handle: Handle to a table. -container: If non-empty, this table is placed in the given container. - Otherwise, a default container is used. -shared_name: If non-empty, this table is shared under the given name across - multiple sessions. -key_dtype: Type of the table keys. -value_dtype: Type of the table values. -value_shape: The shape of each value. -initial_num_buckets: The initial number of hash table buckets. Must be a power - to 2. -max_load_factor: The maximum ratio between number of entries and number of - buckets before growing the table. Must be between 0 and 1. -)doc"); - -REGISTER_OP("MutableDenseHashTableV2") - .Input("empty_key: key_dtype") - .Output("table_handle: resource") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("use_node_name_sharing: bool = false") - .Attr("key_dtype: type") - .Attr("value_dtype: type") - .Attr("value_shape: shape = {}") - .Attr("initial_num_buckets: int = 131072") // 2^17 - .Attr("max_load_factor: float = 0.8") - .SetIsStateful() - .SetShapeFn(ScalarOutput) - .Doc(R"doc( -Creates an empty hash table that uses tensors as the backing store. It uses -"open addressing" with quadratic reprobing to resolve collisions. - -This op creates a mutable hash table, specifying the type of its keys and -values. Each value must be a scalar. Data can be inserted into the table using -the insert operations. It does not support the initialization operation. - -empty_key: The key used to represent empty key buckets internally. Must not - be used in insert or lookup operations. -table_handle: Handle to a table. -container: If non-empty, this table is placed in the given container. - Otherwise, a default container is used. -shared_name: If non-empty, this table is shared under the given name across - multiple sessions. -key_dtype: Type of the table keys. -value_dtype: Type of the table values. -value_shape: The shape of each value. -initial_num_buckets: The initial number of hash table buckets. Must be a power - to 2. -max_load_factor: The maximum ratio between number of entries and number of - buckets before growing the table. Must be between 0 and 1. -)doc"); - -REGISTER_OP("InitializeTable") - .Input("table_handle: Ref(string)") - .Input("keys: Tkey") - .Input("values: Tval") - .Attr("Tkey: type") - .Attr("Tval: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); - DimensionHandle unused_dim; - TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); - - ShapeHandle keys; - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys)); - TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys)); - return Status::OK(); - }) - .Doc(R"doc( -Table initializer that takes two tensors for keys and values respectively. - -table_handle: Handle to a table which will be initialized. -keys: Keys of type Tkey. -values: Values of type Tval. -)doc"); - -REGISTER_OP("InitializeTableV2") - .Input("table_handle: resource") - .Input("keys: Tkey") - .Input("values: Tval") - .Attr("Tkey: type") - .Attr("Tval: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); - - ShapeHandle keys; - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys)); - TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys)); - return Status::OK(); - }) - .Doc(R"doc( -Table initializer that takes two tensors for keys and values respectively. - -table_handle: Handle to a table which will be initialized. -keys: Keys of type Tkey. -values: Values of type Tval. -)doc"); - -REGISTER_OP("InitializeTableFromTextFile") - .Input("table_handle: Ref(string)") - .Input("filename: string") - .Attr("key_index: int >= -2") - .Attr("value_index: int >= -2") - .Attr("vocab_size: int >= -1 = -1") - .Attr("delimiter: string = '\t'") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); - DimensionHandle unused_dim; - TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); - - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle)); - return Status::OK(); - }) - .Doc(R"doc( -Initializes a table from a text file. - -It inserts one key-value pair into the table for each line of the file. -The key and value is extracted from the whole line content, elements from the -split line based on `delimiter` or the line number (starting from zero). -Where to extract the key and value from a line is specified by `key_index` and -`value_index`. - -- A value of -1 means use the line number(starting from zero), expects `int64`. -- A value of -2 means use the whole line content, expects `string`. -- A value >= 0 means use the index (starting at zero) of the split line based - on `delimiter`. - -table_handle: Handle to a table which will be initialized. -filename: Filename of a vocabulary text file. -key_index: Column index in a line to get the table `key` values from. -value_index: Column index that represents information of a line to get the table - `value` values from. -vocab_size: Number of elements of the file, use -1 if unknown. -delimiter: Delimiter to separate fields in a line. -)doc"); - -REGISTER_OP("InitializeTableFromTextFileV2") - .Input("table_handle: resource") - .Input("filename: string") - .Attr("key_index: int >= -2") - .Attr("value_index: int >= -2") - .Attr("vocab_size: int >= -1 = -1") - .Attr("delimiter: string = '\t'") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); - - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle)); - return Status::OK(); - }) - .Doc(R"doc( -Initializes a table from a text file. - -It inserts one key-value pair into the table for each line of the file. -The key and value is extracted from the whole line content, elements from the -split line based on `delimiter` or the line number (starting from zero). -Where to extract the key and value from a line is specified by `key_index` and -`value_index`. - -- A value of -1 means use the line number(starting from zero), expects `int64`. -- A value of -2 means use the whole line content, expects `string`. -- A value >= 0 means use the index (starting at zero) of the split line based - on `delimiter`. - -table_handle: Handle to a table which will be initialized. -filename: Filename of a vocabulary text file. -key_index: Column index in a line to get the table `key` values from. -value_index: Column index that represents information of a line to get the table - `value` values from. -vocab_size: Number of elements of the file, use -1 if unknown. -delimiter: Delimiter to separate fields in a line. -)doc"); - REGISTER_OP("GetSessionHandle") .Input("value: T") .Output("handle: string") diff --git a/tensorflow/core/ops/lookup_ops.cc b/tensorflow/core/ops/lookup_ops.cc new file mode 100644 index 0000000000..498a65690d --- /dev/null +++ b/tensorflow/core/ops/lookup_ops.cc @@ -0,0 +1,666 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + +// -------------------------------------------------------------------------- + +namespace { +Status TwoElementVectorInputsAndScalarOutputs(InferenceContext* c) { + ShapeHandle handle; + DimensionHandle unused_handle; + for (int i = 0; i < c->num_inputs(); ++i) { + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle)); + } + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->Scalar()); + } + return Status::OK(); +} + +Status ScalarAndTwoElementVectorInputsAndScalarOutputs(InferenceContext* c) { + ShapeHandle handle; + DimensionHandle unused_handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + for (int i = 1; i < c->num_inputs(); ++i) { + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle)); + } + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->Scalar()); + } + return Status::OK(); +} + +Status TwoElementOutput(InferenceContext* c) { + c->set_output(0, c->Vector(2)); + return Status::OK(); +} + +Status ScalarOutput(InferenceContext* c) { + c->set_output(0, c->Scalar()); + return Status::OK(); +} +} // namespace + +REGISTER_OP("LookupTableFind") + .Input("table_handle: Ref(string)") + .Input("keys: Tin") + .Input("default_value: Tout") + .Output("values: Tout") + .Attr("Tin: type") + .Attr("Tout: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + + // Default value must be scalar or vector. + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused)); + c->set_output(0, c->UnknownShape()); + return Status::OK(); + }) + .Doc(R"doc( +Looks up keys in a table, outputs the corresponding values. + +The tensor `keys` must of the same type as the keys of the table. +The output `values` is of the type of the table values. + +The scalar `default_value` is the value output for keys not present in the +table. It must also be of the same type as the table values. + +table_handle: Handle to the table. +keys: Any shape. Keys to look up. +values: Same shape as `keys`. Values found in the table, or `default_values` + for missing keys. +)doc"); + +REGISTER_OP("LookupTableFindV2") + .Input("table_handle: resource") + .Input("keys: Tin") + .Input("default_value: Tout") + .Output("values: Tout") + .Attr("Tin: type") + .Attr("Tout: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + + // Default value must be scalar or vector. + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused)); + c->set_output(0, c->UnknownShape()); + return Status::OK(); + }) + .Doc(R"doc( +Looks up keys in a table, outputs the corresponding values. + +The tensor `keys` must of the same type as the keys of the table. +The output `values` is of the type of the table values. + +The scalar `default_value` is the value output for keys not present in the +table. It must also be of the same type as the table values. + +table_handle: Handle to the table. +keys: Any shape. Keys to look up. +values: Same shape as `keys`. Values found in the table, or `default_values` + for missing keys. +)doc"); + +REGISTER_OP("LookupTableInsert") + .Input("table_handle: Ref(string)") + .Input("keys: Tin") + .Input("values: Tout") + .Attr("Tin: type") + .Attr("Tout: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + + // TODO(ebrevdo): Validate keys and values shape. + return Status::OK(); + }) + .Doc(R"doc( +Updates the table to associates keys with values. + +The tensor `keys` must be of the same type as the keys of the table. +The tensor `values` must be of the type of the table values. + +table_handle: Handle to the table. +keys: Any shape. Keys to look up. +values: Values to associate with keys. +)doc"); + +REGISTER_OP("LookupTableInsertV2") + .Input("table_handle: resource") + .Input("keys: Tin") + .Input("values: Tout") + .Attr("Tin: type") + .Attr("Tout: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + + // TODO: Validate keys and values shape. + return Status::OK(); + }) + .Doc(R"doc( +Updates the table to associates keys with values. + +The tensor `keys` must be of the same type as the keys of the table. +The tensor `values` must be of the type of the table values. + +table_handle: Handle to the table. +keys: Any shape. Keys to look up. +values: Values to associate with keys. +)doc"); + +REGISTER_OP("LookupTableSize") + .Input("table_handle: Ref(string)") + .Output("size: int64") + .SetShapeFn(TwoElementVectorInputsAndScalarOutputs) + .Doc(R"doc( +Computes the number of elements in the given table. + +table_handle: Handle to the table. +size: Scalar that contains number of elements in the table. +)doc"); + +REGISTER_OP("LookupTableSizeV2") + .Input("table_handle: resource") + .Output("size: int64") + .SetShapeFn(ScalarAndTwoElementVectorInputsAndScalarOutputs) + .Doc(R"doc( +Computes the number of elements in the given table. + +table_handle: Handle to the table. +size: Scalar that contains number of elements in the table. +)doc"); + +REGISTER_OP("LookupTableExport") + .Input("table_handle: Ref(string)") + .Output("keys: Tkeys") + .Output("values: Tvalues") + .Attr("Tkeys: type") + .Attr("Tvalues: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + + ShapeHandle values = c->UnknownShape(); + TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values)); + ShapeHandle keys = c->Vector(c->Dim(values, 0)); + c->set_output(0, keys); + c->set_output(1, values); + return Status::OK(); + }) + .Doc(R"doc( +Outputs all keys and values in the table. + +table_handle: Handle to the table. +keys: Vector of all keys present in the table. +values: Tensor of all values in the table. Indexed in parallel with `keys`. +)doc"); + +REGISTER_OP("LookupTableExportV2") + .Input("table_handle: resource") + .Output("keys: Tkeys") + .Output("values: Tvalues") + .Attr("Tkeys: type") + .Attr("Tvalues: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + + ShapeHandle values = c->UnknownShape(); + TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values)); + ShapeHandle keys = c->Vector(c->Dim(values, 0)); + c->set_output(0, keys); + c->set_output(1, values); + return Status::OK(); + }) + .Doc(R"doc( +Outputs all keys and values in the table. + +table_handle: Handle to the table. +keys: Vector of all keys present in the table. +values: Tensor of all values in the table. Indexed in parallel with `keys`. +)doc"); + +REGISTER_OP("LookupTableImport") + .Input("table_handle: Ref(string)") + .Input("keys: Tin") + .Input("values: Tout") + .Attr("Tin: type") + .Attr("Tout: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + + // TODO(ebrevdo): Validate keys and values shape. + return Status::OK(); + }) + .Doc(R"doc( +Replaces the contents of the table with the specified keys and values. + +The tensor `keys` must be of the same type as the keys of the table. +The tensor `values` must be of the type of the table values. + +table_handle: Handle to the table. +keys: Any shape. Keys to look up. +values: Values to associate with keys. +)doc"); + +REGISTER_OP("LookupTableImportV2") + .Input("table_handle: resource") + .Input("keys: Tin") + .Input("values: Tout") + .Attr("Tin: type") + .Attr("Tout: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + + // TODO: Validate keys and values shape. + return Status::OK(); + }) + .Doc(R"doc( +Replaces the contents of the table with the specified keys and values. + +The tensor `keys` must be of the same type as the keys of the table. +The tensor `values` must be of the type of the table values. + +table_handle: Handle to the table. +keys: Any shape. Keys to look up. +values: Values to associate with keys. +)doc"); + +REGISTER_OP("HashTable") + .Output("table_handle: Ref(string)") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .SetIsStateful() + .SetShapeFn(TwoElementOutput) + .Doc(R"doc( +Creates a non-initialized hash table. + +This op creates a hash table, specifying the type of its keys and values. +Before using the table you will have to initialize it. After initialization the +table will be immutable. + +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +use_node_name_sharing: If true and shared_name is empty, the table is shared + using the node name. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +)doc"); + +REGISTER_OP("HashTableV2") + .Output("table_handle: resource") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .SetIsStateful() + .SetShapeFn(ScalarOutput) + .Doc(R"doc( +Creates a non-initialized hash table. + +This op creates a hash table, specifying the type of its keys and values. +Before using the table you will have to initialize it. After initialization the +table will be immutable. + +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +use_node_name_sharing: If true and shared_name is empty, the table is shared + using the node name. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +)doc"); + +REGISTER_OP("MutableHashTable") + .Output("table_handle: Ref(string)") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .SetIsStateful() + .SetShapeFn(TwoElementOutput) + .Doc(R"doc( +Creates an empty hash table. + +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a scalar. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +use_node_name_sharing: If true and shared_name is empty, the table is shared + using the node name. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +)doc"); + +REGISTER_OP("MutableHashTableV2") + .Output("table_handle: resource") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .SetIsStateful() + .SetShapeFn(ScalarOutput) + .Doc(R"doc( +Creates an empty hash table. + +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a scalar. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +use_node_name_sharing: If true and shared_name is empty, the table is shared + using the node name. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +)doc"); + +REGISTER_OP("MutableHashTableOfTensors") + .Output("table_handle: Ref(string)") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .Attr("value_shape: shape = {}") + .SetIsStateful() + .SetShapeFn(TwoElementOutput) + .Doc(R"doc( +Creates an empty hash table. + +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a vector. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +)doc"); + +REGISTER_OP("MutableHashTableOfTensorsV2") + .Output("table_handle: resource") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .Attr("value_shape: shape = {}") + .SetIsStateful() + .SetShapeFn(ScalarOutput) + .Doc(R"doc( +Creates an empty hash table. + +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a vector. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +)doc"); + +REGISTER_OP("MutableDenseHashTable") + .Input("empty_key: key_dtype") + .Output("table_handle: Ref(string)") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .Attr("value_shape: shape = {}") + .Attr("initial_num_buckets: int = 131072") // 2^17 + .Attr("max_load_factor: float = 0.8") + .SetIsStateful() + .SetShapeFn(TwoElementOutput) + .Doc(R"doc( +Creates an empty hash table that uses tensors as the backing store. It uses +"open addressing" with quadratic reprobing to resolve collisions. + +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a scalar. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + +empty_key: The key used to represent empty key buckets internally. Must not + be used in insert or lookup operations. +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +value_shape: The shape of each value. +initial_num_buckets: The initial number of hash table buckets. Must be a power + to 2. +max_load_factor: The maximum ratio between number of entries and number of + buckets before growing the table. Must be between 0 and 1. +)doc"); + +REGISTER_OP("MutableDenseHashTableV2") + .Input("empty_key: key_dtype") + .Output("table_handle: resource") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .Attr("value_shape: shape = {}") + .Attr("initial_num_buckets: int = 131072") // 2^17 + .Attr("max_load_factor: float = 0.8") + .SetIsStateful() + .SetShapeFn(ScalarOutput) + .Doc(R"doc( +Creates an empty hash table that uses tensors as the backing store. It uses +"open addressing" with quadratic reprobing to resolve collisions. + +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a scalar. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + +empty_key: The key used to represent empty key buckets internally. Must not + be used in insert or lookup operations. +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +value_shape: The shape of each value. +initial_num_buckets: The initial number of hash table buckets. Must be a power + to 2. +max_load_factor: The maximum ratio between number of entries and number of + buckets before growing the table. Must be between 0 and 1. +)doc"); + +REGISTER_OP("InitializeTable") + .Input("table_handle: Ref(string)") + .Input("keys: Tkey") + .Input("values: Tval") + .Attr("Tkey: type") + .Attr("Tval: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + + ShapeHandle keys; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys)); + TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys)); + return Status::OK(); + }) + .Doc(R"doc( +Table initializer that takes two tensors for keys and values respectively. + +table_handle: Handle to a table which will be initialized. +keys: Keys of type Tkey. +values: Values of type Tval. +)doc"); + +REGISTER_OP("InitializeTableV2") + .Input("table_handle: resource") + .Input("keys: Tkey") + .Input("values: Tval") + .Attr("Tkey: type") + .Attr("Tval: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + + ShapeHandle keys; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys)); + TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys)); + return Status::OK(); + }) + .Doc(R"doc( +Table initializer that takes two tensors for keys and values respectively. + +table_handle: Handle to a table which will be initialized. +keys: Keys of type Tkey. +values: Values of type Tval. +)doc"); + +REGISTER_OP("InitializeTableFromTextFile") + .Input("table_handle: Ref(string)") + .Input("filename: string") + .Attr("key_index: int >= -2") + .Attr("value_index: int >= -2") + .Attr("vocab_size: int >= -1 = -1") + .Attr("delimiter: string = '\t'") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle)); + return Status::OK(); + }) + .Doc(R"doc( +Initializes a table from a text file. + +It inserts one key-value pair into the table for each line of the file. +The key and value is extracted from the whole line content, elements from the +split line based on `delimiter` or the line number (starting from zero). +Where to extract the key and value from a line is specified by `key_index` and +`value_index`. + +- A value of -1 means use the line number(starting from zero), expects `int64`. +- A value of -2 means use the whole line content, expects `string`. +- A value >= 0 means use the index (starting at zero) of the split line based + on `delimiter`. + +table_handle: Handle to a table which will be initialized. +filename: Filename of a vocabulary text file. +key_index: Column index in a line to get the table `key` values from. +value_index: Column index that represents information of a line to get the table + `value` values from. +vocab_size: Number of elements of the file, use -1 if unknown. +delimiter: Delimiter to separate fields in a line. +)doc"); + +REGISTER_OP("InitializeTableFromTextFileV2") + .Input("table_handle: resource") + .Input("filename: string") + .Attr("key_index: int >= -2") + .Attr("value_index: int >= -2") + .Attr("vocab_size: int >= -1 = -1") + .Attr("delimiter: string = '\t'") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle)); + return Status::OK(); + }) + .Doc(R"doc( +Initializes a table from a text file. + +It inserts one key-value pair into the table for each line of the file. +The key and value is extracted from the whole line content, elements from the +split line based on `delimiter` or the line number (starting from zero). +Where to extract the key and value from a line is specified by `key_index` and +`value_index`. + +- A value of -1 means use the line number(starting from zero), expects `int64`. +- A value of -2 means use the whole line content, expects `string`. +- A value >= 0 means use the index (starting at zero) of the split line based + on `delimiter`. + +table_handle: Handle to a table which will be initialized. +filename: Filename of a vocabulary text file. +key_index: Column index in a line to get the table `key` values from. +value_index: Column index that represents information of a line to get the table + `value` values from. +vocab_size: Number of elements of the file, use -1 if unknown. +delimiter: Delimiter to separate fields in a line. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 817d157da2..9fd5ada71e 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1022,7 +1022,6 @@ tf_gen_op_wrapper_private_py( require_shape_functions = True, visibility = [ "//learning/brain/python/ops:__pkg__", - "//tensorflow/python/feature_column:__pkg__", "//tensorflow/python/kernel_tests:__pkg__", ], ) @@ -1058,6 +1057,16 @@ tf_gen_op_wrapper_private_py( ) tf_gen_op_wrapper_private_py( + name = "lookup_ops_gen", + require_shape_functions = True, + visibility = [ + "//learning/brain/python/ops:__pkg__", + "//tensorflow/python/feature_column:__pkg__", + "//tensorflow/python/kernel_tests:__pkg__", + ], +) + +tf_gen_op_wrapper_private_py( name = "math_ops_gen", require_shape_functions = True, visibility = [ @@ -1475,6 +1484,20 @@ py_library( ) py_library( + name = "lookup_ops", + srcs = ["ops/lookup_ops.py"], + srcs_version = "PY2AND3", + deps = [ + ":array_ops", + ":framework", + ":framework_for_generated_wrappers", + ":lookup_ops_gen", + ":math_ops", + "@six_archive//:six", + ], +) + +py_library( name = "math_grad", srcs = ["ops/math_grad.py"], srcs_version = "PY2AND3", @@ -1862,6 +1885,7 @@ py_library( ":io_ops", ":linalg_ops", ":logging_ops", + ":lookup_ops", ":math_grad", ":math_ops", ":numerics", @@ -2269,6 +2293,7 @@ py_library( ":io_ops", ":io_ops_gen", ":lib", + ":lookup_ops", ":math_ops", ":platform", ":protos_all_py", @@ -2991,6 +3016,7 @@ cuda_py_tests( ":framework", ":framework_for_generated_wrappers", ":framework_test_lib", + ":lookup_ops", ":gradients", ":math_ops", ":nn_grad", @@ -3021,7 +3047,7 @@ py_library( srcs = ["training/saver_test_utils.py"], srcs_version = "PY2AND3", deps = [ - ":data_flow_ops_gen", + ":lookup_ops_gen", ":training", ], ) diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index f70c285f04..b8064f0a77 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -42,8 +42,8 @@ from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import init_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import state_ops @@ -1391,9 +1391,10 @@ class EstimatorExportTest(test.TestCase): my_int = variables.Variable(1, name='my_int', collections=[ops.GraphKeys.LOCAL_VARIABLES]) scores = constant_op.constant([3.]) - with ops.control_dependencies( - [variables.local_variables_initializer(), - data_flow_ops.tables_initializer()]): + with ops.control_dependencies([ + variables.local_variables_initializer(), + lookup_ops.tables_initializer() + ]): assign_op = state_ops.assign(my_int, 12345) # local_initSop must be an Operation, not a Tensor. diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD index d734273845..ac7aef96ac 100644 --- a/tensorflow/python/feature_column/BUILD +++ b/tensorflow/python/feature_column/BUILD @@ -80,9 +80,9 @@ py_library( deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", - "//tensorflow/python:data_flow_ops_gen", "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:lookup_ops_gen", "//tensorflow/python:math_ops", "//tensorflow/python:string_ops", "//tensorflow/python:training", diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index d85142abcf..ad67a082dc 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -31,7 +31,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor -from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib @@ -41,7 +41,7 @@ from tensorflow.python.platform import test def _initialized_session(): sess = session.Session() sess.run(variables_lib.global_variables_initializer()) - sess.run(data_flow_ops.tables_initializer()) + sess.run(lookup_ops.tables_initializer()) return sess @@ -1277,7 +1277,7 @@ class VocabularyCategoricalColumnTest(test.TestCase): # pylint: enable=protected-access with self.assertRaisesRegexp(errors.OpError, 'file_does_not_exist'): with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() def test_invalid_vocabulary_size(self): with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'): @@ -1307,7 +1307,7 @@ class VocabularyCategoricalColumnTest(test.TestCase): # pylint: enable=protected-access with self.assertRaisesRegexp(errors.OpError, 'Invalid vocab_size'): with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() def test_invalid_num_oov_buckets(self): with self.assertRaisesRegexp(ValueError, 'Invalid num_oov_buckets'): diff --git a/tensorflow/python/feature_column/lookup_ops.py b/tensorflow/python/feature_column/lookup_ops.py index 13a67fa518..8225b47b20 100644 --- a/tensorflow/python/feature_column/lookup_ops.py +++ b/tensorflow/python/feature_column/lookup_ops.py @@ -27,7 +27,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import gen_data_flow_ops +from tensorflow.python.ops import gen_lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import string_ops from tensorflow.python.training.saver import BaseSaverBuilder @@ -151,7 +151,7 @@ class InitializableLookupTableBase(LookupInterface): with ops.name_scope(name, "%s_Size" % self._name, [self._table_ref]) as scope: # pylint: disable=protected-access - return gen_data_flow_ops._lookup_table_size(self._table_ref, name=scope) + return gen_lookup_ops._lookup_table_size(self._table_ref, name=scope) # pylint: enable=protected-access def lookup(self, keys, name=None): @@ -182,7 +182,7 @@ class InitializableLookupTableBase(LookupInterface): name, "%s_Lookup" % self._name, (self._table_ref, key_tensor, self._default_value)) as scope: # pylint: disable=protected-access - values = gen_data_flow_ops._lookup_table_find( + values = gen_lookup_ops._lookup_table_find( self._table_ref, key_tensor, self._default_value, name=scope) # pylint: enable=protected-access @@ -229,7 +229,7 @@ class HashTable(InitializableLookupTableBase): with ops.name_scope( name, "hash_table", (initializer, default_value)) as scope: # pylint: disable=protected-access - table_ref = gen_data_flow_ops._hash_table( + table_ref = gen_lookup_ops._hash_table( shared_name=shared_name, key_dtype=initializer.key_dtype, value_dtype=initializer.value_dtype, @@ -308,10 +308,8 @@ class KeyValueTensorInitializer(TableInitializerBase): self._name, values=(table.table_ref, self._keys, self._values)) as scope: # pylint: disable=protected-access - init_op = gen_data_flow_ops._initialize_table(table.table_ref, - self._keys, - self._values, - name=scope) + init_op = gen_lookup_ops._initialize_table( + table.table_ref, self._keys, self._values, name=scope) # pylint: enable=protected-access ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) return init_op @@ -477,7 +475,7 @@ class TextFileInitializer(TableInitializerBase): dtypes.string, name="asset_filepath") # pylint: disable=protected-access - init_op = gen_data_flow_ops._initialize_table_from_text_file( + init_op = gen_lookup_ops._initialize_table_from_text_file( table.table_ref, filename, self._key_index, @@ -1333,14 +1331,14 @@ class MutableHashTable(LookupInterface): use_node_name_sharing = checkpoint and shared_name is None # pylint: disable=protected-access if self._default_value.get_shape().ndims == 0: - self._table_ref = gen_data_flow_ops._mutable_hash_table( + self._table_ref = gen_lookup_ops._mutable_hash_table( shared_name=shared_name, use_node_name_sharing=use_node_name_sharing, key_dtype=key_dtype, value_dtype=value_dtype, name=name) else: - self._table_ref = gen_data_flow_ops._mutable_hash_table_of_tensors( + self._table_ref = gen_lookup_ops._mutable_hash_table_of_tensors( shared_name=shared_name, use_node_name_sharing=use_node_name_sharing, key_dtype=key_dtype, @@ -1368,7 +1366,7 @@ class MutableHashTable(LookupInterface): with ops.name_scope(name, "%s_Size" % self._name, [self._table_ref]) as name: # pylint: disable=protected-access - return gen_data_flow_ops._lookup_table_size(self._table_ref, name=name) + return gen_lookup_ops._lookup_table_size(self._table_ref, name=name) def lookup(self, keys, name=None): """Looks up `keys` in a table, outputs the corresponding values. @@ -1394,10 +1392,8 @@ class MutableHashTable(LookupInterface): with ops.name_scope(name, "%s_lookup_table_find" % self._name, (self._table_ref, keys, self._default_value)) as name: # pylint: disable=protected-access - values = gen_data_flow_ops._lookup_table_find(self._table_ref, - keys, - self._default_value, - name=name) + values = gen_lookup_ops._lookup_table_find( + self._table_ref, keys, self._default_value, name=name) values.set_shape(keys.get_shape().concatenate(self._value_shape)) return values @@ -1423,7 +1419,7 @@ class MutableHashTable(LookupInterface): with ops.name_scope(name, "%s_lookup_table_insert" % self._name, [self._table_ref, keys, values]) as name: # pylint: disable=protected-access - op = gen_data_flow_ops._lookup_table_insert( + op = gen_lookup_ops._lookup_table_insert( self._table_ref, keys, values, name=name) return op @@ -1440,11 +1436,8 @@ class MutableHashTable(LookupInterface): with ops.name_scope(name, "%s_lookup_table_export_values" % self._name, [self._table_ref]) as name: # pylint: disable=protected-access - exported_keys, exported_values = gen_data_flow_ops._lookup_table_export( - self._table_ref, - self._key_dtype, - self._value_dtype, - name=name) + exported_keys, exported_values = gen_lookup_ops._lookup_table_export( + self._table_ref, self._key_dtype, self._value_dtype, name=name) exported_values.set_shape(exported_keys.get_shape().concatenate( self._value_shape)) @@ -1464,7 +1457,7 @@ class MutableHashTable(LookupInterface): def restore(self, restored_tensors, unused_restored_shapes): # pylint: disable=protected-access - return gen_data_flow_ops._lookup_table_import( + return gen_lookup_ops._lookup_table_import( self.op._table_ref, restored_tensors[0], restored_tensors[1]) @@ -1539,7 +1532,7 @@ class MutableDenseHashTable(LookupInterface): use_node_name_sharing = checkpoint and shared_name is None empty_key = ops.convert_to_tensor(empty_key, dtype=key_dtype) # pylint: disable=protected-access - self._table_ref = gen_data_flow_ops._mutable_dense_hash_table( + self._table_ref = gen_lookup_ops._mutable_dense_hash_table( empty_key=empty_key, shared_name=shared_name, use_node_name_sharing=use_node_name_sharing, @@ -1567,7 +1560,7 @@ class MutableDenseHashTable(LookupInterface): with ops.name_scope(name, "%s_Size" % self._name, [self._table_ref]) as name: # pylint: disable=protected-access - return gen_data_flow_ops._lookup_table_size(self._table_ref, name=name) + return gen_lookup_ops._lookup_table_size(self._table_ref, name=name) def lookup(self, keys, name=None): """Looks up `keys` in a table, outputs the corresponding values. @@ -1593,7 +1586,7 @@ class MutableDenseHashTable(LookupInterface): with ops.name_scope(name, "%s_lookup_table_find" % self._name, [self._table_ref, keys]) as name: # pylint: disable=protected-access - values = gen_data_flow_ops._lookup_table_find( + values = gen_lookup_ops._lookup_table_find( self._table_ref, keys, self._default_value, name=name) if keys.get_shape().ndims is not None and keys.get_shape().ndims > 0: @@ -1623,7 +1616,7 @@ class MutableDenseHashTable(LookupInterface): with ops.name_scope(name, "%s_lookup_table_insert" % self._name, [self._table_ref, keys, values]) as name: # pylint: disable=protected-access - op = gen_data_flow_ops._lookup_table_insert( + op = gen_lookup_ops._lookup_table_insert( self._table_ref, keys, values, name=name) return op @@ -1640,7 +1633,7 @@ class MutableDenseHashTable(LookupInterface): with ops.name_scope(name, "%s_lookup_table_export_values" % self._name, [self._table_ref]) as name: # pylint: disable=protected-access - exported_keys, exported_values = gen_data_flow_ops._lookup_table_export( + exported_keys, exported_values = gen_lookup_ops._lookup_table_export( self._table_ref, self._key_dtype, self._value_dtype, name=name) exported_values.set_shape(exported_keys.get_shape().concatenate( @@ -1661,6 +1654,5 @@ class MutableDenseHashTable(LookupInterface): def restore(self, restored_tensors, unused_restored_shapes): # pylint: disable=protected-access - return gen_data_flow_ops._lookup_table_import(self.op._table_ref, - restored_tensors[0], - restored_tensors[1]) + return gen_lookup_ops._lookup_table_import( + self.op._table_ref, restored_tensors[0], restored_tensors[1]) diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py index 95e803e2aa..9a208613ad 100644 --- a/tensorflow/python/ops/data_flow_ops.py +++ b/tensorflow/python/ops/data_flow_ops.py @@ -38,7 +38,6 @@ from tensorflow.python.ops import math_ops # pylint: disable=wildcard-import from tensorflow.python.ops.gen_data_flow_ops import * # pylint: enable=wildcard-import -from tensorflow.python.util.deprecation import deprecated def _as_type_list(dtypes): @@ -1037,47 +1036,6 @@ class Barrier(object): self._barrier_ref, name=name) -@deprecated("2017-03-02", "Use `tf.tables_initializer` instead.") -def initialize_all_tables(name="init_all_tables"): - """Returns an Op that initializes all tables of the default graph. - - Args: - name: Optional name for the initialization op. - - Returns: - An Op that initializes all tables. Note that if there are - not tables the returned Op is a NoOp. - """ - return tables_initializer(name) - - -def tables_initializer(name="init_all_tables"): - """Returns an Op that initializes all tables of the default graph. - - Args: - name: Optional name for the initialization op. - - Returns: - An Op that initializes all tables. Note that if there are - not tables the returned Op is a NoOp. - """ - initializers = ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS) - if initializers: - return control_flow_ops.group(*initializers, name=name) - return control_flow_ops.no_op(name=name) - - -ops.NotDifferentiable("LookupTableFind") -ops.NotDifferentiable("LookupTableInsert") -ops.NotDifferentiable("LookupTableSize") -ops.NotDifferentiable("HashTable") -ops.NotDifferentiable("InitializeTable") -ops.NotDifferentiable("InitializeTableFromTextFile") -ops.NotDifferentiable("MutableDenseHashTable") -ops.NotDifferentiable("MutableHashTable") -ops.NotDifferentiable("MutableHashTableOfTensors") - - class ConditionalAccumulatorBase(object): """A conditional accumulator for aggregating gradients. diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py new file mode 100644 index 0000000000..54dba9e38e --- /dev/null +++ b/tensorflow/python/ops/lookup_ops.py @@ -0,0 +1,77 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#============================================================================== +"""Data Flow Operations.""" +# pylint: disable=g-bad-name +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.python.ops.gen_lookup_ops import * +# pylint: enable=wildcard-import +from tensorflow.python.util.deprecation import deprecated + + +@deprecated("2017-03-02", "Use `tf.tables_initializer` instead.") +def initialize_all_tables(name="init_all_tables"): + """Returns an Op that initializes all tables of the default graph. + + Args: + name: Optional name for the initialization op. + + Returns: + An Op that initializes all tables. Note that if there are + not tables the returned Op is a NoOp. + """ + return tables_initializer(name) + + +def tables_initializer(name="init_all_tables"): + """Returns an Op that initializes all tables of the default graph. + + Args: + name: Optional name for the initialization op. + + Returns: + An Op that initializes all tables. Note that if there are + not tables the returned Op is a NoOp. + """ + initializers = ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS) + if initializers: + return control_flow_ops.group(*initializers, name=name) + return control_flow_ops.no_op(name=name) + + +ops.NotDifferentiable("LookupTableFind") +ops.NotDifferentiable("LookupTableFindV2") +ops.NotDifferentiable("LookupTableInsert") +ops.NotDifferentiable("LookupTableInsertV2") +ops.NotDifferentiable("LookupTableSize") +ops.NotDifferentiable("LookupTableSizeV2") +ops.NotDifferentiable("HashTable") +ops.NotDifferentiable("HashTableV2") +ops.NotDifferentiable("InitializeTable") +ops.NotDifferentiable("InitializeTableV2") +ops.NotDifferentiable("InitializeTableFromTextFile") +ops.NotDifferentiable("InitializeTableFromTextFileV2") +ops.NotDifferentiable("MutableDenseHashTable") +ops.NotDifferentiable("MutableDenseHashTableV2") +ops.NotDifferentiable("MutableHashTable") +ops.NotDifferentiable("MutableHashTableV2") +ops.NotDifferentiable("MutableHashTableOfTensors") +ops.NotDifferentiable("MutableHashTableOfTensorsV2") diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py index 09e04d4247..a39d28490c 100644 --- a/tensorflow/python/ops/standard_ops.py +++ b/tensorflow/python/ops/standard_ops.py @@ -57,6 +57,7 @@ from tensorflow.python.ops.io_ops import * from tensorflow.python.ops.linalg_ops import * from tensorflow.python.ops.logging_ops import Print from tensorflow.python.ops.logging_ops import get_summary_op +from tensorflow.python.ops.lookup_ops import * from tensorflow.python.ops.math_ops import * from tensorflow.python.ops.numerics import * from tensorflow.python.ops.parsing_ops import * diff --git a/tensorflow/python/saved_model/main_op_impl.py b/tensorflow/python/saved_model/main_op_impl.py index 66cf9d4d8a..355fd57bf1 100644 --- a/tensorflow/python/saved_model/main_op_impl.py +++ b/tensorflow/python/saved_model/main_op_impl.py @@ -20,7 +20,7 @@ from __future__ import print_function from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import data_flow_ops as tf_data_flow_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables @@ -35,7 +35,7 @@ def main_op(): """ init = variables.global_variables_initializer() init_local = variables.local_variables_initializer() - init_tables = tf_data_flow_ops.tables_initializer() + init_tables = lookup_ops.tables_initializer() return control_flow_ops.group(init, init_local, init_tables) diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index 4c81af56ad..fcec3ed97c 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -26,7 +26,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import resources from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging @@ -238,7 +238,7 @@ class Scaffold(object): @staticmethod def _default_local_init_op(): return control_flow_ops.group(variables.local_variables_initializer(), - data_flow_ops.tables_initializer()) + lookup_ops.tables_initializer()) def MonitoredTrainingSession(master='', # pylint: disable=invalid-name diff --git a/tensorflow/python/training/saver_test_utils.py b/tensorflow/python/training/saver_test_utils.py index 5f31e2aa53..6a73565f82 100644 --- a/tensorflow/python/training/saver_test_utils.py +++ b/tensorflow/python/training/saver_test_utils.py @@ -20,7 +20,7 @@ from __future__ import print_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as ops_lib -from tensorflow.python.ops import gen_data_flow_ops +from tensorflow.python.ops import gen_lookup_ops from tensorflow.python.training import saver as saver_module @@ -34,7 +34,7 @@ class CheckpointedOp(object): # pylint: disable=protected-access def __init__(self, name, table_ref=None): if table_ref is None: - self.table_ref = gen_data_flow_ops._mutable_hash_table( + self.table_ref = gen_lookup_ops._mutable_hash_table( key_dtype=dtypes.string, value_dtype=dtypes.float32, name=name) else: self.table_ref = table_ref @@ -52,10 +52,10 @@ class CheckpointedOp(object): return self._saveable def insert(self, keys, values): - return gen_data_flow_ops._lookup_table_insert(self.table_ref, keys, values) + return gen_lookup_ops._lookup_table_insert(self.table_ref, keys, values) def lookup(self, keys, default): - return gen_data_flow_ops._lookup_table_find(self.table_ref, keys, default) + return gen_lookup_ops._lookup_table_find(self.table_ref, keys, default) def keys(self): return self._export()[0] @@ -64,8 +64,8 @@ class CheckpointedOp(object): return self._export()[1] def _export(self): - return gen_data_flow_ops._lookup_table_export(self.table_ref, dtypes.string, - dtypes.float32) + return gen_lookup_ops._lookup_table_export(self.table_ref, dtypes.string, + dtypes.float32) class CustomSaveable(saver_module.BaseSaverBuilder.SaveableObject): """A custom saveable for CheckpointedOp.""" @@ -81,6 +81,6 @@ class CheckpointedOp(object): super(CheckpointedOp.CustomSaveable, self).__init__(table, specs, name) def restore(self, restore_tensors, shapes): - return gen_data_flow_ops._lookup_table_import( + return gen_lookup_ops._lookup_table_import( self.op.table_ref, restore_tensors[0], restore_tensors[1]) # pylint: enable=protected-access diff --git a/tensorflow/python/training/supervisor.py b/tensorflow/python/training/supervisor.py index 277c11386d..230ed1db68 100644 --- a/tensorflow/python/training/supervisor.py +++ b/tensorflow/python/training/supervisor.py @@ -27,7 +27,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary as _summary @@ -426,8 +426,10 @@ class Supervisor(object): local_init_op = self._get_first_op_from_collection( ops.GraphKeys.LOCAL_INIT_OP) if local_init_op is None: - op_list = [variables.local_variables_initializer(), - data_flow_ops.tables_initializer()] + op_list = [ + variables.local_variables_initializer(), + lookup_ops.tables_initializer() + ] if op_list: local_init_op = control_flow_ops.group(*op_list) ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op) |