aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@gmail.com>2016-02-18 09:29:26 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-02-18 11:27:58 -0800
commit61f12e51ca81c539e69c4d2f1249f16f823c8ba5 (patch)
tree730bae109b92cb80183fef0f966a0b8b55a24dc2
parentc89e89f994f32707bb9fb6dca342b612f19c9db0 (diff)
Add digamma op - derivative of lgamma op.
Change: 114971245
-rw-r--r--tensorflow/core/kernels/cwise_op_digamma.cc23
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_digamma.cu.cc26
-rw-r--r--tensorflow/core/kernels/cwise_ops.h3
-rw-r--r--tensorflow/core/ops/math_ops.cc9
-rw-r--r--tensorflow/core/ops/ops.pbtxt28
-rw-r--r--tensorflow/g3doc/api_docs/python/array_ops.md50
-rw-r--r--tensorflow/g3doc/api_docs/python/framework.md9
-rw-r--r--tensorflow/g3doc/api_docs/python/image.md32
-rw-r--r--tensorflow/g3doc/api_docs/python/index.md7
-rw-r--r--tensorflow/g3doc/api_docs/python/math_ops.md28
-rw-r--r--tensorflow/g3doc/api_docs/python/sparse_ops.md170
-rw-r--r--tensorflow/g3doc/api_docs/python/state_ops.md322
-rw-r--r--tensorflow/python/BUILD1
-rw-r--r--tensorflow/python/kernel_tests/cwise_ops_test.py16
-rw-r--r--tensorflow/python/ops/math_grad.py13
-rw-r--r--tensorflow/python/ops/math_ops.py19
16 files changed, 601 insertions, 155 deletions
diff --git a/tensorflow/core/kernels/cwise_op_digamma.cc b/tensorflow/core/kernels/cwise_op_digamma.cc
new file mode 100644
index 0000000000..393f391a48
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_digamma.cc
@@ -0,0 +1,23 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER2(UnaryOp, CPU, "Digamma", functor::digamma, float, double);
+#if GOOGLE_CUDA
+REGISTER2(UnaryOp, GPU, "Digamma", functor::digamma, float, double);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_gpu_digamma.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_digamma.cu.cc
new file mode 100644
index 0000000000..3c5b68b90b
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_digamma.cu.cc
@@ -0,0 +1,26 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY2(digamma, float, double);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index bfb0977f7f..59450ffcc2 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -333,6 +333,9 @@ template <typename T>
struct lgamma : base<T, Eigen::internal::scalar_lgamma_op<T> > {};
template <typename T>
+struct digamma : base<T, Eigen::internal::scalar_digamma_op<T>> {};
+
+template <typename T>
struct erf : base<T, Eigen::internal::scalar_erf_op<T> > {};
template <typename T>
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 5e382e5a46..63513d3533 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -191,7 +191,14 @@ Computes hyperbolic tangent of `x` element-wise.
REGISTER_OP("Lgamma")
.UNARY()
.Doc(R"doc(
-Computes the log of the absolute value of Gamma of `x` element-wise.
+Computes the log of the absolute value of `Gamma(x)` element-wise.
+)doc");
+
+REGISTER_OP("Digamma")
+ .UNARY()
+ .Doc(R"doc(
+Computes Psi, the derivative of Lgamma (the log of the absolute value of
+`Gamma(x)`), element-wise.
)doc");
REGISTER_OP("Erf")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 0132abe3e4..babc8d6ccb 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -2389,6 +2389,32 @@ op {
description: "Given a `diagonal`, this operation returns a tensor with the `diagonal` and\neverything else padded with zeros. The diagonal is computed as follows:\n\nAssume `diagonal` has dimensions [D1,..., Dk], then the output is a tensor of\nrank 2k with dimensions [D1,..., Dk, D1,..., Dk] where:\n\n`output[i1,..., ik, i1,..., ik] = diagonal[i1, ..., ik]` and 0 everywhere else.\n\nFor example:\n\n```prettyprint\n# \'diagonal\' is [1, 2, 3, 4]\ntf.diag(diagonal) ==> [[1, 0, 0, 0]\n [0, 2, 0, 0]\n [0, 0, 3, 0]\n [0, 0, 0, 4]]\n```"
}
op {
+ name: "Digamma"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_COMPLEX64
+ type: DT_INT64
+ }
+ }
+ }
+ summary: "Computes Psi, the derivative of Lgamma (the log of the absolute value of"
+ description: "`Gamma(x)`), element-wise."
+}
+op {
name: "Div"
input_arg {
name: "x"
@@ -4056,7 +4082,7 @@ op {
}
}
}
- summary: "Computes the log of the absolute value of Gamma of `x` element-wise."
+ summary: "Computes the log of the absolute value of `Gamma(x)` element-wise."
}
op {
name: "LinSpace"
diff --git a/tensorflow/g3doc/api_docs/python/array_ops.md b/tensorflow/g3doc/api_docs/python/array_ops.md
index bd01b7afca..d532731f5c 100644
--- a/tensorflow/g3doc/api_docs/python/array_ops.md
+++ b/tensorflow/g3doc/api_docs/python/array_ops.md
@@ -178,6 +178,28 @@ tf.cast(a, tf.int32) ==> [1, 2] # dtype=tf.int32
* <b>`TypeError`</b>: If `x` cannot be cast to the `dtype`.
+- - -
+
+### `tf.saturate_cast(value, dtype, name=None)` {#saturate_cast}
+
+Performs a safe saturating cast of `value` to `dtype`.
+
+This function casts the input to `dtype` without applying any scaling. If
+there is a danger that values would over or underflow in the cast, this op
+applies the appropriate clamping before the cast.
+
+##### Args:
+
+
+* <b>`value`</b>: A `Tensor`.
+* <b>`dtype`</b>: The desired output `DType`.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ `value` safely cast to `dtype`.
+
+
## Shapes and Shaping
@@ -1267,6 +1289,34 @@ boolean_mask(tensor, mask) ==> [[1, 2], [5, 6]]
## Other Functions and Classes
- - -
+### `tf.bitcast(input, type, name=None)` {#bitcast}
+
+Bitcasts a tensor from one type to another without copying data.
+
+Given a tensor `input`, this operation returns a tensor that has the same buffer
+data as `input` with datatype `type`.
+
+If the input datatype `T` is larger than the output datatype `type` then the
+shape changes from [...] to [..., sizeof(`T`)/sizeof(`type`)].
+
+If `T` is smaller than `type`, the operator requires that the rightmost
+dimension be equal to sizeof(`type`)/sizeof(`T`). The shape then goes from
+[..., sizeof(`type`)/sizeof(`T`)] to [...].
+
+##### Args:
+
+
+* <b>`input`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`, `qint8`, `quint8`, `qint32`.
+* <b>`type`</b>: A `tf.DType` from: `tf.float32, tf.float64, tf.int64, tf.int32, tf.uint8, tf.uint16, tf.int16, tf.int8, tf.complex64, tf.qint8, tf.quint8, tf.qint32`.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ A `Tensor` of type `type`.
+
+
+- - -
+
### `tf.shape_n(input, name=None)` {#shape_n}
Returns shape of tensors.
diff --git a/tensorflow/g3doc/api_docs/python/framework.md b/tensorflow/g3doc/api_docs/python/framework.md
index d5485ae0b7..590355bdab 100644
--- a/tensorflow/g3doc/api_docs/python/framework.md
+++ b/tensorflow/g3doc/api_docs/python/framework.md
@@ -1336,6 +1336,13 @@ Returns the minimum representable value in this data type.
* <b>`TypeError`</b>: if this is a non-numeric, unordered, or quantized type.
+- - -
+
+#### `tf.DType.size` {#DType.size}
+
+
+
+
- - -
@@ -1366,7 +1373,7 @@ Converts the given `type_value` to a `DType`.
- - -
-### `tf.device(dev)` {#device}
+### `tf.device(device_name_or_function)` {#device}
Wrapper for `Graph.device()` using the default graph.
diff --git a/tensorflow/g3doc/api_docs/python/image.md b/tensorflow/g3doc/api_docs/python/image.md
index 094488a0b9..5998b591a0 100644
--- a/tensorflow/g3doc/api_docs/python/image.md
+++ b/tensorflow/g3doc/api_docs/python/image.md
@@ -163,6 +163,7 @@ PNG-encode an image.
where `channels` is:
* 1: for grayscale.
+* 2: for grayscale + alpha.
* 3: for RGB.
* 4: for RGBA.
@@ -678,9 +679,9 @@ Example:
```python
# Decode an image and convert it to HSV.
-rgb_image = tf.decode_png(..., channels=3)
-rgb_image_float = tf.convert_image_dtype(rgb_image, tf.float32)
-hsv_image = tf.rgb_to_hsv(rgb_image)
+rgb_image = tf.image.decode_png(..., channels=3)
+rgb_image_float = tf.image.convert_image_dtype(rgb_image, tf.float32)
+hsv_image = tf.image.rgb_to_hsv(rgb_image)
```
- - -
@@ -908,7 +909,7 @@ channel and then adjusts each component `x` of each pixel to
##### Returns:
- The constrast-adjusted image or images.
+ The contrast-adjusted image or images.
- - -
@@ -1219,26 +1220,3 @@ false and no bounding boxes are supplied, an error is raised.
Provide as input to `tf.image.draw_bounding_boxes`.
-
-## Other Functions and Classes
-- - -
-
-### `tf.image.saturate_cast(image, dtype)` {#saturate_cast}
-
-Performs a safe cast of image data to `dtype`.
-
-This function casts the data in image to `dtype`, without applying any
-scaling. If there is a danger that image data would over or underflow in the
-cast, this op applies the appropriate clamping before the cast.
-
-##### Args:
-
-
-* <b>`image`</b>: An image to cast to a different data type.
-* <b>`dtype`</b>: A `DType` to cast `image` to.
-
-##### Returns:
-
- `image`, safely cast to `dtype`.
-
-
diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md
index 34c7c883b5..82f7fa02fc 100644
--- a/tensorflow/g3doc/api_docs/python/index.md
+++ b/tensorflow/g3doc/api_docs/python/index.md
@@ -64,6 +64,7 @@
* [`latest_checkpoint`](../../api_docs/python/state_ops.md#latest_checkpoint)
* [`make_template`](../../api_docs/python/state_ops.md#make_template)
* [`moving_average_variables`](../../api_docs/python/state_ops.md#moving_average_variables)
+ * [`no_regularizer`](../../api_docs/python/state_ops.md#no_regularizer)
* [`random_normal_initializer`](../../api_docs/python/state_ops.md#random_normal_initializer)
* [`random_uniform_initializer`](../../api_docs/python/state_ops.md#random_uniform_initializer)
* [`Saver`](../../api_docs/python/state_ops.md#Saver)
@@ -78,9 +79,11 @@
* [`Variable`](../../api_docs/python/state_ops.md#Variable)
* [`variable_op_scope`](../../api_docs/python/state_ops.md#variable_op_scope)
* [`variable_scope`](../../api_docs/python/state_ops.md#variable_scope)
+ * [`VariableScope`](../../api_docs/python/state_ops.md#VariableScope)
* [`zeros_initializer`](../../api_docs/python/state_ops.md#zeros_initializer)
* **[Tensor Transformations](../../api_docs/python/array_ops.md)**:
+ * [`bitcast`](../../api_docs/python/array_ops.md#bitcast)
* [`boolean_mask`](../../api_docs/python/array_ops.md#boolean_mask)
* [`cast`](../../api_docs/python/array_ops.md#cast)
* [`concat`](../../api_docs/python/array_ops.md#concat)
@@ -95,6 +98,7 @@
* [`reshape`](../../api_docs/python/array_ops.md#reshape)
* [`reverse`](../../api_docs/python/array_ops.md#reverse)
* [`reverse_sequence`](../../api_docs/python/array_ops.md#reverse_sequence)
+ * [`saturate_cast`](../../api_docs/python/array_ops.md#saturate_cast)
* [`shape`](../../api_docs/python/array_ops.md#shape)
* [`shape_n`](../../api_docs/python/array_ops.md#shape_n)
* [`size`](../../api_docs/python/array_ops.md#size)
@@ -136,6 +140,7 @@
* [`cos`](../../api_docs/python/math_ops.md#cos)
* [`cross`](../../api_docs/python/math_ops.md#cross)
* [`diag`](../../api_docs/python/math_ops.md#diag)
+ * [`digamma`](../../api_docs/python/math_ops.md#digamma)
* [`div`](../../api_docs/python/math_ops.md#div)
* [`edit_distance`](../../api_docs/python/math_ops.md#edit_distance)
* [`erf`](../../api_docs/python/math_ops.md#erf)
@@ -257,7 +262,6 @@
* [`rgb_to_grayscale`](../../api_docs/python/image.md#rgb_to_grayscale)
* [`rgb_to_hsv`](../../api_docs/python/image.md#rgb_to_hsv)
* [`sample_distorted_bounding_box`](../../api_docs/python/image.md#sample_distorted_bounding_box)
- * [`saturate_cast`](../../api_docs/python/image.md#saturate_cast)
* [`transpose_image`](../../api_docs/python/image.md#transpose_image)
* **[Sparse Tensors](../../api_docs/python/sparse_ops.md)**:
@@ -267,6 +271,7 @@
* [`sparse_reorder`](../../api_docs/python/sparse_ops.md#sparse_reorder)
* [`sparse_retain`](../../api_docs/python/sparse_ops.md#sparse_retain)
* [`sparse_split`](../../api_docs/python/sparse_ops.md#sparse_split)
+ * [`sparse_tensor_dense_matmul`](../../api_docs/python/sparse_ops.md#sparse_tensor_dense_matmul)
* [`sparse_tensor_to_dense`](../../api_docs/python/sparse_ops.md#sparse_tensor_to_dense)
* [`sparse_to_dense`](../../api_docs/python/sparse_ops.md#sparse_to_dense)
* [`sparse_to_indicator`](../../api_docs/python/sparse_ops.md#sparse_to_indicator)
diff --git a/tensorflow/g3doc/api_docs/python/math_ops.md b/tensorflow/g3doc/api_docs/python/math_ops.md
index 6e4b653cf9..cdbb54aa6a 100644
--- a/tensorflow/g3doc/api_docs/python/math_ops.md
+++ b/tensorflow/g3doc/api_docs/python/math_ops.md
@@ -573,6 +573,25 @@ Computes `ln(|gamma(x)|)` element-wise.
- - -
+### `tf.digamma(x, name=None)` {#digamma}
+
+Computes Psi, the derivative of lgamma, `ln(|gamma(x)|)`, element-wise.
+
+##### Args:
+
+
+* <b>`x`</b>: A Tensor with type `float`, `double`, `int32`, `int64`,
+ or `qint32`.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ A Tensor with the same type as `x` if `x.dtype != qint32` otherwise
+ the return type is `quint8`.
+
+
+- - -
+
### `tf.erf(x, name=None)` {#erf}
Computes Gauss error function of `x` element-wise.
@@ -1197,11 +1216,10 @@ the minimum-norm solution to the under-determined linear system, i.e.
\\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach}}}\\) or\\(\lambda\\)
is sufficiently large.
-If `fast` is `False` then the solution is computed using the rank revealing
-QR decomposition with column pivoting. This will always compute a
-least-squares solution that minimizes the residual norm \\(||A X - B||_F^2\\),
-even when \\(A\\) is rank deficient or ill-conditioned. Notice: The current
-version does not compute a minimum norm solution. If `fast` is `False` then
+If `fast` is `False` an algorithm based on the numerically robust complete
+orthogonal decomposition is used. This computes the minimum-norm
+least-squares solution, even when \\(A\\) is rank deficient. This path is
+typically 6-7 times slower than the fast path. If `fast` is `False` then
`l2_regularizer` is ignored.
##### Args:
diff --git a/tensorflow/g3doc/api_docs/python/sparse_ops.md b/tensorflow/g3doc/api_docs/python/sparse_ops.md
index 74d18bc689..419a3a0757 100644
--- a/tensorflow/g3doc/api_docs/python/sparse_ops.md
+++ b/tensorflow/g3doc/api_docs/python/sparse_ops.md
@@ -551,3 +551,173 @@ This op also returns an indicator vector such that
* <b>`TypeError`</b>: If `sp_input` is not a `SparseTensor`.
+
+## Math Operations
+- - -
+
+### `tf.sparse_tensor_dense_matmul(sp_a, b, adjoint_a=False, adjoint_b=False, name=None)` {#sparse_tensor_dense_matmul}
+
+Multiply SparseTensor (of rank 2) "A" by dense matrix "B".
+
+No validity checking is performed on the indices of A. However, the following
+input format is recommended for optimal behavior:
+
+if adjoint_a == false:
+ A should be sorted in lexicographically increasing order. Use
+ sparse_reorder if you're not sure.
+if adjoint_a == true:
+ A should be sorted in order of increasing dimension 1 (i.e., "column major"
+ order instead of "row major" order).
+
+Deciding when to use sparse_tensor_dense_matmul vs. matmul(sp_a=True):
+
+There are a number of questions to ask in the decision process, including:
+
+* Will the SparseTensor A fit in memory if densified?
+* Is the column count of the product large (>> 1)?
+* Is the density of A larger than approximately 15%?
+* Is backprop into A necessary?
+
+If the answer to several of these questions is yes, consider
+converting the SparseTensor to a dense one and using tf.matmul with sp_a=True.
+
+This operation tends to perform well when A is more sparse, if the column
+size of the product is small (e.g. matrix-vector multiplication),
+if sp_a.shape takes on large values. While gradients with respect to B
+are supported, gradients with respect to A are not.
+
+Below is a rough speed comparison between sparse_tensor_dense_matmul,
+labelled 'sparse', and matmul(sp_a=True), labelled 'dense'. For purposes of
+the comparison, the time spent converting from a SparseTensor to a dense
+Tensor is not included, so it is overly conservative with respect to
+the time ratio.
+
+Benchmark system:
+CPU: Intel Ivybridge with HyperThreading (6 cores) dL1:32KB dL2:256KB dL3:12MB
+GPU: NVidia Tesla k40c
+
+Compiled with:
+-c opt --config=cuda --copt=-mavx
+
+```tensorflow/python/sparse_tensor_dense_matmul_op_test --benchmarks
+A sparse [m, k] with % nonzero values between 1% and 80%
+B dense [k, n]
+
+% nnz n gpu m k dt(dense) dt(sparse) dt(sparse)/dt(dense)
+0.01 1 True 100 100 0.000221166 0.00010154 0.459112
+0.01 1 True 100 1000 0.00033858 0.000109275 0.322745
+0.01 1 True 1000 100 0.000310557 9.85661e-05 0.317385
+0.01 1 True 1000 1000 0.0008721 0.000100875 0.115669
+0.01 1 False 100 100 0.000208085 0.000107603 0.51711
+0.01 1 False 100 1000 0.000327112 9.51118e-05 0.290762
+0.01 1 False 1000 100 0.000308222 0.00010345 0.335635
+0.01 1 False 1000 1000 0.000865721 0.000101397 0.117124
+0.01 10 True 100 100 0.000218522 0.000105537 0.482958
+0.01 10 True 100 1000 0.000340882 0.000111641 0.327506
+0.01 10 True 1000 100 0.000315472 0.000117376 0.372064
+0.01 10 True 1000 1000 0.000905493 0.000123263 0.136128
+0.01 10 False 100 100 0.000221529 9.82571e-05 0.44354
+0.01 10 False 100 1000 0.000330552 0.000112615 0.340687
+0.01 10 False 1000 100 0.000341277 0.000114097 0.334324
+0.01 10 False 1000 1000 0.000819944 0.000120982 0.147549
+0.01 25 True 100 100 0.000207806 0.000105977 0.509981
+0.01 25 True 100 1000 0.000322879 0.00012921 0.400181
+0.01 25 True 1000 100 0.00038262 0.000141583 0.370035
+0.01 25 True 1000 1000 0.000865438 0.000202083 0.233504
+0.01 25 False 100 100 0.000209401 0.000104696 0.499979
+0.01 25 False 100 1000 0.000321161 0.000130737 0.407076
+0.01 25 False 1000 100 0.000377012 0.000136801 0.362856
+0.01 25 False 1000 1000 0.000861125 0.00020272 0.235413
+0.2 1 True 100 100 0.000206952 9.69219e-05 0.46833
+0.2 1 True 100 1000 0.000348674 0.000147475 0.422959
+0.2 1 True 1000 100 0.000336908 0.00010122 0.300439
+0.2 1 True 1000 1000 0.001022 0.000203274 0.198898
+0.2 1 False 100 100 0.000207532 9.5412e-05 0.459746
+0.2 1 False 100 1000 0.000356127 0.000146824 0.41228
+0.2 1 False 1000 100 0.000322664 0.000100918 0.312764
+0.2 1 False 1000 1000 0.000998987 0.000203442 0.203648
+0.2 10 True 100 100 0.000211692 0.000109903 0.519165
+0.2 10 True 100 1000 0.000372819 0.000164321 0.440753
+0.2 10 True 1000 100 0.000338651 0.000144806 0.427596
+0.2 10 True 1000 1000 0.00108312 0.000758876 0.70064
+0.2 10 False 100 100 0.000215727 0.000110502 0.512231
+0.2 10 False 100 1000 0.000375419 0.0001613 0.429653
+0.2 10 False 1000 100 0.000336999 0.000145628 0.432132
+0.2 10 False 1000 1000 0.00110502 0.000762043 0.689618
+0.2 25 True 100 100 0.000218705 0.000129913 0.594009
+0.2 25 True 100 1000 0.000394794 0.00029428 0.745402
+0.2 25 True 1000 100 0.000404483 0.0002693 0.665788
+0.2 25 True 1000 1000 0.0012002 0.00194494 1.62052
+0.2 25 False 100 100 0.000221494 0.0001306 0.589632
+0.2 25 False 100 1000 0.000396436 0.000297204 0.74969
+0.2 25 False 1000 100 0.000409346 0.000270068 0.659754
+0.2 25 False 1000 1000 0.00121051 0.00193737 1.60046
+0.5 1 True 100 100 0.000214981 9.82111e-05 0.456836
+0.5 1 True 100 1000 0.000415328 0.000223073 0.537101
+0.5 1 True 1000 100 0.000358324 0.00011269 0.314492
+0.5 1 True 1000 1000 0.00137612 0.000437401 0.317851
+0.5 1 False 100 100 0.000224196 0.000101423 0.452386
+0.5 1 False 100 1000 0.000400987 0.000223286 0.556841
+0.5 1 False 1000 100 0.000368825 0.00011224 0.304318
+0.5 1 False 1000 1000 0.00136036 0.000429369 0.31563
+0.5 10 True 100 100 0.000222125 0.000112308 0.505608
+0.5 10 True 100 1000 0.000461088 0.00032357 0.701753
+0.5 10 True 1000 100 0.000394624 0.000225497 0.571422
+0.5 10 True 1000 1000 0.00158027 0.00190898 1.20801
+0.5 10 False 100 100 0.000232083 0.000114978 0.495418
+0.5 10 False 100 1000 0.000454574 0.000324632 0.714146
+0.5 10 False 1000 100 0.000379097 0.000227768 0.600817
+0.5 10 False 1000 1000 0.00160292 0.00190168 1.18638
+0.5 25 True 100 100 0.00023429 0.000151703 0.647501
+0.5 25 True 100 1000 0.000497462 0.000598873 1.20386
+0.5 25 True 1000 100 0.000460778 0.000557038 1.20891
+0.5 25 True 1000 1000 0.00170036 0.00467336 2.74845
+0.5 25 False 100 100 0.000228981 0.000155334 0.678371
+0.5 25 False 100 1000 0.000496139 0.000620789 1.25124
+0.5 25 False 1000 100 0.00045473 0.000551528 1.21287
+0.5 25 False 1000 1000 0.00171793 0.00467152 2.71927
+0.8 1 True 100 100 0.000222037 0.000105301 0.47425
+0.8 1 True 100 1000 0.000410804 0.000329327 0.801664
+0.8 1 True 1000 100 0.000349735 0.000131225 0.375212
+0.8 1 True 1000 1000 0.00139219 0.000677065 0.48633
+0.8 1 False 100 100 0.000214079 0.000107486 0.502085
+0.8 1 False 100 1000 0.000413746 0.000323244 0.781261
+0.8 1 False 1000 100 0.000348983 0.000131983 0.378193
+0.8 1 False 1000 1000 0.00136296 0.000685325 0.50282
+0.8 10 True 100 100 0.000229159 0.00011825 0.516017
+0.8 10 True 100 1000 0.000498845 0.000532618 1.0677
+0.8 10 True 1000 100 0.000383126 0.00029935 0.781336
+0.8 10 True 1000 1000 0.00162866 0.00307312 1.88689
+0.8 10 False 100 100 0.000230783 0.000124958 0.541452
+0.8 10 False 100 1000 0.000493393 0.000550654 1.11606
+0.8 10 False 1000 100 0.000377167 0.000298581 0.791642
+0.8 10 False 1000 1000 0.00165795 0.00305103 1.84024
+0.8 25 True 100 100 0.000233496 0.000175241 0.75051
+0.8 25 True 100 1000 0.00055654 0.00102658 1.84458
+0.8 25 True 1000 100 0.000463814 0.000783267 1.68875
+0.8 25 True 1000 1000 0.00186905 0.00755344 4.04132
+0.8 25 False 100 100 0.000240243 0.000175047 0.728625
+0.8 25 False 100 1000 0.000578102 0.00104499 1.80763
+0.8 25 False 1000 100 0.000485113 0.000776849 1.60138
+0.8 25 False 1000 1000 0.00211448 0.00752736 3.55992
+```
+
+##### Args:
+
+
+* <b>`sp_a`</b>: SparseTensor A, of rank 2.
+* <b>`b`</b>: A dense Matrix with the same dtype as sp_a.
+* <b>`adjoint_a`</b>: Use the adjoint of A in the matrix multiply. If A is complex,
+ this is transpose(conj(A)). Otherwise it's transpose(A).
+* <b>`adjoint_b`</b>: Use the adjoint of B in the matrix multiply. If B is complex,
+ this is transpose(conj(B)). Otherwise it's transpose(B).
+* <b>`name`</b>: A name prefix for the returned tensors (optional)
+
+##### Returns:
+
+ A dense matrix (pseudo-code in dense np.matrix notation):
+ A = A.H if adjoint_a else A
+ B = B.H if adjoint_b else B
+ return A*B
+
+
diff --git a/tensorflow/g3doc/api_docs/python/state_ops.md b/tensorflow/g3doc/api_docs/python/state_ops.md
index 29355cb1ff..01988995bb 100644
--- a/tensorflow/g3doc/api_docs/python/state_ops.md
+++ b/tensorflow/g3doc/api_docs/python/state_ops.md
@@ -955,7 +955,7 @@ create variables contingent on certain conditions.
- - -
-### `tf.get_variable(name, shape=None, dtype=tf.float32, initializer=None, trainable=True, collections=None)` {#get_variable}
+### `tf.get_variable(name, shape=None, dtype=tf.float32, initializer=None, regularizer=None, trainable=True, collections=None)` {#get_variable}
Gets an existing variable with these parameters or create a new one.
@@ -973,10 +973,14 @@ with tf.variable_scope("foo", reuse=True)
```
If initializer is `None` (the default), the default initializer passed in
-the constructor is used. If that one is `None` too, a
+the variable scope will be used. If that one is `None` too, a
`UniformUnitScalingInitializer` will be used. The initializer can also be
a Tensor, in which case the variable is initialized to this value and shape.
+Similarly, if the regularizer is `None` (the default), the default regularizer
+passed in the variable scope will be used (if that is `None` too,
+then by default no regularization is performed).
+
##### Args:
@@ -984,6 +988,9 @@ a Tensor, in which case the variable is initialized to this value and shape.
* <b>`shape`</b>: shape of the new or existing variable.
* <b>`dtype`</b>: type of the new or existing variable (defaults to `DT_FLOAT`).
* <b>`initializer`</b>: initializer for the variable if one is created.
+* <b>`regularizer`</b>: a (Tensor -> Tensor or None) function; the result of
+ applying it on a newly created variable will be added to the collection
+ GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
* <b>`trainable`</b>: If `True` also add the variable to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
* <b>`collections`</b>: List of graph collections keys to add the Variable to.
@@ -1003,164 +1010,86 @@ a Tensor, in which case the variable is initialized to this value and shape.
- - -
-### `tf.get_variable_scope()` {#get_variable_scope}
+### `class tf.VariableScope` {#VariableScope}
-Returns the current variable scope.
+Variable scope object to carry defaults to provide to get_variable.
+Many of the arguments we need for get_variable in a variable store are most
+easily handled with a context. This object is used for the defaults.
+Attributes:
+ name: name of the current scope, used as prefix in get_variable.
+ initializer: default initializer passed to get_variable.
+ regularizer: default regularizer passed to get_variable.
+ reuse: Boolean or None, setting the reuse in get_variable.
+ name_scope: The name passed to tf.name_scope.
- - -
-### `tf.make_template(name_, func_, **kwargs)` {#make_template}
+#### `tf.VariableScope.__init__(reuse, name='', initializer=None, regularizer=None, name_scope='')` {#VariableScope.__init__}
-Given an arbitrary function, wrap it so that it does variable sharing.
+Creates a new VariableScope with the given properties.
-This wraps `func_` in a Template and partially evaluates it. Templates are
-functions that create variables the first time they are called and reuse them
-thereafter. In order for `func_` to be compatible with a `Template` it must
-have the following properties:
-* The function should create all trainable variables and any variables that
- should be reused by calling `tf.get_variable`. If a trainable variable is
- created using `tf.Variable`, then a ValueError will be thrown. Variables
- that are intended to be locals can be created by specifying
- `tf.Variable(..., trainable=false)`.
-* The function may use variable scopes and other templates internally to
- create and reuse variables, but it shouldn't use `tf.get_variables` to
- capture variables that are defined outside of the scope of the function.
-* Internal scopes and variable names should not depend on any arguments that
- are not supplied to `make_template`. In general you will get a ValueError
- telling you that you are trying to reuse a variable that doesn't exist
- if you make a mistake.
-
-In the following example, both `z` and `w` will be scaled by the same `y`. It
-is important to note that if we didn't assign `scalar_name` and used a
-different name for z and w that a `ValueError` would be thrown because it
-couldn't reuse the variable.
-
-```python
-def my_op(x, scalar_name):
- var1 = tf.get_variable(scalar_name,
- shape=[],
- initializer=tf.constant_initializer(1))
- return x * var1
-
-scale_by_y = tf.make_template('scale_by_y', my_op, scalar_name='y')
+- - -
-z = scale_by_y(input1)
-w = scale_by_y(input2)
-```
+#### `tf.VariableScope.get_variable(var_store, name, shape=None, dtype=tf.float32, initializer=None, regularizer=None, trainable=True, collections=None)` {#VariableScope.get_variable}
-As a safe-guard, the returned function will raise a `ValueError` after the
-first call if trainable variables are created by calling `tf.Variable`.
+Gets an existing variable with this name or create a new one.
-If all of these are true, then 2 properties are enforced by the template:
-1. Calling the same template multiple times will share all non-local
- variables.
-2. Two different templates are guaranteed to be unique, unless you reenter the
- same variable scope as the initial definition of a template and redefine
- it. An examples of this exception:
+- - -
-```python
-def my_op(x, scalar_name):
- var1 = tf.get_variable(scalar_name,
- shape=[],
- initializer=tf.constant_initializer(1))
- return x * var1
+#### `tf.VariableScope.initializer` {#VariableScope.initializer}
-with tf.variable_scope('scope') as vs:
- scale_by_y = tf.make_template('scale_by_y', my_op, scalar_name='y')
- z = scale_by_y(input1)
- w = scale_by_y(input2)
-# Creates a template that reuses the variables above.
-with tf.variable_scope(vs, reuse=True):
- scale_by_y2 = tf.make_template('scale_by_y', my_op, scalar_name='y')
- z2 = scale_by_y2(input1)
- w2 = scale_by_y2(input2)
-```
-Note: The full variable scope is captured at the time of the first call.
-Note: `name_` and `func_` have a following underscore to reduce the likelihood
-of collisions with kwargs.
+- - -
-##### Args:
+#### `tf.VariableScope.name` {#VariableScope.name}
-* <b>`name_`</b>: A name for the scope created by this template. If necessary, the name
- will be made unique by appending `_N` to the name.
-* <b>`func_`</b>: The function to wrap.
-* <b>`**kwargs`</b>: Keyword arguments to apply to `func_`.
-##### Returns:
- A function that will enter a `variable_scope` before calling `func_`. The
- first time it is called, it will create a non-reusing scope so that the
- variables will be unique. On each subsequent call, it will reuse those
- variables.
+- - -
-##### Raises:
+#### `tf.VariableScope.regularizer` {#VariableScope.regularizer}
-* <b>`ValueError`</b>: if the name is None.
- - -
-### `tf.variable_op_scope(values, name, default_name, initializer=None)` {#variable_op_scope}
+#### `tf.VariableScope.reuse` {#VariableScope.reuse}
-Returns a context manager for defining an op that creates variables.
-This context manager validates that the given `values` are from the
-same graph, ensures that that graph is the default graph, and pushes a
-name scope and a variable scope.
-If `name` is not None, it is used as is in the variable scope. If `name`
-is None, then `default_name` is used. In that case, if the same name has been
-previously used in the same scope, it will made unique be appending `_N` to
-it.
-This is intended to be used when defining generic ops and so reuse is always
-inherited.
+- - -
-For example, to define a new Python op called `my_op_with_vars`:
+#### `tf.VariableScope.reuse_variables()` {#VariableScope.reuse_variables}
-```python
-def my_op_with_vars(a, b, name=None):
- with tf.variable_op_scope([a, b], name, "MyOp") as scope:
- a = tf.convert_to_tensor(a, name="a")
- b = tf.convert_to_tensor(b, name="b")
- c = tf.get_variable('c')
- # Define some computation that uses `a`, `b`, and `c`.
- return foo_op(..., name=scope)
-```
+Reuse variables in this scope.
-##### Args:
+- - -
-* <b>`values`</b>: The list of `Tensor` arguments that are passed to the op function.
-* <b>`name`</b>: The name argument that is passed to the op function, this name is not
- uniquified in the variable scope.
-* <b>`default_name`</b>: The default name to use if the `name` argument is `None`, this
- name will be uniquified.
-* <b>`initializer`</b>: A default initializer to pass to variable scope.
+#### `tf.VariableScope.set_initializer(initializer)` {#VariableScope.set_initializer}
-##### Returns:
+Set initializer for this scope.
- A context manager for use in defining a Python op.
-##### Raises:
+- - -
+#### `tf.VariableScope.set_regularizer(regularizer)` {#VariableScope.set_regularizer}
+
+Set regularizer for this scope.
-* <b>`ValueError`</b>: when trying to reuse within a create scope, or create within
- a reuse scope, or if reuse is not `None` or `True`.
-* <b>`TypeError`</b>: when the types of some arguments are not appropriate.
- - -
-### `tf.variable_scope(name_or_scope, reuse=None, initializer=None)` {#variable_scope}
+### `tf.variable_scope(name_or_scope, reuse=None, initializer=None, regularizer=None)` {#variable_scope}
Returns a context for variable scope.
@@ -1227,6 +1156,7 @@ then all its sub-scopes become reusing as well.
* <b>`reuse`</b>: `True` or `None`; if `True`, we go into reuse mode for this scope as
well as all sub-scopes; if `None`, we just inherit the parent scope reuse.
* <b>`initializer`</b>: default initializer for variables within this scope.
+* <b>`regularizer`</b>: default regularizer for variables within this scope.
##### Returns:
@@ -1240,6 +1170,172 @@ then all its sub-scopes become reusing as well.
* <b>`TypeError`</b>: when the types of some arguments are not appropriate.
+- - -
+
+### `tf.variable_op_scope(values, name, default_name, initializer=None, regularizer=None)` {#variable_op_scope}
+
+Returns a context manager for defining an op that creates variables.
+
+This context manager validates that the given `values` are from the
+same graph, ensures that that graph is the default graph, and pushes a
+name scope and a variable scope.
+
+If `name` is not None, it is used as is in the variable scope. If `name`
+is None, then `default_name` is used. In that case, if the same name has been
+previously used in the same scope, it will made unique be appending `_N` to
+it.
+
+This is intended to be used when defining generic ops and so reuse is always
+inherited.
+
+For example, to define a new Python op called `my_op_with_vars`:
+
+```python
+def my_op_with_vars(a, b, name=None):
+ with tf.variable_op_scope([a, b], name, "MyOp") as scope:
+ a = tf.convert_to_tensor(a, name="a")
+ b = tf.convert_to_tensor(b, name="b")
+ c = tf.get_variable('c')
+ # Define some computation that uses `a`, `b`, and `c`.
+ return foo_op(..., name=scope)
+```
+
+##### Args:
+
+
+* <b>`values`</b>: The list of `Tensor` arguments that are passed to the op function.
+* <b>`name`</b>: The name argument that is passed to the op function, this name is not
+ uniquified in the variable scope.
+* <b>`default_name`</b>: The default name to use if the `name` argument is `None`, this
+ name will be uniquified.
+* <b>`initializer`</b>: A default initializer to pass to variable scope.
+* <b>`regularizer`</b>: default regularizer for variables within this scope.
+
+##### Returns:
+
+ A context manager for use in defining a Python op.
+
+##### Raises:
+
+
+* <b>`ValueError`</b>: when trying to reuse within a create scope, or create within
+ a reuse scope, or if reuse is not `None` or `True`.
+* <b>`TypeError`</b>: when the types of some arguments are not appropriate.
+
+
+- - -
+
+### `tf.get_variable_scope()` {#get_variable_scope}
+
+Returns the current variable scope.
+
+
+- - -
+
+### `tf.make_template(name_, func_, **kwargs)` {#make_template}
+
+Given an arbitrary function, wrap it so that it does variable sharing.
+
+This wraps `func_` in a Template and partially evaluates it. Templates are
+functions that create variables the first time they are called and reuse them
+thereafter. In order for `func_` to be compatible with a `Template` it must
+have the following properties:
+
+* The function should create all trainable variables and any variables that
+ should be reused by calling `tf.get_variable`. If a trainable variable is
+ created using `tf.Variable`, then a ValueError will be thrown. Variables
+ that are intended to be locals can be created by specifying
+ `tf.Variable(..., trainable=false)`.
+* The function may use variable scopes and other templates internally to
+ create and reuse variables, but it shouldn't use `tf.get_variables` to
+ capture variables that are defined outside of the scope of the function.
+* Internal scopes and variable names should not depend on any arguments that
+ are not supplied to `make_template`. In general you will get a ValueError
+ telling you that you are trying to reuse a variable that doesn't exist
+ if you make a mistake.
+
+In the following example, both `z` and `w` will be scaled by the same `y`. It
+is important to note that if we didn't assign `scalar_name` and used a
+different name for z and w that a `ValueError` would be thrown because it
+couldn't reuse the variable.
+
+```python
+def my_op(x, scalar_name):
+ var1 = tf.get_variable(scalar_name,
+ shape=[],
+ initializer=tf.constant_initializer(1))
+ return x * var1
+
+scale_by_y = tf.make_template('scale_by_y', my_op, scalar_name='y')
+
+z = scale_by_y(input1)
+w = scale_by_y(input2)
+```
+
+As a safe-guard, the returned function will raise a `ValueError` after the
+first call if trainable variables are created by calling `tf.Variable`.
+
+If all of these are true, then 2 properties are enforced by the template:
+
+1. Calling the same template multiple times will share all non-local
+ variables.
+2. Two different templates are guaranteed to be unique, unless you reenter the
+ same variable scope as the initial definition of a template and redefine
+ it. An examples of this exception:
+
+```python
+def my_op(x, scalar_name):
+ var1 = tf.get_variable(scalar_name,
+ shape=[],
+ initializer=tf.constant_initializer(1))
+ return x * var1
+
+with tf.variable_scope('scope') as vs:
+ scale_by_y = tf.make_template('scale_by_y', my_op, scalar_name='y')
+ z = scale_by_y(input1)
+ w = scale_by_y(input2)
+
+# Creates a template that reuses the variables above.
+with tf.variable_scope(vs, reuse=True):
+ scale_by_y2 = tf.make_template('scale_by_y', my_op, scalar_name='y')
+ z2 = scale_by_y2(input1)
+ w2 = scale_by_y2(input2)
+```
+
+Note: The full variable scope is captured at the time of the first call.
+
+Note: `name_` and `func_` have a following underscore to reduce the likelihood
+of collisions with kwargs.
+
+##### Args:
+
+
+* <b>`name_`</b>: A name for the scope created by this template. If necessary, the name
+ will be made unique by appending `_N` to the name.
+* <b>`func_`</b>: The function to wrap.
+* <b>`**kwargs`</b>: Keyword arguments to apply to `func_`.
+
+##### Returns:
+
+ A function that will enter a `variable_scope` before calling `func_`. The
+ first time it is called, it will create a non-reusing scope so that the
+ variables will be unique. On each subsequent call, it will reuse those
+ variables.
+
+##### Raises:
+
+
+* <b>`ValueError`</b>: if the name is None.
+
+
+
+- - -
+
+### `tf.no_regularizer(_)` {#no_regularizer}
+
+Use this function to prevent regularization of variables.
+
+
- - -
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 8ebb47fe1c..312868cd75 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -581,6 +581,7 @@ tf_gen_op_wrapper_py(
"Sigmoid",
"Tanh",
"Lgamma",
+ "Digamma",
"Erf",
"Erfc",
],
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index 3009b3494b..2a12f9d53f 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -57,8 +57,8 @@ class UnaryOpTest(tf.test.TestCase):
self.assertShapeEqual(np_ans, y)
self.assertAllClose(np_ans, tf_cpu)
- # TODO(ebrevdo): add gradient for lgamma (digamma) and remove lgamma here.
- if tf_func in (tf.lgamma,):
+ # TODO(ebrevdo): consider adding polygamma function
+ if tf_func in (tf.digamma,):
return # Return early
if x.dtype == np.float32:
@@ -131,7 +131,7 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBoth(x, np.sin, tf.sin)
self._compareBoth(x, np.cos, tf.cos)
self._compareBoth(
- x,
+ y,
np.vectorize(self._replace_domain_error_with_inf(math.lgamma)),
tf.lgamma)
self._compareBoth(x, np.vectorize(math.erf), tf.erf)
@@ -160,6 +160,10 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBoth(x, np.sign, tf.sign)
self._compareBoth(x, np.sin, tf.sin)
self._compareBoth(x, np.cos, tf.cos)
+ # Can't use vectorize below, so just use some arbitrary function
+ self._compareBoth(x, np.sign, tf.lgamma)
+ self._compareBoth(x, np.sign, tf.erf)
+ self._compareBoth(x, np.sign, tf.erfc)
def testDoubleBasic(self):
x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64)
@@ -180,6 +184,12 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBoth(y, np.sign, tf.sign)
self._compareBoth(x, np.sin, tf.sin)
self._compareBoth(x, np.cos, tf.cos)
+ self._compareBoth(
+ y,
+ np.vectorize(self._replace_domain_error_with_inf(math.lgamma)),
+ tf.lgamma)
+ self._compareBoth(x, np.vectorize(math.erf), tf.erf)
+ self._compareBoth(x, np.vectorize(math.erfc), tf.erfc)
def testInt32Basic(self):
x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int32)
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 1ccf152f70..61fc643e05 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -263,9 +263,16 @@ def _ErfcGrad(op, grad):
@ops.RegisterGradient("Lgamma")
-def _LgammaGrad(op, grad): # pylint: disable=unused-argument
- # TODO(ebrevdo): implement digamma
- raise NotImplementedError("grad(Lgamma) == Digamma is not implemented")
+def _LgammaGrad(op, grad):
+ """Returns grad * digamma(x)."""
+ x = op.inputs[0]
+ with ops.control_dependencies([grad.op]):
+ return grad * math_ops.digamma(x)
+
+
+@ops.RegisterGradient("Digamma")
+def _DigammaGrad(op, grad): # pylint: disable=unused-argument
+ raise NotImplementedError("grad(Digamma) == Polygamma(1) is not implemented")
@ops.RegisterGradient("Sigmoid")
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index bcf85710e5..add76734a5 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -51,6 +51,7 @@ mathematical functions to your graph.
@@cos
@@sin
@@lgamma
+@@digamma
@@erf
@@erfc
@@squared_difference
@@ -1214,6 +1215,23 @@ def lgamma(x, name=None):
return gen_math_ops._lgamma(x, name=name)
+def digamma(x, name=None):
+ """Computes Psi, the derivative of lgamma, `ln(|gamma(x)|)`, element-wise.
+
+ Args:
+ x: A Tensor with type `float`, `double`, `int32`, `int64`,
+ or `qint32`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A Tensor with the same type as `x` if `x.dtype != qint32` otherwise
+ the return type is `quint8`.
+ """
+ with ops.op_scope([x], name, "Digamma") as name:
+ x = ops.convert_to_tensor(x, name="x")
+ return gen_math_ops._digamma(x, name=name)
+
+
def erf(x, name=None):
"""Computes Gauss error function of `x` element-wise.
@@ -1272,6 +1290,7 @@ ops.RegisterShape("Square")(common_shapes.unchanged_shape)
ops.RegisterShape("Sigmoid")(common_shapes.unchanged_shape)
ops.RegisterShape("Tanh")(common_shapes.unchanged_shape)
ops.RegisterShape("Lgamma")(common_shapes.unchanged_shape)
+ops.RegisterShape("Digamma")(common_shapes.unchanged_shape)
ops.RegisterShape("Erf")(common_shapes.unchanged_shape)
ops.RegisterShape("Erfc")(common_shapes.unchanged_shape)
ops.RegisterShape("Cast")(common_shapes.unchanged_shape)