Привет, Habr!
Ни для кого не секрет, что доминирующей на данный момент архитектурой в области Deep Learning являются трансформеры. Они произвели настоящий фурор и стали основой для самых известных LLM. На данный момент они используются почти во всех фундаментальных моделях, от тех, что с открытым исходным кодом, таких как Mistral, до закрытых, таких как ChatGPT. Однако, трансформеры не лишены некоторых недостатков. Сегодня мы разберём архитектуру под названием Mamba, которая претендует на то, чтобы стать соперником трансформеров и решить их уязвимости.
Главной проблемой оригинального трансформера является квадратичная вычислительная сложность алгоритма, из-за чего с ростом размера входной последовательности сильно увеличиваются требования к вычислительным мощностям и памяти.
В ответ на этот вызов, в декабре 2023 года была представлена научная работа и архитектура Mamba. С тех пор, прошло уже достаточно много времени, модель доказала свою жизнеспособность, а количество её улучшений и попыток объединения с трансформерами (например: Jamba, FalconMamba) растёт с каждым днём. Одно только количество цитирований оригинальной статьи не даёт пройти мимо работы. Она заинтересовала нас, поскольку недавно мы выпустили GigaCheck, и у нас есть гипотеза, что дискриминатор на кардинально другой архитектуре может показать интересные результаты в задаче определения авторства текста. Поэтому мы подробно разобрались в Mamba, достаточного с много ней поработали и теперь хотим рассказать о ней вам.
План
-
RNN
-
SSM
-
Linear State Space Layer
-
Дискретное представление
-
Рекуррентное представление
-
Свёрточное представление
-
Три представления. А если объединить?
-
-
Mamba (S6)
-
Какую проблему решали авторы?
-
Добавление селективности
-
Сканирование
-
Аппаратная сторона вопроса
-
Mamba-блок
-
Метрики
-
-
Заключение
-
Материалы
RNN
Перед тем, как перейти к целевой архитектуре, кратко опишем предтечи — рекуррентные нейронные сети. Алгоритм основан на том, чтобы прогнать всю входную последовательность через скрытые состояния, которые мы постоянно обновляем. На каждом шаге такая нейросеть получает на вход последовательность слов, обновляет своё состояние и выдаёт первый ответ, после чего передаёт полученный ответ снова на вход, опять обновляя скрытое состояние и генерирует новый ответ. Повторяя так несколько раз, мы можем сгенерировать полноценный ответ.
Однако, у RNN есть недостаток — такие модели имеют тенденцию забывать информацию со временем, поскольку учитывают только одно предыдущее состояние. Но RNN обладают и преимуществами — они линейны (от размера последовательности) по скорости работы и константны по памяти. Но им всё-таки не хватает точности, которой обладают трансформеры.
SSM
Для решения проблемы забывания в RNN, была использована новая архитектура State Space Model. Это модель, используемая для описания представлений скрытых состояний и предсказания того, каким может быть следующее, в зависимости от некоторых входных данных. Математически это можно описать с помощью системы уравнений, которая состоит из уравнения состояния и уравнения выхода.
— уравнение состояния
— уравнение выхода
где,
-
— отображает входную последовательность;
-
— латентное представление состояния;
-
— предсказанная выходная последовательность;
-
— главный параметр (отвечает за то, как мы преобразуем память с течением времени);
-
— параметр преобразования входа;
-
— параметр преобразования выхода;
-
— skip connection.
Уравнение состояния с помощью матриц A и B описывает, как состояние изменяется под влиянием входных данных.
Уравнение состояния с помощью матриц A и B описывает, как состояние изменяется под влиянием входных данных.
Уравнение выхода описывает, как состояние переводится в выход (через матрицу C) и как вход влияет на выход (через матрицу D).
Примечание: Матрицы A, B, C и D являются обучаемыми параметрами
Объединив всё описанное выше:
Таким образом, вся система работает выглядит так:
-
входной сигнал сначала умножается на матрицу B, которая описывает, как входные сигналы влияют на систему;
-
происходит обновление скрытого состояния. Мы умножаем состояние на матрицу A, которая описывает, как связаны все внутренние состояния. Матрица A применяется перед созданием представлений состояний и обновляется после того, как представление было обновлено;
-
затем, мы используем матрицу C, чтобы описать перевод в выходной сигнал;
-
матрица D — это Skip Connection, который используется, для борьбы с затуханием градиентов внутри сети.
Linear State Space Layer
Дискретное представление
Уравнения состояний, описанные выше, имеют непрерывный вид, что является проблемой из-за того, что на вход мы хотели бы подавать дискретные данные. Поэтому, нам необходимо дискретизировать SSM. Для решения используется техника названием «экстраполятор нулевого порядка», которая работает следующим образом: когда мы получаем на вход дискретный сигнал, то удерживаем его значение до тех пор, пока не получим новый.
Время удержания называется шагом дискретизации (∆) и является обучаемым параметром. Он представляет собой разрешение входного сигнала. Математически, экстраполятор нулевого порядка для нашего случая описывается следующим образом:
И даёт нам итоговое выражение в следующем виде:
# b — размер батча
# l — длина входной последовательности
# d_in — размер эмбеддинга входных данных
# n — размер тензоров B и C
# u — входные данные
# Дискретизация матрицы A
deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
# Дискретизация матрицы B
deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
x = torch.zeros((b, d_in, n), device=deltaA.device)
ys = []
# Проходимся по всей последовательности и получаем выходы
for i in range(l):
x = deltaA[:, i] * x + deltaB_u[:, i]
y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
ys.append(y)
y = torch.stack(ys, dim=1)
y = y + u * D
Полученное выражение даёт нам возможность перейти от непрерывной SSM к дискретной SSM. Таким образом, мы переходим от преобразования функция-функция, к преобразованию последовательность — последовательность, . Соответственно, теперь A и B дискретизированные параметры, а вместо времени мы используем конкретные точки .
Рекуррентное представление
Имея дискретное представление мы можем вернуться к идее рекуррентности, реализованной в RNN. Теперь, мы можем брать конкретные входные значения и обрабатывать их с помощью SSM.
Если проиллюстрировать это графически, то получим схему очень похожую на RNN, но с дополнительными обучаемыми параметрами.
Развернём эту схему, чтобы увидеть её более подробно.
Свёрточное представление
Если рассмотреть подробнее, как изменяется наша система и её уравнения с течением времени, то можно увидеть, что система уравнений разрастается для каждого состояния и имеет повторяющиеся паттерны.
Для упрощения этих выражений мы можем собрать все обучаемые параметры в виде ядра свёртки и перейти к свёрточному представлению. По сути, мы переходим к использованию свёрточной нейронной сети (CNN), но для 1-D тензора.
Если записать это всё математически, то получим следующее выражение:
Итоговое уравнение для выходного сигнала теперь выглядит крайне просто и элегантно, а главное, что теперь мы можем использовать преимущество CNN — параллельное обучение на видеокарте, в отличие от обычных RNN. Однако, из-за фиксированного размера ядра свётки, при такой реализации скорость их инференса не такая быстрая и неограниченная, как у RNN.
Небольшой пример, как ядро свёртки работает на практике. Мы можем использовать его для перебора каждого набора слов и вычисления выходного результата. Для того чтобы длина выходной последовательности имела такую же размерность, как и входная добавим паддинг с нулевыми значениями, которые не будут вносить изменения. После этого, производим операцию свёртки с первым элементом:
На следующем шаге мы смещаем ядро свёртки для получения следующего ответа:
Смещаем ещё раз для получения финального результата:
Так мы получили решение, которое поможет нам быстрее обучать SSM, но имеет недостатки при инференсе, можем ли мы их как-то решить?
Три представления. А если объединить?
На данный момент мы имеем три разных представления SSM. А именно: непрерывное, рекуррентное и свёрточное. С непрерывным и его дискретизацией, мы уже разобрались, остались рекуррентное и свёрточное. Они имеют различные преимущества и недостатки.
Рекуррентное:
-
плюс: быстрый инференс за счёт линейной вычислительной сложности;
-
минус: нельзя распределить вычисления на видеокарте при тренировке.
Свёрточное:
-
плюс: можно распределить на видеокартах при тренировке;
-
минус: ограниченная длина контекста.
Однако, мы можем проделать изящный трюк и использовать сразу оба представления, просто для разных задач. При обучении мы можем использовать свёрточное, которое можно распараллелить, а во время инференса — рекуррентное. Таким образом мы получаем архитектуру, которая называется: Linear State-Space Layer (LSSL).
У этих представлений есть одно важное свойство — линейная инвариантность во времени. Она гласит, что параметры SSM (A, B и C), фиксированы для всех временных интервалов. Это означает, что матрицы A, B и C одинаковы для каждого слова, генерируемого SSM. Другими словами, независимо от того, какую последовательность вы задаете SSM, значения A, B и C остаются неизменными. Мы имеем статическое представление, которое не учитывает содержимое.
Mamba (S6)
Мы рассмотрели концепции, которые послужили основной для разработки Mamba. Перейдём же к самой архитектуре и разберёмся, что конкретно улучшили авторы. SSM может использоваться для генерации текстовых последовательностей, но при этом имеет ряд недостатков (о них будет описано ниже), которых мы хотели бы избежать. Поэтому, разберём два основных улучшения, которые были добавлены:
-
алгоритм селективного сканирования, позволяющий модели фильтровать нерелевантную информацию;
-
аппаратно-ориентированный алгоритм, позволяющий эффективно хранить (промежуточные) результаты путем параллельного сканирования, слияния ядер и повторных вычислений.
Вместе все эти улучшения привели к созданию селективной SSM или S6, которую можно использовать, как attention блоки в трансформере, для создания Mamba-блоков. Прежде чем перейти к рассмотрению двух основных улучшений, давайте сначала разберёмся, зачем они нужны.
Какие проблемы решали авторы?
SSM обладает некоторыми недостатками, снижающими её эффективность при решении задачи языкового моделирования. В частности, отсутствует механизм, позволяющий фокусироваться на определенных фрагментах входных данных или игнорировать их. Мы можем проиллюстрировать это на примере двух задач: выборочного копирования и индукции.
В задаче выборочного копирования целью SSM является копирование значимых частей входного сигнала и вывод их по порядку:
Она плохо справляется с этой задачей, поскольку инвариантна к линейному времени. Как мы уже писали, матрицы A, B и C одинаковы для каждого слова, которое генерирует SSM, и, в результате, нет возможности выполнять рассуждения о содержании. Это проблема, поскольку мы хотим, чтобы SSM могла выделять значимые части из входной последовательности.
Вторая задача, с которой модель справляется плохо — это индукция, где целью является воспроизведение паттернов, найденных во входных данных:
Проиллюстрируем это на примере матрицы B. Независимо от того, какой входной сигнал x, матрица B остается неизменной и, следовательно, не зависит от x:
Аналогично, A и C также остаются фиксированными независимо от входных данных. То есть, мы никак не анализируем входные данные, а просто используем их все для обновления скрытого состояния.
Добавление селективности
SSM создаёт компактное состояние, которое эффективно сжимает всю историю в скрытое представление небольшого размера. Однако, по сравнению с трансформером, который не сжимает историю вовсе (используя вместо этого матрицу внимания), это гораздо менее эффективно. Mamba стремится получить лучшее из двух миров. Благодаря сжатой информации из скрытого состояния,, она использует небольшой объём данных, не уступающий при этом трансформеру по количеству знаний:
Делает она это с помощью селективного выбора данных в скрытое состояние. Во входной последовательности часто присутствует информация, которая не имеет большого значения. Например стоп-слова (слова-связки): «или», «но», «затем», «потом», «что», «который» и т. п. Чтобы отбирать важные элементы последовательности, нам нужно, чтобы параметры зависели от входных данных. Для этого рассмотрим размеры входных и выходных данных в SSM-модели во время обучения:
В SSM матрицы A, B и C не зависят от входных данных, поскольку их размеры N и D статичны и не изменяются. Авторы Mamba сделали матрицы B, C и размер шага ∆ изменяемыми и зависимыми от длины входной последовательности (L) и размера батча (B) входных данных.
Это означает, что теперь для каждого входного слова матрицы B и C будут отличаться. Таким образом, мы получаем систему, которая теперь может решать, какие слова оставлять в скрытом состоянии, а какие игнорировать. Работает это следующим образом:
-
Δ (размер шага дискретизации) — управляет балансом между тем, насколько сильно фокусироваться или игнорировать текущий вход;
-
при малом значении Δ, игнорирует конкретные слова и использует предыдущий контекст;
-
при большом значении Δ, фокусируется не на входных словах, а на контексте.
Примечание: размер матрицы A остается неизменным, поскольку мы хотим, чтобы само состояние оставалось статичным, но способ воздействия на него (через B и C) был динамичным.
Теперь, размеры матриц B и C могут изменяться в зависимости от шага дискретизации, который определяет их размер и делает его зависимым от размера (B, L) входных данных.
В коде это реализовано следующим образом:
# b — размер батча
# l — длина входной последовательности
# dt_rank — размер тензора delta
# n — размер тензоров B и C
# x — входные данные
# d_inner — размер эмбеддинга входных данных после расширения
# Передаём входные данные в линейную проекцию
x_dbl = self.x_proj(x) # (b, l, dt_rank + 2*n)
# Получаем из проекции delta, B, C
(delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1) # delta: (b, l, dt_rank). B, C: (b, l, n)
Полученный механизм селекции в SSM работает следующим образом:
-
подаём данные на вход;
-
передаём их в линейную проекцию из которой получаем наши матрицы B,C, Δ;
-
дискретизируем матрицы A и B;
-
передаём входные данные в матрицу B, которая теперь отвечает за то, как входные данные должны быть структурированы в пространстве состояний;
-
записываем полученные данные в скрытое состояние и выдаём новый ответ;
-
выдаём ответ с помощью матрицы C, которая теперь отвечает за то, чтобы выбрать релевантную информацию из скрытого состояния.
Таким образом, матрицы B и C управляют тем, как последовательность слов влияет на пространство состояний и как пространство состояний влияет на выход.
Итоговый код Selective SSM:
# Передаём входные данные в линейную проекцию
x_dbl = self.x_proj(x)
# Получаем из проекции delta, B, C
(delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1) # delta: (b, l, dt_rank). B, C: (b, l, n)
delta = F.softplus(self.dt_proj(delta)) # [b, l, dt_rank] -> [b, l, d_inner]
(b, l, d_in) = x.shape
n = A.shape[1]
# Дискретизация матрицы A
deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
# Дискретизация матицы B
deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
x = torch.zeros((b, d_in, n), device=deltaA.device)
ys = []
# Проходимся по всей последовательности и получаем выводы
for i in range(l):
x = deltaA[:, i] * x + deltaB_u[:, i]
y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
ys.append(y)
y = torch.stack(ys, dim=1)
y = y + u * D
Сканирование
Поскольку теперь матрицы B и C изменяемые, их нельзя вычислить при обучении с помощью свёрточного представления, так как оно предполагает ядро фиксированного размера. Мы можем использовать только рекуррентное представление и потерять возможность распараллеливания операций. Чтобы его обеспечить, давайте рассмотрим, как мы вычисляем вывод с помощью рекуррентного представления:
Каждое состояние — это сумма предыдущего состояния (умноженного на A) плюс текущий вход (умноженный на B). Распараллеливание в таком случае кажется невозможным, поскольку каждое состояние может быть вычислено только при наличии предыдущего. Однако, Mamba делает это возможным благодаря алгоритму параллельного сканирования, подробнее о том, как он работает можно прочитать здесь. Он предполагает, что порядок выполнения операций не имеет значения благодаря свойству ассоциативности. В результате мы можем вычислять последовательности по частям и итеративно объединять их.
Аппаратная сторона вопроса
У современных GPU есть недостаток в виде ограниченной скорости передачи данных (IO) между небольшой, но высокоэффективной SRAM и большой, но менее эффективной DRAM. Частое копирование информации между SRAM и DRAM становится бутылочным горлышком при создании эффективных алгоритмов
Mamba использует идеи Flash Attention и пытается ограничить количество переходов от DRAM к SRAM и обратно. Для этого несколько операций объединяют в ядро, а именно:
-
шаг дискретизации с размером шага ∆;
-
алгоритм селективного сканирования;
-
умножение на матрицу C.
Полученное ядро позволяет модели не записывать промежуточные результаты и постоянно выполнять вычисления, пока они не закончатся.
Схематично изобразим, какие части Mamba какую память используют:
Вместе, зависимые от входа B, C и алгоритм параллельного сканирования создают алгоритм селективного сканирования.
Mamba-блок
При помощи селективной SSM, которую мы получили, реализуется Mamba-блок. Как и в случае с декодером трансформера, можно складывать несколько Mamba-блоков и использовать их выход в качестве входа для следующего блока Mamba.
Рассмотрим сам блок. Он начинается с линейной проекции для расширения входного эмбеддинга. Затем, применяется свёртка для распараллеливания вычислений при обучении, после чего идёт сам блок с селективной SSM. Затем полученный эмбеддинг объединяется с данными из Skip Connection и передаётся в линейную проекцию для сжатия в исходный размер. Мы можем немного расширить эту схему и изучить, как выглядит работа на конкретном примере:
Поднимемся на уровень повыше и рассмотрим все составляющие архитектуры целиком.
Все составляющие архитектуры Mamba. Сама архитектура. Mamba-слой. Mamba-блок
Она состоит из:
-
входных эмбеддингов, которые мы подаём на вход;
-
набор Mamba-слоёв, которые, в свою очередь, состоят из:
-
слоя нормализации;
-
Mamba-блока;
-
Skip Connection.
-
-
выходного эмбеддинга;
-
Softmax-слоя, для получения итогового распределения вероятностей.
Метрики
В таблице 1 приведены метрики для оценки задач селективного копирования и индукции. Можно увидеть, что архитектура Mamba и конкретно S6-блок, которые предложили создатели, показывает сильно большую точность, чем предыдущие решения. В таблице 2 авторы показывают точность в зависимости от длины контекста, которая со временем падает практически до 0 у всех моделей, кроме Mamba
На картинке ниже приведены сравнения Mamba с трансформером в скорости обучения и инференса в зависимости от длины контекста. Mamba и тут показывает сильный отрыв. До 40 раз быстрее скорость тренировки и до 5 раз быстрее скорость инференса.
Недавно вышла статья, в которой авторы провели масштабное сравнение работы трансформера, Mamba и Mamba-2 (её в данной статье мы рассматривать не будем, но на данный момент уже есть улучшенная версия).
Как мы видим, Mamba показывает сравнимые или даже лучшие результаты, чем трансформер на всех бенчмарках, кроме MMLU:
Авторы попытались понять причину такой низкой метрики на MMLU. Бенчмарк состоит из вопроса, на который модель должна ответить и вариантов ответа (Choices). Ответом должен быть только буква, соответствующая правильному ответу (A, B, C или D). Авторы решили добавить два дополнительных варианта. В первом каждый из вариантов содержит букву и полный текст соответствующего букве ответа, во втором — вариант состоит только из текста правильного ответа (без буквы). Модифицировав таким образом тест, авторы провели замеры снова.
Результаты оказались интереснее: если не давать модели варианты на выбор, а сразу просить дать ответ, то Mamba показывает более высокие метрики, чем трансформеры. Авторы указывают, что такой результат связан тем, что SSM испытывает трудности с тем, чтобы направить знания об ответе в один выходной токен.
Также, они провели замеры на тех же бенчмарках, но в этот раз модели были обучены на 3,5T токенов. И здесь уже Mamba-2 показывает сравнимые с транcформером результаты без всяких ухищрений.
Заключение
Основное различие между трансформером и Mamba заключается в механизме внимания и механизме выбора. Трансформер полностью полагается на механизм внимания, который учитывает весь контекст, а Mamba, напротив, не рассматривает все данные сразу, а избирательно фокусируется на самой важной части входной последовательности для предсказания следующего слова.
Оригинальная архитектура трансформера была представлена в 2017 году и с тех пор появилось не так уж много новых моделей, способных бросить вызов трансформерам. Mamba бросает и предлагает подход, лишённый недостатков соперника. На данный момент уже вышла обновлённая версия Mamba-2, а также появляется множество новых гибридных решений на основе Mamba и трансформера. Модель наделала немало шума, пользуется популярностью в сообществе и весь её потенциал ещё предстоит изучить.
Спасибо, что дочитали до конца!
Если пост вам понравился, то подписывайтесь, на наш Телеграм-канал!
Материалы и прочее
-
Статья про Mamba: https://arxiv.org/abs/2312.00752
-
Репозиторий Mamba: https://github.com/state-spaces/mamba
-
Mamba minimal code: https://github.com/johnma2006/mamba-minimal
-
Разбор Mamba в картинках: https://newsletter.maartengrootendorst.com/p/a-visual-guide-to-mamba-and-state#§what-is-a-state-space
-
Ещё хороший разбор Mamba: https://medium.com/@puneetthegde22/mamba-architecture-a-leap-forward-in-sequence-modeling-370dfcbfe44a
-
Статья про трансформер: https://arxiv.org/pdf/1706.03762
-
Сравнение работы Mamba и трансформера: https://arxiv.org/abs/2406.07887
-
Статья про S4: https://arxiv.org/pdf/2111.00396
-
Визуальный разбор S4: https://srush.github.io/annotated-s4/
Автор: syakubson