Обучение GigaChat с контекстом в сотни тысяч токенов

в 14:38, , рубрики: GigaChat, llm, токенизация
Обучение GigaChat с контекстом в сотни тысяч токенов - 1

Помните фразу «640 килобайт памяти хватит всем»? Запросы человечества вечно растут, а индустрии надо поспевать.

Вот и с языковыми моделями так. Ещё недавно мы все удивлялись тому, на что они стали способны. А теперь нам этого мало: «ну хорошо, а может модель в диалоге учитывать то, что я сказал сотни реплик назад?»

Весной на нашей конференции I'ML Евгений Косарев (SberDevices) рассказал о том, как к увеличению контекста подошли при работе над GigaChat. А сейчас мы публикуем текстовую расшифровку его доклада. Ссылки на его видеозапись тоже прилагаем: YouTube, VK Видео.

План доклада

GigaChat — это популярная языковая модель на русском языке. Используется в десятках продуктов Сбера — SaluteJazz (бывший SberJazz), SaluteBot, умных колонках, телевизорах, голосовом ассистенте. Вне Сбера GigaChat тоже весьма полезен, тысячи клиентов пользуются GigaChat API.

Перед подготовкой доклада я у GigaChat и спросил, как рассказать о длинном языковом контексте на конференции. И AI предложил план доклада, который мне понравился:

  • преимущества длинного контекста;

  • технологии, которые помогают его реализовывать; 

  • перспективы развития.

LLaMa

Десять лет назад достаточно было следующего определения. Языковая модель — некоторый алгоритм, который способен разумно продолжать текст. Так можно было описать и простую систему T9.

Время шло, и наше понимание качественного ответа начало сильно меняться. Теперь мы хотим, чтобы языковые модели решали математические задачи, писали код, классно разговаривали и многое-многое другое. И в начале 2024 года мы пришли к модели LLaMa-3.

На момент подготовки доклада среди языковых моделей в открытом доступе эта была самой сильной по самым разным бенчмаркам: MMLU и GPQA (это бенчмарки на 57 областей общих знаний), HumanEval (оценивает способность писать код), а также GSM-8K и MATH (проверяют способность решать математические задачи на школьном уровне и сложнее).

Прогресс семейства LLaMa-моделей впечатляет. Примерно за год на сложных бенчмарках результат улучшился в два раза, а на некоторых бенчмарках даже в десять раз. Вроде бы все хорошо, но чего-то не хватает.

Контекст в LLM

Расскажу о том, что такое контекст и как его понимают языковые модели. Контекст языковой модели обрабатывают с помощью токенов — единиц текста. Раньше это были буквы или целые слова, сейчас это нечто более сложное. Текст разбивается на подстроки с номерами, которыми оперирует языковая модель.

Обучение GigaChat с контекстом в сотни тысяч токенов - 2

Сейчас большинство языковых моделей имеет контекст 8 тысяч токенов. Если перевести на понятные цифры, то это четыре страницы текста А4. Получается, что мы можем оперировать таким объемом довольно хорошо и решать важные задачи. Это много или мало?

Сейчас, когда языковые модели стали хороши, к ним всё чаще обращаются с задачами, требующими больше. Бизнес может задавать вопрос, связанный с документом более 20 страниц (это может быть статья или юридический документ). Например, в документе содержится юридический акт; что он регулирует? Это написано где-то в тексте. Напомню, наша модель обрабатывает четыре страницы А4. Что делать с остальными? 

Решения

Модификация блока внимания. Сейчас почти во всех языковых моделях присутствует блок внимания. Его задача — анализировать связь токенов друг с другом. Через этот блок модель понимает контекст. Один из вариантов — упростить этот блок, придумать эвристику. 

Эвристика CoLT5 работает так. Допустим, к нам приходит 100 тысяч токенов, а контекст — 8 тысяч токенов. Мы не можем обработать весь контекст, но можем разделить токены на важные и неважные. В контекстном окне обработать все важные токены и посмотреть связи между ними, а неважные токены отправить в более легкую ветку.

Обучение GigaChat с контекстом в сотни тысяч токенов - 3

