aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/lookup_interface.h
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2016-11-16 10:12:45 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-16 10:26:41 -0800
commit634790044c81521c438799b558d33b8440fa9e23 (patch)
treea0319de4c600738e4e917a3f9f2a194b08214635 /tensorflow/core/framework/lookup_interface.h
parenteabf41b7cc8ab515b556bc91b4f282d1d671c1a7 (diff)
Fixing a bug in the MutableDenseHashTable implementation where the difference in shapes between the Insert and Import functions was causing issues with a vector key and scalar value input.
Fixed by splitting the LookupInterface CheckKeysAndValueTensors method into one for Insert and the other for Import. Change: 139346138
Diffstat (limited to 'tensorflow/core/framework/lookup_interface.h')
-rw-r--r--tensorflow/core/framework/lookup_interface.h28
1 files changed, 25 insertions, 3 deletions
diff --git a/tensorflow/core/framework/lookup_interface.h b/tensorflow/core/framework/lookup_interface.h
index 2247bcb82f..b0b661be0e 100644
--- a/tensorflow/core/framework/lookup_interface.h
+++ b/tensorflow/core/framework/lookup_interface.h
@@ -67,8 +67,16 @@ class LookupInterface : public ResourceBase {
// Returns the number of elements in the table.
virtual size_t size() const = 0;
+ // Exports the values of the table to two tensors named keys and values.
+ // Note that the shape of the tensors is completely up to the implementation
+ // of the table and can be different than the tensors used for the Insert
+ // function above.
virtual Status ExportValues(OpKernelContext* ctx) = 0;
+ // Imports previously exported keys and values.
+ // As mentioned above, the shape of the keys and values tensors are determined
+ // by the ExportValues function above and can be different than for the
+ // Insert function.
virtual Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
const Tensor& values) = 0;
@@ -84,14 +92,20 @@ class LookupInterface : public ResourceBase {
// Returns the shape of a value in the table.
virtual TensorShape value_shape() const = 0;
- // Check format of the key and value tensors.
+ // Check format of the key and value tensors for the Insert function.
// Returns OK if all the following requirements are satisfied, otherwise it
// returns InvalidArgument:
// - DataType of the tensor keys equals to the table key_dtype
// - DataType of the tensor values equals to the table value_dtype
// - the values tensor has the required shape given keys and the tables's
// value shape.
- Status CheckKeyAndValueTensors(const Tensor& keys, const Tensor& values);
+ virtual Status CheckKeyAndValueTensorsForInsert(const Tensor& keys,
+ const Tensor& values);
+
+ // Similar to the function above but instead checks eligibility for the Import
+ // function.
+ virtual Status CheckKeyAndValueTensorsForImport(const Tensor& keys,
+ const Tensor& values);
// Check the arguments of a find operation. Returns OK if all the following
// requirements are satisfied, otherwise it returns InvalidArgument:
@@ -111,8 +125,16 @@ class LookupInterface : public ResourceBase {
protected:
virtual ~LookupInterface() = default;
- private:
+ // Makes sure that the key and value tensor DataType's match the table
+ // key_dtype and value_dtype.
Status CheckKeyAndValueTypes(const Tensor& keys, const Tensor& values);
+
+ // Makes sure that the provided shape is consistent with the table keys shape.
+ Status CheckKeyShape(const TensorShape& shape);
+
+ private:
+ Status CheckKeyAndValueTensorsHelper(const Tensor& keys,
+ const Tensor& values);
};
} // namespace lookup