多节点TPU运行sglang-jax教程

第一步 安装sglang

# 安装uv
curl -LsSf https://astral.sh/uv/install.sh | sh
# 克隆sglang-jax
git clone https://github.com/sgl-project/sglang-jax
cd sglang-jax

# 安装sglang-jax
uv venv --python 3.12 && source .venv/bin/activate
uv pip install -e "python[all]"

第二步 下载权重

下载权重到 /dev/shm 内存盘

uv run python -u -m sgl_jax.launch_server --model-path Qwen/Qwen3-235B-A22B-Instruct-2507  --trust-remote-code  --dist-init-addr=0.0.0.0:10011 --nnodes=1  --tp-size=4 --device=tpu --random-seed=3 --node-rank=0 --mem-fraction-static=0.8 --max-prefill-tokens=262144 --download-dir=/dev/shm --dtype=bfloat16  --skip-server-warmup --host 0.0.0.0 --port 30000

拷贝权重到gcs

gcloud storage cp -r /dev/shm/model--【模型名称】 gs://【gcs目录】

第三步 挂载gcs bucket

gcsfuse 【桶名称】 /【挂载路径】

第四步 在多节点上启动sglang-jax

本教程使用 Qwen/Qwen3-235B-A22B-Instruct-2507

节点0

JAX_COMPILATION_CACHE_DIR=/dev/shm/jit_cache uv run python -u -m sgl_jax.launch_server --model-path Qwen/Qwen3-235B-A22B-Instruct-2507 --trust-remote-code  --dist-init-addr=0.0.0.0:10011 --nnodes=2  --tp-size=8 --device=tpu --random-seed=3 --node-rank=0 --mem-fraction-static=0.8 --max-prefill-tokens=262144 --download-dir=/【挂载目录】 --dtype=bfloat16  --skip-server-warmup --host 0.0.0.0 --port 30000 --page-size 16

节点X(X>0)

JAX_COMPILATION_CACHE_DIR=/dev/shm/jit_cache uv run python -u -m sgl_jax.launch_server --model-path Qwen/Qwen3-235B-A22B-Instruct-2507 --trust-remote-code  --dist-init-addr=【节点0 IP地址】:10011 --nnodes=2  --tp-size=8 --device=tpu --random-seed=3 --node-rank=1 --mem-fraction-static=0.8 --max-prefill-tokens=262144 --download-dir=/【挂载目录】 --dtype=bfloat16  --skip-server-warmup --host 0.0.0.0 --port 30000 --page-size 16

第五步 连接!

OpenAI Compatible 格式连接到 http://【节点0】:30000/v1即可