diff options
Diffstat (limited to 'tensorflow/core/framework/lookup_interface.h')
-rw-r--r-- | tensorflow/core/framework/lookup_interface.h | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/tensorflow/core/framework/lookup_interface.h b/tensorflow/core/framework/lookup_interface.h new file mode 100644 index 0000000000..d4036d2019 --- /dev/null +++ b/tensorflow/core/framework/lookup_interface.h @@ -0,0 +1,65 @@ +#ifndef TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_ +#define TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_ + +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { +namespace lookup { + +// Lookup interface for batch lookups used by table lookup ops. +class LookupInterface : public ResourceBase { + public: + // Performs batch lookups, for every element in the key tensor, Find returns + // the corresponding value into the values tensor. + // If an element is not present in the table, the given default value is used. + + // For tables that require initialization, Find is available once the table + // is marked as initialized. + + // Returns the following statuses: + // - OK: when the find finishes successfully. + // - FailedPrecondition: if the table is not initialized. + // - InvalidArgument: if any of the preconditions on the lookup key or value + // fails. + // - In addition, other implementations may provide another non-OK status + // specific to their failure modes. + virtual Status Find(const Tensor& keys, Tensor* values, + const Tensor& default_value) = 0; + + // Returns the number of elements in the table. + virtual size_t size() const = 0; + + // Returns the data type of the key. + virtual DataType key_dtype() const = 0; + + // Returns the data type of the value. + virtual DataType value_dtype() const = 0; + + string DebugString() override { return "A lookup table"; } + + protected: + virtual ~LookupInterface() = default; + + // Check format of the key and value tensors. + // Returns OK if all the following requirements are satisfied, otherwise it + // returns InvalidArgument: + // - DataType of the tensor key equals to the table key_dtype + // - DataType of the test value equals to the table value_dtype + // - key and value have the same size and shape + Status CheckKeyAndValueTensors(const Tensor& keys, const Tensor& values); + + // Check the arguments of a find operation. Returns OK if all the following + // requirements are satisfied, otherwise it returns InvalidArgument: + // - All requirements of CheckKeyAndValueTensors + // - default_value type equals to the table value_dtype + // - default_value is scalar + Status CheckFindArguments(const Tensor& keys, const Tensor& values, + const Tensor& default_value); +}; + +} // namespace lookup +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_ |