В легкой ветке предлагается считать связи не между всеми токенами, а между соседними тремя. Внимание становится считать довольно дешево, и за два прохода Attention в тяжелую ветку через легкую мы можем обработать весь контекст. 

Довольно очевидный недостаток в том, что есть сущность, которая определяет важность токена. А определить это в произвольном тексте и построить такой роутер — довольно сложная задача. Происходит много ошибок в определении важности. 

Во-вторых, легкая ветка внимания все же плохо обрабатывает контекст и теряет много информации. 

Sliding Window Attention. Эта модификация внимания существует далеко не первый год, а в последний раз его популяризировала команда, которая сделала модель Mistral. До появления LLaMa-3 эта языковая модель считалась самой сильной в объеме 7 миллиардов параметров. Авторы декларировали следующее.

Во-первых, давайте вспомним, что мы обучаем не только большую языковую модель, но и глубокую нейронную сеть. Воспользуемся тем, что у нас есть много слоев, и будем передавать информацию последовательно между слоями:

Обучение GigaChat с контекстом в сотни тысяч токенов - 4

Если у нас есть токен под номером 20 000 и контекстное окно в 4000, то он смотрит на последние 4000 токенов. Анализируются взаимосвязи токенов под номерами от 16 000 до 20 000. Поднимаясь вверх по контексту, мы получим связь первого токена с последним. Вроде бы все хорошо. Мы имеем полное внимание в окне и с помощью архитектуры языковой модели получили агрегацию информации с нижних слоев. Недостатком можно назвать потерю информации через слои, но он не так очевиден по сравнению с тем, что было в COLT5. Как понять, есть ли недостаток? Для этого введем бенчмарк.

PassKey. Он исследует иголку в стоге сена. Есть много бенчмарков длинного контекста, но этот — один из самых простых и понятных. Его используют все, когда говорят о длинном контексте. 

Допустим, наша модель понимает контекст в сотни тысяч токенов. Возьмем какой-то факт и спрячем его на определенной глубине контекста. Например, берем глубину контекста 8000 и где-то на позиции в 4000 поместим факт вроде даты рождения или названия — то, что можно достать по запросу. Весь остальной контекст вне этой глубины мы замостим случайным шумом либо длинными текстами вроде «Войны и мира». Элементарная задача с точки зрения определения качества. 

У нас есть большой контекст, факт в середине, начале или конце этого контекста, и наша задача — достать его. Если модель справляется, то она получает 1, если не справляется, то 0. Бинарный формат ответа. По каждой длине контекста мы усредняем все показатели и получаем ответ, в каком проценте случаев модель может достать этот факт. Так становится ясно, хорошо она обрабатывает контекст или нет. 

Для Mistral 7B Sliding Window Attention результаты грустные.

Обучение GigaChat с контекстом в сотни тысяч токенов - 5

В базовом контекстном окне 4000 токенов она справляется с тем, чтобы достать результаты в 100% случаев. А в контексте 8000 токенов — уже в половине случаев. Авторы обучали модель с контекстом 32 000 токенов. Декларировалось, что это контекст, который она хорошо понимает. Но там информация теряется в 80% случаев. Модель не понимает длинный контекст. Предполагаю, что именно в связи с этим авторы переобучили модель. И получился Mistral 7B v.02 с честным механизмом внимания, то есть без всяких оптимизаций. И авторы честно получают 100% качества на всем контексте.

Почему я считаю, что авторы переобучали модель ради этого? Потому что мы проводили замеры разных версий Mistral, и по основным бенчмаркам качества они почти не отличаются. Отличается только понимание контекста. 

RAG, или retrieval augmented generation. Еще одно почти хорошее решение. К нашей языковой модели мы добавляем поисковик. Он может искать информацию как в интернете, так и по документам. Если у нас есть большой документ, поисковик говорит, что по запросу нашлась информация в пятом абзаце. Обычно он маленький. Эту информацию мы можем добавить в контекст языковой модели. Она даст нам красивый верный ответ. 

Преимущества:

  • Поиск намного дешевле.

  • «Бесконечный» контекст.

