Diffstat (limited to 'tensorflow/contrib/tpu/proto/tpu_embedding_config.proto')
-rw-r--r--  tensorflow/contrib/tpu/proto/tpu_embedding_config.proto | 16 +++-------------
1 file changed, 3 insertions(+), 13 deletions(-)
diff --git a/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto b/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto
index b0ec968d3a..3476cc8953 100644
--- a/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto
+++ b/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto
@@ -2,6 +2,8 @@ syntax = "proto3";
package tensorflow.tpu;
+import "tensorflow/contrib/tpu/proto/optimization_parameters.proto";
+
// The TPUEmbeddingConfiguration contains specification of TPU Embedding lookups
// and gradient updates separate from the TF Graph.
message TPUEmbeddingConfiguration {
@@ -30,15 +32,6 @@ message TPUEmbeddingConfiguration {
// The number of training examples per TensorNode.
int32 batch_size = 4;
- message GradientDescentOptimizer {
- float learning_rate = 1;
- }
-
- message AdagradOptimizer {
- float learning_rate = 1;
- float initial_accumulator = 2;
- }
-
// Each Embedding
message TPUEmbeddingTable {
// Name of the embedding table. This will be used to name Variables in the
@@ -66,10 +59,7 @@ message TPUEmbeddingConfiguration {
// separately to the convolutional or recurrent network.
int32 num_features = 5;
- oneof optimizer {
- GradientDescentOptimizer gradient_descent = 6;
- AdagradOptimizer adagrad = 7;
- }
+ OptimizationParameters optimization_parameters = 6;
}
repeated TPUEmbeddingTable table_config = 5;
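
This change drops the per-table GradientDescentOptimizer/AdagradOptimizer oneof and instead points each TPUEmbeddingTable at the shared OptimizationParameters message imported from optimization_parameters.proto. A minimal textproto sketch of a configuration after this change is shown below; the inner field names of OptimizationParameters (adagrad, learning_rate, initial_accumulator) are assumptions mirroring the removed AdagradOptimizer message and are not confirmed by this diff, and TPUEmbeddingTable fields other than num_features and optimization_parameters are likewise omitted or assumed.

# Hypothetical TPUEmbeddingConfiguration textproto (field names inside
# optimization_parameters are assumptions, not taken from this diff).
batch_size: 128
table_config {
  num_features: 1
  optimization_parameters {
    adagrad {                    # assumed submessage name
      learning_rate: 0.1         # assumed field
      initial_accumulator: 0.1   # assumed field
    }
  }
}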