- PVSM.RU - https://www.pvsm.ru -

Опенсорсные модели становятся всё объёмнее, поэтому потребность в надёжной инфраструктуре для выполнения крупномасштабного обучения ИИ сегодня как никогда высока. Недавно наша компания выполнила fine-tuning модели LLaMA 3.1 405B на GPU AMD, доказав их способность эффективно справляться с крупномасштабными задачами ИИ. Наш опыт был крайне положительным, и мы с радостью выложили всю свою работу на GitHub [1] в опенсорс.
GPU AMD, и в особенности серия MI300X — это серьёзная альтернатива ИИ-оборудованию NVIDIA, обеспечивающая больше производительности на вложенный доллар. Наша система состояла из одного узла с 8 GPU AMD MI300x, а для fine-tuning мы использовали JAX. В этой статье мы расскажем всю историю fine-tuning LLaMA 405B, в том числе и подробности шардинга параметров и реализации LoRA.
JAX — это мощная библиотека для машинного обучения, объединяющая в себе NumPy-подобные API, автоматическое дифференцирование и компилятор Google XLA. Она имеет великолепные API для параллелизма моделей, идеально подходящие для обучения огромных моделей наподобие LLaMA 3.1 405B.
Чистые функции: JAX мотивирует к написанию чистых функций (если вы хотите компилировать код при помощи JIT), что упрощает компоновку, отладку и чтение кода.
Продвинутый параллелизм: гибкие JIT API библиотеки JAX изначально поддерживают продвинутый параллелизм данных и моделей, что крайне важно для крупномасштабного обучения.
Повышение чистоты кодовых баз: философия дизайна JAX стимулирует к написанию кода, изначально портируемого между аппаратными платформами (CPU, GPU, TPU), что приводит к повышению чистоты и удобства поддержки кодовых баз.
Если вы хотите глубже изучить преимущества JAX перед PyTorch, то рекомендую прочитать пост PyTorch is dead. Long live JAX [2].
При работе с AMD JAX обеспечивает множество преимуществ:
Независимый от оборудования подход: JAX использует компилятор XLA (Accelerated Linear Algebra), компилирующий вычисления в независимое от оборудования промежуточное представление (граф HLO). Это позволяет оптимизировать и эффективно исполнять без модификаций один и тот же код JAX на разных аппаратных бэкендах, включая GPU AMD.
Платформонезависимые оптимизации: компилятор XLA выполняет оптимизации вне зависимости от оборудования, от чего выигрывают все поддерживаемые платформы.
Упрощённая портируемость: при работе с JAX переход с NVIDIA на AMD (или на другое поддерживаемое оборудование) требует лишь минимальных изменений в коде. Это сильно отличает её от PyTorch, который более тесно связан с экосистемой CUDA NVIDIA.
PyTorch часто использует специфичные для CUDA реализации (например, вызовы torch.cuda, scaled_dot_product_attention).
Хотя PyTorch поддерживает другие бэкенды наподобие ROCm для AMD GPU, портирование кода может быть трудной задачей из-за специфичных для NVIDIA путей исполнения кода.
Процесс «избавления от NVIDIA» кода PyTorch может повысить сложность и помешать портируемости.
Настройка JAX на GPU AMD — это очень простой процесс:
# Подтягиваем образ Docker:
docker pull rocm/jax:latest
# Запускаем контейнер Docker:
docker run -it -w /workspace --device=/dev/kfd --device=/dev/dri --group-add video
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 16G rocm/jax:latest
# Верифицируем установку:
python3 -c 'import jax; print(jax.devices())'
Я работал с узлом AMD, состоящим из 8 GPU AMD MI300x. У каждого из MI300x имелось 192 ГБ памяти HBM3. Они крайне хорошо проявляют себя по сравнению с новыми GPU NVIDIA H100. (См. сравнение ниже, источник: TensorWave [3])

