Сегодня мы выкладываем в опенсорс наш новый инструмент — алгоритм YaFSDP, который помогает существенно ускорить процесс обучения больших языковых моделей.
В этой статье мы расскажем о том, как можно организовать обучение больших языковых моделей на кластере и какие проблемы при этом возникают. Рассмотрим альтернативные методы ZeRo и FSDP, которые помогают организовать этот процесс. И объясним, чем YaFSDP отличается от них.
Проблемы при обучении на нескольких GPU
В чём сложность распределённого обучения языковых моделей на кластере из большого числа GPU? Чтобы ответить на этот вопрос, для начала рассмотрим обучение на одной GPU:
-
Мы делаем проход по сети для нового батча данных, считаем лосс.
-
Делаем обратное распространение ошибки.
-
Оптимизатор обновляет состояния оптимизатора и веса модели.
Что поменяется с введением нескольких GPU? Рассмотрим самую простую реализацию распределённого обучения на четырёх GPU (Distributed Data Parallelism):
Что поменялось:
-
Теперь каждая GPU обрабатывает свой кусок большего батча данных: мы можем увеличить батч в 4 раза при той же нагрузке на память.
-
Нам необходимо синхронизировать работу разных GPU. Для этого мы усредняем градиенты через all_reduce между GPU, чтобы веса на разных картах обновлялись синхронно. all_reduce — одна из наиболее быстрых реализаций, она есть в библиотеке коллективных коммуникаций NCCL и поддержана в torch.distributed.
Вспомним типы коммуникаций (мы к ним будем возвращаться часто):
Какие тут есть проблемы:
-
Коммуникации all_reduce требуют два раза переслать столько градиентов, сколько у нас параметров в сети. Например, для Llama 70B при суммировании градиентов в fp16 нам потребуется переслать между картами 280 GB на каждую итерацию. В современных кластерах это займёт существенное время.
-
Веса, градиенты и состояния оптимизатора будут дублироваться между картами. Для той же Llama 70B и оптимизатора Adam в Mixed precision потребуется более 1 TB памяти при обычной вместимости GPU в 80 GB.
Таким образом, из‑за колоссальной избыточности по памяти мы не сможем вместить даже относительно небольшую модель в память GPU, а наше обучение замедлится за счёт дополнительных коммуникаций.
Можно ли как‑то решить эти проблемы? Да, есть несколько таких решений, из которых можно выделить группу методов: Data Parallelism с полным шардированием весов, градиентов и состояний оптимизатора. Для torch есть три таких метода: ZeRO, FSDP и разработанный нами YaFSDP.
Заход ZeRO
В 2019 году команда разработки DeepSpeed из Microsoft опубликовала статью ZeRO: Memory Optimizations Toward Training Trillion Parameter Models В статье исследователи предложили концепцию ZeRO (Zero Redundancy Optimizer), которая позволила критически снизить нагрузку на память за счёт полного разбиения весов, градиентов и состояний оптимизатора между всеми GPU:
Предложенное разбиение виртуальное. Во время forward и backward модель работает со всеми параметрами так, будто никакого разбиения и нет. Но как такое можно организовать? Ответ: через асинхронную подгрузку параметров.
Реализация ZeRO в библиотеке DeepSpeed при обучении на N GPU такая:
-
Каждый параметр мы разбиваем на N частей, каждую часть храним на своём процессе.
-
Во время первой итерации до шага оптимизатора мы запоминаем, в каком порядке используются параметры.
-
Мы выделяем место под собранные параметры. На каждой следующей итерации на forward и backward мы асинхронно подгружаем параметры через all_gather. Когда какой‑то модуль завершил свою работу, мы освобождаем память под параметры этого модуля и запускаем загрузку следующих параметров. Вычисления идут параллельно.
-
На backward мы делаем reduce_scatter сразу, как посчитаны градиенты.
-
На шаге оптимизатора мы обновляем только те веса и состояния оптимизатора, которые принадлежат нашей GPU. Кстати, это ускоряет сам шаг оптимизатора в N раз!
Как работал бы forward в ZeRO, если бы на каждый слой у нас был ровно один тензор параметров:
Таким образом, для одной GPU схема обучения будет выглядеть так:
Что мы видим:
-
Коммуникации стали асинхронными. Если коммуникации производятся быстрее вычислений, то они не должны мешать.
-
Коммуникаций стало сильно больше.
-
Шаг оптимизатора стал работать очень быстро.
Идея и реализация от DeepSpeed позволила ускорить многие обучения при сильно сниженной нагрузке на память. Но и тут есть свои минусы:
-
Очень большое количество багов и проблемных мест в коде DeepSpeed.
-
Неэффективные коммуникации на больших кластерах.
-
У всех коллективных коммуникаций в NCCL есть особенность: чем меньше данных мы пересылаем за раз, тем менее эффективно работают коммуникации.
-
Пусть у нас N GPU, тогда при all_gather за раз мы будем пересылать не более 1/N от всего числа параметров. При увеличении N эффективность пересылок будет снижаться.
-
В DeepSpeed мы делаем all_gather и reduce_scatter от каждого тензора параметров. В Llama 70B типовой размер такого тензора, 8192 × 8192. При обучении на 1024 картах за раз будет пересылаться не более 128 KB, что не позволит загрузить сеть достаточно.
-
В DeepSpeed попытались разобраться с этой проблемой за счёт одновременной сборки сразу большого количества тензоров. К сожалению, такой подход либо порождает много медленных операций с памятью GPU, либо для этого нужно делать свою реализацию всех коммуникаций.
-
Это приводит к примерно такой картине в профиле (stream 7 — вычисления, stream 24 — коммуникации):
Таким образом, DeepSpeed сильно замедлял обучение при увеличении размера кластера. Но можно ли лучше?
Эпоха FSDP
Да, можно. Это доказали разработчики подхода Fully Sharded Data Parallelism или FSDP. На данный момент он встроен непосредственно в torch, активно поддерживается и им пользуется много разработчиков.
Чем же хорош новый подход? У него несколько преимуществ:
-
Возможность объединять множество параметров слоя в один FlattenParameter, который будет разбиваться при шардировании. Это позволяет производить быстрые коллективные коммуникации с действительно большими пересылками.
-
Более удобный интерфейс:
-
DeepSpeed преобразовывает буквально весь пайплайн обучения: изменяет модель и оптимизатор.
-
FSDP влияет только на модель. Оптимизатору он выдаёт только веса и градиенты своей части параметров, что позволяет использовать произвольный оптимизатор без дополнительной настройки.
-
FSDP не обладает таким количеством багов, как DeepSpeed, по крайней мере в частотных сценариях.
-
-
Динамические графы: ZeRO требует, чтобы модули всегда вызывались в строго определённом порядке, иначе он не будет понимать, когда какой параметр загружать. В FSDP можно использовать динамические графы.
Но при всех этих преимуществах есть и проблемы, с которыми мы столкнулись:
-
FSDP динамически выделяет память под слои и иногда требует её намного больше, чем необходимо.
-
На backward в коммуникациях возникает эффект, который мы назвали «Ёлочкой». Проще пояснить это на профиле:
Что тут происходит? Перед коммуникацией reduce_scatter (синяя) происходит много подготовительных вычислений (маленькие операции под коммуникациями). Эти вычисления идут вперемешку с основным потоком вычислений, поэтому коммуникации сильно тормозятся. Как итог между коммуникациями появляется большой просвет, из‑за которого на такой же промежуток тормозят и вычисления.
Мы попробовали побороть эти проблемы. В результате появился наше решение YaFSDP.
YaFSDP
В этой части мы расскажем о нашей разработке и немного о том, как вообще можно реализовывать такие штуки. Будет много ссылок на код. Если хотите узнать о том, как можно продвинуто использовать torch, эта часть для вас.
Итак: мы хотим гарантировать, что памяти используется ровно столько, сколько реально нужно, а коммуникации ничего не тормозит.
Зачем экономить память
Да и правда: зачем? Давайте рассмотрим, кто потребляет память в обучении:
-
Веса, градиенты, состояния оптимизатора — все они зависят от количества процессов, и потребляемая память стремится к нулю при увеличении количества процессов.
-
Буферы занимают константную память.
-
Активации зависят от размера модели и количества токенов на процесс.
Выходит, единственное, что занимает память, это активации. И это действительно так! Для Llama 2 70B с батчем в 8192 токенов и flash 2 на хранение активаций уйдёт более 110 GB (это число можно существенно уменьшить, но это уже другая история).
Можно существенно уменьшить нагрузку на память, если использовать чекпоинт активаций: на forward мы сохраняем только активации между трансформерными блоками, а на backward заново перевычисляем их. И это сильно экономит память. Для хранения активаций потребуется всего 5 GB. Но избыточные вычисления будут занимать 25% времени для всего обучения.
Поэтому имеет смысл освобождать память для того, чтобы не делать чекпоинт активаций для как можно большего числа слоёв.
Кроме того, есть способы уменьшить и объём коммуникаций в случае, если есть лишняя свободная память.
Буферы
По аналогии с FSDP, мы будем шардировать не параметры по отдельности, а слои целиком, чтобы избежать неэффективных коммуникаций и большого количества операций копирования. Чтобы контролировать использование памяти, мы решили сразу выделить буферы под все необходимые данные, чтобы не отдавать их на откуп аллокатору памяти torch. Схема такая: Для хранения промежуточных весов и градиентов мы заранее выделим два буфера, каждый нечётный слой будет использовать первый буфер, каждый чётный — второй.
Таким образом, у нас разные слои будут физически смотреть в одни и те же фрагменты памяти. Если слои обладают одинаковой структурой, то они всегда будут идентичны! Главное в тот момент, когда требуется слой i, гарантировать, что в буфере находятся веса слоя i. Все параметры в моменте будут смотреть на свой кусок памяти в буфере:
В остальном всё также, как и в FSDP. Нам потребуются:
-
буферы для хранения шардов и градиентов в fp32 для оптимизатора из‑за Mixed Precision;
-
буфер для хранения шарда весов в половинной точности, в нашем случае, в bf16.
Теперь нам нужно выстроить коммуникации так, чтобы гарантировать:
-
Что forward/backward слоя не начнётся до тех пор, пока в его буфере не будут собраны веса этого слоя.
-
Мы не будем собирать в буфер нашего слоя другой слой до того, как закончится операция forward/backward на нашем слое.
-
backward на нашем слое не начнётся, пока не завершится reduce_scatter на предыдущем слое, который использует тот же буфер градиентов.
-
reduce_scatter в буфере не начнётся до тех пор, пока не завершится backward соответствующего слоя.
Как это организовать?
Работа со стримами
Для параллельной работы вычислений и коммуникаций стоит использовать примитив cuda — стримы (можно назвать их потоками, но тогда будет много переплетений с другими терминами).
Как устроено взаимодействие между CPU и GPU в torch и других фреймворках? CPU грузит на GPU кернелы (функции, которые выполняются на GPU) в том порядке, в котором они должны выполняться. Чтобы не было простоев из‑за CPU, кернелы грузятся вперёд вычислений и выполняются асинхронно. В рамках одного stream кернелы всегда выполняются именно в том порядке, в котором они были загружены на CPU. Чтобы кернелы могли работать параллельно, мы можем грузить их в разные стримы. Важный момент: если кернелы в разных стримах используют одни и те же ресурсы, они могут не запуститься параллельно (как в случае с «Ёлочкой»), или оба кернела могут сильно замедлиться.
Для коммуникации между стримами можно использовать примитив event (event = torch.cuda.Event()
в torch). Мы можем записать event в стрим (event.record(stream)
), тогда он встанет в конец стрима чем‑то вроде микрокернела. В другом стриме мы можем подождать этого event (event.wait(another_stream)
), тогда этот стрим остановится до тех пор, пока первый стрим не дойдёт до этого event.
Для организации нашей затеи нам достаточно два стрима: стрим вычислений и стрим коммуникаций. Так можно выстроить выполнение так, чтобы гарантировать выполнение условий 1 и 2, описанных выше:
На данной картинке плотной границей очерчены event.record()
, а пунктиром — event.wait()
. Можно увидеть, что выполнение forward третьего слоя не начнётся, пока мы не завершим all_gather того же слоя (условие 1). А all_gather третьего слоя не начнётся, пока не завершится forward первого слоя, который использует тот же буфер (условие 2). Так как в такой схеме нет циклов, deadlock невозможен.
Как это может быть реализовано в torch? Для таких операций удобно использовать forward_pre_hook
— код на CPU, который выполняется перед forward, а также forward_hook
— код, который выполняется после:
Таким образом, всю подготовительную работу мы вынесли в forward_pre_hook
. Подробнее про хуки можно прочитать в документации.
Что отличается в backward? Там появляется необходимость делать усреднение градиентов между процессами:
По аналогии с forward_hook
и forward_pre_hook
мы могли бы использовать backward_hook
и backward_pre_hook
:
Но здесь есть подвох: если backward_pre_hook
отработает именно так, как и ожидается, то backward_hook
может вести себя неожиданно:
-
Если на входе в модуль есть хотя бы один тензор, который не пропускает градиенты, например, маска attention,
backward_hook
запустится до выполнения backward. -
Даже если все входы в модуль пропускают градиент, нет никаких гарантий, что
backward_hook
запустится тогда, когда будут вычислены.grad всех тензоров.
Итак, исходная реализация backward_hook нас не устраивает, нужен более надёжный вариант.
Надёжный backward_hook
Почему так? Давайте посмотрим на граф вычисления градиента для довольно простых операций:
Мы применяем ко входу два независимых линейных слоя с весами Weight 1 и Weight 2 и перемножаем их выходы.
Так будет выглядеть граф вычисления градиентов:
Что можно увидеть в это графе: Все операции имеют свои *Backward узлы в этом графе, для всех весов в графе появился узел GradAccum, в котором обновляется.grad параметра, который затем будет использоваться YaFSDP для работы с градиентом.
На что стоит обратить внимание: GradAccum находится в листах этого графа. Интересно, что torch не даёт никаких гарантий на порядок обхода графа. GradAccum одного из весов может быть выполнен уже после того, как градиент выйдет из этого блока. Выполнение графа в torch не детерминировано и может отличаться от итерации к итерации.
Как гарантировать, что градиенты весов будут подсчитаны перед запуском backward другого слоя? Ведь без этого знания мы не сможем запустить reduce_scatter — он отработает только на части подсчитанных градиентов. Мы пришли к такой схеме:
Перед каждым forward мы дополнительно делаем такие трюки:
-
Пропускаем все входы и буферы весов через простой
torch.autograd.Function
GateGradFlow который просто пропускает через себя неизменённые входы и градиенты. -
Затем прописываем в нашем слое на место параметров псевдопараметры, которые ссылаются на различные фрагменты буфера весов через нашу функцию Narrow.
Что происходит на backward:
Градиент для параметров может быть записан двумя способами:
-
В стандартном случае мы запишем или добавим градиент в реализации backward Narrow, намного раньше, чем мы дойдём до GradAccum буферов.
-
Мы можем написать кастомную функцию для слоёв, в которых мы будем прописывать градиенты без выделения дополнительного тензора для экономии памяти. Тогда Narrow получит None вместо градиента и не будет ничего делать.
Таким образом, мы гарантируем, что:
-
Все градиенты будут прописаны в буфер градиентов до выполнения
backward GateGradFlow
. -
Градиенты не потекут Inputs, а затем в backward следующих слоёв до выполнения
backward GateGradFlow
.
Следовательно, логичное место для вызова backward_hook
как раз находится в backward GateGradFlow! Все градиенты весов подсчитаны и записаны, backward других слоёв ещё не начался. Теперь у нас есть всё, что нужно для параллельных коммуникаций и вычислений в backward.
Борьба с Ёлочкой
Проблема Ёлочки в том, что перед reduce_scatter
в потоке коммуникаций происходит несколько вычислительных операций: копирование градиентов в разные буферы, predivide градиентов для предотвращения переполнения в fp16, который сейчас используется редко, и ещё несколько операций.
Что мы сделали:
-
Сделали отдельную обработку RMSNorm/LayerNorm. Особенность в том, что они должны немного по‑другому обрабатываться в оптимизаторе, поэтому их стоит выделять в отдельную группу. Таких весов мало, поэтому мы их разово собираем в начале итерации и только в конце усредняем градиенты. Это позволило избавиться от лишних копирований в Ёлочке.
-
Вынесли predivide в самый конец backward, так как при
reduce_scatter
в bf16 или fp32 риска переполнения нет.
Как итог, Ёлочки не стало, что сильно уменьшило простой в вычислениях:
Ограничения
Метод YaFSDP даёт существенный выигрыш в производительности и потреблении памяти. Однако есть и некоторые ограничения:
-
Наибольшая производительность будет достигаться только в том случае, если слои будут вызываться так, чтобы соответствующие им буферы чередовались.
-
Мы явно учитываем, что с точки зрения оптимизатора может быть только одна группа весов с большим числом параметров.
Замеры
Итоговое ускорение YaFSDP на Llama 2 и Llama 3 по сравнению с FSDP:
Итоговое ускорение в сценариях с небольшим батчем превышает 20%, что делает YaFSDP удобным инструментом для дообучения моделей.
А в претрейнах Яндекса внедрение YaFSDP вместе с другими оптимизациями памяти в итоге дало ускорение в 45%.
Теперь YaFSDP можете использовать и вы! Мы выложили его в открытый доступ, пишите в комментариях о вашем опыте, мы готовы рассмотреть возможные пул‑реквесты.
А ещё добавляйтесь в телеграм‑канал «Душный NLP» с разборами интересных статей от NLP‑специалистов Яндекса.
Автор: Михаил Хрущев