diff options
author | 2016-03-14 16:58:53 -0800 | |
---|---|---|
committer | 2016-03-15 11:45:07 -0700 | |
commit | 9795521f98b48bc9c894317d0bea70100dfbe36d (patch) | |
tree | 3a51acbb69888099b8c0516a16a7bd0d7c30e4ba | |
parent | ca9bb0ccb188865cb34b0dc4b2a017f84ecf3d5c (diff) |
Ensure that keys, in HashTable::Find, are validated in the face of asynchronous
updates to the keys tensor.
Change: 117194227
-rw-r--r-- | tensorflow/core/kernels/lookup_table_op.cc | 20 |
1 files changed, 18 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc index ebbf8aa7aa..60c875b1ca 100644 --- a/tensorflow/core/kernels/lookup_table_op.cc +++ b/tensorflow/core/kernels/lookup_table_op.cc @@ -17,16 +17,32 @@ limitations under the License. #define EIGEN_USE_THREADS #include <string> +#include <type_traits> #include <utility> #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/initializable_lookup_table.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/hash/hash.h" namespace tensorflow { namespace lookup { +namespace { + +// Ensure that the compiler cannot elide a copy into a local, for +// bounds checking on source tensors that might be updated asynchronously for +// integral types. However strings variables are not allowed and therefore the +// local copy is unnecessary. +template <typename T> +T SubtleMustCopyUnlessString(const T& value) { + return internal::SubtleMustCopy(value); +} + +const string& SubtleMustCopyUnlessString(const string& value) { return value; } + +} // namespace // Lookup table that wraps an unordered_map, where the key and value data type // is specified. @@ -99,8 +115,8 @@ class HashTable : public InitializableLookupTable { auto value_values = value->flat<V>(); for (int i = 0; i < key_values.size(); ++i) { - value_values(i) = - gtl::FindWithDefault(*table_, key_values(i), default_val); + value_values(i) = gtl::FindWithDefault( + *table_, SubtleMustCopyUnlessString(key_values(i)), default_val); } return Status::OK(); } |