Martin Ganahl (mganahl), Perimeter Institute for Theoretical Physics, Canada

google/TensorNetwork 1328

A library for easy and efficient manipulation of tensor networks.

mhibatallah/RNNWavefunctions 32

A new wavefunction ansatz based on Recurrent Neural Networks to perform Variational Monte-Carlo Simulations

mganahl/PyTeN 2

A library for Matrix Product State calculations

mganahl/MPSTools 1

MPS library for strongly correlated 1D systems

mganahl/evoMPS 0

An implementation of the time dependent variational principle for matrix product states

mganahl/HeisED 0

Arnoldi and Lanczos for the XXZ Heisenberg model in 1 and 2d

mganahl/jax 0

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

mganahl/ncon 0

Tensor network contraction function for Python 3.
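As a quick illustration of the ncon convention this repo implements (a hedged sketch: it assumes the PyPI `ncon` package and its standard index notation, where positive labels are contracted and negative labels are free output indices):

import numpy as np
from ncon import ncon

A = np.random.rand(3, 4)
B = np.random.rand(4, 5)
# Contract the shared label 1; -1 and -2 become the output indices.
C = ncon([A, B], [[-1, 1], [1, -2]])
assert np.allclose(C, A @ B)  # equivalent to a matrix product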

push event mganahl/TensorNetwork

mganahl

commit sha 00ef0db96f73300555a411b3daa641d8f9f45deb

add missing typing

view details

push time in 4 hours

push event mganahl/TensorNetwork

mganahl

commit sha 42ac8879b7de9262e2cb4a9179fdb91acf9bca4c

fix default of ZNCharge

view details

push time in 4 hours

push event google/TensorNetwork

gevenbly

commit sha a053904e65b41cb0435002e796a855d67f0fe0a3

added custom path solvers (#674)
* added custom path solvers: providing an alternative to those based on opt_einsum
* improved custom path solvers: implemented general improvements and some bug fixes; added unit tests
* removed previous solver versions: removed `solve_functs`, which is now split into `pathsolvers` and `nconinterface`
* implement changes suggested by mganahl: improved code linting and formatting, variable name changes to bring in line with existing library functions, minor efficiency improvements
* Delete test_nconinterface.py
* fixing some linting complaints
* Delete test.py
* Update nconinterface_test.py
* removed commented code
* fixed unit tests: removed random variables from unit tests, small formatting changes
* Update example.py
Co-authored-by: Martin Ganahl <martin.ganahl@gmail.com>
Co-authored-by: Chase Roberts <chaseriley@google.com>

view details

push time in 4 hours

PR merged google/TensorNetwork

added custom path solvers cla: yes

Added custom path solvers, providing an alternative to those based on opt_einsum. Includes 3 different solver algorithms (two are variants of a greedy search, one is a full search over all paths), as well as some helper functions that interface the solvers with networks defined in the ncon syntax.

+1006 -0

9 comments

6 changed files

gevenbly

pr closed time in 4 hours
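To make the greedy strategy mentioned in the PR above concrete, here is a toy greedy contraction-order search (an illustrative sketch only; `greedy_path` and its cost model are assumptions, not the solvers added in this PR):

import itertools
import numpy as np

def greedy_path(shapes):
  """shapes: one {index_label: dimension} dict per tensor.
  Returns the contraction order as a list of tensor-position pairs."""
  shapes = [dict(s) for s in shapes]
  path = []
  while len(shapes) > 1:
    best = None
    for i, j in itertools.combinations(range(len(shapes)), 2):
      # FLOP estimate: product over the union of all indices involved.
      dims = {**shapes[i], **shapes[j]}
      cost = np.prod(list(dims.values()))
      if best is None or cost < best[0]:
        best = (cost, i, j)
    _, i, j = best
    shared = set(shapes[i]) & set(shapes[j])
    merged = {k: v for k, v in {**shapes[i], **shapes[j]}.items()
              if k not in shared}
    shapes = [s for k, s in enumerate(shapes) if k not in (i, j)]
    shapes.append(merged)
    path.append((i, j))
  return path

# Example: a three-tensor chain A(a,b) B(b,c) C(c,d).
print(greedy_path([{'a': 2, 'b': 8}, {'b': 8, 'c': 8}, {'c': 8, 'd': 2}]))

A full search, by contrast, enumerates every pairwise contraction tree and keeps the cheapest; it is exact but scales factorially in the number of tensors.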

PR opened sebgrijalva/TensorNetwork

some modifications

Hi @sebgrijalva, I finally had a look at your PR; apologies that it took so long. Thanks again for submitting, and for the work, it looks great! I took the liberty of making some modifications to the tutorial, hope you don't mind. You can merge this PR into your master, and then it should show up in the PR for the upstream repo. I'll hold off on pulling this in, though, until we have finished some ongoing work on the MPS classes.

+70 -85

0 comments

4 changed files

pr created time in 5 hours

push event mganahl/TensorNetwork

mganahl

commit sha 10dfc0d135b2abd70c58ca92ffc4e361a89d0646

some mods

view details

push time in 5 hours

create branch mganahl/TensorNetwork

branch: sebgrijalva-master

created branch time in 5 hours

push event mganahl/TensorNetwork

mganahl

commit sha 739f5317eb718dfd4e6d3219463f161506658e59

typing

view details

push time in 7 hours

push event mganahl/TensorNetwork

mganahl

commit sha e9e369fc7a794bc5094cb47e679cf356b34bd383

linting

view details

push time in 7 hours

push event sebgrijalva/TensorNetwork

Martin Ganahl

commit sha 5322a5b1d2dd97c42de200233684fbadefdc95a2

ncon with batching support (#682) * adding sum and matmul to backends * formatting * change to jitted_functions.py * added batched ncon * typo * typo * docstring * docstring * rename routine, change docstring * renaming * fix test * more readable code * typo * linting * allow alphanumeric strings only * test updated * yapf * typo * typo Co-authored-by: Chase Roberts <chaseriley@google.com>

view details

Chase Roberts

commit sha a4d0afe55d536c16292414fa4c19a0e38552fee3

Update mpo.py (#689)

view details

Martin Ganahl

commit sha c9b418349f63e91d0ae2f8d162e6cceb57ece0a5

reduce python overhead in ncon a tiny bit (#693) * shaved off a few milliseconds of constant overhead * linting * newlinw * shorten code * typo * shorted check * remove newline

view details

Martin Ganahl

commit sha 00065ffdbbe0d6e4075045888e469e78e06dcf55

fix abstract_backend (#694)

view details

jeyassri balachandran

commit sha 7967a91c7571f1fa6c97954f6c0c16af0ca4c173

issues-686 repr in index.py to return more info (#696) * modified repr in index.py to return dense shape, charges and flow information * added formatting descriptors * fixed linting whitespace * modiefied Index.__repr__ * remove commented code Co-authored-by: mganahl <martin.ganahl@gmail.com>

view details

Martin Ganahl

commit sha 892468e40ce7e7191720d1f2e6b5e973d1c117f5

add eigsh_lanczos to blocksparse backend (#688) * typing * typing * typing * typing. lintingc * linting * typing * wip * transpose_data -> contiguous * wip adding eigsh_lanczos * __matmul__ support for tensors (just like numpy) * extend __matmul__ to tensors * restrict to vectors and matrices * tests adjusted * yapf * wip * wip * wip * testing added * comment * comment * typing * yapf * yapf * linting * typinig changes * typing * remove newline * comment * comment * comment * add type * linting * typing * typing * typing * linting * remove parens * space * remove complex test * remove complex casting * remove parense * add some parense for readability

view details

Martin Ganahl

commit sha 51f9e8bc6bad9143a99ddb652da8eb2c12373a00

shave off a few microns of constant overhead (#697) * shave off a few microns of constant overhead * remove time import * shorted code * typo * remove disable Co-authored-by: Chase Roberts <chaseriley@google.com>

view details

Adam Lewis

commit sha 04d4e469ccabce2a05a2a176614410e9e77c6035

Numpy gmres (#704) * Add linalg directory and initialization code * Delete linalg for merge * Add GMRES to numpy backend * Add GMRES to abstract backend and test NotImplemented * Test gmres in all backends, better tests in numpy. Co-authored-by: Martin Ganahl <martin.ganahl@gmail.com>

view details

Adam Lewis

commit sha db81bc0271f212b867ebc5ea115d443bab40fbf2

Jax gmres (#707) * Add linalg directory and initialization code * Delete linalg for merge * Jax GMRES and tests * Correct typing of ShapedArray * Correct typing in gmres.py * Refactor GMRES to use existing Arnoldi * Fix obsolete test from merge * Cleanups for PR * Correct docstring * Fix caching Co-authored-by: Chase Roberts <chaseriley@google.com>

view details

Chase Roberts

commit sha bb440f1b66f6722b675cf3197ca2bd620cf35284

Update version.py Minor release

view details

Martin Ganahl

commit sha de2e0c3a6a03ffc7c15670f007970734cc2bd811

Merge branch 'master' into master

view details

push time in 7 hours

push event gevenbly/TensorNetwork

Chase Roberts

commit sha bb440f1b66f6722b675cf3197ca2bd620cf35284

Update version.py Minor release

view details

Martin Ganahl

commit sha 1d9abf5df11e53d02e92809c766a5adb557772ad

Merge branch 'master' into master

view details

push time in 7 hours

push event mganahl/TensorNetwork

Chase Roberts

commit sha bb440f1b66f6722b675cf3197ca2bd620cf35284

Update version.py Minor release

view details

Martin Ganahl

commit sha a27aea61e986f44a931202348b31e1da9f9ea3df

Merge branch 'master' into diag_backend

view details

push time in 7 hours

push event mganahl/TensorNetwork

mganahl

commit sha 093c79443aab217c473b01006e3d74b665a0a20b

remove newlinw

view details

push time in 7 hours

push event mganahl/TensorNetwork

mganahl

commit sha 9ee676420a587ddbe428f71063e790c442cbe2b8

mnimize changes

view details

push time in 7 hours

PR opened google/TensorNetwork

bug fixes, some improvements
+9 -13

0 comments

1 changed file

pr created time in 7 hours

push event mganahl/TensorNetwork

mganahl

commit sha 9cecf0a8e6af969dd076c5f8490c95ccb82aae28

bug fixes, some improvements

view details

push time in 7 hours

create branch mganahl/TensorNetwork

branch: bugfix_blocksparse

created branch time in 7 hours

push event mganahl/TensorNetwork

Adam Lewis

commit sha db81bc0271f212b867ebc5ea115d443bab40fbf2

Jax gmres (#707) * Add linalg directory and initialization code * Delete linalg for merge * Jax GMRES and tests * Correct typing of ShapedArray * Correct typing in gmres.py * Refactor GMRES to use existing Arnoldi * Fix obsolete test from merge * Cleanups for PR * Correct docstring * Fix caching Co-authored-by: Chase Roberts <chaseriley@google.com>

view details

Chase Roberts

commit sha bb440f1b66f6722b675cf3197ca2bd620cf35284

Update version.py Minor release

view details

mganahl

commit sha 301b42a1c9d2d33d44c2432ff39e5d550db8dd42

WIP

view details

mganahl

commit sha 31d80c3086c369f6246d743ebd735677af900876

Merge branch 'master' into new_charge_encoding

view details

push time in a day

push event mganahl/TensorNetwork

mganahl

commit sha 1253b83069306d64bb538a48d39d6bf291096629

fix import

view details

push time in a day

push event mganahl/TensorNetwork

Adam Lewis

commit sha db81bc0271f212b867ebc5ea115d443bab40fbf2

Jax gmres (#707) * Add linalg directory and initialization code * Delete linalg for merge * Jax GMRES and tests * Correct typing of ShapedArray * Correct typing in gmres.py * Refactor GMRES to use existing Arnoldi * Fix obsolete test from merge * Cleanups for PR * Correct docstring * Fix caching Co-authored-by: Chase Roberts <chaseriley@google.com>

view details

Chase Roberts

commit sha bb440f1b66f6722b675cf3197ca2bd620cf35284

Update version.py Minor release

view details

push time in a day

pull request comment google/TensorNetwork

Jax gmres

looks good, I'll pull it in as soon as the build is finished

alewis

comment created time in a day

push event mganahl/TensorNetwork

mganahl

commit sha 3f890398168a1a60bafb5c8f8948ee1bfe6b2efc

np.sum -> jnp.sum in JaxBackend

view details

Adam Lewis

commit sha 04d4e469ccabce2a05a2a176614410e9e77c6035

Numpy gmres (#704) * Add linalg directory and initialization code * Delete linalg for merge * Add GMRES to numpy backend * Add GMRES to abstract backend and test NotImplemented * Test gmres in all backends, better tests in numpy. Co-authored-by: Martin Ganahl <martin.ganahl@gmail.com>

view details

mganahl

commit sha a438867a5cee7dc4731a729125fa912c607c264f

WIP

view details

mganahl

commit sha 7421305661137cc9e398d0266ef6a33d30396883

WIP

view details

mganahl

commit sha 9aa1a8a4e4dadbf7fc28f8f043b00d56d5918bfa

WIP

view details

mganahl

commit sha b18253584a4bfa3e0ad56551e8bc44d15cb90627

Merge branch 'master' into new_charge_encoding

view details

push time in a day

pull request comment google/TensorNetwork

Jax gmres

sorry can't see it

alewis

comment created time in a day

pull request comment google/TensorNetwork

Jax gmres


import time

import jax
import numpy as np
import tensornetwork as tn

D = 200
dtype = np.float32
matrix = jax.numpy.array(np.random.rand(D, D).astype(dtype))
vector = jax.numpy.array(np.random.rand(D).astype(dtype))

@jax.jit
def matvec_jax_matrix(vec, matrix):
    return jax.numpy.tensordot(matrix, vec, ([1], [0]))

jax_backend = tn.backends.jax.jax_backend.JaxBackend()
ncv = 10

# First call: includes tracing and compilation overhead.
t1 = time.time()
eta_j, U_j = jax_backend.eigsh_lanczos(matvec_jax_matrix, [matrix], vector,
                                       num_krylov_vecs=ncv, numeig=1,
                                       reorthogonalize=False)
print('jax eigvals:', eta_j)
t2 = time.time()

# Second call with identical arguments: should hit the compilation cache
# and run much faster, unless something retriggers tracing.
eta_j, U_j = jax_backend.eigsh_lanczos(matvec_jax_matrix, [matrix], vector,
                                       num_krylov_vecs=ncv, numeig=1,
                                       reorthogonalize=False)
print('jax eigvals:', eta_j)
t3 = time.time()
print('jax first:', t2 - t1)
print('jax second:', t3 - t2)

could you run this and send the result?

alewis

comment created time in a day

pull request comment google/TensorNetwork

Jax gmres

Hey @alewis, did you check that your modifications to eigsh_lanczos and eigs do not trigger unnecessary tracing? I checked out your PR, and on my laptop it seems they do

alewis

comment created time in a day
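For context: a minimal probe of jit re-tracing, assuming standard JAX caching semantics (`f` and `trace_count` are illustrative names, not code from the PR):

import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def f(x):
  global trace_count
  trace_count += 1  # Python side effect: runs only while JAX traces f
  return jnp.sum(x)

f(jnp.ones(4))
f(jnp.ones(4))      # same shape and dtype: served from the compilation cache
print(trace_count)  # 1
f(jnp.ones(5))      # a new shape forces a fresh trace and compilation
print(trace_count)  # 2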

push event mganahl/TensorNetwork

Adam Lewis

commit sha 04d4e469ccabce2a05a2a176614410e9e77c6035

Numpy gmres (#704) * Add linalg directory and initialization code * Delete linalg for merge * Add GMRES to numpy backend * Add GMRES to abstract backend and test NotImplemented * Test gmres in all backends, better tests in numpy. Co-authored-by: Martin Ganahl <martin.ganahl@gmail.com>

view details

push time in a day

Pull request review comment google/TensorNetwork

Jax gmres

 from functools import partial
-from typing import List, Any, Tuple
+from typing import List, Any, Tuple, Callable, Sequence
 import numpy as np

 Tensor = Any

 def _generate_jitted_eigsh_lanczos(jax):
   """
-  Helper function to generate jitted lanczos function used
-  in JaxBackend.eigsh_lanczos. The function `jax_lanczos`
+  Helper function to generate jitted lanczos function used

no, just curious. If you want it removed, then it would be great if you could prepare a PR that fixes this for the whole library (just run your script over all files). This way we'll have a cleaner code history

alewis

comment created time in a day

Pull request review comment google/TensorNetwork

Jax gmres

 def implicitly_restarted_arnoldi_method(
     ]
   return implicitly_restarted_arnoldi_method
+
+
+def gmres_wrapper(jax):
+  """
+  Allows Jax (the module) to be passed in as an argument rather than imported,
+  since doing the latter breaks the build.
+  Args:
+    jax: The imported Jax module.
+  Returns:
+    gmres: A function performing gmres_m as described below.
+  """
+  jnp = jax.numpy
+
+  def gmres_m(A_mv: Callable, A_args: Sequence,
+              b: jax.ShapedArray, x0: jax.ShapedArray, tol: float, atol: float,
+              num_krylov_vectors: int,
+              maxiter: int) -> Tuple[jax.ShapedArray, float, int, bool]:
+    """
+    Solve A x = b for x using the m-restarted GMRES method. This is
+    intended to be called via jax_backend.gmres.
+
+    Given a linear mapping with (n x n) matrix representation
+        A = A_mv(*A_args) gmres_m solves
+        Ax = b          (1)
+    where x and b are length-b vectors, using the method of
+    Generalized Minimum RESiduals with M iterations per restart (GMRES_M).
+
+    Args:
+
+    A_mv     : A function `v0 = A_mv(v, *A_args, **A_kwargs)` where `v0` and
+               `v` have the same shape.
+    b        : The `b` in `A @ x = b`.

looks like A_args is missing from the docstring

alewis

comment created time in 2 days
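As a reference for the method the docstring above describes, here is a minimal unrestarted GMRES in plain NumPy (an illustrative sketch under standard Arnoldi conventions, not the library's implementation; the example matrix reuses the small known problem from the PR's tests):

import numpy as np

def gmres_sketch(A_mv, b, x0, n_kry):
  # Build an Arnoldi basis V and Hessenberg matrix H for the Krylov space
  # of the residual, then solve min_y ||beta*e1 - H y|| and update x0.
  r = b - A_mv(x0)
  beta = np.linalg.norm(r)
  V = np.zeros((b.size, n_kry + 1), dtype=b.dtype)
  H = np.zeros((n_kry + 1, n_kry), dtype=b.dtype)
  V[:, 0] = r / beta
  for k in range(n_kry):
    w = A_mv(V[:, k])
    for j in range(k + 1):       # modified Gram-Schmidt orthogonalization
      H[j, k] = np.vdot(V[:, j], w)
      w = w - H[j, k] * V[:, j]
    H[k + 1, k] = np.linalg.norm(w)
    if H[k + 1, k] < 1e-14:      # happy breakdown: subspace is invariant
      break
    V[:, k + 1] = w / H[k + 1, k]
  e1 = np.zeros(k + 2, dtype=b.dtype)
  e1[0] = beta
  y, *_ = np.linalg.lstsq(H[:k + 2, :k + 1], e1, rcond=None)
  return x0 + V[:, :k + 1] @ y

A = np.array([[1., 1.], [3., -4.]])
b = np.array([3., 2.])
print(gmres_sketch(lambda v: A @ v, b, np.ones(2), 2))  # approx. [2., 1.]

Restarted GMRES(m), as in gmres_m above, simply calls this with a fixed small n_kry, recomputes the residual, and repeats up to maxiter times.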

Pull request review comment google/TensorNetwork

Jax gmres

 def A(H,x):
          type(initial_state)))
    if A not in _CACHED_MATVECS:
      _CACHED_MATVECS[A] = libjax.tree_util.Partial(A)
-    if not hasattr(self, '_jaxlan'):
-      # pylint: disable=attribute-defined-outside-init
-      self._jaxlan = jitted_functions._generate_jitted_eigsh_lanczos(libjax)
+    eigsh_lanczos = jitted_functions._generate_jitted_eigsh_lanczos(libjax)
+    return eigsh_lanczos(_CACHED_MATVECS[A], args, initial_state,
+                         num_krylov_vecs, numeig, delta, reorthogonalize)
+
+  def gmres(self,
+            A_mv: Callable,
+            b: Tensor,
+            A_args: Optional[List] = None,
+            A_kwargs: Optional[dict] = None,
+            x0: Optional[Tensor] = None,
+            tol: float = 1E-05,
+            atol: Optional[float] = None,
+            num_krylov_vectors: Optional[int] = None,
+            maxiter: Optional[int] = 1,

isn't the default 1?

alewis

comment created time in 2 days

Pull request review comment google/TensorNetwork

Jax gmres

(diff context: the same `jitted_functions.py` typing-import and docstring hunk shown above)

why?

alewis

comment created time in 2 days

Pull request review comment google/TensorNetwork

Jax gmres

(diff context: new Jax-backend GMRES tests `test_gmres_raises`, `test_gmres_on_small_known_problem`, and `test_gmres_on_larger_random_problem`; the dtype parametrization is reduced to `jax_qr_dtypes = [np.float32]` with the full dtype list left commented out, and the small-problem test still contains a `print(eps)` before its assert)

good to know!

alewis

comment created time in 2 days

Pull request review comment google/TensorNetwork

Jax gmres

(diff context: the same Jax-backend GMRES test hunk shown above)

test other dtypes (float64, complex64, complex128)

alewis

comment created time in 2 days

Pull request review comment google/TensorNetwork

Jax gmres

(diff context: the same Jax-backend GMRES test hunk shown above)

remove print

alewis

comment created time in 2 days

Pull request review comment google/TensorNetwork

Jax gmres

(diff context: the same Jax-backend GMRES test hunk shown above)

and complex if possible

alewis

comment created time in 2 days

Pull request review comment google/TensorNetwork

Jax gmres

(diff context: the same Jax-backend GMRES test hunk shown above)

Also test for float64

alewis

comment created time in 2 days

Pull request review comment google/TensorNetwork

Jax gmres

(diff context: the same Jax-backend GMRES test hunk shown above)

remove commented code

alewis

comment created time in 2 days

Pull request review comment google/TensorNetwork

Jax gmres

(diff context: the same `gmres` signature hunk in the Jax backend shown above)

remove Optional

alewis

comment created time in 2 days

Pull request review comment google/TensorNetwork

Jax gmres

(diff context: the `gmres_m`, `gmres_residual`, and `gmres` implementation inside `gmres_wrapper` in `jitted_functions.py`, ending at a commented-out `#@partial(jax.jit, static_argnums=(2,))` decorator)

why is the jit command commented?

alewis

comment created time in 2 days

Pull request review comment google/TensorNetwork

Jax gmres

(diff context: the same `gmres_wrapper` docstring hunk shown above)

reverse order b and A_args

alewis

comment created time in 2 days

Pull request review comment google/TensorNetwork

Jax gmres

(diff context: the same `gmres_wrapper` docstring hunk shown above)

length-n vectors?

alewis

comment created time in 2 days

Pull request review comment google/TensorNetwork

Jax gmres

 def body(vals):
      return [krylov_vectors, H, matvec, Av, norm, threshold, i + 1, maxiter]

    def cond_fun(vals):
-      _, _, _, _, norm, threshold, iteration, maxiter = vals
-
-      # check if an invariant subspace has been found
-      def check_thresh(check_vals):
-        val, thresh = check_vals
-        return jax.lax.cond(val < thresh, False, lambda x: x, True, lambda x: x)
-
-      return jax.lax.cond(iteration < maxiter, [norm, threshold], check_thresh,
-                          False, lambda x: x)
-
-    norms_dtype = np.real(v0.dtype).dtype
-    kvfinal, Hfinal, _, _, norm, _, it, _ = jax.lax.while_loop(
-        cond_fun, body, [
-            krylov_vectors, H, matvec, v,
-            norms_dtype.type(1E3), eps, start, num_krylov_vecs
-        ])
+      # Continue loop while iteration < num_krylov_vecs and norm > eps
+      _, _, _, _, norm, _, iteration, _ = vals
+      counter_done = (iteration >= num_krylov_vecs)
+      norm_not_too_small = norm > eps
+      continue_iteration = jax.lax.cond(counter_done,
+                                        _, lambda x: False,
+                                        _, lambda x: norm_not_too_small)
+
+      return continue_iteration
+    initial_norm_typecaster = np.zeros((1,)) + eps + 1.

you can do initial_norm = v.real.dtype.type(1.0+eps) instead

alewis

comment created time in 2 days
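A quick check of the suggested cast, in plain NumPy (`v` is just a stand-in example array, not a variable from the PR):

import numpy as np

eps = 1e-6
v = np.zeros(3, dtype=np.complex64)       # stand-in for the Lanczos vector
# Instead of the (1,)-shaped np.zeros((1,)) + eps + 1. workaround, produce
# a scalar that already has the dtype of v's real part:
initial_norm = v.real.dtype.type(1.0 + eps)
print(initial_norm, initial_norm.dtype)   # 1.000001 float32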

Pull request review comment google/TensorNetwork

Jax gmres

(diff context: the same `jitted_functions.py` typing-import and docstring hunk shown above)

Did you remove trailing white space on purpose, or is this from yapfing?

alewis

comment created time in 2 days

Pull request review comment google/TensorNetwork

Jax gmres

 def A(H,x):
          type(initial_state)))
    if A not in _CACHED_MATVECS:
      _CACHED_MATVECS[A] = libjax.tree_util.Partial(libjax.jit(A))
-    if not hasattr(self, '_iram'):
-      # pylint: disable=attribute-defined-outside-init
-      self._iram = jitted_functions._implicitly_restarted_arnoldi(libjax)
-    return self._iram(_CACHED_MATVECS[A], args, initial_state, num_krylov_vecs,
-                      numeig, which, tol, maxiter)
+    imp_arnoldi = jitted_functions._implicitly_restarted_arnoldi(libjax)

love that name!

alewis

comment created time in 2 days

push event mganahl/TensorNetwork

Martin Ganahl

commit sha 892468e40ce7e7191720d1f2e6b5e973d1c117f5

add eigsh_lanczos to blocksparse backend (#688) * typing * typing * typing * typing. lintingc * linting * typing * wip * transpose_data -> contiguous * wip adding eigsh_lanczos * __matmul__ support for tensors (just like numpy) * extend __matmul__ to tensors * restrict to vectors and matrices * tests adjusted * yapf * wip * wip * wip * testing added * comment * comment * typing * yapf * yapf * linting * typinig changes * typing * remove newline * comment * comment * comment * add type * linting * typing * typing * typing * linting * remove parens * space * remove complex test * remove complex casting * remove parense * add some parense for readability

view details

Martin Ganahl

commit sha 51f9e8bc6bad9143a99ddb652da8eb2c12373a00

shave off a few microns of constant overhead (#697) * shave off a few microns of constant overhead * remove time import * shorted code * typo * remove disable Co-authored-by: Chase Roberts <chaseriley@google.com>

view details

Martin Ganahl

commit sha 2d516b7b441654ef3ed4e931312ddb30bb294cb5

Merge branch 'master' into fix_printing

view details

push time in 3 days

push event alewis/TensorNetwork

Martin Ganahl

commit sha 51f9e8bc6bad9143a99ddb652da8eb2c12373a00

shave off a few microns of constant overhead (#697) * shave off a few microns of constant overhead * remove time import * shorted code * typo * remove disable Co-authored-by: Chase Roberts <chaseriley@google.com>

view details

Martin Ganahl

commit sha bc4d5cb0ceb0ee8d121ae231f830df5953185cad

Merge branch 'master' into numpy_gmres

view details

push time in 3 days

issue comment google/TensorNetwork

Numpy calls in non-numpy backends

Thanks Adam, yes, the jax backend erroneously calls np.sum; this should be jnp.sum

alewis

comment created time in 6 days

Pull request review comment google/TensorNetwork

Jax gmres

(diff context: a new standalone Jax GMRES module adding `gmres_m`, a jitted `gmres_residual`, and jitted `gmres`/`gmres_arnoldi` helpers decorated with `@partial(jax.jit, static_argnums=(2,))`)

would it make sense to move gmres into the jitted_functions.py file?

alewis

comment created time in 6 days

Pull request review comment google/TensorNetwork

Jax gmres

(diff context: the same standalone Jax GMRES module shown above)

sorry, I was not precise in my wording. I meant the arnoldi factorization in jitted_functions.py. This one produces exactly that

alewis

comment created time in 6 days

push event mganahl/TensorNetwork

mganahl

commit sha 8e56e0fb7331f7fcfc106556d22b698587a919a5

fix flatten

view details

mganahl

commit sha 9490cd7b40f6618ed3b5f132758c19aee86c9135

remove unneccesary code

view details

mganahl

commit sha 60702e060aa3cbe038a3322b065e42c1978d4ae8

adapted functions to new charge

view details

mganahl

commit sha 65dfcc339ad52e31d789ead05fd41438691e6474

repr, bugfix

view details

push time in 7 days

Pull request review comment google/TensorNetwork

Jax gmres

(diff context: the same standalone Jax GMRES module shown above)

If so, do you think you can use the one that is implemented in the backend?

alewis

comment created time in 7 days

Pull request review comment google/TensorNetwork

Jax gmres

(diff context: the same standalone Jax GMRES module shown above)

Is this returning an arnoldi factorization?

alewis

comment created time in 7 days
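For reference, a k-step Arnoldi process started from a unit vector v appears to be what produces the pair returned above as `(Vk_1, Htilde)`; in standard notation it satisfies

% V_{k+1} has orthonormal columns spanning the Krylov space K_{k+1}(A, v),
% and \tilde{H}_k is a (k+1) x k upper Hessenberg matrix.
A V_k = V_{k+1} \tilde{H}_k

which is the relation the QR solve in `gmres` above exploits.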

push event mganahl/TensorNetwork

mganahl

commit sha d6a71d0f3d7064cca4a1ddad226592465a655978

finished basic implementation of BaseCharge

view details

push time in 7 days

push event mganahl/TensorNetwork

Martin Ganahl

commit sha 51f9e8bc6bad9143a99ddb652da8eb2c12373a00

shave off a few microns of constant overhead (#697) * shave off a few microns of constant overhead * remove time import * shorted code * typo * remove disable Co-authored-by: Chase Roberts <chaseriley@google.com>

view details

mganahl

commit sha 65389f3b3170eb553d1e5f503ee098731e603ad6

Merge remote-tracking branch 'upstream/master' into new_charge_encoding

view details

mganahl

commit sha 9f3894a6b05bfda42497963772fe32eca0e1ce43

WIP

view details

push time in 7 days

pull request comment google/TensorNetwork

Abstract gmres

sounds good

alewis

comment created time in 7 days

push event google/TensorNetwork

mganahl

commit sha 2e57c0ac3f4e86c6e853190394b01e5d2649ce07

fix error

view details

push time in 7 days

push event google/TensorNetwork

Martin Ganahl

commit sha 6f00431e3c92748423cf0009359f7f6f3a3e5627

implicitly restarted arnoldi method for JaxBackend (#643) * added eigs * added arnoldi stuff * working on arnoldi * fix call signature * revert * ??? * ???? * fix test * fix argument order in matvec * fix linter * fix doc * fix API * linting * linting * tests fixed * fix test * fixing a bunch of things * fix code * working on arnoldi * check * check added * remove [] * remove [] * more arnoldi * docstring * added eigs to jax * fix bug in arnoldi * trying to make arnoldi faster in jax * fix doc * fix issue * added jitting matvec * some docstrings * remove iram * tests for arnoldi * docstring * linting * remove time import * remove typing import * add seed * remove [] * formatting * formatting * docstring * dicstring * docstring * remove typing * remove redundant ravels * docstring * add complex128 testt * typo * fix comments, typos, tests * adding iram * docstring * docstring added * docstring + error msg changed * docstring * appease pytype * tests added * fix defaults, remove typo * linting * typo * typo * fix docstring * improve docstring * linting * comments * newline * rename iram * add typing * typing import * add keywords * remove kwargs from jitted * linting Co-authored-by: Chase Roberts <chaseriley@google.com>

view details

Martin Ganahl

commit sha 28683d739b2e76d71fb74b9d22a4f9faa3fb8cf8

add seeds (#676)

view details

Martin Ganahl

commit sha e8afa6542724113024e2412ce42a23fb720e1d6a

Ncon bugfix (#675) * fix bug * add test * remove skip * linting

view details

Martin Ganahl

commit sha 79d9d886e22d08ac84babe2c8ff14460eb56775d

fix jitting (#678) * fix jitting issue * list->np.ndarray * typo * remove unneded check * code cleaning * formatting * typing * typing * typigng, linting

view details

Martin Ganahl

commit sha 2d38d0a2db92c2105c7269fbd2f0e8e8d14e18fb

fix some typing in backends (#683) * typing * typing * typing * typing. lintingc * linting * typing

view details

Martin Ganahl

commit sha be7713ab53c835b5591c64fbe853f9b66e7d58af

extend __matmul__ support to vectors and matrices (#687) * __matmul__ support for tensors (just like numpy) * extend __matmul__ to tensors * restrict to vectors and matrices * tests adjusted * yapf Co-authored-by: Chase Roberts <chaseriley@google.com>

view details

Martin Ganahl

commit sha 6f62db9f624accdce9d7cd7af2dbf70c4dcf80be

adding sum and matmul to backends (#681) * adding sum and matmul to backends * formatting * change to jitted_functions.py

view details

Jackson

commit sha 98b0b13318ef44e0b20b1b6d6f181db9040d63f2

Add support for asymmetric input/output dimensions in Entangler Layer (#677) * * Changed entangler layer to create complete node levels * Added support for asymmetric input/output dimensions * * Clarify test parameters * * Fix typo * * Update Entangler doc string * Add additional entangler tests * Update entangler to use min(output_leg, input_leg) for primary size * * Cleanup test_layer Co-authored-by: Martin Ganahl <martin.ganahl@gmail.com> Co-authored-by: Ben Penchas <bpenchas@google.com>

view details

Martin Ganahl

commit sha 5ae84eb82b31ccbf84a08c327c8ca472280af786

transpose_data -> contiguous (#684) * transpose_data -> contiguous * linting * fix test

view details

Martin Ganahl

commit sha 5322a5b1d2dd97c42de200233684fbadefdc95a2

ncon with batching support (#682) * adding sum and matmul to backends * formatting * change to jitted_functions.py * added batched ncon * typo * typo * docstring * docstring * rename routine, change docstring * renaming * fix test * more readable code * typo * linting * allow alphanumeric strings only * test updated * yapf * typo * typo Co-authored-by: Chase Roberts <chaseriley@google.com>

view details

Chase Roberts

commit sha a4d0afe55d536c16292414fa4c19a0e38552fee3

Update mpo.py (#689)

view details

Martin Ganahl

commit sha c9b418349f63e91d0ae2f8d162e6cceb57ece0a5

reduce python overhead in ncon a tiny bit (#693) * shaved off a few milliseconds of constant overhead * linting * newlinw * shorten code * typo * shorted check * remove newline

view details

Martin Ganahl

commit sha 00065ffdbbe0d6e4075045888e469e78e06dcf55

fix abstract_backend (#694)

view details

jeyassri balachandran

commit sha 7967a91c7571f1fa6c97954f6c0c16af0ca4c173

issues-686 repr in index.py to return more info (#696) * modified repr in index.py to return dense shape, charges and flow information * added formatting descriptors * fixed linting whitespace * modiefied Index.__repr__ * remove commented code Co-authored-by: mganahl <martin.ganahl@gmail.com>

view details

Martin Ganahl

commit sha 892468e40ce7e7191720d1f2e6b5e973d1c117f5

add eigsh_lanczos to blocksparse backend (#688) * typing * typing * typing * typing. lintingc * linting * typing * wip * transpose_data -> contiguous * wip adding eigsh_lanczos * __matmul__ support for tensors (just like numpy) * extend __matmul__ to tensors * restrict to vectors and matrices * tests adjusted * yapf * wip * wip * wip * testing added * comment * comment * typing * yapf * yapf * linting * typinig changes * typing * remove newline * comment * comment * comment * add type * linting * typing * typing * typing * linting * remove parens * space * remove complex test * remove complex casting * remove parense * add some parense for readability

view details

Martin Ganahl

commit sha 51f9e8bc6bad9143a99ddb652da8eb2c12373a00

shave off a few microns of constant overhead (#697) * shave off a few microns of constant overhead * remove time import * shorted code * typo * remove disable Co-authored-by: Chase Roberts <chaseriley@google.com>

view details

mganahl

commit sha 3de65450334ff2dc66f9c2abe8cdf1a134a042f4

Merge branch 'master' into experimental_ncon

view details

push time in 7 days

PR opened google/TensorNetwork

merge master into experimental_ncon
+2193 -504

0 comment

38 changed files

pr created time in 7 days

push event mganahl/TensorNetwork

Martin Ganahl

commit sha 6f00431e3c92748423cf0009359f7f6f3a3e5627

implicitly restarted arnoldi method for JaxBackend (#643) * added eigs * added arnoldi stuff * working on arnoldi * fix call signature * revert * ??? * ???? * fix test * fix argument order in matvec * fix linter * fix doc * fix API * linting * linting * tests fixed * fix test * fixing a bunch of things * fix code * working on arnoldi * check * check added * remove [] * remove [] * more arnoldi * docstring * added eigs to jax * fix bug in arnoldi * trying to make arnoldi faster in jax * fix doc * fix issue * added jitting matvec * some docstrings * remove iram * tests for arnoldi * docstring * linting * remove time import * remove typing import * add seed * remove [] * formatting * formatting * docstring * dicstring * docstring * remove typing * remove redundant ravels * docstring * add complex128 testt * typo * fix comments, typos, tests * adding iram * docstring * docstring added * docstring + error msg changed * docstring * appease pytype * tests added * fix defaults, remove typo * linting * typo * typo * fix docstring * improve docstring * linting * comments * newline * rename iram * add typing * typing import * add keywords * remove kwargs from jitted * linting Co-authored-by: Chase Roberts <chaseriley@google.com>

view details

Martin Ganahl

commit sha 28683d739b2e76d71fb74b9d22a4f9faa3fb8cf8

add seeds (#676)

view details

Martin Ganahl

commit sha e8afa6542724113024e2412ce42a23fb720e1d6a

Ncon bugfix (#675) * fix bug * add test * remove skip * linting

view details

Martin Ganahl

commit sha 79d9d886e22d08ac84babe2c8ff14460eb56775d

fix jitting (#678) * fix jitting issue * list->np.ndarray * typo * remove unneded check * code cleaning * formatting * typing * typing * typigng, linting

view details

Martin Ganahl

commit sha 2d38d0a2db92c2105c7269fbd2f0e8e8d14e18fb

fix some typing in backends (#683) * typing * typing * typing * typing. lintingc * linting * typing

view details

Martin Ganahl

commit sha be7713ab53c835b5591c64fbe853f9b66e7d58af

extend __matmul__ support to vectors and matrices (#687) * __matmul__ support for tensors (just like numpy) * extend __matmul__ to tensors * restrict to vectors and matrices * tests adjusted * yapf Co-authored-by: Chase Roberts <chaseriley@google.com>

view details

Martin Ganahl

commit sha 6f62db9f624accdce9d7cd7af2dbf70c4dcf80be

adding sum and matmul to backends (#681) * adding sum and matmul to backends * formatting * change to jitted_functions.py

view details

Jackson

commit sha 98b0b13318ef44e0b20b1b6d6f181db9040d63f2

Add support for asymmetric input/output dimensions in Entangler Layer (#677) * * Changed entangler layer to create complete node levels * Added support for asymmetric input/output dimensions * * Clarify test parameters * * Fix typo * * Update Entangler doc string * Add additional entangler tests * Update entangler to use min(output_leg, input_leg) for primary size * * Cleanup test_layer Co-authored-by: Martin Ganahl <martin.ganahl@gmail.com> Co-authored-by: Ben Penchas <bpenchas@google.com>

view details

Martin Ganahl

commit sha 5ae84eb82b31ccbf84a08c327c8ca472280af786

transpose_data -> contiguous (#684) * transpose_data -> contiguous * linting * fix test

view details

Martin Ganahl

commit sha 5322a5b1d2dd97c42de200233684fbadefdc95a2

ncon with batching support (#682) * adding sum and matmul to backends * formatting * change to jitted_functions.py * added batched ncon * typo * typo * docstring * docstring * rename routine, change docstring * renaming * fix test * more readable code * typo * linting * allow alphanumeric strings only * test updated * yapf * typo * typo Co-authored-by: Chase Roberts <chaseriley@google.com>

view details

Chase Roberts

commit sha a4d0afe55d536c16292414fa4c19a0e38552fee3

Update mpo.py (#689)

view details

Martin Ganahl

commit sha c9b418349f63e91d0ae2f8d162e6cceb57ece0a5

reduce python overhead in ncon a tiny bit (#693) * shaved off a few milliseconds of constant overhead * linting * newlinw * shorten code * typo * shorted check * remove newline

view details

Martin Ganahl

commit sha 00065ffdbbe0d6e4075045888e469e78e06dcf55

fix abstract_backend (#694)

view details

jeyassri balachandran

commit sha 7967a91c7571f1fa6c97954f6c0c16af0ca4c173

issues-686 repr in index.py to return more info (#696) * modified repr in index.py to return dense shape, charges and flow information * added formatting descriptors * fixed linting whitespace * modiefied Index.__repr__ * remove commented code Co-authored-by: mganahl <martin.ganahl@gmail.com>

view details

Martin Ganahl

commit sha 892468e40ce7e7191720d1f2e6b5e973d1c117f5

add eigsh_lanczos to blocksparse backend (#688) * typing * typing * typing * typing. lintingc * linting * typing * wip * transpose_data -> contiguous * wip adding eigsh_lanczos * __matmul__ support for tensors (just like numpy) * extend __matmul__ to tensors * restrict to vectors and matrices * tests adjusted * yapf * wip * wip * wip * testing added * comment * comment * typing * yapf * yapf * linting * typinig changes * typing * remove newline * comment * comment * comment * add type * linting * typing * typing * typing * linting * remove parens * space * remove complex test * remove complex casting * remove parense * add some parense for readability

view details

Martin Ganahl

commit sha 51f9e8bc6bad9143a99ddb652da8eb2c12373a00

shave off a few microns of constant overhead (#697) * shave off a few microns of constant overhead * remove time import * shorted code * typo * remove disable Co-authored-by: Chase Roberts <chaseriley@google.com>

view details

mganahl

commit sha cebf5a1e4b8afb8bf95ad981a652d953827c6602

Merge branch 'master' into experimental_ncon

view details

push time in 7 days

push event mganahl/TensorNetwork

Adam GM Lewis

commit sha 81b9db39c0d7f6d984af2b6a592311fa5ab1859b

Add linalg directory and initialization code

view details

Adam GM Lewis

commit sha c52b16bfbb87b3624ae4d6e46c22d36ca60e3b4d

Delete linalg for merge

view details

Adam GM Lewis

commit sha 15f9447c8b98a8299d01766e837a260c5c4b1690

Merge https://github.com/google/TensorNetwork

view details

Adam GM Lewis

commit sha 7393da01cf2c2f901c4724e453bda7f271303a7e

Merge remote-tracking branch 'upstream/experimental_ncon' into experimental_ncon

view details

Jackson

commit sha c8cc12c246f72ce645cabffa0a6b5ccf3579b65b

* Update README.md - add Conv2DMPO Layer (#648)

view details

Ben Penchas

commit sha ec2008035f1879f4e34bf048dc37439baa9f51d2

Adding jupyter notebook with Keras layer examples (#645) * Adding jupyter notebook with Keras layer examples * Moving to colabs dir Co-authored-by: Chase Roberts <chaseriley@google.com>

view details

Martin Ganahl

commit sha b8965af51538b59ea69131e9a0860eb5ae462b36

k-step arnoldi factorization (#641) * added eigs * added arnoldi stuff * working on arnoldi * fix call signature * revert * ??? * ???? * fix test * fix argument order in matvec * fix linter * fix doc * fix API * linting * linting * tests fixed * fix test * fixing a bunch of things * fix code * working on arnoldi * check * check added * remove [] * remove [] * more arnoldi * docstring * added eigs to jax * fix bug in arnoldi * trying to make arnoldi faster in jax * fix doc * fix issue * added jitting matvec * some docstrings * remove iram * tests for arnoldi * docstring * linting * remove time import * remove typing import * add seed * remove [] * formatting * formatting * docstring * dicstring * docstring * remove typing * fix comments, typos, tests * comments * revert * remove line Co-authored-by: Chase Roberts <chaseriley@google.com>

view details

Chase Roberts

commit sha ecd48474a70bc565e7eb8163ddd78461d68a1de6

Update README.md

view details

Jackson

commit sha fd1e6bf3ba93c388b8bcc06956ab2034340b7053

Modify DenseDecomp Layer to handle ndimensional input (#649) * * Modify DenseDecomp Layer to handle ndimensional input * Add tests * * fix docstring Co-authored-by: Chase Roberts <chaseriley@google.com>

view details

Ben Penchas

commit sha efcc2d50190c26b253239a355632135c0c8f7597

Remove commented line (#654) Co-authored-by: Chase Roberts <chaseriley@google.com>

view details

Martin Ganahl

commit sha bb92c3a3f2271fc9b34c43541935ed596fb834f8

Refactor operations (#644) * move stuff from network_operations.py to linalg.py * move stuff * move tests to linalg.py * add tests * add imports * formatting * formatting * formatting * remove tensornetwork import * fix tests * fix tests Co-authored-by: Chase Roberts <chaseriley@google.com>

view details

Jackson

commit sha 3c72f403295fd1eb06c3da3ca6e7cc43d54e6fcf

* Add N-dim input handling for mpo, entangler, expander, and condensor (#655) layer * Update tests

view details

Lazersmoke

commit sha 72360d6b0ec91b2b5b7e0c04fcc1641908def50b

Add documentation stubs (#608) These are linked to by the "Edit on GitHub" links in the documentation. Adding these means contributors can immediately edit them on GitHub without download, building docs, etc. Disable overwriting documentation stubs Co-authored-by: Chase Roberts <chaseriley@google.com>

view details

Adam GM Lewis

commit sha 99e33607a3e7ec21570b03b0e71021fd8e8c9067

Comment out __init__.py code for linalg

view details

Adam GM Lewis

commit sha 1f486434a6c664f00a7bdc2bc1d74405d09eb845

Restore header to linalg.py for some reason

view details

Adam GM Lewis

commit sha beee0378a8ac281a2266a88a005c9f02a664518d

tensor.py and tests pass linter and typing

view details

Adam GM Lewis

commit sha 28698dbc63c390fb2a238c37e2f6859afa344212

Add tn.Tensor with some basic methods

view details

Adam GM Lewis

commit sha 43a1c9e919c35fa3fd1d8c57f220534f3f34e6d1

Merge remote-tracking branch 'upstream/experimental_ncon' into experimental_ncon

view details

Adam GM Lewis

commit sha dcdc101314f15d2909cfb05b6892aa62b6e166b5

Type hinting in tensor.py supports Python 3.6

view details

Adam GM Lewis

commit sha 9329aa3505953bd2f4c7ce3a3f67cdedaee2d183

Forgot to delete the import

view details

push time in 7 days

delete branch mganahl/TensorNetwork

delete branch : index_repr

delete time in 7 days

push event mganahl/TensorNetwork

Martin Ganahl

commit sha 51f9e8bc6bad9143a99ddb652da8eb2c12373a00

shave off a few microns of constant overhead (#697) * shave off a few microns of constant overhead * remove time import * shorted code * typo * remove disable Co-authored-by: Chase Roberts <chaseriley@google.com>

view details

push time in 7 days

delete branch mganahl/TensorNetwork

delete branch : faster_ncon

delete time in 7 days

push event google/TensorNetwork

Martin Ganahl

commit sha 51f9e8bc6bad9143a99ddb652da8eb2c12373a00

shave off a few microns of constant overhead (#697) * shave off a few microns of constant overhead * remove time import * shorted code * typo * remove disable Co-authored-by: Chase Roberts <chaseriley@google.com>

view details

push time in 7 days

PR merged google/TensorNetwork

shave off a few microns of constant overhead cla: yes

change network_structure from List[np.ndarray] to List[List] and replace numpy calls with list operations (this is faster for small arrays)

+158 -206

0 comment

2 changed files

mganahl

pr closed time in 7 days
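The speed claim is easy to sanity-check; a rough sketch (timings are machine-dependent, and the label list below is a hypothetical stand-in for ncon's network_structure):

import timeit
import numpy as np

labels = [1, 2, -1, -2]
arr = np.array(labels)

# For few-element label lists, plain list operations typically beat the
# equivalent numpy calls, since the latter pay array-creation overhead.
print(timeit.timeit(lambda: [l for l in labels if l > 0], number=100000))
print(timeit.timeit(lambda: arr[arr > 0], number=100000))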

push event google/TensorNetwork

Adam Lewis

commit sha bd00d6de2921f6932142690d11eb816497c77bd4

Linalg (#701) * Initialization tests pass * Fixes for PyType * Add decomposition functions, change name of decomps * Fix typos, remove incorrect concat and prod * Amend __init__.py per last commit * Suppress warning * Fix build

view details

push time in 7 days

PR merged google/TensorNetwork

Linalg cla: yes
  • Decomposition functions and tests.
  • Renamed svd_decomposition, qr_decomposition, rq_decomposition to svd, qr, rq.
  • Some backend changes to make functions work.
  • Changes throughout to respect new names.
  • High level names (e.g. MPS) untouched.
  • Minor changes to silence warnings.
+491 -155

1 comment

23 changed files

alewis

pr closed time in 7 days

pull request comment google/TensorNetwork

Linalg

Hey @alewis, would you like to make a PR to master with the renaming of svd_decomposition, qr_decomposition, ...? There's no reason to wait on this.

alewis

comment created time in 7 days

Pull request review comment google/TensorNetwork

Numpy gmres

 def test_eigsh_lanczos_raises():
     backend.eigsh_lanczos(lambda x: x, initial_state=[1, 2, 3])
 
+def test_gmres_raises():
+  backend = numpy_backend.NumPyBackend()
+  dummy_mv = lambda x: x
+  N = 10
+  with pytest.raises(ValueError):  # x0, b have different sizes
+    backend.gmres(dummy_mv, np.zeros((N,)), x0=np.zeros((N+1,)))
+  with pytest.raises(ValueError):  # x0, b have different dtypes
+    backend.gmres(dummy_mv, np.zeros((N,), dtype=np.float32),
+                  x0=np.zeros(N, dtype=np.float64))
+  with pytest.raises(ValueError):  # x0, b have different shapes
+    x0 = np.zeros(N)
+    b = np.zeros(N).reshape(2, N//2)
+    backend.gmres(dummy_mv, b, x0=x0)
+  with pytest.raises(ValueError):  # num_krylov_vectors < 0
+    backend.gmres(dummy_mv, np.zeros((N,)), num_krylov_vectors=-1)
+  with pytest.raises(ValueError):  # num_krylov_vectors == 0
+    backend.gmres(dummy_mv, np.zeros((N,)), num_krylov_vectors=0)
+  with pytest.raises(ValueError):  # num_krylov_vectors > b.size
+    backend.gmres(dummy_mv, np.zeros((N,)), num_krylov_vectors=N+1)
+  with pytest.raises(ValueError):  # tol < 0
+    backend.gmres(dummy_mv, np.zeros((N,)), tol=-0.3)
+  with pytest.raises(ValueError):  # atol < 0
+    backend.gmres(dummy_mv, np.zeros((N,)), atol=-0.3)
+
+
+@pytest.mark.parametrize("dtype", np_dtypes)

add a test that checks the accuracy of ||Ax-b|| for the found solution x for some non-trivial matrix dimension (maybe 100 or so)

alewis

comment created time in 7 days
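A sketch of such a test, assuming the `backend.gmres` signature from the hunk above (dimension, seed and tolerances are illustrative):

import numpy as np
from tensornetwork.backends.numpy import numpy_backend

def test_gmres_accuracy():
  backend = numpy_backend.NumPyBackend()
  N = 100
  rng = np.random.RandomState(10)
  # Diagonally dominant => well-conditioned, so GMRES converges quickly.
  A = rng.randn(N, N) + N * np.eye(N)
  b = rng.randn(N)
  x, info = backend.gmres(lambda v: A @ v, b, num_krylov_vectors=N, tol=1E-10)
  assert info == 0
  assert np.linalg.norm(A @ x - b) < 1E-8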

Pull request review comment google/TensorNetwork

Numpy gmres

 def test_eigsh_lanczos_raises():     backend.eigsh_lanczos(lambda x: x, initial_state=[1, 2, 3])  +def test_gmres_raises():+  backend = numpy_backend.NumPyBackend()+  dummy_mv = lambda x: x+  N = 10+  with pytest.raises(ValueError): # x0, b have different sizes

add pytest.raises(ValueError, match="the error string") to all checks so we know the right error is raised

alewis

comment created time in 7 days
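For instance (a sketch reusing `dummy_mv` and `N` from the test above; the match string must agree with whatever the backend actually raises):

with pytest.raises(ValueError, match="num_krylov_vectors must be in"):
  backend.gmres(dummy_mv, np.zeros((N,)), num_krylov_vectors=0)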

Pull request review comment google/TensorNetwork

Numpy gmres

 def matvec(vector):     if dtype:       eta = eta.astype(dtype)       U = U.astype(dtype)-    return eta, [np.reshape(U[:, n], shape) for n in range(numeig)]+    evs = list(eta)+    eVs = [np.reshape(U[:, n], shape) for n in range(numeig)]+    return evs, eVs++  def gmres(self,

let's submit two issues then

alewis

comment created time in 7 days

Pull request review comment google/TensorNetwork

Abstract gmres

 def eigsh_lanczos(self,
     raise NotImplementedError(
         "Backend '{}' has not implemented eighs_lanczos.".format(self.name))
 
+  def gmres(self,
+            A_mv: Callable,
+            b: Tensor,
+            A_args: Optional[List] = None,
+            A_kwargs: Optional[dict] = None,
+            x0: Optional[Tensor] = None,
+            tol: float = 1E-05,
+            atol: Optional[float] = None,
+            num_krylov_vectors: Optional[int] = None,
+            maxiter: Optional[int] = 1,
+            M: Optional[Callable] = None
+            ) -> Tuple[Tensor, int]:
+    """ GMRES solves the linear system A @ x = b for x given a vector `b` and
+    a general (not necessarily symmetric/Hermitian) linear operator `A`.
+
+    As a Krylov method, GMRES does not require a concrete matrix representation
+    of the n by n `A`, but only a function
+    `vector1 = A_mv(vector0, *A_args, **A_kwargs)`
+    prescribing a one-to-one linear map from vector0 to vector1 (that is,
+    A must be square, and thus vector0 and vector1 the same size). If `A` is a
+    dense matrix, or if it is a symmetric/Hermitian operator, a different
+    linear solver will usually be preferable.
+
+    GMRES works by first constructing the Krylov basis
+    K = (x0, A_mv@x0, A_mv@A_mv@x0, ..., (A_mv^num_krylov_vectors)@x_0) and then
+    solving a certain dense linear system K @ q0 = q1 from whose solution x can
+    be approximated. For `num_krylov_vectors = n` the solution is provably exact
+    in infinite precision, but the expense is cubic in `num_krylov_vectors` so
+    one is typically interested in the `num_krylov_vectors << n` case.
+    The solution can in this case be repeatedly
+    improved, to a point, by restarting the Arnoldi iterations each time
+    `num_krylov_vectors` is reached. Unfortunately the optimal parameter choices
+    balancing expense and accuracy are difficult to predict in advance, so
+    applying this function requires a degree of experimentation.
+
+    In a tensor network code one is typically interested in A_mv implementing
+    some tensor contraction. This implementation thus allows `b` and `x0` to be
+    of whatever arbitrary, though identical, shape `b = A_mv(x0, ...)` expects.
+    Reshaping to and from a matrix problem is handled internally.
+
+    This function is supported only with the NumPy and Jax backends.
+    The numpy backend version is simply an interface to

The last part can go, it's enough if it is described in the appropriate backend I think

alewis

comment created time in 7 days
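To make the cost/accuracy trade-off described in the docstring concrete, a hedged NumPy-backend sketch (system and parameters are illustrative):

import numpy as np
from tensornetwork.backends.numpy import numpy_backend

backend = numpy_backend.NumPyBackend()
N = 500
A = np.diag(np.linspace(1.0, 2.0, N))  # benign spectrum
b = np.ones(N)
A_mv = lambda v: A @ v

# Full Krylov space: exact up to round-off, but cubic in N.
x_full, _ = backend.gmres(A_mv, b, num_krylov_vectors=N)

# Small Krylov space with restarts: far cheaper per pass; whether 50
# restarts reach a given tol has to be found by experimentation.
x_restart, info = backend.gmres(A_mv, b, num_krylov_vectors=20, maxiter=50)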

Pull request review comment google/TensorNetwork

Abstract gmres

 def test_eigs_not_implemented():     backend.eigs(np.ones((2, 2)))  +def test_gmres_not_implemented():

I think they both inherit from AbstractBackend, if not then that's a bug

alewis

comment created time in 7 days

pull request comment google/TensorNetwork

Abstract gmres

if #704 builds then it's fine

alewis

comment created time in 7 days

Pull request review comment google/TensorNetwork

Abstract gmres

 def test_eigs_not_implemented():     backend.eigs(np.ones((2, 2)))  +def test_gmres_not_implemented():

thanks for writing those tests as well!

alewis

comment created time in 7 days


PR closed google/TensorNetwork

Abstract gmres cla: yes

Adds the GMRES interface to abstract_backend. Tests that it is not implemented there, in PyTorch, or in TensorFlow.

+116 -0

1 comment

4 changed files

alewis

pr closed time in 7 days

pull request comment google/TensorNetwork

Abstract gmres

I think this should probably go before #704, since the latter shouldn't build if gmres is not defined in the AbstractBackend (either pytype or pylint should complain).

alewis

comment created time in 7 days

Pull request review comment google/TensorNetwork

Numpy gmres

 def matvec(vector):
     if dtype:
       eta = eta.astype(dtype)
       U = U.astype(dtype)
-    return eta, [np.reshape(U[:, n], shape) for n in range(numeig)]
+    evs = list(eta)
+    eVs = [np.reshape(U[:, n], shape) for n in range(numeig)]
+    return evs, eVs
+
+  def gmres(self,
+            A_mv: Callable,
+            b: np.ndarray,
+            A_args: Optional[List] = None,
+            A_kwargs: Optional[dict] = None,
+            x0: Optional[np.ndarray] = None,
+            tol: float = 1E-05,
+            atol: Optional[float] = None,
+            num_krylov_vectors: Optional[int] = None,
+            maxiter: Optional[int] = 1,
+            M: Optional[Callable] = None
+            ) -> Tuple[np.ndarray, int]:
+    """ GMRES solves the linear system A @ x = b for x given a vector `b` and
+    a general (not necessarily symmetric/Hermitian) linear operator `A`.
+
+    As a Krylov method, GMRES does not require a concrete matrix representation
+    of the n by n `A`, but only a function
+    `vector1 = A_mv(vector0, *A_args, **A_kwargs)`
+    prescribing a one-to-one linear map from vector0 to vector1 (that is,
+    A must be square, and thus vector0 and vector1 the same size). If `A` is a
+    dense matrix, or if it is a symmetric/Hermitian operator, a different
+    linear solver will usually be preferable.
+
+    GMRES works by first constructing the Krylov basis
+    K = (x0, A_mv@x0, A_mv@A_mv@x0, ..., (A_mv^num_krylov_vectors)@x_0) and then
+    solving a certain dense linear system K @ q0 = q1 from whose solution x can
+    be approximated. For `num_krylov_vectors = n` the solution is provably exact
+    in infinite precision, but the expense is cubic in `num_krylov_vectors` so
+    one is typically interested in the `num_krylov_vectors << n` case.
+    The solution can in this case be repeatedly
+    improved, to a point, by restarting the Arnoldi iterations each time
+    `num_krylov_vectors` is reached. Unfortunately the optimal parameter choices
+    balancing expense and accuracy are difficult to predict in advance, so
+    applying this function requires a degree of experimentation.
+
+    In a tensor network code one is typically interested in A_mv implementing
+    some tensor contraction. This implementation thus allows `b` and `x0` to be
+    of whatever arbitrary, though identical, shape `b = A_mv(x0, ...)` expects.
+    Reshaping to and from a matrix problem is handled internally.
+
+    The numpy backend version of GMRES is simply an interface to
+    `scipy.sparse.linalg.gmres`, itself an interface to ARPACK.
+    SciPy 1.1.0 or newer (May 05 2018) is required.
+
+    Args:
+      A_mv     : A function `v0 = A_mv(v, *A_args, **A_kwargs)` where `v0` and
+                 `v` have the same shape.
+      b        : The `b` in `A @ x = b`; it should be of the shape `A_mv`
+                 operates on.
+      A_args   : Positional arguments to `A_mv`, supplied to this interface
+                 as a list.
+                 Default: None.
+      A_kwargs : Keyword arguments to `A_mv`, supplied to this interface
+                 as a dictionary.
+                 Default: None.
+      x0       : An optional guess solution. Zeros are used by default.
+                 If `x0` is supplied, its shape and dtype must match those of
+                 `b`, or an error will be thrown.
+                 Default: zeros.
+      tol, atol: Solution tolerance to achieve,
+                 norm(residual) <= max(tol*norm(b), atol).
+                 Default: tol=1E-05
+                          atol=tol
+      num_krylov_vectors
+               : Size of the Krylov space to build at each restart.
+                 Expense is cubic in this parameter. If supplied, it must be
+                 an integer in 0 < num_krylov_vectors <= b.size.
+                 Default: b.size.
+      maxiter  : The Krylov space will be repeatedly rebuilt up to this many
+                 times. Large values of this argument
+                 should be used only with caution, since especially for nearly
+                 symmetric matrices and small `num_krylov_vectors` convergence
+                 might well freeze at a value significantly larger than `tol`.
+                 Default: 1.
+      M        : Inverse of the preconditioner of A; see the docstring for
+                 `scipy.sparse.linalg.gmres`. This is only supported in the
+                 numpy backend. Supplying this argument to other backends will
+                 trigger NotImplementedError.
+                 Default: None.
+
+    Raises:
+      ValueError: -if `x0` is supplied but its shape differs from that of `b`.
+                  -if the ARPACK solver reports a breakdown (which usually
+                   indicates some kind of floating point issue).
+                  -if num_krylov_vectors is 0 or exceeds b.size.
+                  -if tol was negative.
+
+    Returns:
+      x       : The converged solution. It has the same shape as `b`.
+      info    : 0 if convergence was achieved, the number of restarts otherwise.
+    """
+
+    if x0 is not None and x0.shape != b.shape:
+      errstring = "If x0 is supplied, its shape, " + str(x0.shape) + ", must \n"
+      errstring += "match that of b, " + str(b.shape) + "."
+      raise ValueError(errstring)
+    if x0 is not None and x0.dtype != b.dtype:
+      errstring = "If x0 is supplied, its dtype, " + str(x0.dtype) + ", must \n"
+      errstring += "match that of b, " + str(b.dtype) + "."
+      raise ValueError(errstring)
+    if num_krylov_vectors is None:
+      num_krylov_vectors = b.size
+    if num_krylov_vectors <= 0 or num_krylov_vectors > b.size:
+      errstring = "num_krylov_vectors must be in "
+      errstring += "0 < num_krylov_vectors <= b.size.\n"
+      errstring += "num_krylov_vectors = " + str(num_krylov_vectors) + ".\n"
+      errstring += "b.size = " + str(b.size) + "."
+      raise ValueError(errstring)
+    if tol < 0:
+      raise ValueError("tol = ", tol, " must be positive.")
+    if atol is None:
+      atol = tol
+    if atol < 0:
+      raise ValueError("atol = ", atol, " must be positive.")
+
+    if A_args is None:
+      A_args = []
+    if A_kwargs is None:
+      A_kwargs = {}
+
+    def matvec(v):
+      v_tensor = v.reshape(b.shape)
+      Av = A_mv(v_tensor, *A_args, **A_kwargs)
+      Avec = Av.ravel()
+      return Avec
+
+    A_shape = (b.size, b.size)
+    A_op = sp.sparse.linalg.LinearOperator(matvec=matvec, shape=A_shape)
+    x, info = sp.sparse.linalg.gmres(A_op, b, x0=x0, tol=tol, atol=atol,
+                                     restart=num_krylov_vectors,
+                                     maxiter=maxiter, M=M)
+    if info < 0:
+      raise ValueError("ARPACK gmres received illegal input or broke down.")

I think. If not, then it's fine

alewis

comment created time in 7 days

Pull request review comment google/TensorNetwork

Numpy gmres

[diff hunk identical to the `numpy_backend.gmres` hunk quoted above; this comment attaches to the final `if info < 0:` check]

it would make sense to report info since its value usually conveys some information about what went wrong

alewis

comment created time in 7 days
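One way to surface it, as a sketch:

if info < 0:
  raise ValueError(f"ARPACK gmres received illegal input or broke down "
                   f"(info = {info}).")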

Pull request review comment google/TensorNetwork

Numpy gmres

[diff hunk identical to the `numpy_backend.gmres` hunk quoted above; this comment attaches to the `errstring` construction in the `num_krylov_vectors` check]

small hint: you can break lines with strings without having to do +=, and with f-strings (supported in Python 3.6 and later) you can use {} to insert expressions into the string:

errstring = (f"num_krylov_vectors must be in "
             f"0 < num_krylov_vectors <= b.size.\n"
             f"num_krylov_vectors = {num_krylov_vectors}.\n"
             f"b.size = {b.size}.")
alewis

comment created time in 7 days

Pull request review comment google/TensorNetwork

Numpy gmres

 def matvec(vector):     if dtype:       eta = eta.astype(dtype)       U = U.astype(dtype)-    return eta, [np.reshape(U[:, n], shape) for n in range(numeig)]+    evs = list(eta)+    eVs = [np.reshape(U[:, n], shape) for n in range(numeig)]+    return evs, eVs++  def gmres(self,

should we also add lgmres? I found lgmres to be more stable and also faster than gmres in my applications, though it's been a few years since I last tested this.

alewis

comment created time in 7 days
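For reference, a sketch of what an lgmres variant would wrap; `scipy.sparse.linalg.lgmres` accepts essentially the same core arguments as `gmres` (the system below is illustrative):

import numpy as np
import scipy.sparse.linalg as spla

N = 100
rng = np.random.RandomState(0)
A = rng.randn(N, N) + N * np.eye(N)
b = rng.randn(N)
A_op = spla.LinearOperator((N, N), matvec=lambda v: A @ v)
x, info = spla.lgmres(A_op, b, tol=1E-8)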

push event mganahl/TensorNetwork

mganahl

commit sha 955cef9bec0a6eecf8c146a542c11a59a42ffc8d

file added

view details

push time in 8 days

create branch mganahl/TensorNetwork

branch : new_charge_encoding

created branch time in 8 days

issue comment google/TensorNetwork

PEPS

Sounds good. How far are we with VUMPS? Should I start looking into the GMRES implementation for jax, or are you doing it?

alewis

comment created time in 8 days

push event mganahl/TensorNetwork

jeyassri balachandran

commit sha 7967a91c7571f1fa6c97954f6c0c16af0ca4c173

issues-686 repr in index.py to return more info (#696) * modified repr in index.py to return dense shape, charges and flow information * added formatting descriptors * fixed linting whitespace * modiefied Index.__repr__ * remove commented code Co-authored-by: mganahl <martin.ganahl@gmail.com>

view details

Martin Ganahl

commit sha 892468e40ce7e7191720d1f2e6b5e973d1c117f5

add eigsh_lanczos to blocksparse backend (#688) * typing * typing * typing * typing. lintingc * linting * typing * wip * transpose_data -> contiguous * wip adding eigsh_lanczos * __matmul__ support for tensors (just like numpy) * extend __matmul__ to tensors * restrict to vectors and matrices * tests adjusted * yapf * wip * wip * wip * testing added * comment * comment * typing * yapf * yapf * linting * typinig changes * typing * remove newline * comment * comment * comment * add type * linting * typing * typing * typing * linting * remove parens * space * remove complex test * remove complex casting * remove parense * add some parense for readability

view details

push time in 9 days

delete branch mganahl/TensorNetwork

delete branch : eigsh_blocksparse

delete time in 9 days

push event google/TensorNetwork

Martin Ganahl

commit sha 892468e40ce7e7191720d1f2e6b5e973d1c117f5

add eigsh_lanczos to blocksparse backend (#688) * typing * typing * typing * typing. lintingc * linting * typing * wip * transpose_data -> contiguous * wip adding eigsh_lanczos * __matmul__ support for tensors (just like numpy) * extend __matmul__ to tensors * restrict to vectors and matrices * tests adjusted * yapf * wip * wip * wip * testing added * comment * comment * typing * yapf * yapf * linting * typinig changes * typing * remove newline * comment * comment * comment * add type * linting * typing * typing * typing * linting * remove parens * space * remove complex test * remove complex casting * remove parense * add some parense for readability

view details

push time in 9 days

push event mganahl/TensorNetwork

jeyassri balachandran

commit sha 7967a91c7571f1fa6c97954f6c0c16af0ca4c173

issues-686 repr in index.py to return more info (#696) * modified repr in index.py to return dense shape, charges and flow information * added formatting descriptors * fixed linting whitespace * modiefied Index.__repr__ * remove commented code Co-authored-by: mganahl <martin.ganahl@gmail.com>

view details

Martin Ganahl

commit sha f0238d04135ee0289602bfc9652919865e4b87d0

Merge branch 'master' into eigsh_blocksparse

view details

push time in 9 days

push event mganahl/TensorNetwork

jeyassri balachandran

commit sha 7967a91c7571f1fa6c97954f6c0c16af0ca4c173

issues-686 repr in index.py to return more info (#696) * modified repr in index.py to return dense shape, charges and flow information * added formatting descriptors * fixed linting whitespace * modiefied Index.__repr__ * remove commented code Co-authored-by: mganahl <martin.ganahl@gmail.com>

view details

Martin Ganahl

commit sha 13093a1c2152bdc7dcd82e97229ad8908deac302

Merge branch 'master' into fix_printing

view details

push time in 9 days

push event google/TensorNetwork

jeyassri balachandran

commit sha 7967a91c7571f1fa6c97954f6c0c16af0ca4c173

issues-686 repr in index.py to return more info (#696) * modified repr in index.py to return dense shape, charges and flow information * added formatting descriptors * fixed linting whitespace * modiefied Index.__repr__ * remove commented code Co-authored-by: mganahl <martin.ganahl@gmail.com>

view details

push time in 9 days

PR merged google/TensorNetwork

issues-686 repr in index.py to return more info cla: yes

modified repr in index.py to return dense shape, charges and flow information

+6 -1

9 comments

1 changed file

shreeju

pr closed time in 9 days

pull request comment google/TensorNetwork

issues-686 repr in index.py to return more info

@googlebot I consent

shreeju

comment created time in 9 days

delete branch mganahl/TensorNetwork

delete branch : fix_abstract_backend

delete time in 9 days

delete branch mganahl/TensorNetwork

delete branch : batched_ncon

delete time in 9 days

delete branch mganahl/TensorNetwork

delete branch : batch_ncon

delete time in 9 days

delete branch mganahl/TensorNetwork

delete branch : backends_sum_matmul

delete time in 9 days

delete branch mganahl/TensorNetwork

delete branch : backend_typing

delete time in 9 days

Pull request review comment google/TensorNetwork

issues-686 repr in index.py to return more info

 def __len__(self) -> int:
     return self.dim
 
   def __repr__(self) -> str:
-    return str(self.dim)
+    dense_shape = str(self.dim)

Thanks @shreeju! I submitted a PR to your PR with some minor modifications. Should be good once you pull it in

shreeju

comment created time in 9 days
