aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-06-27 13:16:43 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-06-27 14:32:32 -0700
commit370a6d4e91ffcaa155dfc72a74ca082c987580f3 (patch)
treef33f950f043f538432ae6a6a1e7a7c9f8059a038 /tensorflow/contrib
parent6c14cd8e1bfbf1484ef22c82ac3badc53ed73f7d (diff)
Change SDCA to use MutableHashTable op instead of internal hash table.
Change: 125997620
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--tensorflow/contrib/linear_optimizer/BUILD5
-rw-r--r--tensorflow/contrib/linear_optimizer/kernels/BUILD27
-rw-r--r--tensorflow/contrib/linear_optimizer/kernels/resources.cc86
-rw-r--r--tensorflow/contrib/linear_optimizer/kernels/resources.h108
-rw-r--r--tensorflow/contrib/linear_optimizer/kernels/resources_test.cc184
-rw-r--r--tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc164
-rw-r--r--tensorflow/contrib/linear_optimizer/ops/sdca_ops.cc34
-rw-r--r--tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py45
-rw-r--r--tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py55
9 files changed, 106 insertions, 602 deletions
diff --git a/tensorflow/contrib/linear_optimizer/BUILD b/tensorflow/contrib/linear_optimizer/BUILD
index 500587a6b7..677c8dc0b8 100644
--- a/tensorflow/contrib/linear_optimizer/BUILD
+++ b/tensorflow/contrib/linear_optimizer/BUILD
@@ -50,7 +50,10 @@ py_library(
],
data = [":python/ops/_sdca_ops.so"],
srcs_version = "PY2AND3",
- deps = [":sdca_ops"],
+ deps = [
+ ":sdca_ops",
+ "//tensorflow/contrib/lookup:lookup_py",
+ ],
)
py_test(
diff --git a/tensorflow/contrib/linear_optimizer/kernels/BUILD b/tensorflow/contrib/linear_optimizer/kernels/BUILD
index e5770085a2..1f68065c60 100644
--- a/tensorflow/contrib/linear_optimizer/kernels/BUILD
+++ b/tensorflow/contrib/linear_optimizer/kernels/BUILD
@@ -39,39 +39,14 @@ cc_test(
)
cc_library(
- name = "resources",
- srcs = ["resources.cc"],
- hdrs = ["resources.h"],
- deps = [
- "//tensorflow/core:framework_headers_lib",
- "//third_party/eigen3",
- "@farmhash_archive//:farmhash",
- "@protobuf//:protobuf",
- ],
-)
-
-cc_test(
- name = "resources_test",
- size = "small",
- srcs = ["resources_test.cc"],
- deps = [
- ":resources",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- ],
-)
-
-cc_library(
name = "sdca_ops",
srcs = ["sdca_ops.cc"],
deps = [
":loss_updaters",
- ":resources",
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core/kernels:bounds_check_lib",
"//third_party/eigen3",
+ "@farmhash_archive//:farmhash",
"@protobuf//:protobuf",
],
alwayslink = 1,
diff --git a/tensorflow/contrib/linear_optimizer/kernels/resources.cc b/tensorflow/contrib/linear_optimizer/kernels/resources.cc
deleted file mode 100644
index de3efea280..0000000000
--- a/tensorflow/contrib/linear_optimizer/kernels/resources.cc
+++ /dev/null
@@ -1,86 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/contrib/linear_optimizer/kernels/resources.h"
-
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-
-DataByExample::DataByExample(const string& container, const string& solver_uuid)
- : container_(container), solver_uuid_(solver_uuid) {}
-
-DataByExample::~DataByExample() {}
-
-// static
-DataByExample::EphemeralKey DataByExample::MakeKey(const string& example_id) {
- return Fingerprint128(example_id);
-}
-
-DataByExample::Data DataByExample::Get(const EphemeralKey& key) {
- mutex_lock l(mu_);
- return data_by_key_[key];
-}
-
-void DataByExample::Set(const EphemeralKey& key, const Data& data) {
- mutex_lock l(mu_);
- data_by_key_[key] = data;
-}
-
-Status DataByExample::Visit(
- std::function<void(const Data& data)> visitor) const {
- struct State {
- // Snapshoted size of data_by_key_.
- size_t size;
-
- // Number of elements visited so far.
- size_t num_visited = 0;
-
- // Current element.
- DataByKey::const_iterator it;
- };
-
- auto state = [this] {
- mutex_lock l(mu_);
- State result;
- result.size = data_by_key_.size();
- result.it = data_by_key_.cbegin();
- return result;
- }();
-
- while (state.num_visited < state.size) {
- mutex_lock l(mu_);
- // Since DataByExample is modify-or-append only, a visit will (continue to)
- // be successful if and only if the size of the backing store hasn't
- // changed (since the body of this while-loop is under lock).
- if (data_by_key_.size() != state.size) {
- return errors::Unavailable("The number of elements for ", solver_uuid_,
- " has changed which nullifies a visit.");
- }
- for (size_t i = 0; i < kVisitChunkSize && state.num_visited < state.size;
- ++i, ++state.num_visited, ++state.it) {
- visitor(state.it->second);
- }
- }
- return Status::OK();
-}
-
-string DataByExample::DebugString() {
- return strings::StrCat("DataByExample(", container_, ", ", solver_uuid_, ")");
-}
-
-} // namespace tensorflow
diff --git a/tensorflow/contrib/linear_optimizer/kernels/resources.h b/tensorflow/contrib/linear_optimizer/kernels/resources.h
deleted file mode 100644
index 40a683a999..0000000000
--- a/tensorflow/contrib/linear_optimizer/kernels/resources.h
+++ /dev/null
@@ -1,108 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LINEAR_OPTIMIZER_KERNELS_RESOURCES_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LINEAR_OPTIMIZER_KERNELS_RESOURCES_H_
-
-#include <cstddef>
-#include <functional>
-#include <string>
-#include <unordered_map>
-#include <utility>
-
-#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/platform/fingerprint.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-
-// Resource for storing per-example data across many sessions. The data is
-// operated on in a modify or append fashion (data can be modified or added, but
-// never deleted).
-//
-// This class is thread-safe.
-class DataByExample : public ResourceBase {
- public:
- // The container and solver_uuid are only used for debugging purposes.
- DataByExample(const string& container, const string& solver_uuid);
-
- virtual ~DataByExample();
-
- // Platform independent, compact and unique (with very high probability)
- // representation of an example id. 'Ephemeral' because it shouldn't be put
- // in persistent storage, as its implementation may change in the future.
- //
- // The current probability of at least one collision for 1B example_ids is
- // approximately 10^-21 (ie 2^60 / 2^129).
- using EphemeralKey = Fprint128;
-
- // Makes a key for the supplied example_id, for compact storage.
- static EphemeralKey MakeKey(const string& example_id);
-
- struct Data {
- float dual = 0;
- float primal_loss = 0;
- float dual_loss = 0;
- float example_weight = 0;
- };
-
- // Accessor and mutator for the entry at Key. Accessor creates an entry with
- // default value (default constructed object) if the key is not present and
- // returns it.
- Data Get(const EphemeralKey& key) LOCKS_EXCLUDED(mu_);
- void Set(const EphemeralKey& key, const Data& data) LOCKS_EXCLUDED(mu_);
-
- // Visits all elements in this resource. The view of each element (Data) is
- // atomic, but the entirety of the visit is not (ie the visitor might see
- // different versions of the Data across elements).
- //
- // Returns OK on success or UNAVAILABLE if the number of elements in this
- // container has changed since the beginning of the visit (in which case the
- // visit cannot be completed and is aborted early, and computation can be
- // restarted).
- Status Visit(std::function<void(const Data& data)> visitor) const
- LOCKS_EXCLUDED(mu_);
-
- string DebugString() override;
-
- private:
- // Backing container.
- //
- // sizeof(EntryPayload) =
- // sizeof(Key) + sizeof(Data) =
- // 16 + 16 = 32.
- //
- // So on average we use ~51.5 (32 + 19.5) bytes per entry in this table.
- using EphemeralKeyHasher = Fprint128Hasher;
- using DataByKey = std::unordered_map<EphemeralKey, Data, EphemeralKeyHasher>;
-
- // TODO(sibyl-Mooth6ku): Benchmark and/or optimize this.
- static const size_t kVisitChunkSize = 100;
-
- const string container_;
- const string solver_uuid_;
-
- // TODO(sibyl-Mooth6ku): Come up with a more efficient locking scheme.
- mutable mutex mu_;
- DataByKey data_by_key_ GUARDED_BY(mu_);
-
- friend class DataByExampleTest;
-};
-
-} // namespace tensorflow
-
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LINEAR_OPTIMIZER_KERNELS_RESOURCES_H_
diff --git a/tensorflow/contrib/linear_optimizer/kernels/resources_test.cc b/tensorflow/contrib/linear_optimizer/kernels/resources_test.cc
deleted file mode 100644
index 571e843281..0000000000
--- a/tensorflow/contrib/linear_optimizer/kernels/resources_test.cc
+++ /dev/null
@@ -1,184 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/contrib/linear_optimizer/kernels/resources.h"
-
-#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/threadpool.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-
-// Operators for testing convenience (for EQ and NE GUnit macros).
-bool operator==(const DataByExample::Data& lhs,
- const DataByExample::Data& rhs) {
- return lhs.dual == rhs.dual && //
- lhs.primal_loss == rhs.primal_loss && //
- lhs.dual_loss == rhs.dual_loss && //
- lhs.example_weight == rhs.example_weight;
-}
-
-bool operator!=(const DataByExample::Data& lhs,
- const DataByExample::Data& rhs) {
- return !(lhs == rhs);
-}
-
-class DataByExampleTest : public ::testing::Test {
- protected:
- void SetUp() override {
- const string solver_uuid = "TheSolver";
- ASSERT_TRUE(resource_manager_
- .LookupOrCreate<DataByExample>(
- container_, solver_uuid, &data_by_example_,
- [&, this](DataByExample** ret) {
- *ret = new DataByExample(container_, solver_uuid);
- return Status::OK();
- })
- .ok());
- }
-
- void TearDown() override {
- data_by_example_->Unref();
- ASSERT_TRUE(resource_manager_.Cleanup(container_).ok());
- }
-
- // Accessors and mutators to private members of DataByExample for better
- // testing.
- static size_t VisitChunkSize() { return DataByExample::kVisitChunkSize; }
- void InsertReservedEntryUnlocked() NO_THREAD_SAFETY_ANALYSIS {
- data_by_example_->data_by_key_[{0, 0}];
- }
-
- const string container_ = "TheContainer";
- ResourceMgr resource_manager_;
- DataByExample* data_by_example_ = nullptr;
-};
-
-TEST_F(DataByExampleTest, MakeKeyIsCollisionResistent) {
- const DataByExample::EphemeralKey key =
- DataByExample::MakeKey("TheExampleId");
- EXPECT_NE(key.low64, key.high64);
-}
-
-TEST_F(DataByExampleTest, MakeKeyIsPlatformAgnostic) {
- // This is one way of enforcing the platform-agnostic nature of
- // DataByExample::MakeKey. Basically we are checking against exact values and
- // this test could be running across different platforms.
- // Note that it is fine for expected values to change in the future, if the
- // implementation of MakeKey changes (ie this is *not* a frozen test).
- const DataByExample::EphemeralKey key =
- DataByExample::MakeKey("TheExampleId");
- EXPECT_EQ(10492632643343118393ULL, key.low64);
- EXPECT_EQ(1007244271654873956ULL, key.high64);
-}
-
-TEST_F(DataByExampleTest, ElementAccessAndMutation) {
- const DataByExample::EphemeralKey key1 =
- DataByExample::MakeKey("TheExampleId1");
- EXPECT_EQ(DataByExample::Data(), data_by_example_->Get(key1));
-
- DataByExample::Data data1;
- data1.dual = 1.0f;
- data_by_example_->Set(key1, data1);
- EXPECT_EQ(data1, data_by_example_->Get(key1));
-
- const DataByExample::EphemeralKey key2 =
- DataByExample::MakeKey("TheExampleId2");
- EXPECT_NE(data_by_example_->Get(key1), data_by_example_->Get(key2));
-}
-
-TEST_F(DataByExampleTest, VisitEmpty) {
- size_t num_elements = 0;
- ASSERT_TRUE(
- data_by_example_
- ->Visit([&](const DataByExample::Data& data) { ++num_elements; })
- .ok());
- EXPECT_EQ(0, num_elements);
-}
-
-TEST_F(DataByExampleTest, VisitMany) {
- const size_t kNumElements = 2 * VisitChunkSize() + 1;
- for (size_t i = 0; i < kNumElements; ++i) {
- DataByExample::Data data;
- data.dual = static_cast<float>(i);
- data_by_example_->Set(DataByExample::MakeKey(strings::StrCat(i)), data);
- }
- size_t num_elements = 0;
- double total_dual = 0;
- ASSERT_TRUE(data_by_example_
- ->Visit([&](const DataByExample::Data& data) {
- ++num_elements;
- total_dual += data.dual;
- })
- .ok());
- EXPECT_EQ(kNumElements, num_elements);
- EXPECT_DOUBLE_EQ(
- // 0 + 1 + ... + (N-1) = (N-1)*N/2
- (kNumElements - 1) * kNumElements / 2.0, total_dual);
-}
-
-TEST_F(DataByExampleTest, VisitUnavailable) {
- // Populate enough entries so that Visiting will be chunked.
- for (size_t i = 0; i < 2 * VisitChunkSize(); ++i) {
- data_by_example_->Get(DataByExample::MakeKey(strings::StrCat(i)));
- }
-
- struct Condition {
- mutex mu;
- bool c GUARDED_BY(mu) = false;
- condition_variable cv;
- };
- auto signal = [](Condition* const condition) {
- mutex_lock l(condition->mu);
- condition->c = true;
- condition->cv.notify_all();
- };
- auto wait = [](Condition* const condition) {
- mutex_lock l(condition->mu);
- while (!condition->c) {
- condition->cv.wait(l);
- }
- };
-
- Condition paused_visit; // Signaled after a Visit has paused.
- Condition updated_data; // Signaled after data has been updated.
- Condition completed_visit; // Signaled after a Visit has completed.
-
- thread::ThreadPool thread_pool(Env::Default(), "test", 2 /* num_threads */);
- Status status;
- size_t num_visited = 0;
- thread_pool.Schedule([&] {
- status = data_by_example_->Visit([&](const DataByExample::Data& unused) {
- ++num_visited;
- if (num_visited == VisitChunkSize()) {
- // Safe point to mutate the data structure without a lock below.
- signal(&paused_visit);
- wait(&updated_data);
- }
- });
- signal(&completed_visit);
- });
- thread_pool.Schedule([&, this] {
- wait(&paused_visit);
- InsertReservedEntryUnlocked();
- signal(&updated_data);
- });
- wait(&completed_visit);
- EXPECT_TRUE(errors::IsUnavailable(status));
-}
-
-} // namespace tensorflow
diff --git a/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc b/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc
index 94fe137c2c..20e4eba451 100644
--- a/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc
+++ b/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc
@@ -28,14 +28,12 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/contrib/linear_optimizer/kernels/hinge-loss.h"
#include "tensorflow/contrib/linear_optimizer/kernels/logistic-loss.h"
-#include "tensorflow/contrib/linear_optimizer/kernels/resources.h"
#include "tensorflow/contrib/linear_optimizer/kernels/squared-loss.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
@@ -47,6 +45,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/sparse/group_iterator.h"
@@ -596,22 +595,18 @@ class FeaturesAndWeights {
};
Status RunTrainStepsForMiniBatch(
- const int num_examples, const TTypes<const string>::Vec example_ids,
- const TTypes<const float>::Vec example_labels,
+ const int num_examples, const TTypes<const float>::Vec example_labels,
const TTypes<const float>::Vec example_weights,
const DeviceBase::CpuWorkerThreads& worker_threads,
const Regularizations& regularizations, const DualLossUpdater& loss_updater,
FeaturesAndWeights* const features_and_weights,
- DataByExample* const data_by_example) {
+ TTypes<float>::Matrix example_state_data) {
// Process examples in parallel, in a partitioned fashion.
mutex mu;
Status train_step_status GUARDED_BY(mu);
auto train_step = [&](const int64 begin, const int64 end) {
for (int64 example_index = begin; example_index < end; ++example_index) {
- // Get example id, label, and weight.
- const DataByExample::EphemeralKey example_key =
- DataByExample::MakeKey(example_ids(example_index));
- DataByExample::Data data = data_by_example->Get(example_key);
+ const float dual = example_state_data(example_index, 0);
const float example_weight = example_weights(example_index);
float example_label = example_labels(example_index);
const Status conversion_status =
@@ -633,24 +628,23 @@ Status RunTrainStepsForMiniBatch(
const double primal_loss = loss_updater.ComputePrimalLoss(
per_example_data.wx, example_label, example_weight);
- const double dual_loss = loss_updater.ComputeDualLoss(
- data.dual, example_label, example_weight);
+ const double dual_loss =
+ loss_updater.ComputeDualLoss(dual, example_label, example_weight);
const double new_dual = loss_updater.ComputeUpdatedDual(
- example_label, example_weight, data.dual, per_example_data.wx,
+ example_label, example_weight, dual, per_example_data.wx,
per_example_data.normalized_squared_norm, primal_loss, dual_loss);
// Compute new weights.
- const double bounded_dual_delta = (new_dual - data.dual) * example_weight;
+ const double bounded_dual_delta = (new_dual - dual) * example_weight;
features_and_weights->UpdateDeltaWeights(
example_index, bounded_dual_delta, regularizations.symmetric_l2());
// Update example data.
- data.dual = new_dual;
- data.primal_loss = primal_loss;
- data.dual_loss = dual_loss;
- data.example_weight = example_weight;
- data_by_example->Set(example_key, data);
+ example_state_data(example_index, 0) = new_dual;
+ example_state_data(example_index, 1) = primal_loss;
+ example_state_data(example_index, 2) = dual_loss;
+ example_state_data(example_index, 3) = example_weight;
}
};
// TODO(sibyl-Aix6ihai): Current multiplier 100000 works well empirically
@@ -689,30 +683,9 @@ class SdcaSolver : public OpKernel {
OP_REQUIRES_OK(context, regularizations_.Initialize(context));
OP_REQUIRES_OK(context, context->GetAttr("num_inner_iterations",
&num_inner_iterations_));
- OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
- OP_REQUIRES_OK(context, context->GetAttr("solver_uuid", &solver_uuid_));
}
void Compute(OpKernelContext* context) override {
- // Get a handle on a shared container across invocations of this Kernel.
- // The shared container is intended to maintain state at the example level
- // across invocations of the kernel on different input data.
- //
- // TODO(sibyl-Mooth6ku): Replace this in-Kernel data structure with a first class
- // citizen mutable Dictionary in tensorflow proper, that we will initialize
- // and update externally.
- DataByExample* data_by_example = nullptr;
- OP_REQUIRES_OK(context,
- context->resource_manager()->LookupOrCreate<DataByExample>(
- container_, solver_uuid_, &data_by_example,
- [this](DataByExample** ret) {
- *ret = new DataByExample(container_, solver_uuid_);
- return Status::OK();
- }));
- OP_REQUIRES(
- context, !data_by_example->RefCountIsOne(),
- errors::Internal("Expected shared-ownership of data_by_example."));
-
const Tensor* example_weights_t;
OP_REQUIRES_OK(context,
context->input("example_weights", &example_weights_t));
@@ -738,16 +711,19 @@ class SdcaSolver : public OpKernel {
"number of example weights (%d).",
example_labels.size(), num_examples)));
- const Tensor* example_ids_t;
- OP_REQUIRES_OK(context, context->input("example_ids", &example_ids_t));
- OP_REQUIRES(context, TensorShapeUtils::IsVector(example_ids_t->shape()),
- errors::InvalidArgument("example_ids should be a vector."));
- const auto example_ids = example_ids_t->vec<string>();
- OP_REQUIRES(context, example_labels.size() == num_examples,
- errors::InvalidArgument(strings::Printf(
- "The number of example ids (%ld) should match the number "
- "of example weights (%d).",
- example_ids.size(), num_examples)));
+ const Tensor* example_state_data_t;
+ OP_REQUIRES_OK(context,
+ context->input("example_state_data", &example_state_data_t));
+ TensorShape expected_example_state_shape({num_examples, 4});
+ OP_REQUIRES(
+ context, example_state_data_t->shape() == expected_example_state_shape,
+ errors::InvalidArgument("Expected shape ",
+ expected_example_state_shape.DebugString(),
+ " for example_state_data, got ",
+ example_state_data_t->shape().DebugString()));
+
+ Tensor mutable_example_state_data_t(*example_state_data_t);
+ auto example_state_data = mutable_example_state_data_t.matrix<float>();
FeaturesAndWeights features_and_weights;
OP_REQUIRES_OK(context,
@@ -757,17 +733,15 @@ class SdcaSolver : public OpKernel {
for (int i = 0; i < num_inner_iterations_; ++i) {
OP_REQUIRES_OK(
- context,
- RunTrainStepsForMiniBatch(
- num_examples, example_ids, example_labels, example_weights,
- *context->device()->tensorflow_cpu_worker_threads(),
- regularizations_, *loss_updater_, &features_and_weights,
- data_by_example));
+ context, RunTrainStepsForMiniBatch(
+ num_examples, example_labels, example_weights,
+ *context->device()->tensorflow_cpu_worker_threads(),
+ regularizations_, *loss_updater_, &features_and_weights,
+ example_state_data));
}
features_and_weights.AddDeltaWeights();
- // TODO(sibyl-Mooth6ku): Use core::ScopedUnref once it's moved out of internal.
- data_by_example->Unref();
+ context->set_output(0, mutable_example_state_data_t);
}
private:
@@ -779,8 +753,6 @@ class SdcaSolver : public OpKernel {
int64 num_dense_features_;
Regularizations regularizations_;
int num_inner_iterations_;
- string container_;
- string solver_uuid_;
};
REGISTER_KERNEL_BUILDER(Name("SdcaSolver").Device(DEVICE_CPU), SdcaSolver);
@@ -803,72 +775,26 @@ class SdcaShrinkL1 : public OpKernel {
};
REGISTER_KERNEL_BUILDER(Name("SdcaShrinkL1").Device(DEVICE_CPU), SdcaShrinkL1);
-class SdcaTrainingStats : public OpKernel {
+class SdcaFprint : public OpKernel {
public:
- explicit SdcaTrainingStats(OpKernelConstruction* context)
- : OpKernel(context) {
- OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
- OP_REQUIRES_OK(context, context->GetAttr("solver_uuid", &solver_uuid_));
- }
+ explicit SdcaFprint(OpKernelConstruction* context) : OpKernel(context) {}
- void Compute(OpKernelContext* context) override {
- DataByExample* data_by_example = nullptr;
- OP_REQUIRES_OK(context, context->resource_manager()->Lookup<DataByExample>(
- container_, solver_uuid_, &data_by_example));
- OP_REQUIRES(
- context, !data_by_example->RefCountIsOne(),
- errors::Internal("Expected shared-ownership of data_by_example."));
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& input = ctx->input(0);
+ Tensor* out;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &out));
- double total_primal_loss = 0;
- double total_dual_loss = 0;
- double total_example_weight = 0;
- OP_REQUIRES_OK(context,
- data_by_example->Visit([&](const DataByExample::Data& data) {
- total_primal_loss += data.primal_loss;
- total_dual_loss += data.dual_loss;
- total_example_weight += data.example_weight;
- }));
-
- // TODO(sibyl-Mooth6ku): Think about the most arithmetically stable way of
- // computing (dual + primal) loss (if it matters).
-
- {
- Tensor* tensor = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output("primal_loss", {}, &tensor));
- tensor->scalar<double>()() = total_primal_loss;
- }
+ const auto in_values = input.flat<string>();
+ auto out_values = out->flat<string>();
- {
- Tensor* tensor = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output("dual_loss", {}, &tensor));
- tensor->scalar<double>()() = total_dual_loss;
+ for (int64 i = 0; i < in_values.size(); ++i) {
+ const string& s = in_values(i);
+ Fprint128 fprint = Fingerprint128(s);
+ out_values(i) = strings::StrCat(strings::FpToString(fprint.high64), "-",
+ strings::FpToString(fprint.low64));
}
-
- {
- OP_REQUIRES(
- context, total_example_weight > 0,
- errors::FailedPrecondition(
- "No examples found or all examples have zero weight. Either the "
- "optimizer was trained with no instances or perhaps there is a "
- "bug in the training data."));
-
- Tensor* tensor = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output("example_weights", {}, &tensor));
- tensor->scalar<double>()() = total_example_weight;
- }
-
- // TODO(sibyl-Mooth6ku): Use core::ScopedUnref once it's moved out of internal.
- data_by_example->Unref();
}
-
- private:
- string container_;
- string solver_uuid_;
};
-REGISTER_KERNEL_BUILDER(Name("SdcaTrainingStats").Device(DEVICE_CPU),
- SdcaTrainingStats);
+REGISTER_KERNEL_BUILDER(Name("SdcaFprint").Device(DEVICE_CPU), SdcaFprint);
} // namespace tensorflow
diff --git a/tensorflow/contrib/linear_optimizer/ops/sdca_ops.cc b/tensorflow/contrib/linear_optimizer/ops/sdca_ops.cc
index 05e515423d..cfb6741459 100644
--- a/tensorflow/contrib/linear_optimizer/ops/sdca_ops.cc
+++ b/tensorflow/contrib/linear_optimizer/ops/sdca_ops.cc
@@ -25,16 +25,15 @@ REGISTER_OP("SdcaSolver")
.Attr("l1: float")
.Attr("l2: float")
.Attr("num_inner_iterations: int >= 1")
- .Attr("container: string")
- .Attr("solver_uuid: string")
.Input("sparse_features_indices: num_sparse_features * int64")
.Input("sparse_features_values: num_sparse_features * float")
.Input("dense_features: num_dense_features * float")
.Input("example_weights: float")
.Input("example_labels: float")
- .Input("example_ids: string")
.Input("sparse_weights: Ref(num_sparse_features * float)")
.Input("dense_weights: Ref(num_dense_features * float)")
+ .Input("example_state_data: float")
+ .Output("example_data_data_out: float")
.Doc(R"doc(
Stochastic Dual Coordinate Ascent (SDCA) optimizer for linear models with
L1 + L2 regularization. As global optimization objective is strongly-convex, the
@@ -54,9 +53,6 @@ num_dense_features: Number of dense feature groups to train on.
l1: Symmetric l1 regularization strength.
l2: Symmetric l2 regularization strength.
num_inner_iterations: Number of iterations per mini-batch.
-container: Name of the Container that stores data across invocations of this
- Kernel. Together with SolverUUID form an isolation unit for this solver.
-solver_uuid: Universally Unique Identifier for this solver.
sparse_features_indices: a list of matrices with two columns that contain
example_indices, and feature_indices.
sparse_features_values: a list of vectors which contains feature value
@@ -66,12 +62,13 @@ example_weights: a vector which contains the weight associated with each
example.
example_labels: a vector which contains the label/target associated with each
example.
-example_ids: a vector which contains the unique identifier associated with each
- example.
sparse_weights: a list of vectors where each value is the weight associated with
a feature group.
dense_weights: a list of vectors where the value is the weight associated with
a dense feature group.
+example_state_data: a list of vectors containing the example state data.
+example_data_data_out: a list of vectors containing the updated example state
+ data.
)doc");
REGISTER_OP("SdcaShrinkL1")
@@ -94,23 +91,14 @@ dense_weights: a list of vectors where the value is the weight associated with
a dense feature group.
)doc");
-REGISTER_OP("SdcaTrainingStats")
- .Attr("container: string")
- .Attr("solver_uuid: string")
- .Output("primal_loss: float64")
- .Output("dual_loss: float64")
- .Output("example_weights: float64")
+REGISTER_OP("SdcaFprint")
+ .Input("input: string")
+ .Output("output: string")
.Doc(R"doc(
-Computes statistics over all examples seen by the optimizer.
+Computes fingerprints of the input strings.
-container: Name of the Container that stores data across invocations of this
- Kernel. Together with SolverUUID form an isolation unit for this solver.
-solver_uuid: Universally Unique Identifier for this solver.
-primal_loss: total primal loss of all examples seen by the optimizer.
-dual_loss: total dual loss of all examples seen by the optimizer.
-example_weights: total example weights of all examples seen by the optimizer
- (guaranteed to be positive; otherwise returns FAILED_PRECONDITION as it
- probably indicates a bug in the training data).
+input: strings to compute fingerprints on.
+output: the computed fingerprints.
)doc");
} // namespace tensorflow
diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
index ef4733b1f1..bd5a24f9cb 100644
--- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
@@ -24,6 +24,7 @@ import uuid
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
+from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import _sdca_ops
from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import SdcaModel
from tensorflow.python.framework.test_util import TensorFlowTestCase
from tensorflow.python.platform import googletest
@@ -138,37 +139,6 @@ class SdcaOptimizerTest(TensorFlowTestCase):
intra_op_parallelism_threads=1)
return self.test_session(use_gpu=False, config=config)
- # The following tests, check that operations raise errors when certain
- # preconditions on the input data are not satisfied. These errors are raised
- # regardless of the loss type.
- def testNoWeightedExamples(self):
- # Setup test data with 1 positive, and 1 negative example.
- example_protos = [
- make_example_proto(
- {'age': [0],
- 'gender': [0]}, 0),
- make_example_proto(
- {'age': [1],
- 'gender': [1]}, 1),
- ]
- # Zeroed out example weights.
- example_weights = [0.0, 0.0]
- with self._single_threaded_test_session():
- examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1)
- options = dict(symmetric_l2_regularization=1,
- symmetric_l1_regularization=0,
- loss_type='logistic_loss')
-
- lr = SdcaModel(CONTAINER, examples, variables, options)
- tf.initialize_all_variables().run()
- self.assertAllClose([0.5, 0.5], lr.predictions(examples).eval())
- lr.minimize().run()
- self.assertAllClose([0.5, 0.5], lr.predictions(examples).eval())
- with self.assertRaisesOpError(
- 'No examples found or all examples have zero weight.'):
- lr.approximate_duality_gap().eval()
-
class SdcaWithLogisticLossTest(SdcaOptimizerTest):
"""SDCA optimizer test class for logistic loss."""
@@ -815,5 +785,18 @@ class SdcaWithHingeLossTest(SdcaOptimizerTest):
self.assertAllClose(0.2, unregularized_loss.eval(), atol=0.02)
self.assertAllClose(0.4, regularized_loss.eval(), atol=0.02)
+
+class SdcaFprintTest(TensorFlowTestCase):
+ """Tests for the SdcaFprint op."""
+
+ def testFprint(self):
+ with self.test_session():
+ in_data = tf.constant(['abc', 'very looooooong string', 'def'])
+ out_data = _sdca_ops.sdca_fprint(in_data)
+ self.assertAllEqual([b'a085f09013029e45-3980b2afd2126c04',
+ b'bc5a254df959f26c-512e479a50910f9f',
+ b'79999cd817a03f12-085f182230e03022'],
+ out_data.eval())
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
index 3501faa529..16355de400 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
@@ -17,11 +17,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import uuid
-
from six.moves import range # pylint: disable=redefined-builtin
from tensorflow.contrib.linear_optimizer.ops import gen_sdca_ops
+from tensorflow.contrib.lookup import lookup_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework.load_library import load_op_library
@@ -106,10 +105,11 @@ class SdcaModel(object):
```
"""
- def __init__(self, container, examples, variables, options):
+ def __init__(self, container, examples, variables, options): # pylint: disable=unused-argument
"""Create a new sdca optimizer."""
+ # TODO(andreasst): get rid of obsolete container parameter
- if not container or not examples or not variables or not options:
+ if not examples or not variables or not options:
raise ValueError('All arguments must be specified.')
supported_losses = ('logistic_loss', 'squared_loss', 'hinge_loss')
@@ -136,12 +136,12 @@ class SdcaModel(object):
raise ValueError('%s should be non-negative. Found (%f)' %
(name, value))
- self._container = container
self._examples = examples
self._variables = variables
self._options = options
- self._solver_uuid = uuid.uuid4().hex
self._create_slots()
+ self._hashtable = lookup_ops.MutableHashTable(dtypes.string, dtypes.float32,
+ [0.0, 0.0, 0.0, 0.0])
def _symmetric_l2_regularization(self):
# Algorithmic requirement (for now) is to have minimal l2 of 1.0
@@ -264,19 +264,23 @@ class SdcaModel(object):
sparse_features_indices.append(convert_to_tensor(sf.indices))
sparse_features_values.append(convert_to_tensor(sf.values))
- step_op = _sdca_ops.sdca_solver(
+ example_ids_hashed = _sdca_ops.sdca_fprint(convert_to_tensor(
+ self._examples['example_ids']))
+ example_state_data = self._hashtable.lookup(example_ids_hashed)
+
+ example_state_data_updated = _sdca_ops.sdca_solver(
sparse_features_indices,
sparse_features_values,
self._convert_n_to_tensor(self._examples['dense_features']),
convert_to_tensor(self._examples['example_weights']),
convert_to_tensor(self._examples['example_labels']),
- convert_to_tensor(self._examples['example_ids']),
self._convert_n_to_tensor(
self._slots['unshrinked_sparse_features_weights'],
as_ref=True),
self._convert_n_to_tensor(
self._slots['unshrinked_dense_features_weights'],
as_ref=True),
+ example_state_data,
l1=self._options['symmetric_l1_regularization'],
l2=self._symmetric_l2_regularization(),
# TODO(sibyl-Aix6ihai): Provide empirical evidence for this. It is better
@@ -286,17 +290,17 @@ class SdcaModel(object):
# reuse old samples than train on new samples.
# See: http://arxiv.org/abs/1602.02136.
num_inner_iterations=2,
- loss_type=self._options['loss_type'],
- container=self._container,
- solver_uuid=self._solver_uuid)
- with ops.control_dependencies([step_op]):
- assign_ops = []
+ loss_type=self._options['loss_type'])
+ with ops.control_dependencies([example_state_data_updated]):
+ insert_op = self._hashtable.insert(example_ids_hashed,
+ example_state_data_updated)
+ update_ops = [insert_op]
for name in ['sparse_features_weights', 'dense_features_weights']:
for var, slot_var in zip(self._variables[name],
self._slots['unshrinked_' + name]):
- assign_ops.append(var.assign(slot_var))
- assign_group = control_flow_ops.group(*assign_ops)
- with ops.control_dependencies([assign_group]):
+ update_ops.append(var.assign(slot_var))
+ update_group = control_flow_ops.group(*update_ops)
+ with ops.control_dependencies([update_group]):
shrink_l1 = _sdca_ops.sdca_shrink_l1(
self._convert_n_to_tensor(
self._variables['sparse_features_weights'],
@@ -318,14 +322,17 @@ class SdcaModel(object):
An Operation that computes the approximate duality gap over all
examples.
"""
- (primal_loss, dual_loss, example_weights) = _sdca_ops.sdca_training_stats(
- container=self._container,
- solver_uuid=self._solver_uuid)
- # Note that example_weights is guaranteed to be positive by
- # sdca_training_stats so dividing by it is safe.
- return (primal_loss + dual_loss + math_ops.to_double(self._l1_loss()) +
- (2.0 * math_ops.to_double(self._l2_loss(
- self._symmetric_l2_regularization())))) / example_weights
+ _, exported_values = self._hashtable.export()
+ summed_values = math_ops.reduce_sum(exported_values, 0)
+ primal_loss = summed_values[1]
+ dual_loss = summed_values[2]
+ example_weights = summed_values[3]
+ # TODO(andreasst): what about handle examples_weights == 0?
+ return (
+ primal_loss + dual_loss + math_ops.to_float(self._l1_loss()) +
+ (2.0 *
+ math_ops.to_float(self._l2_loss(self._symmetric_l2_regularization())))
+ ) / example_weights
def unregularized_loss(self, examples):
"""Add operations to compute the loss (without the regularization loss).