FLAX框架教程-基础知识篇-1

FLAX框架教程-1

介绍JAX

JAX 是一个面向数组的数值计算库(类似于 NumPy),支持自动微分和即时编译(JIT),以实现高性能的机器学习研究。

  • JAX 提供了一个统一的类似 NumPy 的接口,用于在 CPU、GPU 或 TPU 上运行计算,并支持本地或分布式环境。
  • JAX 通过 Open XLA(一个开源的机器学习编译器生态系统)内置即时编译(JIT)。
  • JAX 函数支持通过自动微分转换高效计算梯度。
  • JAX 函数可以自动向量化,高效地将其映射到表示输入批次的数组上。

介绍FLAX

FLAX即Neural Networks For JAX(基于JAX的神经网络框架)

Flax 为使用 JAX 构建神经网络的研究人员和开发者提供了灵活且完整的用户体验,充分发挥 JAX 的强大功能。

Flax 的核心是 NNX ——一个简化的 API,让用户能够更轻松地创建、检查、调试和分析 JAX 中的神经网络。Flax NNX 对 Python 的引用语义提供了一流支持,使用户能够使用常规的 Python 对象来表达模型。Flax NNX 是此前 Flax Linen API 的演进版本,结合多年的实践经验,带来了更加简单和用户友好的 API。

因此本教程会使用最新的NNX,而不是原有的Linen API

配置JAX环境

由于使用JAX的一大优势是可以移植到TPU上面进行使用,我们选择WSL作为学习和实验环境。

安装VSCode

点击Download for Windows下载

https://code.visualstudio.com

配置WSL

首先确认你的电脑系统版本是Windows 11,然后执行下面指令,按照引导步骤即可

wsl --install

新建文件夹并开启VSCode

mkdir jax-test && cd jax-test

安装JAX

如果只使用CPU进行调试执行

pip install jax

如果使用GPU进行调试执行

pip install -U "jax[cuda12]"

安装FLAX