author    A. Unique TensorFlower <gardener@tensorflow.org>  2016-06-26 08:14:31 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2016-06-26 09:18:06 -0700
commit    acd8c859f7c2d077464422cd033efeb5cce4b986
tree      d6ccfcd3ed7d932b7ab9876ff0f327aed7b46a19
parent    0868ce67f4174b2b857641f473a00da81a1f511a
Add a variant of mutable hash table that supports tensors as values.
Add support for exporting the contents of a table. Change: 125901929
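
For context, a minimal usage sketch of the new tensor-valued table and the
export op, assembled from the tests in this change (the import path and
method names are taken from lookup_ops_test.py below; this is an
illustration, not part of the patch):

    import tensorflow as tf

    # A vector default value selects the new tensor-valued table variant;
    # a scalar default value keeps the existing scalar behavior.
    default_val = tf.constant([-1, -1], tf.int64)
    table = tf.contrib.lookup.MutableHashTable(tf.string, tf.int64, default_val)

    keys = tf.constant(["brain", "salad", "surgery"])
    values = tf.constant([[0, 1], [2, 3], [4, 5]], tf.int64)

    with tf.Session() as sess:
        table.insert(keys, values).run()
        # Missing keys ("tank") fall back to the default value.
        print(table.lookup(tf.constant(["brain", "tank"])).eval())
        # New: export all keys and values; iteration order is undefined.
        exported_keys, exported_values = table.export()
        print(sess.run([exported_keys, exported_values]))
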
 tensorflow/contrib/lookup/lookup_ops.py               |  54
 tensorflow/contrib/lookup/lookup_ops_test.py          | 131
 tensorflow/core/framework/lookup_interface.cc         |  26
 tensorflow/core/framework/lookup_interface.h          |  40
 tensorflow/core/kernels/initializable_lookup_table.cc |   1
 tensorflow/core/kernels/initializable_lookup_table.h  |   8
 tensorflow/core/kernels/lookup_table_op.cc            | 212
 tensorflow/core/kernels/lookup_table_op.h             |   4
 tensorflow/core/ops/data_flow_ops.cc                  |  42
 tensorflow/python/BUILD                               |   2
 tensorflow/python/ops/data_flow_ops.py                |  16
 11 files changed, 462 insertions(+), 74 deletions(-)
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py
index b3248c4741..cdfe3f478b 100644
--- a/tensorflow/contrib/lookup/lookup_ops.py
+++ b/tensorflow/contrib/lookup/lookup_ops.py
@@ -179,6 +179,7 @@ class InitializableLookupTableBase(LookupInterface):
name=name)
# pylint: enable=protected-access
+ values.set_shape(key_tensor.get_shape())
if isinstance(keys, ops.SparseTensor):
return ops.SparseTensor(keys.indices, values, keys.shape)
else:
@@ -726,22 +727,30 @@ class MutableHashTable(LookupInterface):
Returns:
A `MutableHashTable` object.
"""
+ self._default_value = ops.convert_to_tensor(default_value,
+ dtype=value_dtype)
+ self._value_shape = self._default_value.get_shape()
+
# pylint: disable=protected-access
- self._table_ref = gen_data_flow_ops._mutable_hash_table(
- shared_name=shared_name,
- key_dtype=key_dtype,
- value_dtype=value_dtype,
- name=name)
+ if self._default_value.get_shape().ndims == 0:
+ self._table_ref = gen_data_flow_ops._mutable_hash_table(
+ shared_name=shared_name,
+ key_dtype=key_dtype,
+ value_dtype=value_dtype,
+ name=name)
+ else:
+ self._table_ref = gen_data_flow_ops._mutable_hash_table_of_tensors(
+ shared_name=shared_name,
+ key_dtype=key_dtype,
+ value_dtype=value_dtype,
+ value_shape=self._default_value.get_shape(),
+ name=name)
# pylint: enable=protected-access
+
super(MutableHashTable, self).__init__(key_dtype, value_dtype,
self._table_ref.op.name.split(
"/")[-1])
- with ops.op_scope([self._table_ref, default_value], name,
- "MutableHashTable"):
- self._default_value = ops.convert_to_tensor(default_value,
- dtype=self._value_dtype)
-
def size(self, name=None):
"""Compute the number of elements in this table.
@@ -786,6 +795,7 @@ class MutableHashTable(LookupInterface):
name=name)
# pylint: enable=protected-access
+ values.set_shape(keys.get_shape().concatenate(self._value_shape))
return values
def insert(self, keys, values, name=None):
@@ -814,3 +824,27 @@ class MutableHashTable(LookupInterface):
# pylint: enable=protected-access
return op
+
+ def export(self, name=None):
+ """Returns tensors of all keys and values in the table.
+
+ Args:
+ name: A name for the operation (optional).
+
+ Returns:
+ A pair of tensors with the first tensor containing all keys and the
+ second tensor containing all values in the table.
+ """
+ with ops.op_scope([self._table_ref], name,
+ "%s_lookup_table_export_values" % self._name) as name:
+ # pylint: disable=protected-access
+ exported_keys, exported_values = gen_data_flow_ops._lookup_table_export(
+ self._table_ref,
+ self._key_dtype,
+ self._value_dtype,
+ name=name)
+ # pylint: enable=protected-access
+
+ exported_values.set_shape(exported_keys.get_shape().concatenate(
+ self._value_shape))
+ return exported_keys, exported_values
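
The constructor now dispatches on the rank of default_value: a scalar
default keeps the original MutableHashTable kernel, while any other rank
routes to the new MutableHashTableOfTensors kernel with value_shape taken
from the default. A small sketch of the resulting static lookup shapes,
borrowing values from the tests below (a sketch, not part of the patch):

    import tensorflow as tf

    scalar_table = tf.contrib.lookup.MutableHashTable(tf.string, tf.int64, -1)
    vector_table = tf.contrib.lookup.MutableHashTable(tf.string, tf.int64,
                                                      [-1, -1])

    keys = tf.constant([["brain", "salad"], ["tank", "tarkus"]])  # shape [2, 2]
    print(scalar_table.lookup(keys).get_shape())  # (2, 2): key shape + []
    print(vector_table.lookup(keys).get_shape())  # (2, 2, 2): key shape + [2]
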
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index c8a718741d..48e5e43a6e 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -38,6 +38,7 @@ class HashTableOpTest(tf.test.TestCase):
input_string = tf.constant(["brain", "salad", "tank"])
output = table.lookup(input_string)
+ self.assertAllEqual([3], output.get_shape())
result = output.eval()
self.assertAllEqual([0, 1, -1], result)
@@ -251,10 +252,120 @@ class MutableHashTableOpTest(tf.test.TestCase):
input_string = tf.constant(["brain", "salad", "tank"])
output = table.lookup(input_string)
+ self.assertAllEqual([3], output.get_shape())
result = output.eval()
self.assertAllEqual([0, 1, -1], result)
+ exported_keys, exported_values = table.export()
+ self.assertAllEqual([None], exported_keys.get_shape().as_list())
+ self.assertAllEqual([None], exported_values.get_shape().as_list())
+
+ # exported data is in the order of the internal map, i.e. undefined
+ sorted_keys = np.sort(exported_keys.eval())
+ sorted_values = np.sort(exported_values.eval())
+ self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys)
+ self.assertAllEqual([0, 1, 2], sorted_values)
+
+ def testMutableHashTableOfTensors(self):
+ with self.test_session():
+ default_val = tf.constant([-1, -1], tf.int64)
+ keys = tf.constant(["brain", "salad", "surgery"])
+ values = tf.constant([[0, 1], [2, 3], [4, 5]], tf.int64)
+ table = tf.contrib.lookup.MutableHashTable(tf.string, tf.int64,
+ default_val)
+ self.assertAllEqual(0, table.size().eval())
+
+ table.insert(keys, values).run()
+ self.assertAllEqual(3, table.size().eval())
+
+ input_string = tf.constant(["brain", "salad", "tank"])
+ output = table.lookup(input_string)
+ self.assertAllEqual([3, 2], output.get_shape())
+
+ result = output.eval()
+ self.assertAllEqual([[0, 1], [2, 3], [-1, -1]], result)
+
+ exported_keys, exported_values = table.export()
+ self.assertAllEqual([None], exported_keys.get_shape().as_list())
+ self.assertAllEqual([None, 2], exported_values.get_shape().as_list())
+ # exported data is in the order of the internal map, i.e. undefined
+ sorted_keys = np.sort(exported_keys.eval())
+ sorted_values = np.sort(exported_values.eval())
+ self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys)
+ self.assertAllEqual([[4, 5], [2, 3], [0, 1]], sorted_values)
+
+ def testMutableHashTableExportInsert(self):
+ with self.test_session():
+ default_val = tf.constant([-1, -1], tf.int64)
+ keys = tf.constant(["brain", "salad", "surgery"])
+ values = tf.constant([[0, 1], [2, 3], [4, 5]], tf.int64)
+ table1 = tf.contrib.lookup.MutableHashTable(tf.string, tf.int64,
+ default_val)
+ self.assertAllEqual(0, table1.size().eval())
+ table1.insert(keys, values).run()
+ self.assertAllEqual(3, table1.size().eval())
+
+ input_string = tf.constant(["brain", "salad", "tank"])
+ expected_output = [[0, 1], [2, 3], [-1, -1]]
+ output1 = table1.lookup(input_string)
+ self.assertAllEqual(expected_output, output1.eval())
+
+ exported_keys, exported_values = table1.export()
+ self.assertAllEqual(3, exported_keys.eval().size)
+ self.assertAllEqual(6, exported_values.eval().size)
+
+ # Populate a second table from the exported data
+ table2 = tf.contrib.lookup.MutableHashTable(tf.string, tf.int64,
+ default_val)
+ self.assertAllEqual(0, table2.size().eval())
+ table2.insert(exported_keys, exported_values).run()
+ self.assertAllEqual(3, table2.size().eval())
+
+ # Verify lookup result is still the same
+ output2 = table2.lookup(input_string)
+ self.assertAllEqual(expected_output, output2.eval())
+
+ def testMutableHashTableOfTensorsInvalidShape(self):
+ with self.test_session():
+ default_val = tf.constant([-1, -1], tf.int64)
+ keys = tf.constant(["brain", "salad", "surgery"])
+ table = tf.contrib.lookup.MutableHashTable(tf.string, tf.int64,
+ default_val)
+
+ # Shape [6] instead of [3, 2]
+ values = tf.constant([0, 1, 2, 3, 4, 5], tf.int64)
+ with self.assertRaisesOpError("Expected shape"):
+ table.insert(keys, values).run()
+
+ # Shape [2,3] instead of [3, 2]
+ values = tf.constant([[0, 1, 2], [3, 4, 5]], tf.int64)
+ with self.assertRaisesOpError("Expected shape"):
+ table.insert(keys, values).run()
+
+ # Shape [2, 2] instead of [3, 2]
+ values = tf.constant([[0, 1], [2, 3]], tf.int64)
+ with self.assertRaisesOpError("Expected shape"):
+ table.insert(keys, values).run()
+
+ # Shape [3, 1] instead of [3, 2]
+ values = tf.constant([[0], [2], [4]], tf.int64)
+ with self.assertRaisesOpError("Expected shape"):
+ table.insert(keys, values).run()
+
+ # Valid Insert
+ values = tf.constant([[0, 1], [2, 3], [4, 5]], tf.int64)
+ table.insert(keys, values).run()
+ self.assertAllEqual(3, table.size().eval())
+
+ def testMutableHashTableInvalidDefaultValue(self):
+ with self.test_session():
+ default_val = tf.constant([[-1, -1]], tf.int64)
+ table = tf.contrib.lookup.MutableHashTable(tf.string, tf.int64,
+ default_val)
+ with self.assertRaisesOpError("Default value must be a vector"):
+ self.assertAllEqual(0, table.size().eval())
+
def testMutableHashTableDuplicateInsert(self):
with self.test_session():
default_val = -1
@@ -288,6 +399,7 @@ class MutableHashTableOpTest(tf.test.TestCase):
input_string = tf.constant([["brain", "salad"], ["tank", "tarkus"]])
output = table.lookup(input_string)
+ self.assertAllEqual([2, 2], output.get_shape())
result = output.eval()
self.assertAllEqual([[0, 1], [-1, -1]], result)
@@ -310,6 +422,25 @@ class MutableHashTableOpTest(tf.test.TestCase):
result = output.eval()
self.assertAllEqual([0, 1, 3, -1], result)
+ def testMutableHashTableOfTensorsFindHighRank(self):
+ with self.test_session():
+ default_val = tf.constant([-1, -1, -1], tf.int64)
+ keys = tf.constant(["brain", "salad", "surgery"])
+ values = tf.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]], tf.int64)
+ table = tf.contrib.lookup.MutableHashTable(tf.string, tf.int64,
+ default_val)
+
+ table.insert(keys, values).run()
+ self.assertAllEqual(3, table.size().eval())
+
+ input_string = tf.constant([["brain", "salad"], ["tank", "tarkus"]])
+ output = table.lookup(input_string)
+ self.assertAllEqual([2, 2, 3], output.get_shape())
+
+ result = output.eval()
+ self.assertAllEqual(
+ [[[0, 1, 2], [2, 3, 4]], [[-1, -1, -1], [-1, -1, -1]]], result)
+
def testMultipleMutableHashTables(self):
with self.test_session() as sess:
default_val = -1
diff --git a/tensorflow/core/framework/lookup_interface.cc b/tensorflow/core/framework/lookup_interface.cc
index aafa9e084a..0d20766673 100644
--- a/tensorflow/core/framework/lookup_interface.cc
+++ b/tensorflow/core/framework/lookup_interface.cc
@@ -30,28 +30,30 @@ Status LookupInterface::CheckKeyAndValueTensors(const Tensor& key,
return errors::InvalidArgument("Value must be type ", value_dtype(),
" but got ", value.dtype());
}
- if (key.NumElements() != value.NumElements()) {
- return errors::InvalidArgument("Number of elements of key(",
- key.NumElements(), ") and value(",
- value.NumElements(), ") are different.");
- }
- if (!key.shape().IsSameSize(value.shape())) {
- return errors::InvalidArgument("key and value have different shapes.");
+ TensorShape expected_value_shape = key.shape();
+ expected_value_shape.AppendShape(value_shape());
+ if (value.shape() != expected_value_shape) {
+ return errors::InvalidArgument(
+ "Expected shape ", expected_value_shape.DebugString(),
+ " for value, got ", value.shape().DebugString());
}
return Status::OK();
}
Status LookupInterface::CheckFindArguments(const Tensor& key,
- const Tensor& value,
const Tensor& default_value) {
- TF_RETURN_IF_ERROR(CheckKeyAndValueTensors(key, value));
-
+ if (key.dtype() != key_dtype()) {
+ return errors::InvalidArgument("Key must be type ", key_dtype(),
+ " but got ", key.dtype());
+ }
if (default_value.dtype() != value_dtype()) {
return errors::InvalidArgument("Default value must be type ", value_dtype(),
" but got ", default_value.dtype());
}
- if (!TensorShapeUtils::IsScalar(default_value.shape())) {
- return errors::InvalidArgument("Default values must be scalar.");
+ if (default_value.shape() != value_shape()) {
+ return errors::InvalidArgument(
+ "Expected shape ", value_shape().DebugString(),
+ " for default value, got ", default_value.shape().DebugString());
}
return Status::OK();
}
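
The new check replaces the old same-shape requirement: the required value
shape is the key shape with the table's value_shape appended. A worked
example of the rule in plain Python (the helper name is hypothetical,
introduced only for illustration):

    def expected_value_shape(key_shape, table_value_shape):
        # Keys of shape [3] in a table with value_shape [2] require values
        # of shape [3, 2]; a scalar table (value_shape []) keeps the old
        # behavior of values matching keys exactly.
        return list(key_shape) + list(table_value_shape)

    assert expected_value_shape([3], [2]) == [3, 2]
    assert expected_value_shape([3], []) == [3]
    assert expected_value_shape([2, 2], [3]) == [2, 2, 3]  # high-rank lookup
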
diff --git a/tensorflow/core/framework/lookup_interface.h b/tensorflow/core/framework/lookup_interface.h
index 5968f6e84d..b2279de66f 100644
--- a/tensorflow/core/framework/lookup_interface.h
+++ b/tensorflow/core/framework/lookup_interface.h
@@ -21,6 +21,9 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
+
+class OpKernelContext;
+
namespace lookup {
// Forward declaration so we can define GetInitializableLookupTable() in
@@ -63,12 +66,33 @@ class LookupInterface : public ResourceBase {
// Returns the number of elements in the table.
virtual size_t size() const = 0;
+ virtual Status ExportValues(OpKernelContext* context) = 0;
+
// Returns the data type of the key.
virtual DataType key_dtype() const = 0;
// Returns the data type of the value.
virtual DataType value_dtype() const = 0;
+ // Returns the shape of a value in the table.
+ virtual TensorShape value_shape() const = 0;
+
+ // Check format of the key and value tensors.
+ // Returns OK if all the following requirements are satisfied, otherwise it
+ // returns InvalidArgument:
+ // - DataType of the tensor keys equals the table key_dtype
+ // - DataType of the tensor values equals the table value_dtype
+ // - the values tensor has the required shape given the keys and the
+ // table's value shape.
+ Status CheckKeyAndValueTensors(const Tensor& keys, const Tensor& values);
+
+ // Check the arguments of a find operation. Returns OK if all the following
+ // requirements are satisfied, otherwise it returns InvalidArgument:
+ // - DataType of the tensor keys equals the table key_dtype
+ // - DataType of the tensor default_value equals the table value_dtype
+ // - the default_value tensor shape matches the table's value shape.
+ Status CheckFindArguments(const Tensor& keys, const Tensor& default_value);
+
string DebugString() override { return "A lookup table"; }
// Returns an InitializableLookupTable, a subclass of LookupInterface, if the
@@ -79,22 +103,6 @@ class LookupInterface : public ResourceBase {
protected:
virtual ~LookupInterface() = default;
-
- // Check format of the key and value tensors.
- // Returns OK if all the following requirements are satisfied, otherwise it
- // returns InvalidArgument:
- // - DataType of the tensor key equals to the table key_dtype
- // - DataType of the test value equals to the table value_dtype
- // - key and value have the same size and shape
- Status CheckKeyAndValueTensors(const Tensor& keys, const Tensor& values);
-
- // Check the arguments of a find operation. Returns OK if all the following
- // requirements are satisfied, otherwise it returns InvalidArgument:
- // - All requirements of CheckKeyAndValueTensors
- // - default_value type equals to the table value_dtype
- // - default_value is scalar
- Status CheckFindArguments(const Tensor& keys, const Tensor& values,
- const Tensor& default_value);
};
} // namespace lookup
diff --git a/tensorflow/core/kernels/initializable_lookup_table.cc b/tensorflow/core/kernels/initializable_lookup_table.cc
index feb3e8dcc7..743464d4ba 100644
--- a/tensorflow/core/kernels/initializable_lookup_table.cc
+++ b/tensorflow/core/kernels/initializable_lookup_table.cc
@@ -28,7 +28,6 @@ Status InitializableLookupTable::Find(const Tensor& keys, Tensor* values,
// Do not let the use migrate before the check; table is used without
// a lock by the readers.
std::atomic_thread_fence(std::memory_order_acquire);
- TF_RETURN_IF_ERROR(CheckFindArguments(keys, *values, default_value));
return DoFind(keys, values, default_value);
}
diff --git a/tensorflow/core/kernels/initializable_lookup_table.h b/tensorflow/core/kernels/initializable_lookup_table.h
index 1e394bc902..be6b4c7aa1 100644
--- a/tensorflow/core/kernels/initializable_lookup_table.h
+++ b/tensorflow/core/kernels/initializable_lookup_table.h
@@ -50,6 +50,14 @@ class InitializableLookupTable : public LookupInterface {
"Insert not supported by InitializableLookupTable implementations");
}
+ Status ExportValues(OpKernelContext* context) final {
+ return errors::Unimplemented(
+ "ExportValues not supported by InitializableLookupTable "
+ "implementations");
+ }
+
+ TensorShape value_shape() const final { return TensorShape(); }
+
// Returns whether the table was initialized and is ready to serve lookups.
bool is_initialized() const { return is_initialized_; }
diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc
index 213844ec00..c48856c206 100644
--- a/tensorflow/core/kernels/lookup_table_op.cc
+++ b/tensorflow/core/kernels/lookup_table_op.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/initializable_lookup_table.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
@@ -73,6 +74,8 @@ const float SubtleMustCopyUnlessStringOrFloat(const float value) {
template <class K, class V>
class HashTable : public InitializableLookupTable {
public:
+ HashTable(OpKernelContext* ctx, OpKernel* kernel) {}
+
size_t size() const override {
// return the size of the table only if it's initialized, otherwise 0.
if (!is_initialized_) {
@@ -105,7 +108,7 @@ class HashTable : public InitializableLookupTable {
const auto key_values = keys.flat<K>();
const auto value_values = values.flat<V>();
- for (int i = 0; i < key_values.size(); ++i) {
+ for (int64 i = 0; i < key_values.size(); ++i) {
const K key = SubtleMustCopyUnlessStringOrFloat(key_values(i));
const V value = SubtleMustCopyUnlessStringOrFloat(value_values(i));
const V& previous_value = gtl::LookupOrInsert(table_.get(), key, value);
@@ -124,7 +127,7 @@ class HashTable : public InitializableLookupTable {
const auto key_values = key.flat<K>();
auto value_values = value->flat<V>();
- for (int i = 0; i < key_values.size(); ++i) {
+ for (int64 i = 0; i < key_values.size(); ++i) {
value_values(i) = gtl::FindWithDefault(
*table_, SubtleMustCopyUnlessStringOrFloat(key_values(i)),
default_val);
@@ -137,22 +140,23 @@ class HashTable : public InitializableLookupTable {
};
// Lookup table that wraps an unordered_map, where the key and value data type
-// is specified.
+// is specified. Each individual value must be a scalar. If vector values are
+// required, use MutableHashTableOfTensors.
//
// This table is mutable and thread safe - Insert can be called at any time.
//
// Sample use case:
//
-// MutableHashTable<int64, int64> table; // int64 -> int64.
+// MutableHashTableOfScalars<int64, int64> table; // int64 -> int64.
// // Populate the table, elements could be added in one or multiple calls.
// table.Insert(key_tensor, value_tensor); // Populate the table.
//
// table.Find(in_t, &out_t, default_t)
//
template <class K, class V>
-class MutableHashTable : public LookupInterface {
+class MutableHashTableOfScalars : public LookupInterface {
public:
- MutableHashTable() {}
+ MutableHashTableOfScalars(OpKernelContext* ctx, OpKernel* kernel) {}
size_t size() const override {
mutex_lock l(mu_);
@@ -161,14 +165,12 @@ class MutableHashTable : public LookupInterface {
Status Find(const Tensor& key, Tensor* value,
const Tensor& default_value) override {
- TF_RETURN_IF_ERROR(CheckFindArguments(key, *value, default_value));
-
const V default_val = default_value.flat<V>()(0);
const auto key_values = key.flat<K>();
auto value_values = value->flat<V>();
mutex_lock l(mu_);
- for (int i = 0; i < key_values.size(); ++i) {
+ for (int64 i = 0; i < key_values.size(); ++i) {
value_values(i) = gtl::FindWithDefault(
table_, SubtleMustCopyUnlessStringOrFloat(key_values(i)),
default_val);
@@ -178,13 +180,11 @@ class MutableHashTable : public LookupInterface {
}
Status Insert(const Tensor& keys, const Tensor& values) override {
- TF_RETURN_IF_ERROR(CheckKeyAndValueTensors(keys, values));
-
const auto key_values = keys.flat<K>();
const auto value_values = values.flat<V>();
mutex_lock l(mu_);
- for (int i = 0; i < key_values.size(); ++i) {
+ for (int64 i = 0; i < key_values.size(); ++i) {
const K key = SubtleMustCopyUnlessStringOrFloat(key_values(i));
const V value = SubtleMustCopyUnlessStringOrFloat(value_values(i));
gtl::InsertOrUpdate(&table_, key, value);
@@ -192,14 +192,139 @@ class MutableHashTable : public LookupInterface {
return Status::OK();
}
+ Status ExportValues(OpKernelContext* ctx) override {
+ mutex_lock l(mu_);
+ int64 size = table_.size();
+
+ Tensor* keys;
+ Tensor* values;
+ TF_RETURN_IF_ERROR(
+ ctx->allocate_output("keys", TensorShape({size}), &keys));
+ TF_RETURN_IF_ERROR(
+ ctx->allocate_output("values", TensorShape({size}), &values));
+
+ auto keys_data = keys->flat<K>();
+ auto values_data = values->flat<V>();
+ int64 i = 0;
+ for (auto it = table_.begin(); it != table_.end(); ++it, ++i) {
+ keys_data(i) = it->first;
+ values_data(i) = it->second;
+ }
+ return Status::OK();
+ }
+
+ DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
+
+ DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }
+
+ TensorShape value_shape() const override { return TensorShape(); }
+
+ private:
+ // TODO(andreasst): consider using a read/write lock or a concurrent map
+ mutable mutex mu_;
+ std::unordered_map<K, V> table_ GUARDED_BY(mu_);
+};
+
+// Lookup table that wraps an unordered_map. Behaves identically to
+// MutableHashTableOfScalars except that each value must be a vector.
+template <class K, class V>
+class MutableHashTableOfTensors : public LookupInterface {
+ public:
+ MutableHashTableOfTensors(OpKernelContext* ctx, OpKernel* kernel) {
+ OP_REQUIRES_OK(ctx,
+ GetNodeAttr(kernel->def(), "value_shape", &value_shape_));
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsVector(value_shape_),
+ errors::InvalidArgument("Default value must be a vector, got shape ",
+ value_shape_.DebugString()));
+ }
+
+ size_t size() const override {
+ mutex_lock l(mu_);
+ return table_.size();
+ }
+
+ Status Find(const Tensor& key, Tensor* value,
+ const Tensor& default_value) override {
+ const auto default_flat = default_value.flat<V>();
+ const auto key_values = key.flat<K>();
+ auto value_values = value->flat_inner_dims<V, 2>();
+ int64 value_dim = value_shape_.dim_size(0);
+
+ mutex_lock l(mu_);
+ for (int64 i = 0; i < key_values.size(); ++i) {
+ ValueArray* value_vec = gtl::FindOrNull(
+ table_, SubtleMustCopyUnlessStringOrFloat(key_values(i)));
+ if (value_vec != nullptr) {
+ for (int64 j = 0; j < value_dim; j++) {
+ value_values(i, j) = value_vec->at(j);
+ }
+ } else {
+ for (int64 j = 0; j < value_dim; j++) {
+ value_values(i, j) = default_flat(j);
+ }
+ }
+ }
+
+ return Status::OK();
+ }
+
+ Status Insert(const Tensor& keys, const Tensor& values) override {
+ const auto key_values = keys.flat<K>();
+ const auto value_values = values.flat_inner_dims<V, 2>();
+ int64 value_dim = value_shape_.dim_size(0);
+
+ mutex_lock l(mu_);
+ for (int64 i = 0; i < key_values.size(); ++i) {
+ const K key = SubtleMustCopyUnlessStringOrFloat(key_values(i));
+ ValueArray value_vec;
+ for (int64 j = 0; j < value_dim; j++) {
+ V value = value_values(i, j);
+ value_vec.push_back(value);
+ }
+ gtl::InsertOrUpdate(&table_, key, value_vec);
+ }
+ return Status::OK();
+ }
+
+ Status ExportValues(OpKernelContext* ctx) override {
+ mutex_lock l(mu_);
+ int64 size = table_.size();
+ int64 value_dim = value_shape_.dim_size(0);
+
+ Tensor* keys;
+ Tensor* values;
+ TF_RETURN_IF_ERROR(
+ ctx->allocate_output("keys", TensorShape({size}), &keys));
+ TF_RETURN_IF_ERROR(ctx->allocate_output(
+ "values", TensorShape({size, value_dim}), &values));
+
+ auto keys_data = keys->flat<K>();
+ auto values_data = values->matrix<V>();
+ int64 i = 0;
+ for (auto it = table_.begin(); it != table_.end(); ++it, ++i) {
+ K key = it->first;
+ ValueArray value = it->second;
+ keys_data(i) = key;
+ for (int64 j = 0; j < value_dim; j++) {
+ values_data(i, j) = value[j];
+ }
+ }
+ return Status::OK();
+ }
+
DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }
+ TensorShape value_shape() const override { return value_shape_; }
+
private:
+ TensorShape value_shape_;
// TODO(andreasst): consider using a read/write lock or a concurrent map
mutable mutex mu_;
- std::unordered_map<K, V> table_;
+ typedef gtl::InlinedVector<V, 4> ValueArray;
+ std::unordered_map<K, ValueArray> table_ GUARDED_BY(mu_);
};
} // namespace lookup
@@ -219,17 +344,16 @@ class LookupTableFindOp : public OpKernel {
DataTypeVector expected_outputs = {table->value_dtype()};
OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs));
- const Tensor& input = ctx->input(1);
-
+ const Tensor& key = ctx->input(1);
const Tensor& default_value = ctx->input(2);
- OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(default_value.shape()),
- errors::InvalidArgument("Default value must be a scalar, not ",
- default_value.shape().DebugString()));
+ OP_REQUIRES_OK(ctx, table->CheckFindArguments(key, default_value));
+ TensorShape output_shape = key.shape();
+ output_shape.AppendShape(table->value_shape());
Tensor* out;
- OP_REQUIRES_OK(ctx, ctx->allocate_output("values", input.shape(), &out));
+ OP_REQUIRES_OK(ctx, ctx->allocate_output("values", output_shape, &out));
- OP_REQUIRES_OK(ctx, table->Find(input, out, default_value));
+ OP_REQUIRES_OK(ctx, table->Find(key, out, default_value));
}
};
@@ -252,6 +376,7 @@ class LookupTableInsertOp : public OpKernel {
const Tensor& keys = ctx->input(1);
const Tensor& values = ctx->input(2);
+ OP_REQUIRES_OK(ctx, table->CheckKeyAndValueTensors(keys, values));
OP_REQUIRES_OK(ctx, table->Insert(keys, values));
}
};
@@ -278,6 +403,23 @@ class LookupTableSizeOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("LookupTableSize").Device(DEVICE_CPU),
LookupTableSizeOp);
+// Op that outputs tensors of all keys and all values.
+class LookupTableExportOp : public OpKernel {
+ public:
+ explicit LookupTableExportOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ lookup::LookupInterface* table;
+ OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
+ core::ScopedUnref unref_me(table);
+
+ OP_REQUIRES_OK(ctx, table->ExportValues(ctx));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("LookupTableExport").Device(DEVICE_CPU),
+ LookupTableExportOp);
+
// Register the HashTable op with the currently supported key and value types.
#define REGISTER_KERNEL(key_dtype, value_dtype) \
REGISTER_KERNEL_BUILDER( \
@@ -294,13 +436,29 @@ REGISTER_KERNEL(int64, string);
#undef REGISTER_KERNEL
// Register the MutableHashTable op.
-#define REGISTER_KERNEL(key_dtype, value_dtype) \
- REGISTER_KERNEL_BUILDER( \
- Name("MutableHashTable") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<key_dtype>("key_dtype") \
- .TypeConstraint<value_dtype>("value_dtype"), \
- LookupTableOp<lookup::MutableHashTable<key_dtype, value_dtype>, \
+#define REGISTER_KERNEL(key_dtype, value_dtype) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("MutableHashTable") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<key_dtype>("key_dtype") \
+ .TypeConstraint<value_dtype>("value_dtype"), \
+ LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \
+ key_dtype, value_dtype>)
+
+REGISTER_KERNEL(string, float);
+REGISTER_KERNEL(string, int64);
+REGISTER_KERNEL(int64, string);
+
+#undef REGISTER_KERNEL
+
+// Register the MutableHashTableOfTensors op.
+#define REGISTER_KERNEL(key_dtype, value_dtype) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("MutableHashTableOfTensors") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<key_dtype>("key_dtype") \
+ .TypeConstraint<value_dtype>("value_dtype"), \
+ LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \
key_dtype, value_dtype>)
REGISTER_KERNEL(string, float);
diff --git a/tensorflow/core/kernels/lookup_table_op.h b/tensorflow/core/kernels/lookup_table_op.h
index 8f316a0f3c..cd44fb64c2 100644
--- a/tensorflow/core/kernels/lookup_table_op.h
+++ b/tensorflow/core/kernels/lookup_table_op.h
@@ -49,8 +49,8 @@ class LookupTableOp : public OpKernel {
mutex_lock l(mu_);
if (!table_handle_set_) {
OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def()));
- auto creator = [this](lookup::LookupInterface** ret) {
- *ret = new Container();
+ auto creator = [ctx, this](lookup::LookupInterface** ret) {
+ *ret = new Container(ctx, this);
return Status::OK();
};
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index 85c3cd28fc..742c7d0b38 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -704,6 +704,20 @@ table_handle: Handle to the table.
size: Scalar that contains number of elements in the table.
)doc");
+REGISTER_OP("LookupTableExport")
+ .Input("table_handle: Ref(string)")
+ .Output("keys: Tkeys")
+ .Output("values: Tvalues")
+ .Attr("Tkeys: type")
+ .Attr("Tvalues: type")
+ .Doc(R"doc(
+Outputs all keys and values in the table.
+
+table_handle: Handle to the table.
+keys: Vector of all keys present in the table.
+values: Tensor of all values in the table. Indexed in parallel with `keys`.
+)doc");
+
REGISTER_OP("HashTable")
.Output("table_handle: Ref(string)")
.Attr("container: string = ''")
@@ -738,8 +752,32 @@ REGISTER_OP("MutableHashTable")
Creates an empty hash table.
This op creates a mutable hash table, specifying the type of its keys and
-values. Data can be inserted into the table using the insert operations. It
-does not support the initialization operation.
+values. Each value must be a scalar. Data can be inserted into the table using
+the insert operations. It does not support the initialization operation.
+
+table_handle: Handle to a table.
+container: If non-empty, this table is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this table is shared under the given name across
+ multiple sessions.
+key_dtype: Type of the table keys.
+value_dtype: Type of the table values.
+)doc");
+
+REGISTER_OP("MutableHashTableOfTensors")
+ .Output("table_handle: Ref(string)")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("key_dtype: type")
+ .Attr("value_dtype: type")
+ .Attr("value_shape: shape = {}")
+ .SetIsStateful()
+ .Doc(R"doc(
+Creates an empty hash table.
+
+This op creates a mutable hash table, specifying the type of its keys and
+values. Each value must be a vector. Data can be inserted into the table using
+the insert operations. It does not support the initialization operation.
table_handle: Handle to a table.
container: If non-empty, this table is placed in the given container.
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index e65d44c378..ef7754a920 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -544,10 +544,12 @@ tf_gen_op_wrapper_py(
"HashTable",
"InitializeTable",
"InitializeTableFromTextFile",
+ "LookupTableExport",
"LookupTableFind",
"LookupTableInsert",
"LookupTableSize",
"MutableHashTable",
+ "MutableHashTableOfTensors",
"Mutex",
"MutexAcquire",
"MutexRelease",
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 9fbcbd645d..a4a3c4e669 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -732,6 +732,7 @@ ops.NoGradient("HashTable")
ops.NoGradient("InitializeTable")
ops.NoGradient("InitializeTableFromTextFile")
ops.NoGradient("MutableHashTable")
+ops.NoGradient("MutableHashTableOfTensors")
ops.RegisterShape("QueueSize")(common_shapes.scalar_shape)
@@ -807,16 +808,13 @@ def _DynamicStitchShape(op):
def _LookupTableFindShape(op):
"""Shape function for data_flow_ops._lookup_table_find."""
op.inputs[0].get_shape().merge_with(tensor_shape.scalar())
- shape_in = op.inputs[1].get_shape()
- return [shape_in]
+ return [tensor_shape.unknown_shape()]
@ops.RegisterShape("LookupTableInsert")
def _LookupTableInsertShape(op):
"""Shape function for data_flow_ops._lookup_table_insert."""
op.inputs[0].get_shape().merge_with(tensor_shape.scalar())
- keys_shape = op.inputs[1].get_shape()
- op.inputs[2].get_shape().merge_with(keys_shape)
return []
@@ -827,8 +825,18 @@ def _LookupTableSizeShape(op):
return [tensor_shape.scalar()]
+@ops.RegisterShape("LookupTableExport")
+def _LookupTableExportShape(op):
+ """Shape function for data_flow_ops._lookup_table_export_values."""
+ op.inputs[0].get_shape().merge_with(tensor_shape.scalar())
+ keys_shape = tensor_shape.vector(None)
+ values_shape = tensor_shape.unknown_shape()
+ return [keys_shape, values_shape]
+
+
@ops.RegisterShape("HashTable")
@ops.RegisterShape("MutableHashTable")
+@ops.RegisterShape("MutableHashTableOfTensors")
def _HashTableShape(_):
"""Shape function for data_flow_ops._hash_table."""
return [tensor_shape.scalar()]
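
Note the division of labor: these op-level shape functions return unknown
(or [None]) shapes because they cannot see which table kernel backs the
handle or what its value_shape is; the wrapper class in lookup_ops.py
recovers the static shape with set_shape, as in the hunks above. A sketch
of the observable effect, reusing the test values (illustration only):

    import tensorflow as tf

    table = tf.contrib.lookup.MutableHashTable(tf.string, tf.int64, [-1, -1])
    exported_keys, exported_values = table.export()
    # _LookupTableExportShape alone yields vector(None) and unknown_shape();
    # MutableHashTable.export() then refines values to [None] + value_shape.
    print(exported_keys.get_shape().as_list())    # [None]
    print(exported_values.get_shape().as_list())  # [None, 2]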