[{"data":1,"prerenderedAt":-1},["ShallowReactive",2],{"project-1589":3},{"id":4,"name":5,"fullName":6,"owner":7,"repo":5,"description":8,"homepage":9,"htmlUrl":10,"language":11,"languages":10,"totalLinesOfCode":10,"stars":12,"forks":13,"watchers":14,"openIssues":15,"contributorsCount":16,"subscribersCount":16,"size":16,"stars1d":17,"stars7d":18,"stars30d":19,"stars90d":16,"forks30d":16,"starsTrendScore":20,"compositeScore":21,"rankGlobal":10,"rankLanguage":10,"license":22,"archived":23,"fork":23,"defaultBranch":24,"hasWiki":23,"hasPages":23,"topics":25,"createdAt":10,"pushedAt":10,"updatedAt":26,"readmeContent":27,"aiSummary":28,"trendingCount":16,"starSnapshotCount":16,"syncStatus":29,"lastSyncTime":30,"discoverSource":31},1589,"jax","jax-ml\u002Fjax","jax-ml","Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU\u002FTPU, and more","https:\u002F\u002Fdocs.jax.dev",null,"Python",35800,3627,330,1666,0,9,46,201,43,118,"Apache License 2.0",false,"main",[5],"2026-06-12 04:00:10","\u003Cdiv align=\"center\">\n\u003Cimg src=\"https:\u002F\u002Fraw.githubusercontent.com\u002Fjax-ml\u002Fjax\u002Fmain\u002Fimages\u002Fjax_logo_250px.png\" alt=\"logo\">\u003C\u002Fimg>\n\u003C\u002Fdiv>\n\n# Transformable numerical computing at scale\n\n[![Continuous integration](https:\u002F\u002Fgithub.com\u002Fjax-ml\u002Fjax\u002Factions\u002Fworkflows\u002Fci-build.yaml\u002Fbadge.svg)](https:\u002F\u002Fgithub.com\u002Fjax-ml\u002Fjax\u002Factions\u002Fworkflows\u002Fci-build.yaml)\n[![PyPI version](https:\u002F\u002Fimg.shields.io\u002Fpypi\u002Fv\u002Fjax)](https:\u002F\u002Fpypi.org\u002Fproject\u002Fjax\u002F)\n\n[**Transformations**](#transformations)\n| [**Scaling**](#scaling)\n| [**Install guide**](#installation)\n| [**Change logs**](https:\u002F\u002Fdocs.jax.dev\u002Fen\u002Flatest\u002Fchangelog.html)\n| [**Reference docs**](https:\u002F\u002Fdocs.jax.dev\u002Fen\u002Flatest\u002F)\n\n\n## What is JAX?\n\nJAX is a Python library for 
accelerator-oriented array computation and program transformation,\ndesigned for high-performance numerical computing and large-scale machine learning.\n\nJAX can automatically differentiate native\nPython and NumPy functions. It can differentiate through loops, branches,\nrecursion, and closures, and it can take derivatives of derivatives of\nderivatives. It supports reverse-mode differentiation (a.k.a. backpropagation)\nvia [`jax.grad`](#automatic-differentiation-with-grad) as well as forward-mode differentiation,\nand the two can be composed arbitrarily to any order.\n\nJAX uses [XLA](https:\u002F\u002Fwww.openxla.org\u002Fxla)\nto compile and scale your NumPy programs on TPUs, GPUs, and other hardware accelerators.\nYou can compile your own pure functions with [`jax.jit`](#compilation-with-jit).\nCompilation and automatic differentiation can be composed arbitrarily.\n\nDig a little deeper, and you'll see that JAX is really an extensible system for\n[composable function transformations](#transformations) at [scale](#scaling).\n\nThis is a research project, not an official Google product. 
Expect\n[sharp edges](https:\u002F\u002Fdocs.jax.dev\u002Fen\u002Flatest\u002Fnotebooks\u002FCommon_Gotchas_in_JAX.html).\nPlease help by trying it out, [reporting bugs](https:\u002F\u002Fgithub.com\u002Fjax-ml\u002Fjax\u002Fissues),\nand letting us know what you think!\n\n```python\nimport jax\nimport jax.numpy as jnp\n\ndef predict(params, inputs):\n  for W, b in params:\n    outputs = jnp.dot(inputs, W) + b\n    inputs = jnp.tanh(outputs)  # inputs to the next layer\n  return outputs                # no activation on last layer\n\ndef loss(params, inputs, targets):\n  preds = predict(params, inputs)\n  return jnp.sum((preds - targets)**2)\n\ngrad_loss = jax.jit(jax.grad(loss))  # compiled gradient evaluation function\nperex_grads = jax.jit(jax.vmap(grad_loss, in_axes=(None, 0, 0)))  # fast per-example grads\n```\n\n### Contents\n* [Transformations](#transformations)\n* [Scaling](#scaling)\n* [Current gotchas](#gotchas-and-sharp-bits)\n* [Installation](#installation)\n* [Citing JAX](#citing-jax)\n* [Reference documentation](#reference-documentation)\n\n## Transformations\n\nAt its core, JAX is an extensible system for transforming numerical functions.\nHere are three: `jax.grad`, `jax.jit`, and `jax.vmap`.\n\n### Automatic differentiation with `grad`\n\nUse [`jax.grad`](https:\u002F\u002Fdocs.jax.dev\u002Fen\u002Flatest\u002Fjax.html#jax.grad)\nto efficiently compute reverse-mode gradients:\n\n```python\nimport jax\nimport jax.numpy as jnp\n\ndef tanh(x):\n  y = jnp.exp(-2.0 * x)\n  return (1.0 - y) \u002F (1.0 + y)\n\ngrad_tanh = jax.grad(tanh)\nprint(grad_tanh(1.0))\n# prints 0.4199743\n```\n\nYou can differentiate to any order with `grad`:\n\n```python\nprint(jax.grad(jax.grad(jax.grad(tanh)))(1.0))\n# prints 0.62162673\n```\n\nYou're free to use differentiation with Python control flow:\n\n```python\ndef abs_val(x):\n  if x > 0:\n    return x\n  else:\n    return -x\n\nabs_val_grad = jax.grad(abs_val)\nprint(abs_val_grad(1.0))   # prints 
1.0\nprint(abs_val_grad(-1.0))  # prints -1.0 (abs_val is re-evaluated)\n```\n\nSee the [JAX Autodiff\nCookbook](https:\u002F\u002Fdocs.jax.dev\u002Fen\u002Flatest\u002Fnotebooks\u002Fautodiff_cookbook.html)\nand the [reference docs on automatic\ndifferentiation](https:\u002F\u002Fdocs.jax.dev\u002Fen\u002Flatest\u002Fjax.html#automatic-differentiation)\nfor more.\n\n### Compilation with `jit`\n\nUse XLA to compile your functions end-to-end with\n[`jit`](https:\u002F\u002Fdocs.jax.dev\u002Fen\u002Flatest\u002Fjax.html#just-in-time-compilation-jit),\nused either as an `@jit` decorator or as a higher-order function.\n\n```python\nimport jax\nimport jax.numpy as jnp\n\ndef slow_f(x):\n  # Element-wise ops see a large benefit from fusion\n  return x * x + x * 2.0\n\nx = jnp.ones((5000, 5000))\nfast_f = jax.jit(slow_f)\n%timeit -n10 -r3 fast_f(x)\n%timeit -n10 -r3 slow_f(x)\n```\n\nUsing `jax.jit` constrains the kind of Python control flow\nthe function can use; see\nthe tutorial on [Control Flow and Logical Operators with JIT](https:\u002F\u002Fdocs.jax.dev\u002Fen\u002Flatest\u002Fcontrol-flow.html)\nfor more.\n\n### Auto-vectorization with `vmap`\n\n[`vmap`](https:\u002F\u002Fdocs.jax.dev\u002Fen\u002Flatest\u002Fjax.html#vectorization-vmap) maps\na function along array axes.\nBut instead of just looping over function applications, it pushes the loop down\nonto the function’s primitive operations, e.g. 
turning matrix-vector multiplies into\nmatrix-matrix multiplies for better performance.\n\nUsing `vmap` can save you from having to carry around batch dimensions in your\ncode:\n\n```python\nimport jax\nimport jax.numpy as jnp\n\ndef l1_distance(x, y):\n  assert x.ndim == y.ndim == 1  # only works on 1D inputs\n  return jnp.sum(jnp.abs(x - y))\n\ndef pairwise_distances(dist1D, xs):\n  return jax.vmap(jax.vmap(dist1D, (0, None)), (None, 0))(xs, xs)\n\nxs = jax.random.normal(jax.random.key(0), (100, 3))\ndists = pairwise_distances(l1_distance, xs)\ndists.shape  # (100, 100)\n```\n\nBy composing `jax.vmap` with `jax.grad` and `jax.jit`, we can get efficient\nJacobian matrices, or per-example gradients:\n\n```python\nper_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))\n```\n\n## Scaling\n\nTo scale your computations across thousands of devices, you can use any\ncomposition of these:\n* [**Compiler-based automatic parallelization**](https:\u002F\u002Fdocs.jax.dev\u002Fen\u002Flatest\u002Fparallel.html)\nwhere you program as if using a single global machine, and the compiler chooses\nhow to shard data and partition computation (with some user-provided constraints);\n* [**Explicit sharding and automatic partitioning**](https:\u002F\u002Fdocs.jax.dev\u002Fen\u002Flatest\u002Fparallel.html)\nwhere you still have a global view but data shardings are\nexplicit in JAX types, inspectable using `jax.typeof`;\n* [**Manual per-device programming**](https:\u002F\u002Fdocs.jax.dev\u002Fen\u002Flatest\u002Fnotebooks\u002Fshard_map.html)\nwhere you have a per-device view of data\nand computation, and can communicate with explicit collectives.\n\n| Mode | View? | Explicit sharding? | Explicit Collectives? 
|\n|---|---|---|---|\n| Auto | Global | ❌ | ❌ |\n| Explicit | Global | ✅ | ❌ |\n| Manual | Per-device | ✅ | ✅ |\n\n```python\nimport jax\nfrom jax.sharding import set_mesh, AxisType, PartitionSpec as P\nmesh = jax.make_mesh((8,), ('data',), axis_types=(AxisType.Explicit,))\nset_mesh(mesh)\n\n# parameters are sharded for FSDP:\nfor W, b in params:\n  print(f'{jax.typeof(W)}')  # f32[512@data,512]\n  print(f'{jax.typeof(b)}')  # f32[512]\n\n# shard data for batch parallelism:\ninputs, targets = jax.device_put((inputs, targets), P('data'))\n\n# evaluate gradients, automatically parallelized!\ngradfun = jax.jit(jax.grad(loss))\nparam_grads = gradfun(params, inputs, targets)  # loss(params, inputs, targets) as defined above\n```\n\nSee the [tutorial](https:\u002F\u002Fdocs.jax.dev\u002Fen\u002Flatest\u002Fparallel.html) and\n[advanced guides](https:\u002F\u002Fdocs.jax.dev\u002Fen\u002Flatest\u002Fadvanced_guide.html) for more.\n\n## Gotchas and sharp bits\n\nSee the [Gotchas\nNotebook](https:\u002F\u002Fdocs.jax.dev\u002Fen\u002Flatest\u002Fnotebooks\u002FCommon_Gotchas_in_JAX.html).\n\n## Installation\n\n### Supported platforms\n\n|            | Linux x86_64 | Linux aarch64 | Mac aarch64  | Windows x86_64 | Windows WSL2 x86_64 |\n|------------|--------------|---------------|--------------|----------------|---------------------|\n| CPU        | yes          | yes           | yes          | yes            | yes                 |\n| NVIDIA GPU | yes          | yes           | n\u002Fa          | no             | experimental        |\n| Google TPU | yes          | n\u002Fa           | n\u002Fa          | n\u002Fa            | n\u002Fa                 |\n| AMD GPU    | yes          | no            | n\u002Fa          | no             | experimental        |\n| Apple GPU  | n\u002Fa          | no            | experimental | n\u002Fa            | n\u002Fa                 |\n| Intel GPU  | experimental | n\u002Fa           | n\u002Fa          | no             | no                  |\n\n\n### Instructions\n\n| Platform        | 
Instructions                                                                                                    |\n|-----------------|-----------------------------------------------------------------------------------------------------------------|\n| CPU             | `pip install -U jax`                                                                                            |\n| NVIDIA GPU      | `pip install -U \"jax[cuda13]\"`                                                                                  |\n| Google TPU      | `pip install -U \"jax[tpu]\"`                                                                                     |\n| AMD GPU (Linux) | Follow [AMD's instructions](https:\u002F\u002Fgithub.com\u002Fjax-ml\u002Fjax\u002Fblob\u002Fmain\u002Fbuild\u002Frocm\u002FREADME.md).                      |\n| Intel GPU       | Follow [Intel's instructions](https:\u002F\u002Fgithub.com\u002Fintel\u002Fintel-extension-for-openxla\u002Fblob\u002Fmain\u002Fdocs\u002Facc_jax.md).  |\n\nSee [the documentation](https:\u002F\u002Fdocs.jax.dev\u002Fen\u002Flatest\u002Finstallation.html)\nfor information on alternative installation strategies. 
These include compiling\nfrom source, installing with Docker, using other versions of CUDA, a\ncommunity-supported conda build, and answers to some frequently asked questions.\n\n## Citing JAX\n\nTo cite this repository:\n\n```\n@software{jax2018github,\n  author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Yash Katariya and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},\n  title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},\n  url = {http:\u002F\u002Fgithub.com\u002Fjax-ml\u002Fjax},\n  version = {0.3.13},\n  year = {2018},\n}\n```\n\nIn the above bibtex entry, names are in alphabetical order, the version number\nis intended to be that from [jax\u002Fversion.py](..\u002Fmain\u002Fjax\u002Fversion.py), and\nthe year corresponds to the project's open-source release.\n\nA nascent version of JAX, supporting only automatic differentiation and\ncompilation to XLA, was described in a [paper that appeared at SysML\n2018](https:\u002F\u002Fmlsys.org\u002FConferences\u002F2019\u002Fdoc\u002F2018\u002F146.pdf). We're currently working on\ncovering JAX's ideas and capabilities in a more comprehensive and up-to-date\npaper.\n\n## Reference documentation\n\nFor details about the JAX API, see the\n[reference documentation](https:\u002F\u002Fdocs.jax.dev\u002F).\n\nFor getting started as a JAX developer, see the\n[developer documentation](https:\u002F\u002Fdocs.jax.dev\u002Fen\u002Flatest\u002Fdeveloper.html).\n","JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning. Its core features include automatic differentiation, just-in-time (JIT) compilation to GPU\u002FTPU, and vectorization. JAX can automatically differentiate native Python and NumPy functions, and it supports arbitrary composition of reverse-mode (backpropagation) and forward-mode differentiation. It also uses the XLA compiler to optimize execution on different hardware accelerators. It suits research and development work that needs efficient numerical computation and deep learning model training.",2,"2026-06-11 02:44:51","top_all"]