При помощи JAX мне удалось обучить модель LLaMA 405B на GPU AMD, добившись впечатляющих результатов.
Мы выполнили fine-tuning LoRA со всеми весами модели и параметрами lora с точностью bfloat16, с LoRA rank = 8 и LoRA alpha = 16:
Размер модели: веса модели LLaMA занимают примерно 800 ГБ VRAM.
Веса LoRA + состояние оптимизатора: приблизительно 400 ГБ VRAM.
Общее использование VRAM: 77% от общего объёма VRAM, примерно 1200 ГБ.
Ограничения: из-за большого размера модели 405B пространство для размеров батчей и длины последовательностей было ограничено. Я использовал размер батчей 16 и длину последовательностей 64.
JIT-компиляция: кроме того, из-за ограничений пространства я не смог запустить JIT-компилируемую версию; вероятно, для этого требуется чуть больше пространства, чем для графа eager mode.
Скорость обучения: примерно 35 токенов в секунду в eager mode JAX (1 этап обучения занимал 30 с)
Эффективность использования памяти: стабильно примерно 70%
Масштабирование: при работе с JAX масштабирование было примерно линейным среди всех 8 GPU.
Ниже представлены показатели GPU, эффективности использования памяти и результаты rocm-smi для 8 GPU на одном этапе обучения прогона fine-tuning:
Использование GPU:

Использование VRAM:

результаты rocm-smi:
|
Устройство |
Температура |
Мощность |
Разделы |
Кулер |
Производительность |
PwrCap |
VRAM% |
GPU% |
|---|---|---|---|---|---|---|---|---|
|
0 |
58,0°C |
232,0 Вт |
NPS1, SPX, 0 |
0% |
auto |
750,0 Вт |
77% |
27% |
|
1 |
58,0°C |
233,0 Вт |
NPS1, SPX, 0 |
0% |
auto |
750,0 Вт |
77% |
25% |
|
2 |
56,0°C |
236,0 Вт |
NPS1, SPX, 0 |
0% |
auto |
750,0 Вт |
77% |
24% |
|
3 |
52,0°C |
228,0 Вт |
NPS1, SPX, 0 |
0% |
auto |
750,0 Вт |
77% |
23% |
|
4 |
59,0°C |
232,0 Вт |
NPS1, SPX, 0 |
0% |
auto |
750,0 Вт |
77% |
22% |
|
5 |
51,0°C |
230,0 Вт |
NPS1, SPX, 0 |
0% |
auto |
750,0 Вт |
77% |
21% |
|
6 |
61,0°C |
235.0W |
NPS1, SPX, 0 |
0% |
auto |
750,0 Вт |
77% |
18% |
|
7 |
56,0°C |
227,0 Вт |
NPS1, SPX, 0 |
0% |
auto |
750,0 Вт |
77% |
18% |
Полную информацию об использовании GPU, VRAM и данные rocm-smi можно найти в нашем репозитории Github [4].

Мы перенесли архитектуру LLaMA 3.1 с PyTorch на JAX. Нашу реализацию можно изучить в репозитории GitHub [1].
Этот перенос открыл для нас новые возможности с точки зрения производительности и масштабируемости.
Для работы с такой огромной моделью, как LLaMA 405B, требуется эффективный шардинг параметров между несколькими устройствами. Ниже мы расскажем, как добились его при помощи JAX.
Чтобы эффективно распределить огромную модель LLaMA 405B на 8 GPU AMD, мы применили функцию меша устройств (device mesh) JAX (codepointer [5]). Меш устройств упорядочивает имеющиеся устройства в многомерную сетку, позволяя нам указывать, как будут разбиты вычисления и данные. В своей системе мы создали меш с формой (1, 8, 1), а именно с такими осями, как параллелизм данных (data parallelism, dp), параллелизм данных с полным шардингом (fully sharded data parallelism, fsdp) и параллелизм модели (model parallelism, mp). Затем мы применили к параметрам модели конкретные правила шардинга, указав для каждого тензора модели способ разбиения его размерностей между осями меша.
DEVICES = jax.devices()
DEVICE_COUNT = len(DEVICES)
DEVICE_MESH = mesh_utils.create_device_mesh((1, 8, 1))
MESH = Mesh(devices=DEVICE_MESH, axis_names=("dp", "fsdp", "mp"))
Шардинг массивов можно визуализировать при помощи jax.debug.visualize_array_sharding. Это невероятно полезно для проверки правильности применения спецификаций шардинга.
Мы определили [6] правила разбиения для различных компонентов модели:
Обычные параметры: разбиты шардингом на 8 GPU.
Например, тензор LM head (lm_head/kernel) имеет две оси, разбитые с PS("fsdp", "mp"); в наше случае это 8, 1, так что мы видим, что по первой оси тензор разбит на 8 GPU.

