FLAX框架教程-1
介绍JAX

JAX 是一个面向数组的数值计算库(类似于 NumPy),支持自动微分和即时编译(JIT),以实现高性能的机器学习研究。
- JAX 提供了一个统一的类似 NumPy 的接口,用于在 CPU、GPU 或 TPU 上运行计算,并支持本地或分布式环境。
- JAX 通过 Open XLA(一个开源的机器学习编译器生态系统)内置即时编译(JIT)。
- JAX 函数支持通过自动微分转换高效计算梯度。
- JAX 函数可以自动向量化,高效地将其映射到表示输入批次的数组上。
介绍FLAX

FLAX即Neural Networks For JAX(基于JAX的神经网络框架)
Flax 为使用 JAX 构建神经网络的研究人员和开发者提供了灵活且完整的用户体验,充分发挥 JAX 的强大功能。
Flax 的核心是 NNX ——一个简化的 API,让用户能够更轻松地创建、检查、调试和分析 JAX 中的神经网络。Flax NNX 对 Python 的引用语义提供了一流支持,使用户能够使用常规的 Python 对象来表达模型。Flax NNX 是此前 Flax Linen API 的演进版本,结合多年的实践经验,带来了更加简单和用户友好的 API。
因此本教程会使用最新的NNX,而不是原有的Linen API
配置JAX环境
由于使用JAX的一大优势是可以移植到TPU上面进行使用,我们选择WSL作为学习和实验环境。
安装VSCode
点击Download for Windows下载
配置WSL
首先确认你的电脑系统版本是Windows 11,然后执行下面指令,按照引导步骤即可
wsl --install
新建文件夹并开启VSCode
mkdir jax-test && cd jax-test
安装JAX
如果只使用CPU进行调试执行
pip install jax
如果使用GPU进行调试执行
pip install -U "jax[cuda12]"