Недостатки:

  • Качество поисковика.

  • Ограниченность информации.

Вообще обучение языковых моделей дорого обходится, поэтому она обучается не каждую неделю и не каждый месяц. Поэтому мы обучаем модель, потом в ней хранятся знания, которые со временем устаревают. Вот чтобы они не устаревали, мы и добавляем RAG.


У всех перечисленных подходов есть недостатки. Не получается сделать большой контекст, чтобы модель понимала все факты и могла оперировать большим объемом данных.

Поэтому нужно честно увеличивать контекстное окно языковой модели до 130 000 токенов или больше. Почему так не делают все? Проблема в том, что это слишком дорого обходится:

  1. Активации тензоров не помещаются на GPU.

  2. Замедление вычислений.

Попробуем оптимизировать вычисления.

Оптимизация вычислений

Для начала определим, что мы собираемся оптимизировать. У нас есть нейронные сети, которые состоят из слоев. Очень приблизительно работа внутреннего слоя выглядит так.

Обучение GigaChat с контекстом в сотни тысяч токенов - 6

 На вход слоя приходит матрица М1 — [B, T, D]. Она имеет размерность. B — размер батча, Т — число токенов в контексте, D —внутреннее представление. Последнее зависит от размера модели: чем больше модель, тем больше параметр D. Также есть некоторый вес слоя [D,H]. Происходит матричное произведение (активация) — то, что хранится на видеокарте и что может в нее не поместиться. Как мы видим, для небольших матриц M1 и M2 может получиться довольно большая активация. Нужно что-то сделать, чтобы иметь возможность обучать большие языковые модели с большим контекстом. 

Для проведения оптимизаций нужно выбрать архитектуру. В докладе я буду говорить об архитектурах открытых языковых моделей, таких как LLaMa-1,2,3. Это декодер-архитектуры, и они состоят из двух важных блоков.

Обучение GigaChat с контекстом в сотни тысяч токенов - 7

Первый — блок внимания, второй — блок MLP. Обратите внимание, что оптимизируются не только эти блоки. При распределенном обучении всегда оптимизируется слой эмбеддинга, а также распределенный подсчет функции потерь. В слое эмбеддингов все тривиально — если у нас много видеокарт, то мы просто храним веса этого слоя раздельно на видеокартах. После того, как слой собирается для прохождения форварда, есть половина слоя и половина слоя на второй видеокарте, которые общаются между собой, и получается результат. 

С подсчетом функции потерь все просто — подключаем библиотеку, и происходит подсчет. Вариаций пока что нет. 

Интереснее, что делать с механизмом внимания MLP. Он нужен нам для учета взаимосвязи токенов и производит преобразования знаний модели. Сами по себе MLP-блоки довольно большие. Если посчитать, сколько времени занимает вычисление механизмов внимания и сколько времени занимает MLP, то получим следующее.

Обучение GigaChat с контекстом в сотни тысяч токенов - 8

Сложность вычисления механизмов внимания квадратична как по длине контекста, так и по внутренним представлениям, можно сказать, по размерам модели. Память у механизма внимания квадратична всегда. Если мы хотим увеличить контекст в два раза, то наша память увеличивается в четыре раза. У MLP все немного лучше, там зависимость линейная. Перейдем к первым оптимизациям.

Flash-Attention

Эта оптимизация знакома всем, кто занимается MLP. Обучение больших языковых моделей происходит обычно на Python, фреймворке PyTorch. Стоит сказать, что PyTorch — всего лишь интерфейс. Когда мы просим его перемножить матрицы, он отправляет команду на видеокарту, на нее выгружаются тензоры. 

Обучение GigaChat с контекстом в сотни тысяч токенов - 9

Выполняется операция на GPU, результат выполнения снова отправляется на процессор, и так это все работает. Если мы пишем много операций подряд, то затраты на запуск операции велики. Поэтому предлагается написать эффективное вычисление механизма внимания на видеокартах. 

Во-первых, как мы видим слева, мы можем делать не много вызовов операций в PyTorch, а сделать один, но ядро, написанное на CUDA, будет отрабатывать быстрее, чем много разных вызовов. 

