diff options
author | Justine Tunney <jart@google.com> | 2016-12-29 22:46:24 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-12-29 23:06:59 -0800 |
commit | e121667dc609de978a223c56ee906368d2c4ceef (patch) | |
tree | 7d4e1f1e1b4fd469487872c0cd34ddace5ac570c /tensorflow/contrib/solvers | |
parent | 7815fcba7767aa1eb3196c5861e174f8b3c43bab (diff) |
Remove so many more hourglass imports
Change: 143230429
Diffstat (limited to 'tensorflow/contrib/solvers')
9 files changed, 155 insertions, 102 deletions
diff --git a/tensorflow/contrib/solvers/BUILD b/tensorflow/contrib/solvers/BUILD index bd7e1c0a25..87b67486ad 100644 --- a/tensorflow/contrib/solvers/BUILD +++ b/tensorflow/contrib/solvers/BUILD @@ -14,6 +14,15 @@ py_library( name = "solvers_py", srcs = ["__init__.py"] + glob(["python/ops/*.py"]), srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:tensor_array_ops", + ], ) # Ops tests @@ -24,8 +33,12 @@ cuda_py_test( ], additional_deps = [ ":solvers_py", - "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], shard_count = 4, @@ -38,7 +51,10 @@ cuda_py_test( ], additional_deps = [ ":solvers_py", - "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", ], @@ -52,7 +68,10 @@ cuda_py_test( ], additional_deps = [ ":solvers_py", - "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", ], @@ -66,7 +85,10 @@ cuda_py_test( ], additional_deps = [ ":solvers_py", - "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", ], diff --git a/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py b/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py index 5fea07cd83..4707dc2229 100644 --- a/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py +++ b/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py @@ -18,10 +18,13 @@ from __future__ import division from __future__ import print_function import numpy as np -import tensorflow as tf from tensorflow.contrib.solvers.python.ops import lanczos from tensorflow.contrib.solvers.python.ops import util +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test as test_lib def _add_test(test, test_name, fn): @@ -31,7 +34,7 @@ def _add_test(test, test_name, fn): setattr(test, test_name, fn) -class LanczosBidiagTest(tf.test.TestCase): +class LanczosBidiagTest(test_lib.TestCase): pass # Filled in below. @@ -46,9 +49,9 @@ def _get_lanczos_tests(dtype_, use_static_shape_, shape_, orthogonalize_, with self.test_session() as sess: if use_static_shape_: - a = tf.constant(a_np) + a = constant_op.constant(a_np) else: - a = tf.placeholder(dtype_) + a = array_ops.placeholder(dtype_) operator = util.create_operator(a) lbd = lanczos.lanczos_bidiag( operator, steps_, orthogonalize=orthogonalize_) @@ -56,9 +59,9 @@ def _get_lanczos_tests(dtype_, use_static_shape_, shape_, orthogonalize_, # The computed factorization should satisfy the equations # A * V = U * B # A' * U[:, :-1] = V * B[:-1, :]' - av = tf.matmul(a, lbd.v) + av = math_ops.matmul(a, lbd.v) ub = lanczos.bidiag_matmul(lbd.u, lbd.alpha, lbd.beta, adjoint_b=False) - atu = tf.matmul(a, lbd.u[:, :-1], adjoint_a=True) + atu = math_ops.matmul(a, lbd.u[:, :-1], adjoint_a=True) vbt = lanczos.bidiag_matmul(lbd.v, lbd.alpha, lbd.beta, adjoint_b=True) if use_static_shape_: @@ -86,4 +89,4 @@ if __name__ == "__main__": name = "_".join(["Lanczos", test_fn.__name__, arg_string]) _add_test(LanczosBidiagTest, name, test_fn) - tf.test.main() + test_lib.main() diff --git a/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py b/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py index be66311935..a73642716b 100644 --- a/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py +++ b/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py @@ -18,10 +18,12 @@ from __future__ import division from __future__ import print_function import numpy as np -import tensorflow as tf from tensorflow.contrib.solvers.python.ops import least_squares from tensorflow.contrib.solvers.python.ops import util +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test as test_lib def _add_test(test, test_name, fn): @@ -31,7 +33,7 @@ def _add_test(test, test_name, fn): setattr(test, test_name, fn) -class LeastSquaresTest(tf.test.TestCase): +class LeastSquaresTest(test_lib.TestCase): pass # Filled in below. @@ -47,11 +49,11 @@ def _get_least_squares_tests(dtype_, use_static_shape_, shape_): max_iter = 20 with self.test_session() as sess: if use_static_shape_: - a = tf.constant(a_np) - rhs = tf.constant(rhs_np) + a = constant_op.constant(a_np) + rhs = constant_op.constant(rhs_np) else: - a = tf.placeholder(dtype_) - rhs = tf.placeholder(dtype_) + a = array_ops.placeholder(dtype_) + rhs = array_ops.placeholder(dtype_) operator = util.create_operator(a) cgls_graph = least_squares.cgls(operator, rhs, tol=tol, max_iter=max_iter) if use_static_shape_: @@ -82,4 +84,4 @@ if __name__ == "__main__": name = "_".join(["LeastSquares", test_fn.__name__, arg_string]) _add_test(LeastSquaresTest, name, test_fn) - tf.test.main() + test_lib.main() diff --git a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py index f8265883c9..930df2414b 100644 --- a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py +++ b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py @@ -18,10 +18,12 @@ from __future__ import division from __future__ import print_function import numpy as np -import tensorflow as tf from tensorflow.contrib.solvers.python.ops import linear_equations from tensorflow.contrib.solvers.python.ops import util +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test as test_lib def _add_test(test, test_name, fn): @@ -31,7 +33,7 @@ def _add_test(test, test_name, fn): setattr(test, test_name, fn) -class LinearEquationsTest(tf.test.TestCase): +class LinearEquationsTest(test_lib.TestCase): pass # Filled in below. @@ -49,11 +51,11 @@ def _get_linear_equations_tests(dtype_, use_static_shape_, shape_): max_iter = 20 with self.test_session() as sess: if use_static_shape_: - a = tf.constant(a_np) - rhs = tf.constant(rhs_np) + a = constant_op.constant(a_np) + rhs = constant_op.constant(rhs_np) else: - a = tf.placeholder(dtype_) - rhs = tf.placeholder(dtype_) + a = array_ops.placeholder(dtype_) + rhs = array_ops.placeholder(dtype_) operator = util.create_operator(a) cg_graph = linear_equations.conjugate_gradient( operator, rhs, tol=tol, max_iter=max_iter) @@ -82,8 +84,7 @@ if __name__ == "__main__": use_static_shape) for test_fn in _get_linear_equations_tests(dtype, use_static_shape, shape): - name = "_".join( - ["LinearEquations", test_fn.__name__, arg_string]) + name = "_".join(["LinearEquations", test_fn.__name__, arg_string]) _add_test(LinearEquationsTest, name, test_fn) - tf.test.main() + test_lib.main() diff --git a/tensorflow/contrib/solvers/python/kernel_tests/util_test.py b/tensorflow/contrib/solvers/python/kernel_tests/util_test.py index c1d85546e8..1566984b27 100644 --- a/tensorflow/contrib/solvers/python/kernel_tests/util_test.py +++ b/tensorflow/contrib/solvers/python/kernel_tests/util_test.py @@ -18,12 +18,15 @@ from __future__ import division from __future__ import print_function import numpy as np -import tensorflow as tf from tensorflow.contrib.solvers.python.ops import util +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test -class UtilTest(tf.test.TestCase): +class UtilTest(test.TestCase): def _testCreateOperator(self, use_static_shape_): for dtype in np.float32, np.float64: @@ -32,17 +35,17 @@ class UtilTest(tf.test.TestCase): y_np = np.array([[2], [-3.], [5.]], dtype=dtype) with self.test_session() as sess: if use_static_shape_: - a = tf.constant(a_np, dtype=dtype) - x = tf.constant(x_np, dtype=dtype) - y = tf.constant(y_np, dtype=dtype) + a = constant_op.constant(a_np, dtype=dtype) + x = constant_op.constant(x_np, dtype=dtype) + y = constant_op.constant(y_np, dtype=dtype) else: - a = tf.placeholder(dtype) - x = tf.placeholder(dtype) - y = tf.placeholder(dtype) + a = array_ops.placeholder(dtype) + x = array_ops.placeholder(dtype) + y = array_ops.placeholder(dtype) op = util.create_operator(a) ax = op.apply(x) aty = op.apply_adjoint(y) - op_shape = tf.convert_to_tensor(op.shape) + op_shape = ops.convert_to_tensor(op.shape) if use_static_shape_: op_shape_val, ax_val, aty_val = sess.run([op_shape, ax, aty]) else: @@ -65,7 +68,7 @@ class UtilTest(tf.test.TestCase): x_np = np.array([[2], [-3.], [5.]]) x_norm_np = np.linalg.norm(x_np) x_normalized_np = x_np / x_norm_np - x = tf.constant(x_np) + x = constant_op.constant(x_np) l2norm = util.l2norm(x) l2norm_squared = util.l2norm_squared(x) x_normalized, x_norm = util.l2normalize(x) @@ -76,4 +79,4 @@ class UtilTest(tf.test.TestCase): if __name__ == '__main__': - tf.test.main() + test.main() diff --git a/tensorflow/contrib/solvers/python/ops/lanczos.py b/tensorflow/contrib/solvers/python/ops/lanczos.py index e2eba0d999..565639ff12 100644 --- a/tensorflow/contrib/solvers/python/ops/lanczos.py +++ b/tensorflow/contrib/solvers/python/ops/lanczos.py @@ -22,9 +22,15 @@ from __future__ import print_function import collections -import tensorflow as tf - from tensorflow.contrib.solvers.python.ops import util +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import tensor_array_ops def lanczos_bidiag(operator, @@ -82,20 +88,17 @@ def lanczos_bidiag(operator, """ def tarray(size, dtype, name): - return tf.TensorArray( - dtype=dtype, - size=size, - tensor_array_name=name, - clear_after_read=False) + return tensor_array_ops.TensorArray( + dtype=dtype, size=size, tensor_array_name=name, clear_after_read=False) # Reads a row-vector at location i in tarray and returns it as a # column-vector. def read_colvec(tarray, i): - return tf.expand_dims(tarray.read(i), -1) + return array_ops.expand_dims(tarray.read(i), -1) # Writes an column-vector as a row-vecor at location i in tarray. def write_colvec(tarray, colvec, i): - return tarray.write(i, tf.squeeze(colvec)) + return tarray.write(i, array_ops.squeeze(colvec)) # Ephemeral class holding Lanczos bidiagonalization state: # u = left Lanczos vectors @@ -112,21 +115,20 @@ def lanczos_bidiag(operator, return lanzcos_bidiag_state( write_colvec(old.u, u, i + 1), write_colvec(old.v, v, i), - old.alpha.write(i, alpha), - old.beta.write(i, beta)) + old.alpha.write(i, alpha), old.beta.write(i, beta)) def gram_schmidt_step(j, basis, v): """Makes v orthogonal to the j'th vector in basis.""" v_shape = v.get_shape() basis_vec = read_colvec(basis, j) - v -= tf.matmul(basis_vec, v, adjoint_a=True) * basis_vec + v -= math_ops.matmul(basis_vec, v, adjoint_a=True) * basis_vec v.set_shape(v_shape) return j + 1, basis, v def orthogonalize_once(i, basis, v): - j = tf.constant(0, dtype=tf.int32) - _, _, v = tf.while_loop(lambda j, basis, v: j < i, gram_schmidt_step, - [j, basis, v]) + j = constant_op.constant(0, dtype=dtypes.int32) + _, _, v = control_flow_ops.while_loop(lambda j, basis, v: j < i, + gram_schmidt_step, [j, basis, v]) return util.l2normalize(v) # Iterated modified Gram-Schmidt orthogonalization adapted from PROPACK. @@ -139,9 +141,9 @@ def lanczos_bidiag(operator, # round of MGS. See proof in: # B. N. Parlett, ``The Symmetric Eigenvalue Problem'', # Prentice-Hall, Englewood Cliffs, NJ, 1980. pp. 105-109 - return tf.cond(v_new_norm < 0.7071 * v_norm, - lambda: orthogonalize_once(i, basis, v), - lambda: (v_new, v_new_norm)) + return control_flow_ops.cond(v_new_norm < 0.7071 * v_norm, + lambda: orthogonalize_once(i, basis, v), + lambda: (v_new, v_new_norm)) def stopping_criterion(i, _): # TODO(rmlarsen): Stop if an invariant subspace is detected. @@ -153,9 +155,8 @@ def lanczos_bidiag(operator, r = operator.apply_adjoint(u) # The shape inference doesn't work across cond, save and reapply the shape. r_shape = r.get_shape() - r = tf.cond( - i > 0, - lambda: r - ls.beta.read(i - 1) * read_colvec(ls.v, i - 1), + r = control_flow_ops.cond( + i > 0, lambda: r - ls.beta.read(i - 1) * read_colvec(ls.v, i - 1), lambda: r) r.set_shape(r_shape) if orthogonalize: @@ -170,10 +171,10 @@ def lanczos_bidiag(operator, return i + 1, update_state(ls, i, u, v, alpha, beta) - with tf.name_scope(name): + with ops.name_scope(name): dtype = operator.dtype if starting_vector is None: - starting_vector = tf.random_uniform( + starting_vector = random_ops.random_uniform( operator.shape[:1], -1, 1, dtype=dtype) u0, _ = util.l2normalize(starting_vector) ls = lanzcos_bidiag_state( @@ -181,11 +182,13 @@ def lanczos_bidiag(operator, v=tarray(k, dtype, "v"), alpha=tarray(k, dtype, "alpha"), beta=tarray(k, dtype, "beta")) - i = tf.constant(0, dtype=tf.int32) - _, ls = tf.while_loop(stopping_criterion, lanczos_bidiag_step, [i, ls]) + i = constant_op.constant(0, dtype=dtypes.int32) + _, ls = control_flow_ops.while_loop(stopping_criterion, lanczos_bidiag_step, + [i, ls]) return lanzcos_bidiag_state( - tf.matrix_transpose(ls.u.stack()), - tf.matrix_transpose(ls.v.stack()), ls.alpha.stack(), ls.beta.stack()) + array_ops.matrix_transpose(ls.u.stack()), + array_ops.matrix_transpose(ls.v.stack()), + ls.alpha.stack(), ls.beta.stack()) # TODO(rmlarsen): Implement C++ ops for handling bidiagonal matrices @@ -219,14 +222,16 @@ def bidiag_matmul(matrix, alpha, beta, adjoint_b=False, name="bidiag_matmul"): If `adjoint_b` is False the `A * B` is returned. If `adjoint_b` is True the `A * B'` is returned. """ - with tf.name_scope(name): - alpha = tf.expand_dims(alpha, 0) + with ops.name_scope(name): + alpha = array_ops.expand_dims(alpha, 0) if adjoint_b is False: - beta = tf.expand_dims(beta, 0) + beta = array_ops.expand_dims(beta, 0) return matrix[:, :-1] * alpha + matrix[:, 1:] * beta else: - beta = tf.expand_dims(beta[:-1], 0) - shape = tf.shape(matrix) - zero_column = tf.expand_dims(tf.zeros(shape[:1], dtype=matrix.dtype), 1) - return matrix * alpha + tf.concat_v2([zero_column, matrix[:, :-1] * beta], - 1) + beta = array_ops.expand_dims(beta[:-1], 0) + shape = array_ops.shape(matrix) + zero_column = array_ops.expand_dims( + array_ops.zeros( + shape[:1], dtype=matrix.dtype), 1) + return matrix * alpha + array_ops.concat_v2( + [zero_column, matrix[:, :-1] * beta], 1) diff --git a/tensorflow/contrib/solvers/python/ops/least_squares.py b/tensorflow/contrib/solvers/python/ops/least_squares.py index 9a2d3b24dd..fb7c0eb649 100644 --- a/tensorflow/contrib/solvers/python/ops/least_squares.py +++ b/tensorflow/contrib/solvers/python/ops/least_squares.py @@ -20,9 +20,13 @@ from __future__ import print_function import collections -import tensorflow as tf - from tensorflow.contrib.solvers.python.ops import util +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops def cgls(operator, rhs, tol=1e-6, max_iter=20, name="cgls"): @@ -74,7 +78,7 @@ def cgls(operator, rhs, tol=1e-6, max_iter=20, name="cgls"): ["i", "x", "r", "p", "gamma"]) def stopping_criterion(i, state): - return tf.logical_and(i < max_iter, state.gamma > tol) + return math_ops.logical_and(i < max_iter, state.gamma > tol) # TODO(rmlarsen): add preconditioning def cgls_step(i, state): @@ -88,19 +92,22 @@ def cgls(operator, rhs, tol=1e-6, max_iter=20, name="cgls"): p = s + beta * state.p return i + 1, cgls_state(i + 1, x, r, p, gamma) - with tf.name_scope(name): + with ops.name_scope(name): n = operator.shape[1:] - rhs = tf.expand_dims(rhs, -1) + rhs = array_ops.expand_dims(rhs, -1) s0 = operator.apply_adjoint(rhs) gamma0 = util.l2norm_squared(s0) tol = tol * tol * gamma0 - x = tf.expand_dims(tf.zeros(n, dtype=rhs.dtype.base_dtype), -1) - i = tf.constant(0, dtype=tf.int32) + x = array_ops.expand_dims( + array_ops.zeros( + n, dtype=rhs.dtype.base_dtype), -1) + i = constant_op.constant(0, dtype=dtypes.int32) state = cgls_state(i=i, x=x, r=rhs, p=s0, gamma=gamma0) - _, state = tf.while_loop(stopping_criterion, cgls_step, [i, state]) + _, state = control_flow_ops.while_loop(stopping_criterion, cgls_step, + [i, state]) return cgls_state( state.i, - x=tf.squeeze(state.x), - r=tf.squeeze(state.r), - p=tf.squeeze(state.p), + x=array_ops.squeeze(state.x), + r=array_ops.squeeze(state.r), + p=array_ops.squeeze(state.p), gamma=state.gamma) diff --git a/tensorflow/contrib/solvers/python/ops/linear_equations.py b/tensorflow/contrib/solvers/python/ops/linear_equations.py index 41fd6e466b..8cba56eba6 100644 --- a/tensorflow/contrib/solvers/python/ops/linear_equations.py +++ b/tensorflow/contrib/solvers/python/ops/linear_equations.py @@ -20,9 +20,13 @@ from __future__ import print_function import collections -import tensorflow as tf - from tensorflow.contrib.solvers.python.ops import util +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops def conjugate_gradient(operator, @@ -67,7 +71,7 @@ def conjugate_gradient(operator, cg_state = collections.namedtuple("CGState", ["i", "x", "r", "p", "gamma"]) def stopping_criterion(i, state): - return tf.logical_and(i < max_iter, state.gamma > tol) + return math_ops.logical_and(i < max_iter, state.gamma > tol) # TODO(rmlarsen): add preconditioning def cg_step(i, state): @@ -80,18 +84,21 @@ def conjugate_gradient(operator, p = r + beta * state.p return i + 1, cg_state(i + 1, x, r, p, gamma) - with tf.name_scope(name): + with ops.name_scope(name): n = operator.shape[1:] - rhs = tf.expand_dims(rhs, -1) + rhs = array_ops.expand_dims(rhs, -1) gamma0 = util.l2norm_squared(rhs) tol = tol * tol * gamma0 - x = tf.expand_dims(tf.zeros(n, dtype=rhs.dtype.base_dtype), -1) - i = tf.constant(0, dtype=tf.int32) + x = array_ops.expand_dims( + array_ops.zeros( + n, dtype=rhs.dtype.base_dtype), -1) + i = constant_op.constant(0, dtype=dtypes.int32) state = cg_state(i=i, x=x, r=rhs, p=rhs, gamma=gamma0) - _, state = tf.while_loop(stopping_criterion, cg_step, [i, state]) + _, state = control_flow_ops.while_loop(stopping_criterion, cg_step, + [i, state]) return cg_state( state.i, - x=tf.squeeze(state.x), - r=tf.squeeze(state.r), - p=tf.squeeze(state.p), + x=array_ops.squeeze(state.x), + r=array_ops.squeeze(state.r), + p=array_ops.squeeze(state.p), gamma=state.gamma) diff --git a/tensorflow/contrib/solvers/python/ops/util.py b/tensorflow/contrib/solvers/python/ops/util.py index 4f8bbb883d..777e0c185d 100644 --- a/tensorflow/contrib/solvers/python/ops/util.py +++ b/tensorflow/contrib/solvers/python/ops/util.py @@ -20,7 +20,10 @@ from __future__ import print_function import collections -import tensorflow as tf +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops def create_operator(matrix): @@ -34,27 +37,27 @@ def create_operator(matrix): if shape.is_fully_defined(): shape = shape.as_list() else: - shape = tf.shape(matrix) + shape = array_ops.shape(matrix) return linear_operator( shape=shape, dtype=matrix.dtype, - apply=lambda v: tf.matmul(matrix, v, adjoint_a=False), - apply_adjoint=lambda v: tf.matmul(matrix, v, adjoint_a=True)) + apply=lambda v: math_ops.matmul(matrix, v, adjoint_a=False), + apply_adjoint=lambda v: math_ops.matmul(matrix, v, adjoint_a=True)) # TODO(rmlarsen): Measure if we should just call matmul. def dot(x, y): - return tf.reduce_sum(tf.conj(x) * y) + return math_ops.reduce_sum(math_ops.conj(x) * y) # TODO(rmlarsen): Implement matrix/vector norm op in C++ in core. # We need 1-norm, inf-norm, and Frobenius norm. def l2norm_squared(v): - return tf.constant(2, dtype=v.dtype.base_dtype) * tf.nn.l2_loss(v) + return constant_op.constant(2, dtype=v.dtype.base_dtype) * nn_ops.l2_loss(v) def l2norm(v): - return tf.sqrt(l2norm_squared(v)) + return math_ops.sqrt(l2norm_squared(v)) def l2normalize(v): |