Skip to article frontmatterSkip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

Introduction

La promesse de Jax

Jax promet de pouvoir écrire du code Python et de le déployer sur des plateformes CPU, GPU et TPU-Google sans efforts de traduction particuliers. Il permet aussi de faire des opérations (transformations) assez inédites comme la dérivation automatique des fonctions par rapport à leurs arguments ce qui constitue en soit un game changer pour l’IA et bien d’autres domaines.

Le monde de Jax

Avec des pincettes énormes, on pourrait résumer le monde de Jax à des données sous forme de tenseurs qui sont manipulées par des fonctions pures auxquelles on applique des transformations. Dans les nuances à apporter, il faut noter les données tensorielle sont agencées sous forme de pytrees ce qui une idées extrêmement puissante à elle seule, même si ce n’est pas ce qui saute aux yeux quand on début Jax.

Des fonctions pures ?

Dans Jax, la pureté des fonctions est un sujet qui revient souvent. Une fonction pure est une fonction qui n’a pas d’effets de bords. Elle n’utilise donc pas l’infinité de bidouilles que Python autorise. Dans les grandes lignes, les sorties d’une fonction doivent dépendre de manière déterministe de ses arguments et uniquement d’eux. Cela interdit notamment l’usage de variables globales (enfin, on verra que c’est plus subtile) et aussi de modifier dynamiquement ses propres arguments comme on le ferait souvent en C. Si on pense C justement, on peut se dire cette approche est antinomique avec l’économie de mémoire et la performance en général, en fait oui et non. Jax impose cette contrainte car il va voir les fonctions comme des scripts à interpréter dans son langage (voir jaxpr) et à traduire dans un langage dédié à la plateforme cible. L’optimisation au sens ou la verrait en C n’a donc pas lieu d’être. Le but de la pureté est avant tout de lever toute ambiguïté sur le fonctionnement interne de la fonction et de pouvoir y tracer le chemin de l’information. Il faut donc voir les fonctions comme des tuyaux dans lesquels le flux de données passe et la pureté indique juste que ces derniers ne doivent pas fuir ni laisser entrer des choses du milieu extérieur.

Voici un petit exemple de fonction pure et de la manière sont Jax la comprend:

import jax
from jax import numpy as jnp
import time


def dumb_pure_func(x):
    b = x + 3
    c = b**2
    return c


dumb_pure_func(3)
36
jax.make_jaxpr(dumb_pure_func)(2)
{ lambda ; a:i32[]. let b:i32[] = add a 3:i32[] c:i32[] = integer_pow[y=2] b in (c,) }

Jax est donc capable de comprendre le fonctionnement de la fonction est de la traduire par un bout de code dans son langage.

Les transformations

Imaginons qu’on travaille sur la fonction suivante:

def myfunc(x, a=1, b=1, c=1):
    return a * x**2 + b * x + c


jax.make_jaxpr(myfunc)(3, 1, 1, 1)
{ lambda ; a:i32[] b:i32[] c:i32[] d:i32[]. let e:i32[] = integer_pow[y=2] a f:i32[] = mul b e g:i32[] = mul c a h:i32[] = add f g i:i32[] = add h d in (i,) }
myfunc(5)
31

Vectoriser avec vmap

On peut vectoriser par rapport à un axe par exemple avec vmap.

vmyfunc = jax.vmap(myfunc, in_axes=(0, None, None, None))
xa = jnp.linspace(0.0, 5.0, 6)
vmyfunc(xa, 1, 1, 1)
Array([ 1., 3., 7., 13., 21., 31.], dtype=float32)

Mais on peut faire des structure bien plus complexes en combinant plusieurs transformations:

vmyfunc2 = jax.vmap(vmyfunc, in_axes=(None, 0, None, None))
aa = jnp.linspace(0.0, 1.0, 3)
vmyfunc2(xa, aa, 1, 1)
Array([[ 1. , 2. , 3. , 4. , 5. , 6. ], [ 1. , 2.5, 5. , 8.5, 13. , 18.5], [ 1. , 3. , 7. , 13. , 21. , 31. ]], dtype=float32)

Le potentiel est énorme car on peut soit vectoriser en plusieurs strates ou aussi le faire d’un coup en jouant sur les axes selon les besoins.

Compiler avec jit

Il est possible de compiler tout ou partie du code avec jit. La compilation va coûter quelques milisecondes et permettre une execution optimisée par la suite.

Ne = 100
xa = jnp.linspace(0.0, 5.0, 6000)
aa = jnp.linspace(0.0, 1.0, 3000)
t0 = time.time()
for e in range(Ne):
    val = vmyfunc2(xa, aa, 1, 1)
    val.block_until_ready()