Нереплицированные параметры:
Параметры без спецификаций шардинга реплицируются между всеми устройствами.
Например, нормы слоёв (attention_norm/kernel и ffn_norm/kernel) используют PS(None).

В процессе загрузки модели мы инкрементно выполняем шардинг весов модели при помощи специальных функций шардинга:
def make_shard_and_gather_fns(partition_specs):
def make_shard_fn(partition_spec):
out_sharding = NamedSharding(mesh, partition_spec)
def shard_fn(tensor):
return jax.device_put(tensor, out_sharding).block_until_ready()
return shard_fn
shard_fns = jax.tree_util.tree_map(make_shard_fn, partition_specs)
return shard_fns
# Создаём функции шардинга на основании правил разбиения
shard_fns = make_shard_and_gather_fns(partitioning_rules)
Это позволяет нам помещать каждый параметр на соответствующие устройства с указанным шардингом.
Изначально батч обучения создаётся обычным образом. Перед передачей его модели мы выполняем его шардинг между GPU в соответствии со следующим кодом:
train_batch = jax.device_put(
train_batch, NamedSharding(self.mesh, PS("dp", "fsdp"))
)
Здесь мы указываем, что батч обучения должен быть разделён шардингом между осями data parallel ("dp") и fully sharded data parallel ("fsdp"), которые в нашем случае соответствуют 1, 8; это приводит к следующей визуализации:
до шардинга

после вызова jax.device_put

LoRA (Low-Rank Adaptation) снижает количество параметров для обучения, разбивая обновления весов на низкоранговые матрицы. Это особенно полезно для fine-tuning больших моделей.
Ключевые аспекты нашей реализации LoRA:
Раздельная параметризация: мы храним параметры LoRA (lora_a и lora_b) отдельно от параметров основной модели.
Прекращение градиента: мы используем jax.lax.stop_gradient(kernel), чтобы предотвратить обновления весов основной модели.
Эффективное умножение матриц: мы используем lax.dot_general для быстрых матричных операций с контролем точности.
Коэффициент масштабирования: перед добавлением к основным выходным данных выходные данные LoRA масштабируются на (self.lora_alpha / self.lora_rank).
Мы реализовали специальный слой LoRADense, включающий в себя параметры LoRA:
class LoRADense(nn.Module):
features: int
lora_rank: int = 8
lora_alpha: float = 16.0
@nn.compact
def __call__(self, inputs: Any) -> Any:
# Параметр исходного ядра (заморожен)
kernel = self.param('kernel', ...)
y = lax.dot_general(inputs, jax.lax.stop_gradient(kernel), ...)
# Параметры LoRA (обучаемые)
lora_a = self.variable('lora_params', 'lora_a', ..., ...)
lora_b = self.variable('lora_params', 'lora_b', ..., ...)
# Вычисление выходных данных LoRA
lora_output = lax.dot_general(inputs, lora_a.value, ...)
lora_output = lax.dot_general(lora_output, lora_b.value, ...)
# Комбинирование исходных выходных данных с модификациями LoRA
y += (self.lora_alpha / self.lora_rank) * lora_output
return y.astype(self.dtype)
Для эффективного распределения параметров LoRA между устройствами мы с помощью JAX применили особые правила шардинга. Это гарантирует, что параметры LoRA выравниваются с шардингом параметров основной модели, оптимизируя при этом и использование памяти, и эффективность вычислений.
Использованная нами спецификация разбиения: PS("fsdp", "mp").
Визуализация:
Шардинг осей: шардинг параметров lora_a между осями будет выполняться как (8, 1), то есть первая ось разбивается шардингом на 8 устройств (ось fsdp), а вторая ось не разбивается.

