aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/models/rnn/linear.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/models/rnn/linear.py')
-rw-r--r--tensorflow/models/rnn/linear.py49
1 files changed, 49 insertions, 0 deletions
diff --git a/tensorflow/models/rnn/linear.py b/tensorflow/models/rnn/linear.py
new file mode 100644
index 0000000000..96278e73e4
--- /dev/null
+++ b/tensorflow/models/rnn/linear.py
@@ -0,0 +1,49 @@
+"""Basic linear combinations that implicitly generate variables."""
+
+import tensorflow as tf
+
+
+def linear(args, output_size, bias, bias_start=0.0, scope=None):
+ """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
+
+ Args:
+ args: a 2D Tensor or a list of 2D, batch x n, Tensors.
+ output_size: int, second dimension of W[i].
+ bias: boolean, whether to add a bias term or not.
+ bias_start: starting value to initialize the bias; 0 by default.
+ scope: VariableScope for the created subgraph; defaults to "Linear".
+
+ Returns:
+ A 2D Tensor with shape [batch x output_size] equal to
+ sum_i(args[i] * W[i]), where W[i]s are newly created matrices.
+
+ Raises:
+ ValueError: if some of the arguments has unspecified or wrong shape.
+ """
+ assert args
+ if not isinstance(args, (list, tuple)):
+ args = [args]
+
+ # Calculate the total size of arguments on dimension 1.
+ total_arg_size = 0
+ shapes = [a.get_shape().as_list() for a in args]
+ for shape in shapes:
+ if len(shape) != 2:
+ raise ValueError("Linear is expecting 2D arguments: %s" % str(shapes))
+ if not shape[1]:
+ raise ValueError("Linear expects shape[1] of arguments: %s" % str(shapes))
+ else:
+ total_arg_size += shape[1]
+
+ # Now the computation.
+ with tf.variable_scope(scope or "Linear"):
+ matrix = tf.get_variable("Matrix", [total_arg_size, output_size])
+ if len(args) == 1:
+ res = tf.matmul(args[0], matrix)
+ else:
+ res = tf.matmul(tf.concat(1, args), matrix)
+ if not bias:
+ return res
+ bias_term = tf.get_variable("Bias", [output_size],
+ initializer=tf.constant_initializer(bias_start))
+ return res + bias_term