Во-вторых, мы все же хотим эффективно вычислять механизмы внимания на видеокарте, поэтому нельзя обойти вниманием устройство видеокарты и вообще устройство нашей системы. Устройство системы выглядит как пирамида. То есть на видеокарте есть быстрая память (оранжевая на изображении), зеленая память — она побольше и чуть помедленнее. А самая медленная память из трех — память процессора. 

Таким образом, мы можем вычислять блок внимания, как показано справа на изображении выше. Для блока внимания нам нужно посчитать попарные взаимосвязи всех токенов в контексте. Если размер контекста — Т, то это Т2 вычислений. Автор утверждает, что нам не нужно хранить всю матрицу в памяти и полностью ее материализовывать. Можем считать ее блочно и отправлять ее на вычисления, а собирать ее уже в других местах. У нас эффективнее расходуется память. В сетапе модели я брал модель 7 миллиардов, и с использованием flash-attention у нас на 50% возрастает скорость вычислений и мы теряем 7 Гб памяти на активации. Это довольно хорошо.

GQA

Суть в следующем. Мы используем механизм многоголового внимания, то есть query key value матрицы разбивается на более мелкие подматрицы. В стандартной реализации мы используем изображение слева.

Обучение GigaChat с контекстом в сотни тысяч токенов - 10

Каждой query-голове соответствует key- и value-голова. Далее с помощью таких вычислений мы определяем их взаимодействие и считаем матрицу внимания. В одной работе утверждается, что нам нет нужды для каждой query-матрицы делать уникальные key- и value-головы. Можно их сгруппировать и уменьшить key-матрицы, value-матрицы без особых потерь в качестве. Эта оптимизация дает нам уменьшение памяти на гигабайт и прирост скорости на десять процентов. 

Еще attention очень важен для инференса языковых моделей. Помимо того чтобы обучить языковую модель, нужно еще и использовать ее. Key value-кэш — это некоторая оптимизация пересчета, чтобы инференс языковой модели был быстрым. Чем он меньше, тем эффективнее инференс. То есть эта оптимизация также влияет на скорость работы языковой модели. 

Итак, мы смогли ускорить attention. Что дальше? Здесь на ум приходят две вещи: tensor parallel и sequence parallel.

Tensor Parallel

Произведение матриц происходит на видеокартах. Предположим, что на одну видеокарту произведение матриц не влезает. Давайте использовать две видеокарты или в общем случае — n видеокарт.

Обучение GigaChat с контекстом в сотни тысяч токенов - 11

У матрицы две важные размерности: размерность внутреннего представления и размерность контекста. Таким образом, мы можем разделить эти матрицы между двумя видеокартами. Видеокарта №0 имеет в себе и оперирует первой половиной матрицы, видеокарта №1 — второй половиной матрицы. С помощью общения между видеокартами можно сделать так, что все активации становятся меньше в n раз и хранимая память становится меньше в n раз. Также в n раз у вас возрастает количество видеокарт. Это позволяет обучать большие модели с большим контекстом, но на большем количестве карт.

Что в итоге. Tensor parallel — это разделение по размерности модели. Он сокращает память на 50%, а скорость увеличивает на 90%. Почему не на 100%? Чтобы производить такие вещи, нужно делать коммуникации между видеокартами. На этапах forward и backward видеокарты общаются и обмениваются результатами вычислений. Таким образом, для эффективной реализации ускорение должно быть примерно на 90%. 

Sequence parallel

Снова покажу вам страшную картинку.

Обучение GigaChat с контекстом в сотни тысяч токенов - 12

Sequence parallel имеет схожие показатели. Память в два раза меньше, скорость опять же 90%. В итоге в обоих подходах мы делим матрицу пополам. Что лучше, зачем и как это использовать?

Tensor vs Sequence Parallel

Была статья, которой на момент доклада уже два или три года. В ней подсчитали, сколько памяти тратится на первый и второй подходы. Авторы пришли к вот таким формулам:

Обучение GigaChat с контекстом в сотни тысяч токенов - 13