t1 = time.time()
dt0 = (t1 - t0) / Ne
print(f"Exectution took {dt0*1.e3:.2f} ms")
Exectution took 6.99 ms
jvmyfunc2 = jax.jit(vmyfunc2)
t0 = time.time()
val = jvmyfunc2(xa, aa, 1, 1)
val.block_until_ready()
t1 = time.time()
dt1 = t1 - t0
print(f"Compilation + first execution took {dt1*1.e3:.2f} ms")
Compilation + first execution took 19.40 ms
t0 = time.time()
for e in range(Ne):
    val = jvmyfunc2(xa, aa, 1, 1)
    val.block_until_ready()
t1 = time.time()
dt2 = (t1 - t0) / Ne
print(f"Second execution took {dt2*1.e3:.2f} ms")
Second execution took 1.83 ms

On a donc gagné du temps avec le jit et ce malgré le fait que notre fonction est très simple et donc très optimisée à la base. Cette tendance sera amplifiée sur des calculs lourds sur GPU/TPU.

Autres transformations

Les autres transformations ne sont pas cruciales maintenant alors je les passe sous couvert. Mais elles sont ultra intéressantes dans d’autres cas, surtout grad.

# WIP

Liste non exhaustive des limitations de Jax

Forcément, cette belle promesse vient avec pas mal de limitations.

Les structures de contrôle

On commence par une des plus agacentes au début, les structures de contrôle. Fini les if, foret while.

En fait, ces dernières ne sont pas claires dans leurs buts et peuvent correspondre à plusieurs objectifs. Jax fournit donc des outils de remplacements qui ne manqueront pas de vous énerver (parfois). A titre d’exemple, forsera remplacée alternativement selon les buts par vmap, scan, where, lax.fori_loop ou pourra rester for dans ces bien choisis.

L’allocation dynamique de mémoire

Dans le monde de jax, il est interdit d’allouer dynamiquement de la mémoire, par exemple en créant des array de taille inconnue à la compilation. Cela ne manquera pas de vous créer des frustrations. On verra aussi qu’il est possible de trouver des compromis sur ce point. Le chapitre des sharp bits et globalement toutes les prises de parole de JakeVPD et Patrick Kidger mérite d’être lues pour comprendre la parole sainte à ce sujet.

Exemple:

def dumb_func_allocating_memory(n):
    a = jnp.arange(n)
    return a
# jax.make_jaxpr(dumb_func_allocating_memory)(2) # Uncommment to see the error.

ConcretizationTypeError Traceback (most recent call last) Cell In[11], line 1 ----> 1 jax.make_jaxpr(dumb_func_allocating_memory)(2)

[... skipping hidden 14 frame]

Cell In[10], line 2, in dumb_func_allocating_memory(n) 1 def dumb_func_allocating_memory(n): ----> 2 a = jnp.arange(n) 3 return a

File ~/miniforge3/envs/science/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:5947, in arange(start, stop, step, dtype, device, out_sharding) 5945 if sharding is None or not sharding._is_concrete: 5946 assert sharding is None or isinstance(sharding, NamedSharding) -> 5947 return _arange(start, stop=stop, step=step, dtype=dtype, 5948 out_sharding=sharding) 5949 else: 5950 output = _arange(start, stop=stop, step=step, dtype=dtype)

File ~/miniforge3/envs/science/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:5962, in _arange(start, stop, step, dtype, out_sharding) 5960 util.check_arraylike(“arange”, start) 5961 if stop is None and step is None: -> 5962 start = core.concrete_or_error(None, start, “It arose in the jnp.arange argument ‘stop’”) 5963 else: 5964 start = core.concrete_or_error(None, start, “It arose in the jnp.arange argument ‘start’”)

File ~/miniforge3/envs/science/lib/python3.12/site-packages/jax/_src/core.py:1847, in concrete_or_error(force, val, context) 1845 maybe_concrete = val.to_concrete_value() 1846 if maybe_concrete is None: -> 1847 raise ConcretizationTypeError(val, context) 1848 else: 1849 return force(maybe_concrete)

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[] It arose in the jnp.arange argument ‘stop’ The error occurred while tracing the function dumb_func_allocating_memory at /var/folders/67/hblp6z8n36ldk_9_bl9g80kh0000gn/T/ipykernel_50677/3347747001.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument n.

See https://docs.jax.dev/en/latest/errors.html#jax.errors.ConcretizationTypeError

Frustration, colère ...

Dans un tel, cas il faut généralement se demander si on a vraiment besoin que nsoit dynamique. Si c’est vraiment le cas, alors on peut le rendre statique (au sens jax) en spécifiant:

import numpy as np


def make_dumb_function_allocating_memory(n):
    def dumb_func_allocating_memory2(a):
        x = a * jnp.arange(n)
        return x

    return dumb_func_allocating_memory2


dfam = jax.jit(make_dumb_function_allocating_memory(3))
jax.make_jaxpr(dfam)(2)
{ lambda ; a:i32[]. let b:i32[3] = jit[ name=dumb_func_allocating_memory2 jaxpr={ lambda ; a:i32[]. let c:i32[3] = iota[dimension=0 dtype=int32 shape=(3,) sharding=None] d:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a b:i32[3] = mul d c in (b,) } ] a in (b,) }