aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yutaka Leon <yutaka.leon@gmail.com>2016-03-14 16:58:53 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-03-15 11:45:07 -0700
commit9795521f98b48bc9c894317d0bea70100dfbe36d (patch)
tree3a51acbb69888099b8c0516a16a7bd0d7c30e4ba
parentca9bb0ccb188865cb34b0dc4b2a017f84ecf3d5c (diff)
Ensure that keys, in HashTable::Find, are validated in the face of asynchronous
updates to the keys tensor. Change: 117194227
-rw-r--r--tensorflow/core/kernels/lookup_table_op.cc20
1 files changed, 18 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc
index ebbf8aa7aa..60c875b1ca 100644
--- a/tensorflow/core/kernels/lookup_table_op.cc
+++ b/tensorflow/core/kernels/lookup_table_op.cc
@@ -17,16 +17,32 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include <string>
+#include <type_traits>
#include <utility>
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/initializable_lookup_table.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
namespace tensorflow {
namespace lookup {
+namespace {
+
+// Ensure that the compiler cannot elide a copy into a local, for
+// bounds checking on source tensors that might be updated asynchronously for
+// integral types. However strings variables are not allowed and therefore the
+// local copy is unnecessary.
+template <typename T>
+T SubtleMustCopyUnlessString(const T& value) {
+ return internal::SubtleMustCopy(value);
+}
+
+const string& SubtleMustCopyUnlessString(const string& value) { return value; }
+
+} // namespace
// Lookup table that wraps an unordered_map, where the key and value data type
// is specified.
@@ -99,8 +115,8 @@ class HashTable : public InitializableLookupTable {
auto value_values = value->flat<V>();
for (int i = 0; i < key_values.size(); ++i) {
- value_values(i) =
- gtl::FindWithDefault(*table_, key_values(i), default_val);
+ value_values(i) = gtl::FindWithDefault(
+ *table_, SubtleMustCopyUnlessString(key_values(i)), default_val);
}
return Status::OK();
}