Можно заметить, что есть общие члены и есть совершенно отличающиеся, которые уменьшают память в n раз. Поэтому нужно выбирать один из двух подходов по ситуации. Если пересчитать все то же самое для LLaMa-архитектуры, то получим следующее.

Tensor vs Sequence Parallel

Tensor vs Sequence Parallel

Все подсчитано для B = 1, можно пересчитать и для большего количества. Механизм внимания выгоден в том случае, если размер контекста превышает внутреннее представление в 8 раз. В семимиллиардных LLaMa-моделях это 4000 токенов. То есть нам выгоднее использовать sequence parallel при контексте 32 000 токенов. Для MLP мы тоже можем это посчитать. Превышение в 16 раз. 

Если контекст составляет 64 000 токенов, то нам гарантированно выгоднее использовать sequence parallel. Иногда можно и комбинировать. 

Итого, так как мы хотим учить экстремально большие контексты, то обсудим, как можно реализовать sequence parallel.

Реализация SP — all2all

Ленивая реализация, предложенная в дипспиде. Она вообще не меняет код. У вас есть код обучения языковой модели. Вам нужно сделать всего две вещи. Во-первых, до входа в языковую модель разделить токены. Зачем это нужно?

Языковые модели, а именно декодер-слои, могут архитектурно обрабатывать произвольный контекст.

Обучение GigaChat с контекстом в сотни тысяч токенов - 15

Сама модель может в себя принять контекст любого размера с точки зрения матриц, с точки зрения вычислений. Таким образом, мы можем для обучения модели, не меняя код, просто направить в одну часть модели первую половину контекста, в другую часть — его вторую половину. Единственное, что мы меняем, — до блока attention и после блока attention применяется all2all-коммуникация. 

Вкратце она делает следующее: если матрица разбита по размерности sequence, то после all2all-коммуникации она полностью собирает sequence (последовательность) и разбирает внутреннюю последовательность модели. Таким образом, до all2all-коммуникаций модель работает в режиме sequence parallel, а внутри модели она работает в режиме tensor parallel. Эта реализация самая простая, почему бы и не попробовать?

Ring Attention

Эта реализация сложнее. Она позволяет обучать экстремальные контексты, и ее можно комбинировать с предыдущей реализацией. Давайте посмотрим, что происходит.

Ring attention наследует идею у flash-attention (а может, и наоборот), что не нужно реализовывать всю матрицу внимания внутри. Напомню, что матрица внимания квадратична относительно длины контекста. Мы можем последовательно вычислять блоки внимания в цикле. Это и предлагается сделать здесь.

Обучение GigaChat с контекстом в сотни тысяч токенов - 16

Есть query-матрица, key value-матрица. Они покрашены в четыре цвета, и за четыре цикла мы можем, передавая query key value между видеокартами, посчитать полный блок внимания. Это довольно хорошо, потому что пока считается блок внимания, видеокарты общаются. Происходит наложение коммуникаций и вычислений. Если мы вспомним all2all-коммуникации, то они вызывают простой модели. Модели параллельно посчитали блок MLP или предыдущие блоки, остановились, пообщались. Посчитали attention, снова остановились, снова пообщались и продолжили считать. Во время общения кластер стоит, и видеокарты ничего не делают. Это плохо.

В ring attention такого нет. Пока видеокарты общаются, производятся вычисления, key value передается матрице на следующий блок. Что здесь можно заметить?

Снизу показаны маски внимания — черные и серые квадраты. Можно заметить, что есть маски внимания полностью квадратные, полностью серые, полностью черные и треугольные. Черная маска — простой видеокарты. У нас внимание каузальное, то есть второй токен зависит от первого, третий — от первого и второго, четвертый — от первых трех, но четвертый не зависит от пятого. То есть токен не видит того, что происходит в будущем. Когда в видеокарте выпадает черная маска, то весь результат ее вычислений выбрасывается и модель считает вхолостую. Авторы striped ring attention предлагают это исправить.

Striped Ring Attention

Предлагается иная нарезка на query и key value в матрице.

