diff options
author | 2016-11-16 10:12:45 -0800 | |
---|---|---|
committer | 2016-11-16 10:26:41 -0800 | |
commit | 634790044c81521c438799b558d33b8440fa9e23 (patch) | |
tree | a0319de4c600738e4e917a3f9f2a194b08214635 /tensorflow/core/framework/lookup_interface.h | |
parent | eabf41b7cc8ab515b556bc91b4f282d1d671c1a7 (diff) |
Fixing a bug in the MutableDenseHashTable implementation where the difference in shapes between the Insert and Import functions was causing issues with a vector key and scalar value input.
Fixed by splitting the LookupInterface CheckKeysAndValueTensors method into one for Insert and the other for Import.
Change: 139346138
Diffstat (limited to 'tensorflow/core/framework/lookup_interface.h')
-rw-r--r-- | tensorflow/core/framework/lookup_interface.h | 28 |
1 files changed, 25 insertions, 3 deletions
diff --git a/tensorflow/core/framework/lookup_interface.h b/tensorflow/core/framework/lookup_interface.h index 2247bcb82f..b0b661be0e 100644 --- a/tensorflow/core/framework/lookup_interface.h +++ b/tensorflow/core/framework/lookup_interface.h @@ -67,8 +67,16 @@ class LookupInterface : public ResourceBase { // Returns the number of elements in the table. virtual size_t size() const = 0; + // Exports the values of the table to two tensors named keys and values. + // Note that the shape of the tensors is completely up to the implementation + // of the table and can be different than the tensors used for the Insert + // function above. virtual Status ExportValues(OpKernelContext* ctx) = 0; + // Imports previously exported keys and values. + // As mentioned above, the shape of the keys and values tensors are determined + // by the ExportValues function above and can be different than for the + // Insert function. virtual Status ImportValues(OpKernelContext* ctx, const Tensor& keys, const Tensor& values) = 0; @@ -84,14 +92,20 @@ class LookupInterface : public ResourceBase { // Returns the shape of a value in the table. virtual TensorShape value_shape() const = 0; - // Check format of the key and value tensors. + // Check format of the key and value tensors for the Insert function. // Returns OK if all the following requirements are satisfied, otherwise it // returns InvalidArgument: // - DataType of the tensor keys equals to the table key_dtype // - DataType of the tensor values equals to the table value_dtype // - the values tensor has the required shape given keys and the tables's // value shape. - Status CheckKeyAndValueTensors(const Tensor& keys, const Tensor& values); + virtual Status CheckKeyAndValueTensorsForInsert(const Tensor& keys, + const Tensor& values); + + // Similar to the function above but instead checks eligibility for the Import + // function. + virtual Status CheckKeyAndValueTensorsForImport(const Tensor& keys, + const Tensor& values); // Check the arguments of a find operation. Returns OK if all the following // requirements are satisfied, otherwise it returns InvalidArgument: @@ -111,8 +125,16 @@ class LookupInterface : public ResourceBase { protected: virtual ~LookupInterface() = default; - private: + // Makes sure that the key and value tensor DataType's match the table + // key_dtype and value_dtype. Status CheckKeyAndValueTypes(const Tensor& keys, const Tensor& values); + + // Makes sure that the provided shape is consistent with the table keys shape. + Status CheckKeyShape(const TensorShape& shape); + + private: + Status CheckKeyAndValueTensorsHelper(const Tensor& keys, + const Tensor& values); }; } // namespace lookup |