На иллюстрации показано, что первая ось разбита шардингом на 8 устройств (ось fsdp), а вторая ось не разбита.
Использованная нами спецификация разбиения: PS("mp", "fsdp").
Визуализация:
Шардинг осей: шардинг параметров lora_b по слоям будет выполняться как (1, 8), то есть вторая ось разбивается шардингом на 8 устройств (ось fsdp), а первая ось не разбивается.

На иллюстрации показано, что вторая ось разбита шардингом на 8 устройств (ось fsdp), разбивая столбцы матрицы.
Такая стратегия шардинга оптимизирует распределение параметров, снижает лишнюю трату ресурсов на коммуникации и повышает параллелизм при обучении. Она гарантирует, что на каждом устройстве содержится только часть параметров LoRA, обеспечивая эффективное масштабирование для больших моделей наподобие LLaMA 405B.
Для оптимизации обучения при fine-tuning модели LLaMA 405B мы вычисляем градиенты только для параметров LoRA, оставляя параметры основной модели замороженными. При таком подходе снижается объём используемой памяти и ускоряется обучение, потому что мы обновляем меньшее количество параметров. Подробности реализации можно посмотреть в нашем репозитории GitHub [1].
В нашем цикле обучения на каждом этапе используется передача батча входных данных через модель. Так как обучаются только параметры LoRA, прогнозы модели и вычисляемая функция потерь зависят только от этих параметров. Затем мы выполняем обратное распространение градиентов с параметрами LoRA. Сосредоточив обновления только на этих параметрах, мы упрощаем процесс обучения, что позволяет эффективно выполнять на нескольких GPU fine-tuning чрезвычайно больших моделей наподобие LLaMA 405B.
Fine-tuning огромной модели LLaMA 3.1 405B на GPU AMD при помощи JAX оставил у нас крайне положительное впечатление. Благодаря использованию мощных возможностей параллелизма JAX и её независящих от оборудования методик я смог эффективно распределить модель по 8 GPU AMD MI300x. Использование шардинга параметров позволило эффективно управлять огромным объёмом параметров модели между устройствами, что обеспечило почти линейную масштабируемость и высокую эффективность использования памяти.
Этот опыт подчёркивает способности GPU AMD в качестве мощной альтернативы оборудованию NVIDIA в крупномасштабном обучении ИИ. Беспроблемная интеграция JAX с поддержкой ROCm упрощает переход и открывает новые возможности для сообщества исследователей и разработчиков ИИ. Делясь своим опытом и кодом, я надеюсь, что это мотивирует других исследовать и применять эти инструменты в собственных крупномасштабных проектах машинного обучения.
Автор: PatientZero
Источник [7]
Сайт-источник PVSM.RU: https://www.pvsm.ru
Путь до страницы источника: https://www.pvsm.ru/amd/397522
Ссылки в тексте:
[1] GitHub: https://github.com/felafax/felafax
[2] PyTorch is dead. Long live JAX: https://neel04.github.io/my-website/blog/pytorch_rant/
[3] TensorWave: https://tensorwave.com/
[4] репозитории Github: https://github.com/felafax/felafax?tab=readme-ov-file#amd-405b-fine-tuning-run
[5] codepointer: https://github.com/felafax/felafax/blob/e2a96a0e207e1dc70effde099fe33a9e42a7d5cb/llama3_jax/trainer_engine/jax_utils.py#L69
[6] определили: https://github.com/felafax/felafax/blob/e2a96a0e207e1dc70effde099fe33a9e42a7d5cb/llama3_jax/trainer_engine/llama_config.py#L44
[7] Источник: https://habr.com/ru/articles/845674/?utm_campaign=845674&utm_source=habrahabr&utm_medium=rss
Нажмите здесь для печати.