diff options
-rw-r--r-- | tensorflow/contrib/lookup/lookup_ops.py | 54 | ||||
-rw-r--r-- | tensorflow/contrib/lookup/lookup_ops_test.py | 131 | ||||
-rw-r--r-- | tensorflow/core/framework/lookup_interface.cc | 26 | ||||
-rw-r--r-- | tensorflow/core/framework/lookup_interface.h | 40 | ||||
-rw-r--r-- | tensorflow/core/kernels/initializable_lookup_table.cc | 1 | ||||
-rw-r--r-- | tensorflow/core/kernels/initializable_lookup_table.h | 8 | ||||
-rw-r--r-- | tensorflow/core/kernels/lookup_table_op.cc | 212 | ||||
-rw-r--r-- | tensorflow/core/kernels/lookup_table_op.h | 4 | ||||
-rw-r--r-- | tensorflow/core/ops/data_flow_ops.cc | 42 | ||||
-rw-r--r-- | tensorflow/python/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/python/ops/data_flow_ops.py | 16 |
11 files changed, 462 insertions, 74 deletions
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index b3248c4741..cdfe3f478b 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -179,6 +179,7 @@ class InitializableLookupTableBase(LookupInterface): name=name) # pylint: enable=protected-access + values.set_shape(key_tensor.get_shape()) if isinstance(keys, ops.SparseTensor): return ops.SparseTensor(keys.indices, values, keys.shape) else: @@ -726,22 +727,30 @@ class MutableHashTable(LookupInterface): Returns: A `MutableHashTable` object. """ + self._default_value = ops.convert_to_tensor(default_value, + dtype=value_dtype) + self._value_shape = self._default_value.get_shape() + # pylint: disable=protected-access - self._table_ref = gen_data_flow_ops._mutable_hash_table( - shared_name=shared_name, - key_dtype=key_dtype, - value_dtype=value_dtype, - name=name) + if self._default_value.get_shape().ndims == 0: + self._table_ref = gen_data_flow_ops._mutable_hash_table( + shared_name=shared_name, + key_dtype=key_dtype, + value_dtype=value_dtype, + name=name) + else: + self._table_ref = gen_data_flow_ops._mutable_hash_table_of_tensors( + shared_name=shared_name, + key_dtype=key_dtype, + value_dtype=value_dtype, + value_shape=self._default_value.get_shape(), + name=name) # pylint: enable=protected-access + super(MutableHashTable, self).__init__(key_dtype, value_dtype, self._table_ref.op.name.split( "/")[-1]) - with ops.op_scope([self._table_ref, default_value], name, - "MutableHashTable"): - self._default_value = ops.convert_to_tensor(default_value, - dtype=self._value_dtype) - def size(self, name=None): """Compute the number of elements in this table. @@ -786,6 +795,7 @@ class MutableHashTable(LookupInterface): name=name) # pylint: enable=protected-access + values.set_shape(keys.get_shape().concatenate(self._value_shape)) return values def insert(self, keys, values, name=None): @@ -814,3 +824,27 @@ class MutableHashTable(LookupInterface): # pylint: enable=protected-access return op + + def export(self, name=None): + """Returns tensors of all keys and values in the table. + + Args: + name: A name for the operation (optional). + + Returns: + A pair of tensors with the first tensor containing all keys and the + second tensors containing all values in the table. + """ + with ops.op_scope([self._table_ref], name, + "%s_lookup_table_export_values" % self._name) as name: + # pylint: disable=protected-access + exported_keys, exported_values = gen_data_flow_ops._lookup_table_export( + self._table_ref, + self._key_dtype, + self._value_dtype, + name=name) + # pylint: enable=protected-access + + exported_values.set_shape(exported_keys.get_shape().concatenate( + self._value_shape)) + return exported_keys, exported_values diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index c8a718741d..48e5e43a6e 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -38,6 +38,7 @@ class HashTableOpTest(tf.test.TestCase): input_string = tf.constant(["brain", "salad", "tank"]) output = table.lookup(input_string) + self.assertAllEqual([3], output.get_shape()) result = output.eval() self.assertAllEqual([0, 1, -1], result) @@ -251,10 +252,120 @@ class MutableHashTableOpTest(tf.test.TestCase): input_string = tf.constant(["brain", "salad", "tank"]) output = table.lookup(input_string) + self.assertAllEqual([3], output.get_shape()) result = output.eval() self.assertAllEqual([0, 1, -1], result) + exported_keys, exported_values = table.export() + self.assertAllEqual([None], exported_keys.get_shape().as_list()) + self.assertAllEqual([None], exported_values.get_shape().as_list()) + + # exported data is in the order of the internal map, i.e. undefined + sorted_keys = np.sort(exported_keys.eval()) + sorted_values = np.sort(exported_values.eval()) + self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys) + self.assertAllEqual([0, 1, 2], sorted_values) + + def testMutableHashTableOfTensors(self): + with self.test_session(): + default_val = tf.constant([-1, -1], tf.int64) + keys = tf.constant(["brain", "salad", "surgery"]) + values = tf.constant([[0, 1], [2, 3], [4, 5]], tf.int64) + table = tf.contrib.lookup.MutableHashTable(tf.string, tf.int64, + default_val) + self.assertAllEqual(0, table.size().eval()) + + table.insert(keys, values).run() + self.assertAllEqual(3, table.size().eval()) + + input_string = tf.constant(["brain", "salad", "tank"]) + output = table.lookup(input_string) + self.assertAllEqual([3, 2], output.get_shape()) + + result = output.eval() + self.assertAllEqual([[0, 1], [2, 3], [-1, -1]], result) + + exported_keys, exported_values = table.export() + self.assertAllEqual([None], exported_keys.get_shape().as_list()) + self.assertAllEqual([None, 2], exported_values.get_shape().as_list()) + # exported data is in the order of the internal map, i.e. undefined + sorted_keys = np.sort(exported_keys.eval()) + sorted_values = np.sort(exported_values.eval()) + self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys) + self.assertAllEqual([[4, 5], [2, 3], [0, 1]], sorted_values) + + def testMutableHashTableExportInsert(self): + with self.test_session(): + default_val = tf.constant([-1, -1], tf.int64) + keys = tf.constant(["brain", "salad", "surgery"]) + values = tf.constant([[0, 1], [2, 3], [4, 5]], tf.int64) + table1 = tf.contrib.lookup.MutableHashTable(tf.string, tf.int64, + default_val) + self.assertAllEqual(0, table1.size().eval()) + table1.insert(keys, values).run() + self.assertAllEqual(3, table1.size().eval()) + + input_string = tf.constant(["brain", "salad", "tank"]) + expected_output = [[0, 1], [2, 3], [-1, -1]] + output1 = table1.lookup(input_string) + self.assertAllEqual(expected_output, output1.eval()) + + exported_keys, exported_values = table1.export() + self.assertAllEqual(3, exported_keys.eval().size) + self.assertAllEqual(6, exported_values.eval().size) + + # Populate a second table from the exported data + table2 = tf.contrib.lookup.MutableHashTable(tf.string, tf.int64, + default_val) + self.assertAllEqual(0, table2.size().eval()) + table2.insert(exported_keys, exported_values).run() + self.assertAllEqual(3, table2.size().eval()) + + # Verify lookup result is still the same + output2 = table2.lookup(input_string) + self.assertAllEqual(expected_output, output2.eval()) + + def testMutableHashTableOfTensorsInvalidShape(self): + with self.test_session(): + default_val = tf.constant([-1, -1], tf.int64) + keys = tf.constant(["brain", "salad", "surgery"]) + table = tf.contrib.lookup.MutableHashTable(tf.string, tf.int64, + default_val) + + # Shape [6] instead of [3, 2] + values = tf.constant([0, 1, 2, 3, 4, 5], tf.int64) + with self.assertRaisesOpError("Expected shape"): + table.insert(keys, values).run() + + # Shape [2,3] instead of [3, 2] + values = tf.constant([[0, 1, 2], [3, 4, 5]], tf.int64) + with self.assertRaisesOpError("Expected shape"): + table.insert(keys, values).run() + + # Shape [2, 2] instead of [3, 2] + values = tf.constant([[0, 1], [2, 3]], tf.int64) + with self.assertRaisesOpError("Expected shape"): + table.insert(keys, values).run() + + # Shape [3, 1] instead of [3, 2] + values = tf.constant([[0], [2], [4]], tf.int64) + with self.assertRaisesOpError("Expected shape"): + table.insert(keys, values).run() + + # Valid Insert + values = tf.constant([[0, 1], [2, 3], [4, 5]], tf.int64) + table.insert(keys, values).run() + self.assertAllEqual(3, table.size().eval()) + + def testMutableHashTableInvalidDefaultValue(self): + with self.test_session(): + default_val = tf.constant([[-1, -1]], tf.int64) + table = tf.contrib.lookup.MutableHashTable(tf.string, tf.int64, + default_val) + with self.assertRaisesOpError("Default value must be a vector"): + self.assertAllEqual(0, table.size().eval()) + def testMutableHashTableDuplicateInsert(self): with self.test_session(): default_val = -1 @@ -288,6 +399,7 @@ class MutableHashTableOpTest(tf.test.TestCase): input_string = tf.constant([["brain", "salad"], ["tank", "tarkus"]]) output = table.lookup(input_string) + self.assertAllEqual([2, 2], output.get_shape()) result = output.eval() self.assertAllEqual([[0, 1], [-1, -1]], result) @@ -310,6 +422,25 @@ class MutableHashTableOpTest(tf.test.TestCase): result = output.eval() self.assertAllEqual([0, 1, 3, -1], result) + def testMutableHashTableOfTensorsFindHighRank(self): + with self.test_session(): + default_val = tf.constant([-1, -1, -1], tf.int64) + keys = tf.constant(["brain", "salad", "surgery"]) + values = tf.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]], tf.int64) + table = tf.contrib.lookup.MutableHashTable(tf.string, tf.int64, + default_val) + + table.insert(keys, values).run() + self.assertAllEqual(3, table.size().eval()) + + input_string = tf.constant([["brain", "salad"], ["tank", "tarkus"]]) + output = table.lookup(input_string) + self.assertAllEqual([2, 2, 3], output.get_shape()) + + result = output.eval() + self.assertAllEqual( + [[[0, 1, 2], [2, 3, 4]], [[-1, -1, -1], [-1, -1, -1]]], result) + def testMultipleMutableHashTables(self): with self.test_session() as sess: default_val = -1 diff --git a/tensorflow/core/framework/lookup_interface.cc b/tensorflow/core/framework/lookup_interface.cc index aafa9e084a..0d20766673 100644 --- a/tensorflow/core/framework/lookup_interface.cc +++ b/tensorflow/core/framework/lookup_interface.cc @@ -30,28 +30,30 @@ Status LookupInterface::CheckKeyAndValueTensors(const Tensor& key, return errors::InvalidArgument("Value must be type ", value_dtype(), " but got ", value.dtype()); } - if (key.NumElements() != value.NumElements()) { - return errors::InvalidArgument("Number of elements of key(", - key.NumElements(), ") and value(", - value.NumElements(), ") are different."); - } - if (!key.shape().IsSameSize(value.shape())) { - return errors::InvalidArgument("key and value have different shapes."); + TensorShape expected_value_shape = key.shape(); + expected_value_shape.AppendShape(value_shape()); + if (value.shape() != expected_value_shape) { + return errors::InvalidArgument( + "Expected shape ", expected_value_shape.DebugString(), + " for value, got ", value.shape().DebugString()); } return Status::OK(); } Status LookupInterface::CheckFindArguments(const Tensor& key, - const Tensor& value, const Tensor& default_value) { - TF_RETURN_IF_ERROR(CheckKeyAndValueTensors(key, value)); - + if (key.dtype() != key_dtype()) { + return errors::InvalidArgument("Key must be type ", key_dtype(), + " but got ", key.dtype()); + } if (default_value.dtype() != value_dtype()) { return errors::InvalidArgument("Default value must be type ", value_dtype(), " but got ", default_value.dtype()); } - if (!TensorShapeUtils::IsScalar(default_value.shape())) { - return errors::InvalidArgument("Default values must be scalar."); + if (default_value.shape() != value_shape()) { + return errors::InvalidArgument( + "Expected shape ", value_shape().DebugString(), + " for default value, got ", default_value.shape().DebugString()); } return Status::OK(); } diff --git a/tensorflow/core/framework/lookup_interface.h b/tensorflow/core/framework/lookup_interface.h index 5968f6e84d..b2279de66f 100644 --- a/tensorflow/core/framework/lookup_interface.h +++ b/tensorflow/core/framework/lookup_interface.h @@ -21,6 +21,9 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" namespace tensorflow { + +class OpKernelContext; + namespace lookup { // Forward declaration so we can define GetInitializableLookupTable() in @@ -63,12 +66,33 @@ class LookupInterface : public ResourceBase { // Returns the number of elements in the table. virtual size_t size() const = 0; + virtual Status ExportValues(OpKernelContext* context) = 0; + // Returns the data type of the key. virtual DataType key_dtype() const = 0; // Returns the data type of the value. virtual DataType value_dtype() const = 0; + // Returns the shape of a value in the table. + virtual TensorShape value_shape() const = 0; + + // Check format of the key and value tensors. + // Returns OK if all the following requirements are satisfied, otherwise it + // returns InvalidArgument: + // - DataType of the tensor keys equals to the table key_dtype + // - DataType of the tensor values equals to the table value_dtype + // - the values tensor has the required shape given keys and the tables's + // value shape. + Status CheckKeyAndValueTensors(const Tensor& keys, const Tensor& values); + + // Check the arguments of a find operation. Returns OK if all the following + // requirements are satisfied, otherwise it returns InvalidArgument: + // - DataType of the tensor keys equals to the table key_dtype + // - DataType of the tensor default_value equals to the table value_dtype + // - the default_value tensor shape matches the table's value shape. + Status CheckFindArguments(const Tensor& keys, const Tensor& default_value); + string DebugString() override { return "A lookup table"; } // Returns an InitializableLookupTable, a subclass of LookupInterface, if the @@ -79,22 +103,6 @@ class LookupInterface : public ResourceBase { protected: virtual ~LookupInterface() = default; - - // Check format of the key and value tensors. - // Returns OK if all the following requirements are satisfied, otherwise it - // returns InvalidArgument: - // - DataType of the tensor key equals to the table key_dtype - // - DataType of the test value equals to the table value_dtype - // - key and value have the same size and shape - Status CheckKeyAndValueTensors(const Tensor& keys, const Tensor& values); - - // Check the arguments of a find operation. Returns OK if all the following - // requirements are satisfied, otherwise it returns InvalidArgument: - // - All requirements of CheckKeyAndValueTensors - // - default_value type equals to the table value_dtype - // - default_value is scalar - Status CheckFindArguments(const Tensor& keys, const Tensor& values, - const Tensor& default_value); }; } // namespace lookup diff --git a/tensorflow/core/kernels/initializable_lookup_table.cc b/tensorflow/core/kernels/initializable_lookup_table.cc index feb3e8dcc7..743464d4ba 100644 --- a/tensorflow/core/kernels/initializable_lookup_table.cc +++ b/tensorflow/core/kernels/initializable_lookup_table.cc @@ -28,7 +28,6 @@ Status InitializableLookupTable::Find(const Tensor& keys, Tensor* values, // Do not let the use migrate before the check; table is used without // a lock by the readers. std::atomic_thread_fence(std::memory_order_acquire); - TF_RETURN_IF_ERROR(CheckFindArguments(keys, *values, default_value)); return DoFind(keys, values, default_value); } diff --git a/tensorflow/core/kernels/initializable_lookup_table.h b/tensorflow/core/kernels/initializable_lookup_table.h index 1e394bc902..be6b4c7aa1 100644 --- a/tensorflow/core/kernels/initializable_lookup_table.h +++ b/tensorflow/core/kernels/initializable_lookup_table.h @@ -50,6 +50,14 @@ class InitializableLookupTable : public LookupInterface { "Insert not supported by InitializableLookupTable implementations"); } + Status ExportValues(OpKernelContext* context) final { + return errors::Unimplemented( + "ExportValues not supported by InitializableLookupTable " + "implementations"); + } + + TensorShape value_shape() const final { return TensorShape(); } + // Returns whether the table was initialized and is ready to serve lookups. bool is_initialized() const { return is_initialized_; } diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc index 213844ec00..c48856c206 100644 --- a/tensorflow/core/kernels/lookup_table_op.cc +++ b/tensorflow/core/kernels/lookup_table_op.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/initializable_lookup_table.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/hash/hash.h" @@ -73,6 +74,8 @@ const float SubtleMustCopyUnlessStringOrFloat(const float value) { template <class K, class V> class HashTable : public InitializableLookupTable { public: + HashTable(OpKernelContext* ctx, OpKernel* kernel) {} + size_t size() const override { // return the size of the table only if it's initialized, otherwise 0. if (!is_initialized_) { @@ -105,7 +108,7 @@ class HashTable : public InitializableLookupTable { const auto key_values = keys.flat<K>(); const auto value_values = values.flat<V>(); - for (int i = 0; i < key_values.size(); ++i) { + for (int64 i = 0; i < key_values.size(); ++i) { const K key = SubtleMustCopyUnlessStringOrFloat(key_values(i)); const V value = SubtleMustCopyUnlessStringOrFloat(value_values(i)); const V& previous_value = gtl::LookupOrInsert(table_.get(), key, value); @@ -124,7 +127,7 @@ class HashTable : public InitializableLookupTable { const auto key_values = key.flat<K>(); auto value_values = value->flat<V>(); - for (int i = 0; i < key_values.size(); ++i) { + for (int64 i = 0; i < key_values.size(); ++i) { value_values(i) = gtl::FindWithDefault( *table_, SubtleMustCopyUnlessStringOrFloat(key_values(i)), default_val); @@ -137,22 +140,23 @@ class HashTable : public InitializableLookupTable { }; // Lookup table that wraps an unordered_map, where the key and value data type -// is specified. +// is specified. Each individual value must be a scalar. If vector values are +// required, use MutableHashTableOfTensors. // // This table is mutable and thread safe - Insert can be called at any time. // // Sample use case: // -// MutableHashTable<int64, int64> table; // int64 -> int64. +// MutableHashTableOfScalars<int64, int64> table; // int64 -> int64. // // Populate the table, elements could be added in one or multiple calls. // table.Insert(key_tensor, value_tensor); // Populate the table. // // table.Find(in_t, &out_t, default_t) // template <class K, class V> -class MutableHashTable : public LookupInterface { +class MutableHashTableOfScalars : public LookupInterface { public: - MutableHashTable() {} + MutableHashTableOfScalars(OpKernelContext* ctx, OpKernel* kernel) {} size_t size() const override { mutex_lock l(mu_); @@ -161,14 +165,12 @@ class MutableHashTable : public LookupInterface { Status Find(const Tensor& key, Tensor* value, const Tensor& default_value) override { - TF_RETURN_IF_ERROR(CheckFindArguments(key, *value, default_value)); - const V default_val = default_value.flat<V>()(0); const auto key_values = key.flat<K>(); auto value_values = value->flat<V>(); mutex_lock l(mu_); - for (int i = 0; i < key_values.size(); ++i) { + for (int64 i = 0; i < key_values.size(); ++i) { value_values(i) = gtl::FindWithDefault( table_, SubtleMustCopyUnlessStringOrFloat(key_values(i)), default_val); @@ -178,13 +180,11 @@ class MutableHashTable : public LookupInterface { } Status Insert(const Tensor& keys, const Tensor& values) override { - TF_RETURN_IF_ERROR(CheckKeyAndValueTensors(keys, values)); - const auto key_values = keys.flat<K>(); const auto value_values = values.flat<V>(); mutex_lock l(mu_); - for (int i = 0; i < key_values.size(); ++i) { + for (int64 i = 0; i < key_values.size(); ++i) { const K key = SubtleMustCopyUnlessStringOrFloat(key_values(i)); const V value = SubtleMustCopyUnlessStringOrFloat(value_values(i)); gtl::InsertOrUpdate(&table_, key, value); @@ -192,14 +192,139 @@ class MutableHashTable : public LookupInterface { return Status::OK(); } + Status ExportValues(OpKernelContext* ctx) override { + mutex_lock l(mu_); + int64 size = table_.size(); + + Tensor* keys; + Tensor* values; + TF_RETURN_IF_ERROR( + ctx->allocate_output("keys", TensorShape({size}), &keys)); + TF_RETURN_IF_ERROR( + ctx->allocate_output("values", TensorShape({size}), &values)); + + auto keys_data = keys->flat<K>(); + auto values_data = values->flat<V>(); + int64 i = 0; + for (auto it = table_.begin(); it != table_.end(); ++it, ++i) { + keys_data(i) = it->first; + values_data(i) = it->second; + } + return Status::OK(); + } + + DataType key_dtype() const override { return DataTypeToEnum<K>::v(); } + + DataType value_dtype() const override { return DataTypeToEnum<V>::v(); } + + TensorShape value_shape() const override { return TensorShape(); } + + private: + // TODO(andreasst): consider using a read/write lock or a concurrent map + mutable mutex mu_; + std::unordered_map<K, V> table_ GUARDED_BY(mu_); +}; + +// Lookup table that wraps an unordered_map. Behaves identical to +// MutableHashTableOfScalars except that each value must be a vector. +template <class K, class V> +class MutableHashTableOfTensors : public LookupInterface { + public: + MutableHashTableOfTensors(OpKernelContext* ctx, OpKernel* kernel) { + OP_REQUIRES_OK(ctx, + GetNodeAttr(kernel->def(), "value_shape", &value_shape_)); + OP_REQUIRES( + ctx, TensorShapeUtils::IsVector(value_shape_), + errors::InvalidArgument("Default value must be a vector, got shape ", + value_shape_.DebugString())); + } + + size_t size() const override { + mutex_lock l(mu_); + return table_.size(); + } + + Status Find(const Tensor& key, Tensor* value, + const Tensor& default_value) override { + const auto default_flat = default_value.flat<V>(); + const auto key_values = key.flat<K>(); + auto value_values = value->flat_inner_dims<V, 2>(); + int64 value_dim = value_shape_.dim_size(0); + + mutex_lock l(mu_); + for (int64 i = 0; i < key_values.size(); ++i) { + ValueArray* value_vec = gtl::FindOrNull( + table_, SubtleMustCopyUnlessStringOrFloat(key_values(i))); + if (value_vec != nullptr) { + for (int64 j = 0; j < value_dim; j++) { + value_values(i, j) = value_vec->at(j); + } + } else { + for (int64 j = 0; j < value_dim; j++) { + value_values(i, j) = default_flat(j); + } + } + } + + return Status::OK(); + } + + Status Insert(const Tensor& keys, const Tensor& values) override { + const auto key_values = keys.flat<K>(); + const auto value_values = values.flat_inner_dims<V, 2>(); + int64 value_dim = value_shape_.dim_size(0); + + mutex_lock l(mu_); + for (int64 i = 0; i < key_values.size(); ++i) { + const K key = SubtleMustCopyUnlessStringOrFloat(key_values(i)); + ValueArray value_vec; + for (int64 j = 0; j < value_dim; j++) { + V value = value_values(i, j); + value_vec.push_back(value); + } + gtl::InsertOrUpdate(&table_, key, value_vec); + } + return Status::OK(); + } + + Status ExportValues(OpKernelContext* ctx) override { + mutex_lock l(mu_); + int64 size = table_.size(); + int64 value_dim = value_shape_.dim_size(0); + + Tensor* keys; + Tensor* values; + TF_RETURN_IF_ERROR( + ctx->allocate_output("keys", TensorShape({size}), &keys)); + TF_RETURN_IF_ERROR(ctx->allocate_output( + "values", TensorShape({size, value_dim}), &values)); + + auto keys_data = keys->flat<K>(); + auto values_data = values->matrix<V>(); + int64 i = 0; + for (auto it = table_.begin(); it != table_.end(); ++it, ++i) { + K key = it->first; + ValueArray value = it->second; + keys_data(i) = key; + for (int64 j = 0; j < value_dim; j++) { + values_data(i, j) = value[j]; + } + } + return Status::OK(); + } + DataType key_dtype() const override { return DataTypeToEnum<K>::v(); } DataType value_dtype() const override { return DataTypeToEnum<V>::v(); } + TensorShape value_shape() const override { return value_shape_; } + private: + TensorShape value_shape_; // TODO(andreasst): consider using a read/write lock or a concurrent map mutable mutex mu_; - std::unordered_map<K, V> table_; + typedef gtl::InlinedVector<V, 4> ValueArray; + std::unordered_map<K, ValueArray> table_ GUARDED_BY(mu_); }; } // namespace lookup @@ -219,17 +344,16 @@ class LookupTableFindOp : public OpKernel { DataTypeVector expected_outputs = {table->value_dtype()}; OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs)); - const Tensor& input = ctx->input(1); - + const Tensor& key = ctx->input(1); const Tensor& default_value = ctx->input(2); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(default_value.shape()), - errors::InvalidArgument("Default value must be a scalar, not ", - default_value.shape().DebugString())); + OP_REQUIRES_OK(ctx, table->CheckFindArguments(key, default_value)); + TensorShape output_shape = key.shape(); + output_shape.AppendShape(table->value_shape()); Tensor* out; - OP_REQUIRES_OK(ctx, ctx->allocate_output("values", input.shape(), &out)); + OP_REQUIRES_OK(ctx, ctx->allocate_output("values", output_shape, &out)); - OP_REQUIRES_OK(ctx, table->Find(input, out, default_value)); + OP_REQUIRES_OK(ctx, table->Find(key, out, default_value)); } }; @@ -252,6 +376,7 @@ class LookupTableInsertOp : public OpKernel { const Tensor& keys = ctx->input(1); const Tensor& values = ctx->input(2); + OP_REQUIRES_OK(ctx, table->CheckKeyAndValueTensors(keys, values)); OP_REQUIRES_OK(ctx, table->Insert(keys, values)); } }; @@ -278,6 +403,23 @@ class LookupTableSizeOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("LookupTableSize").Device(DEVICE_CPU), LookupTableSizeOp); +// Op that outputs tensors of all keys and all values. +class LookupTableExportOp : public OpKernel { + public: + explicit LookupTableExportOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + lookup::LookupInterface* table; + OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); + core::ScopedUnref unref_me(table); + + OP_REQUIRES_OK(ctx, table->ExportValues(ctx)); + } +}; + +REGISTER_KERNEL_BUILDER(Name("LookupTableExport").Device(DEVICE_CPU), + LookupTableExportOp); + // Register the HashTable op with the currently supported key and value types. #define REGISTER_KERNEL(key_dtype, value_dtype) \ REGISTER_KERNEL_BUILDER( \ @@ -294,13 +436,29 @@ REGISTER_KERNEL(int64, string); #undef REGISTER_KERNEL // Register the MutableHashTable op. -#define REGISTER_KERNEL(key_dtype, value_dtype) \ - REGISTER_KERNEL_BUILDER( \ - Name("MutableHashTable") \ - .Device(DEVICE_CPU) \ - .TypeConstraint<key_dtype>("key_dtype") \ - .TypeConstraint<value_dtype>("value_dtype"), \ - LookupTableOp<lookup::MutableHashTable<key_dtype, value_dtype>, \ +#define REGISTER_KERNEL(key_dtype, value_dtype) \ + REGISTER_KERNEL_BUILDER( \ + Name("MutableHashTable") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<key_dtype>("key_dtype") \ + .TypeConstraint<value_dtype>("value_dtype"), \ + LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \ + key_dtype, value_dtype>) + +REGISTER_KERNEL(string, float); +REGISTER_KERNEL(string, int64); +REGISTER_KERNEL(int64, string); + +#undef REGISTER_KERNEL + +// Register the MutableHashTableOfTensors op. +#define REGISTER_KERNEL(key_dtype, value_dtype) \ + REGISTER_KERNEL_BUILDER( \ + Name("MutableHashTableOfTensors") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<key_dtype>("key_dtype") \ + .TypeConstraint<value_dtype>("value_dtype"), \ + LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \ key_dtype, value_dtype>) REGISTER_KERNEL(string, float); diff --git a/tensorflow/core/kernels/lookup_table_op.h b/tensorflow/core/kernels/lookup_table_op.h index 8f316a0f3c..cd44fb64c2 100644 --- a/tensorflow/core/kernels/lookup_table_op.h +++ b/tensorflow/core/kernels/lookup_table_op.h @@ -49,8 +49,8 @@ class LookupTableOp : public OpKernel { mutex_lock l(mu_); if (!table_handle_set_) { OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def())); - auto creator = [this](lookup::LookupInterface** ret) { - *ret = new Container(); + auto creator = [ctx, this](lookup::LookupInterface** ret) { + *ret = new Container(ctx, this); return Status::OK(); }; diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index 85c3cd28fc..742c7d0b38 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -704,6 +704,20 @@ table_handle: Handle to the table. size: Scalar that contains number of elements in the table. )doc"); +REGISTER_OP("LookupTableExport") + .Input("table_handle: Ref(string)") + .Output("keys: Tkeys") + .Output("values: Tvalues") + .Attr("Tkeys: type") + .Attr("Tvalues: type") + .Doc(R"doc( +Outputs all keys and values in the table. + +table_handle: Handle to the table. +keys: Vector of all keys present in the table. +values: Tensor of all values in the table. Indexed in parallel with `keys`. +)doc"); + REGISTER_OP("HashTable") .Output("table_handle: Ref(string)") .Attr("container: string = ''") @@ -738,8 +752,32 @@ REGISTER_OP("MutableHashTable") Creates an empty hash table. This op creates a mutable hash table, specifying the type of its keys and -values. Data can be inserted into the table using the insert operations. It -does not support the initialization operation. +values. Each value must be a scalar. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +)doc"); + +REGISTER_OP("MutableHashTableOfTensors") + .Output("table_handle: Ref(string)") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .Attr("value_shape: shape = {}") + .SetIsStateful() + .Doc(R"doc( +Creates an empty hash table. + +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a vector. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. table_handle: Handle to a table. container: If non-empty, this table is placed in the given container. diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index e65d44c378..ef7754a920 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -544,10 +544,12 @@ tf_gen_op_wrapper_py( "HashTable", "InitializeTable", "InitializeTableFromTextFile", + "LookupTableExport", "LookupTableFind", "LookupTableInsert", "LookupTableSize", "MutableHashTable", + "MutableHashTableOfTensors", "Mutex", "MutexAcquire", "MutexRelease", diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py index 9fbcbd645d..a4a3c4e669 100644 --- a/tensorflow/python/ops/data_flow_ops.py +++ b/tensorflow/python/ops/data_flow_ops.py @@ -732,6 +732,7 @@ ops.NoGradient("HashTable") ops.NoGradient("InitializeTable") ops.NoGradient("InitializeTableFromTextFile") ops.NoGradient("MutableHashTable") +ops.NoGradient("MutableHashTableOfTensors") ops.RegisterShape("QueueSize")(common_shapes.scalar_shape) @@ -807,16 +808,13 @@ def _DynamicStitchShape(op): def _LookupTableFindShape(op): """Shape function for data_flow_ops._lookup_table_find.""" op.inputs[0].get_shape().merge_with(tensor_shape.scalar()) - shape_in = op.inputs[1].get_shape() - return [shape_in] + return [tensor_shape.unknown_shape()] @ops.RegisterShape("LookupTableInsert") def _LookupTableInsertShape(op): """Shape function for data_flow_ops._lookup_table_insert.""" op.inputs[0].get_shape().merge_with(tensor_shape.scalar()) - keys_shape = op.inputs[1].get_shape() - op.inputs[2].get_shape().merge_with(keys_shape) return [] @@ -827,8 +825,18 @@ def _LookupTableSizeShape(op): return [tensor_shape.scalar()] +@ops.RegisterShape("LookupTableExport") +def _LookupTableExportShape(op): + """Shape function for data_flow_ops._lookup_table_export_values.""" + op.inputs[0].get_shape().merge_with(tensor_shape.scalar()) + keys_shape = tensor_shape.vector(None) + values_shape = tensor_shape.unknown_shape() + return [keys_shape, values_shape] + + @ops.RegisterShape("HashTable") @ops.RegisterShape("MutableHashTable") +@ops.RegisterShape("MutableHashTableOfTensors") def _HashTableShape(_): """Shape function for data_flow_ops._hash_table.""" return [tensor_shape.scalar()] |