aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto
blob: 3476cc89534efb7fe05640935d1387d02737f240 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
syntax = "proto3";

package tensorflow.tpu;

import "tensorflow/contrib/tpu/proto/optimization_parameters.proto";

// The TPUEmbeddingConfiguration contains specification of TPU Embedding lookups
// and gradient updates separate from the TF Graph.
message TPUEmbeddingConfiguration {
  // model_mode specifies whether the model is to be run in training or
  // inference. In inference mode, gradient updates to embedding tables are not
  // performed.
  enum ModelMode {
    // Proto3 zero value, used as the "unset" default. A valid configuration
    // must explicitly select TRAINING or INFERENCE.
    // NOTE(review): values are unprefixed (style guides recommend
    // MODEL_MODE_*), but renaming published enum values would break
    // generated code, so they are left as-is.
    INVALID = 0;
    TRAINING = 1;
    INFERENCE = 2;
  }

  ModelMode model_mode = 1;

  // num_hosts is the number of host CPU systems in the training/inference job.
  // Each embedding table must be sharded into num_hosts separate Variables,
  // placed separately on the num_hosts CPU devices in the cluster. Sharding
  // will be performed equivalently to the 'div' sharding_strategy option of
  // embedding_lookup() and embedding_lookup_sparse().
  int32 num_hosts = 2;

  // The total number of TensorNodes. This is equal to num_hosts times the
  // number of TensorNodes attached to each host.
  int32 num_tensornodes = 3;

  // The number of training examples per TensorNode.
  int32 batch_size = 4;

  // Specification of a single embedding table: its shape, how many activation
  // vectors it yields per example, and its per-table optimizer settings.
  message TPUEmbeddingTable {
    // Name of the embedding table. This will be used to name Variables in the
    // Tensorflow Graph.
    string name = 1;

    // Field number 2 is skipped (presumably a deleted field); reserve it so
    // it cannot be accidentally reused for a new field with different
    // semantics, which would silently reinterpret old serialized data.
    reserved 2;

    // Number of rows of the embedding table. The Variable created to hold the
    // learned embedding table values will have shape (num_rows, width).
    int32 num_rows = 3;

    // Width of the embedding table. The Variable created to hold the
    // learned embedding table values will have shape (num_rows, width).
    int32 width = 4;

    // Number of distinct embedding activation vectors per training example
    // produced by lookups into this table during model evaluation. For each
    // table, the Graph will receive an activations Tensor of shape
    //   (batch_size * table.num_features, table.width).
    // For example, num_features = 1 produces equivalent behavior to a single
    // tf.nn.embedding_lookup() call. In the case of 'multivalent' embeddings,
    // (i.e. tf.nn.embedding_lookup_sparse()) which compute weighted averages of
    // embedding table rows, num_features is the number of vectors produced
    // after averaging. In sequence models num_features is typically equal
    // to the sequence length, since each sequence element must be represented
    // separately to the convolutional or recurrent network.
    int32 num_features = 5;

    // Optimizer configuration applied to this table's gradient updates
    // (defined in tensorflow/contrib/tpu/proto/optimization_parameters.proto).
    // Only meaningful when model_mode is TRAINING.
    OptimizationParameters optimization_parameters = 6;
  }

  // One entry per embedding table used by the model.
  repeated TPUEmbeddingTable table_config = 5;
}