diff options
author | 2017-01-17 19:51:14 -0800 | |
---|---|---|
committer | 2017-01-17 20:07:35 -0800 | |
commit | 1c7ef3db0840d6530e7dde693b9351b8205f7e84 (patch) | |
tree | 5ef7d9e688a8fcf7787bff2dd59dae5ef5cfccde | |
parent | 64ce45728c19c1058196dd64b19b2d10b7af42a3 (diff) |
linear_operator_test_util.py: test_log_abs_det added. This tests
LinearOperator.log_abs_determinant. This test was missing previously.
Change: 144787893
-rw-r--r-- | tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py | 25 |
1 files changed, 24 insertions, 1 deletions
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py index 466fedd578..85cd7fcd9a 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py @@ -174,6 +174,29 @@ class LinearOperatorDerivedClassTest(test.TestCase): feed_dict=feed_dict) self.assertAC(op_det_v, mat_det_v) + def test_log_abs_det(self): + self._maybe_skip("log_abs_det") + for use_placeholder in False, True: + for shape in self._shapes_to_test: + for dtype in self._dtypes_to_test: + if dtype.is_complex: + self.skipTest( + "tf.matrix_determinant does not work with complex, so this " + "test is being skipped.") + with self.test_session(graph=ops.Graph()) as sess: + sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED + operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + shape, dtype, use_placeholder=use_placeholder) + op_log_abs_det = operator.log_abs_determinant() + mat_log_abs_det = math_ops.log( + math_ops.abs(linalg_ops.matrix_determinant(mat))) + if not use_placeholder: + self.assertAllEqual(shape[:-2], op_log_abs_det.get_shape()) + op_log_abs_det_v, mat_log_abs_det_v = sess.run( + [op_log_abs_det, mat_log_abs_det], + feed_dict=feed_dict) + self.assertAC(op_log_abs_det_v, mat_log_abs_det_v) + def test_apply(self): self._maybe_skip("apply") for use_placeholder in False, True: @@ -291,7 +314,7 @@ class NonSquareLinearOperatorDerivedClassTest(LinearOperatorDerivedClassTest): @property def _tests_to_skip(self): """List of test names to skip.""" - return ["solve", "det"] + return ["solve", "det", "log_abs_det"] @property def _shapes_to_test(self): |