Обучение GigaChat с контекстом в сотни тысяч токенов - 17

Блоки получаются немного другие, которые делают матрицы треугольными. Эти страйпы позволяют нам всегда считать либо больше, либо чуть меньше половины механизмов внимания. Нагрузка между видеокартами распределяется более оптимально, и мы уже имеем некоторое ускорение. Но есть нюанс, с которым справляется уже следующее ускорение.

ZigZag Ring Attention

Если striped ring attention — это последовательность, которая написана в первой строчке на изображении ниже, то они бы пообщались и вычислили бы attention так. 

Обучение GigaChat с контекстом в сотни тысяч токенов - 18

ZigZag группирует по два или в общем случае по n токенов друг с другом, то есть последовательно. Такая группировка приводит нас к одному из трех блоков внимания. Либо он закрашен наполовину по горизонтали или вертикали, либо он верхнетреугольный или нижнетреугольный (смотря как посмотреть). То есть видеокарта стоит либо на 50%, либо меньше. 

Кажется, что такая реализация лучше предыдущих. Проверим на бенчмарках.

Сравнение реализаций

Вот таблица сравнения бенчмарков на разных видеокартах.

Обучение GigaChat с контекстом в сотни тысяч токенов - 19

ZigZag ring attention выигрывает по скорости у всех. Можно заметить, что прирост относительно стандартного ring attention у ZigZag довольно сильный, примерно в полтора раза.

В итоге у нас есть all2all и ring attention. Что выгоднее? Чтобы измерить эффект, я взял 64 видеокарты и использовал следующий сетап.

Обучение GigaChat с контекстом в сотни тысяч токенов - 20

Это TP=1, SP=4. И я замерял скорость токенов, которые обрабатывает один GPU, и максимальное количество памяти, которое требуется одной видеокарте. Что мы имеем? 

Ring attention для всех моделей оптимален по памяти. All2all ест память. Собственно, он и не проектировался для того, чтобы быть эффективным по памяти. На маленьких моделях 7 миллиардов all2all выгоднее по скорости, но ненамного. Скорее всего, в случае ring attention не происходит полного перекрытия коммуникаций и вычислений, и может быть, можно оптимизировать сетап и получить разницу ring attention и all2all. Но если не заморачиваться, то для маленьких моделей all2all будет лучше. Он и проще в реализации. 

Если брать большие модели на 30 миллиардов и делать на них контекст в 100 000 токенов, то all2all не запускается в принципе — не помещается на видеокарту. А для меньшего контекста он медленнее практически в два раза и в полтора раза больше потребляет памяти.

В итоге у нас есть технологии. Теперь перейдем к обучению GigaChat.

Обучение GigaChat

Подведем итоги:

  • Имеем формулы подсчета активаций.

  • Знаем, когда tensor parallel выгоднее sequence parallel.

  • Умеем объединять оба режима.

  • Эффективный ring attention нужен для экстремальных контекстов.

Наши результаты получились такими. Наш GigaChat Pro в 29 миллиардов параметров обучен на контексте 128 000 токенов. GigaChat Light в 7 миллиардов обучен на контексте 1 миллион токенов.

Обучение GigaChat с контекстом в сотни тысяч токенов - 21

Чтобы было понятнее, я зашел в интернет, посчитал, сколько примерно в русском слове букв, сколько слов помещается на странице А4, и результаты такие. GigaChat Pro может за раз обрабатывать 64 страницы, а GigaChat Light — 512 страниц А4. В итоге контекст гигантский, можно даже подсовывать ему книги. Зачем это вообще нужно?

Перспективы

Перспективы заключаются в следующем.

  • Код, видео и аудио требуют большого контекста. 

  • Решение бизнес-задач с дополнительной информацией.

  • Персонализация языковых моделей.

  •  Мультиагентное взаимодействие (то есть нескольких умных устройств).

На этом все! Вопросы по докладу мы можете задать в Telegram: evgenijkkk.

Автор: lelyakuznetsova

Источник

* - обязательные к заполнению поля


https://ajax.googleapis.com/ajax/libs/jquery/3.4.1/jquery.min.js