Могут ли трансформеры «думать»

в 11:56, , рубрики: generalization, gpt, грокинг, задачи рассуждения, обобщение, общий искусственный интеллект, память ИИ, память трансформеров, трансформатор

Недавние исследования показывают, что модели трансформеров способны почти безошибочно решать задачи, требующие нескольких логических шагов. Например, из утверждения А вывести Б и дойти логически до В. И что удивительно, это достигается без использования Chain-of-Thought или особых промптов — только классический GPT-2. Давайте посмотрим, как трансформеры «думают» при решении задач рассуждения, и напишем для этого код с использованием библиотеки Hugging Face.

Могут ли трансформеры «думать» - 1

Привет! Меня зовут Роман и я сейчас получаю Ph.D. в Мюнхене. Одной из тем моей кандидатской является память и гроккинг. Этот концепт мне кажется очень необычным и кое-где противоречащим современным концепциям обучения ИИ. Иногда пишу в свой телеграмм канал про ИИ технологии, в частности про мои исследования в университете, стартапы, кастдевы и вообще все, что мне кажется интересным. В ближайшие месяца мечтаю начать вести эфиры с экспертами из разных областей и задавать им мои глупые вопросы.

Здесь могло быть ваше введение

Сегодня мы рассмотрим необычное явление — гроккинг. Но перед этим давайте определим, что я понимаю под рассуждением и какие задачи будем решать. Возьмём, к примеру, вопрос: «Как соотносится возраст Трампа и Байдена?» Чтобы ответить, вам нужно сначала вспомнить их возраст, а затем сравнить эти значения — это и есть рассуждение в два шага.

"В этой задаче нет ничего сложного, GPT-4 ответит на этот вопрос без проблем", — воскликнете вы. И да, и нет: некоторую часть задач (около 30%, если верить данным статьи) все передовые модели действительно решают, но это очень далеко от полного обобщения! В целом, задачи рассуждений с несколькими шагами остаются сложными для текущих моделей трансформеров, тогда как гроккинг достигает впечатляющих 99% точности.

Почему статья вообще так называется и как это связано с «думанием»? В таких логических задачах тебе необходимо не просто закончить предложение наиболее статистически верным вариантом, но проделать работу:

  1. Понять, с какими объектами предстоит работать

  2. Извлечь из памяти информацию об этих объектах

  3. Выполнить логическую операцию — сравнение или композицию

  4. Сформулировать ответ

Отчасти это напоминает цепочку рассуждений (Chain-of-Thought), когда модель генерирует все доступные данные и логику перед ответом, создавая своего рода «рассуждение, записанное на электронной бумаге». Однако в этой статье и подобных задачах все эти действия происходят в скрытой части трансформера — он выполняет все 4 пункта «в своём скрытом состоянии» и лишь затем выдаёт ответ в виде токенов.

С одной стороны, неявные рассуждения затрудняют интерпретацию модели. С другой — как мы увидим далее — они позволяют достичь обобщения на сложных задачах и добиться почти 100% точности.

Как читать эту статью

Мне хотелось оставить базу статьи достаточно простой, чтобы человек с минимальным пониманием работы нейронных сетей и программирования, мог познакомиться с концепцией. При этом я хотел оставить возможность для более глубокого погружения в сложные процессы. Если вы хотите больше фактов, разобраться в коде или получить детальное объяснение — читайте все спойлеры. Если же вам нужен базовый уровень, спойлеры можно пропустить — основные идеи будут понятны и без них.

Гроккинг

Обобщение в контексте машинного обучения — это способность модели применять полученные знания к новым, ранее не встречавшимся данным или ситуациям.

Представим, что мы собрали данные для ИИ-калькулятора: 500 примеров умножения вида 2 × 3 = 6, 4 × 2 = 8 и так далее. Мы обучили модель на этих данных и теперь проверяем её на тестовой выборке: 10 × 10 = 100, 10 × 11 = 110 — примеры, которых модель не видела во время обучения. Способность к обобщению означает применение алгоритмов, изученных во время тренировки, к новым данным. Если модель не дала ни одного правильного ответа, мы говорим, что обобщения не произошло — вероятно, она просто «запомнила» тренировочный набор данных.

Современные постулаты ИИ говорят, что если у модели высокая метрика на тренировочной выборке, но низкая на тестовой — это переобучение. В таком случае в тренировочную выборку расширяют примерами, применяют регуляризацию или что-то еще. Потому что при наличии огромного объёма тренировочных данных, как и при регуляризации, модель физически не способна всё запомнить — ей приходится искать более эффективные подходы. Но необходима ли эта избыточность? Гроккинг ставит это под сомнение.

Гроккинг— это феномен отложенного обобщения модели на небольших наборах данных.

Метрики на тренировочной и тестовой выборке для задачи модульного деления. Модель через 100 итерацией достигает 100%, и лишь через 1 миллион итераций достигает такого же результата на тестовой.

Метрики на тренировочной и тестовой выборке для задачи модульного деления. Модель через 100 итерацией достигает 100%, и лишь через 1 миллион итераций достигает такого же результата на тестовой.

Гроккинг открыли в каких-то подземных лабораториях OpenAI, откуда исследователей, видимо, не выпускали без паспортов и заставляли придумывать самые странные эксперименты, иначе как они до такого дошли, я понять не могу. Они решили обучить модель на небольшом наборе данных с примером модульного деления — то есть натренировать калькулятор. Что-то у них ничего не шло до тех пор, пока кто-то не оставил включенным компьютер на ночь. Модель всю ночь обучалась, и оказывается, если долго обучать, то модель начинает учиться более обобщённому алгоритму, а не просто запоминать все данные.

Из графика можно увидеть, что отложенность у обобщения довольно значительная: модель запомнила тренировочную выборку за тысячу итераций, а для тестовой потребовался миллион. Если перенести эти масштабы на современные языковые модели, обучение которых занимает недели или месяцы, гроккинг увеличил бы этот период до сотен лет. Неудивительно, что данному явлению уделяется так мало внимания — на данный момент оно не слишком эффективно.

Гроккинг в логических задачах

Хватит с историей, давайте перейдем к статье. В ней рассматриваются два набора данных: сравнение фактов и их композиция. Мы поочередно рассмотрим каждую задачу.

Важно отметить, что мы используем знакомые сущности и отношения между ними лишь для простоты. В оригинале они представлены абстракциями, которые можно заменить чем угодно: от госдолга стран до результатов теста твоих друзей "Кто ты из Смешариков".

Из чего состоят эти наборы? У нас есть атомарные факты и выведенные факты. Атомарные факты — это неоспоримые истины, на которых строится вся наша логика: вода кипит при 100 градусах Цельсия, Трампу 78 лет и так далее. Из этих атомарных фактов мы хотим вывести новую информацию: либо их отношение (одно меньше другого), либо их композицию (из А следует Б, из Б - В, значит из А следует В). Давайте рассмотрим на примере, как это работает.

Сравнение

Сравнение, как мы уже поняли, это когда что-то сравнивают. Например, возраст людей, их зарплату или что угодно. Ключевая идея здесь в том, что у сущностей есть определённые значения, и наша задача — сопоставить эти значения между собой.

Могут ли трансформеры «думать» - 3

В задаче сравнения атомарные факты — это значения атрибутов, например, возраст Трампа и Байдена. Выведенные факты — результат их сравнения.

Атомарные факты:

  1. Возраст Трампа 78 лет

  2. Возраст Байдена 82 года

Выведенный факт:

  1. Байден старше Трампа.

Скрытый текст

Как уже было сказано, данные в датасете представлены в более абстрактном формате.

Атомарный факт:
<e_488><attr_0><1></a>

<e_488> - сущность номер 488
<attr_0> - атрибут номер 0
<1> - значение атрибута у сущности

Выведенный факт:
<attr_2><q><e_621><mask><e_545><attr_2_2></a>

<attr_2><q> Сравнивается аттрибут 2
<e_621><mask><e_545> у сущностей #612 и #545
<attr_2_2> - аттрибут 2 больше у #621

По сути своей, в этой части

<attr_2><q><e_621><mask><e_545>

задается вопрос: “Как соотносится аттрибут 2 у сущности 621 по сравнению с 545?”. Он может быть больше, меньше или равен. Соответственно, это отражается в следующем токене: <attr_2_0>, <attr_2_1>, <attr_2_2> - равен, меньше или больше.

Композиция

Композиция несколько сложнее сравнения. В общем виде, композиция — это своего рода объединение фактов. Например, если из факта 1 следует факт 2, а из факта 2 следует факт 3, то из факта 1 следует факт 3. Связи между сущностями в этом случае называются отношениями.

Например, Миша женат на Свете, а у Светы есть подруга Вика. Отсюда следует, что подруга жены Миши — Вика. Обратите внимание, мы не упоминали Свету напрямую — иначе задача из двухшаговой превратилась бы в одношаговую. Если мы не указываем явно промежуточную сущность, то на первом этапе нам нужно определить, кто жена Миши — Света. И только во втором шаге — кто подруга Светы.

Могут ли трансформеры «думать» - 4

Атомарные факты:

  1. Миша женат на Свете.

  2. Света дружит с Викой.

Выведенный факт:

  1. Подругу жены Мишы зовут Вика

или, выражаясь более формально:

  1. Миша женат дружит с Викой

Скрытый текст

Давайте посмотрим на более абстрактный набор

Атомарные факты:
1. <e_1034><r_6><e_339></a>
2. <e_339><r_85><e_1745></a>

Тут присутствуют уже знакомые нам сущности, но есть и новые части - <r_x> - отношения между сущностями. Так как это абстракция, то может быть там что угодно - от социального статуса, физической принадлежности (колесо - часть машины) до всего, что можно выразить подобным образом.

Что тут могло бы быть:
1. <Миша> <женат> <Света> </a>
2. <Света> <дружит> <Вика></a>

Выведенные факты:

<e_1034><r_6><r_85><e_1745></a>

Здесь присутствует неявная сущность - Света. Она не упоминается в примере явно, но косвенно на нее указывает связь <r_6>, или женат. То есть, переводя на естественный язык, это было бы так:

<Миша><женат><дружит><Вика></a>

Так имеют ли трансформеры думать?

Да!

(Слева) График метрик для композиции. (Справа) График метрик для сравнения. Красная линия - точность на тренировочной выборке. Зеленая линия - точность на тестовой выборке. Синяя линия - точность на тестовой выборке, где модель не видела ни одной связи с сущностями, то есть не было выведенных фактов. О ней поговорим чуть позже

(Слева) График метрик для композиции. (Справа) График метрик для сравнения. Красная линия - точность на тренировочной выборке. Зеленая линия - точность на тестовой выборке. Синяя линия - точность на тестовой выборке, где модель не видела ни одной связи с сущностями, то есть не было выведенных фактов. О ней поговорим чуть позже

С помощью гроккинга возможно решение обоих типов задач: композиции и сравнения. В обоих случаях модель достигла 100% точности на обучающей выборке примерно к 1000-му шагу. Однако если тренировать модель в 300 раз дольше (до 300 000 шагов), то 100% точность достигается и на тестовой выборке. Неочевидный механизм, не правда ли? Он становится еще более неочевидным, если не выполнить одно важное условие.

Качество данных

Качество данных — необходимый элемент, без которого эта "магия" не сработает. Под качеством данных здесь подразумевается соотношение выведенных фактов к атомарным. При этом, по заявлению авторов, количество данных не коррелирует ни с достижением метрики, ни со скоростью её достижения.

(Слева) График тестовой метрики от соотношения выведенных / атомарных фактов. Можно заметить, что если соотношения 3.6 (то есть на 1 атомарный факт приходится 3.6 выведенных), то сходимость не наступает. В целом, чем выше соотношение, тем быстрее сходимость. (Справа) График тестовой метрики в зависимости от размера набора данных - 2, 5 и 10 тысяч.

(Слева) График тестовой метрики от соотношения выведенных / атомарных фактов. Можно заметить, что если соотношения 3.6 (то есть на 1 атомарный факт приходится 3.6 выведенных), то сходимость не наступает. В целом, чем выше соотношение, тем быстрее сходимость. (Справа) График тестовой метрики в зависимости от размера набора данных - 2, 5 и 10 тысяч.

Очевидно, что эти значения специфичны для задач и данных, сгенерированных авторами, но тенденция интересная. Если подумать, то соотношение 3,6 уже весьма высокое для реальных данных, не говоря уже о 12 и 18. Возьмём, к примеру, Википедию — статьи там в основном написаны в "атомарном стиле", без явных примеров рассуждений и композиции. В нашем неструктурированном мире обработки естественного языка преобладают атомарные факты, а выведенных гораздо меньше. Возможно, именно здесь кроется причина "ограниченности" моделей?

Ну ладно, если вы уже собрались уходить читать следующую статью, то задержитесь еще на пару минут - сейчас мы посмотрим, как это работает на практике. Авторы любезно предоставили свою интерпретацию происходящего с экспериментами и картинками - любопытно же, как и чему модель обучается.

Как это работает?

Любую нейронную сеть можно представить как функцию f(x), преобразующую набор входных токенов x в набор сгенерированных токенов y. На самом деле, эта большая функция f состоит из множества вложенных функций: f1(f2(f3(x))). Это я абстрактно веду к тому, что сложность и эффективность трансформера определяется сложностью и эффективностью преобразований, которые он выполняет от слоя к слою и их взаимодействием.

Если коротко, это и есть ответ! Во время гроккинга, внутри трансформера формируются более эффективные преобразования и их взаимодействия на уровне слоев. Благодаря этому появляется способность к обобщению — более практичный способ решать задачи по сравнению с простым запоминанием. А если длинно, давайте разберёмся.

Обучение модели

Когда модель начинает обучаться, она находит простейший способ решения задачи — запоминание всех фактов. Это объясняет, почему на тренировочной выборке метрика так быстро достигает 100%. Со временем гроккинг может наступить или нет. Если мы вспомним, отчего это зависит, а это — соотношение выведенных фактов к атомарным, то шторка неизвестности начнет приоткрываться. Гроккинг наступает при высоком соотношении фактов, так как у модели появляется стимул искать более эффективный способ решения задачи, помимо запоминания.

Исследователи обнаружили, что модель внутри начинает разделяется на две части. В нижних слоях она находит у себя в памяти сущность, используя первое отношение (жена Мишы - Света). Затем, модель в верхних слоях пытается связать этот найденную сущность и второе отношение (подруга Светы - Вика).

Получается, что модель разделяется на этапы, соответствующие шагам рассуждения. Хотя мы не упоминали Свету явно в нашем примере, модель всё равно находит её на первом этапе, так как это необходимо для определения её подруги. Вместо простого запоминания Миша-Вика, что является менее эффективным, формируется схема для поиска Миша-Света-Вика.

Нижние слои (0-5):

  1. Получают информацию о связующей сущности (Света)

  2. Сохраняют информацию о её связи со второй сущностью для последующего извлечения

Верхние слои (5-8):

  1. Выполняют второй шаг рассуждения, извлекая оставшуюся информацию — связь между связующей и конечной сущностями (Вика)

  2. Передают информацию в конечный слой, где модель использует её для формирования ответа

Если умножить скрытое состояние на нашу матрицу эмбеддингов, мы получим распределение токенов. На 5 слое, первый узел содержит связующую сущность, а второй узел передает второе отношение дальше.

Если умножить скрытое состояние на нашу матрицу эмбеддингов, мы получим распределение токенов. На 5 слое, первый узел содержит связующую сущность, а второй узел передает второе отношение дальше.

Та синяя линия на графике, или неэффективность архитектуры трансформера

(Слева) График метрик для композиции. (Справа) График метрик для сравнения. Красная линия - точность на тренировочной выборке. Зеленая линия - точность на тестовой выборке. Синяя линия - точность на тестовой выборке, где модель не видела ни одной связи с сущностями, то есть не было выведенных фактов. Для композиции модель не может обобщиться, а для сравнения — может!

(Слева) График метрик для композиции. (Справа) График метрик для сравнения. Красная линия - точность на тренировочной выборке. Зеленая линия - точность на тестовой выборке. Синяя линия - точность на тестовой выборке, где модель не видела ни одной связи с сущностями, то есть не было выведенных фактов. Для композиции модель не может обобщиться, а для сравнения — может!

Помните, у нас чуть выше был график данных и там была странная линия Test (OOD — Out of Distribution), которая никак не хотела расти выше 5%? Если вкратце, модели вообще не показывали выведенные факты об этих сущностях.

Давайте рассмотрим пример: у нас есть Антон, женатый на Маше, а у Маши есть подруга Катя. Во время обучения модели предоставили все атомарные факты об этом, но не дали выведенных. То есть модель знает, что Антон женат на Маше, а подруга Маши — Катя, но во время обучения ни разу не отвечала на вопросы вроде "Кто подруга жены Антона?". То есть, модель знает все связи, знает логику работы этих связей, но почему-то, если в тренировочную выборку не включить примеры использования этих связей, то у нее не получится ответить на эти вопросы в тесте. Честно говоря, такое поведение кажется несколько нелогичным — ведь на остальные вопросы модель отвечает без проблем. Но у авторов статьи и на это нашелся ответ!

Как мы выяснили, извлечение фактов второго шага — когда мы нашли жену и хотим найти её подругу — происходит в верхних слоях (5-8). Исследователи пришли к выводу, что модели просто незачем хранить в памяти атомарные факты второго шага в верхних слоях, если они не использовались во время обучения — модели они были не нужны. Модель может понять, что Маша — жена Антона, но когда она пытается понять, кто же подруга Маши, она осознает, что не знает этого… По крайней мере в верхних слоях.

Из-за особенностей архитектуры памяти трансформера, а именно отсутствия общей "шины" памяти, факты в первых и последних слоях никак совместно не используются. Поэтому если у модели нет причин хранить данные в верхних слоях, она не сможет решить задачу на примерах, которых никогда не видела.

Совсем немного про сравнение

Для задачи сравнения модель способна к обобщению на Test (OOD) и вот почему: формируется параллельная схема с несколько иным алгоритмом работы.

  1. Первые слои (0-5) отвечают за извлечение фактов о сущностях. Таким образом, модели не нужно хранить атомарные факты в разных своих частях.

  2. Верхние слои (5-8) сравнивают извлеченные значения и определяют отношение — больше, меньше или равно.

Могут ли трансформеры «думать» - 9

Резюме

  • Гроккинг — феномен отложенного обобщения модели на небольших наборах данных.

  • Благодаря гроккингу модель формирует более эффективные алгоритмы решения задач, достигая до 100% точности на тестовой выборке.

  • Качество данных критично для гроккинга — ключевое значение имеет соотношение выведенных фактов к атомарным.

  • В процессе обучения модель разделяется на две части:

    • Нижние слои (0-5) выполняют первый шаг — получают значения сущностей или ищут связующую.

    • Верхние слои (5-8) объединяют полученные факты или сравнивают значения.

  • Текущая архитектура памяти трансформера не позволяет ей эффективно делиться знаниями и фактами между слоями, поэтому в задаче композиции модель не смогла обобщиться полностью

Практика

Давайте попробуем повторить результаты и поведение гроккинга на таком же наборе данных, что и у авторов. Мы сделаем несколько изменений для простоты, о которых я буду сообщать по ходу. Весь код будет доступен в виде ноутбука. Найти ссылка на него и на данные вы сможете в самом конце статьи.

Мы будем использовать библиотеку Hugging Face и её реализацию класса для простого обучения модели — Trainer. С его помощью модель можно обучить всего за несколько строк кода. В нашем примере мы сосредоточимся на задаче сравнения, но этот же код можно применить и для задачи композиции.

Данные

Как мы уже разбирали ранее, данные представляют из себя абстрактные сущности.

{
  'input_text': '<e_156><attr_0>',
  'target_text': '<e_156><attr_0><4></a>',
  'type': 'id_atomic'
}

В оригинальной статье используется seq2seq модель, поэтому здесь есть input_text и target_text, который модель должна сгенерировать. Мы будем работать с классическим предсказанием следующего токена, поэтому понадобиться лишь target_text. Атрибут type говорит нам, что именно это за данные - атомарный факт, выведенный факт или что-то еще. Вот какие значения могут быть:

  1. id_atomic - атомарный факт

  2. ood_atomic - атомарный факт, для которого нет выведенных в тренировочной выборке

  3. train_inferred - выведенный тренировочный факт

  4. test_inferred_iid - выведенный тестовый факт (атомарные факты были в каких-то других выведенных фактах)

  5. test_inferred_ood - выведенный тестовый факт (атомарные факты не были в каких-то других выведенных фактах)

Для простоты и скорости тренировки, использовать будем соотношение выведенных к атомарным фактам — 12.6, то есть почти 13 выведенных фактов на 1 атомарный.

train.json
{
  'id_atomic': 18000, 
  'ood_atomic': 2000, 
  'train_inferred': 226800
}

test.json
{
  'test_inferred_ood': 3000,
  'test_inferred_iid': 3000,
  'train_inferred': 3000
}
Скрытый текст
with open('train.json', 'r') as f:
   train = json.load(f)

with open('test.json', 'r') as f:
   valid = json.load(f)

Токенизатор и словарь

В статье используется уже готовый словарь с токенами, который был составлен вместе с генерацией набора данных. Можно использовать его, но мы не ищем легких путей, поэтому натренируем свой токенизатор и составим словарь под нас.

Скрытый текст
from datasets import Dataset, DatasetDict
from tokenizers import Tokenizer
from tokenizers.models import BPE
from transformers import GPT2TokenizerFast
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.decoders import ByteLevel

Инициализируем токенизатор. Использовать будем Byte Pair Encoding - изначально он был представлен вместе с первыми GPT моделями, и до сих пор там используется. И раз в дальнейшем мы будем работать с GPT-2 small моделью, то этот выбор энкодера логичен.

Byte Pair Encoding работает, как ни странно, с байтами и объединяет их часто встречающиеся последовательности в токены. Благодаря этому разбиваются основы и окончания слов. Например, "lower" = "low" + "er" и "lowest" = "low" + "est". Мы не будем углубляться в эту тему, так как она не является основной в нашей статье. Если вам интересно узнать больше, можете почитать тут.

Это особенно полезно для нашей задачи, поскольку наши примеры не разделены пробелами и фактически представляют собой одно длинное слово. Благодаря работе на уровне байтов, нам не нужно вручную разбивать этот текст на отдельные токены — энкодер сам определит, что следует объединить, а что разделить.

Скрытый текст
tokenizer = Tokenizer(BPE())
tokenizer.pre_tokenizer = Whitespace()
tokenizer.decoder = ByteLevel()
# добавляем специальные токен
SPECIAL_TOKENS = ["<q>", "<pad>", "</q>", "<unk>", "<mask>", "<a>", "</a>"]

trainer = BpeTrainer(min_frequency=1, special_tokens=SPECIAL_TOKENS)

Напишем небольшую функцию, где натренируем наш токенизатор из итерируемого объекта. В нашем случае это будет просто list с текстами. Затем закодируем какой-нибудь случайный текст и декодируем его, чтобы проверить адекватность нашего подхода.

Скрытый текст
def train_tokenizer(texts, save_path = None):
    tokenizer.train_from_iterator(texts, trainer=trainer)

    print('Обучение завершено. nВсего токенов в словаре {}.nПримеры токенов:n{}'.format(
        tokenizer.get_vocab_size(),
        'n'.join(["{} - {}".format(token, token_id) for token, token_id in tokenizer.get_vocab().items()][:10])
    ))

    fast_tokenizer = GPT2TokenizerFast(tokenizer_object=tokenizer)
    fast_tokenizer.add_special_tokens({'pad_token': '<pad>'})

    # проверим, что вышло
    for text in texts[:1]:
        encoded_text = fast_tokenizer.encode(text)
        decoded_text = fast_tokenizer.decode(encoded_text)
        tokenized_text = fast_tokenizer.tokenize(text)

        print('nnОригинальный текст: {}nТокенезированный: {}nЗакодированный: {}nДекодированный: {}'.format(
            text, tokenized_text, encoded_text, decoded_text
        ))

    if save_path is not None:
        fast_tokenizer.save_pretrained(save_path)

# обучаем токенайзер на всех данных
train_texts = [row['target_text'] for row in train]
all_texts = train_texts + valid_texts
train_tokenizer(all_texts, save_path = 'tokenizer')

И вот что получилось

Обучение завершено. 
Всего токенов в словаре 1132.
Примеры токенов:
6 - 14
e_299 - 664
e_105 - 511
e_143 - 684
e_585 - 317
e_805 - 475
e_363 - 710
e_655 - 268
e_17 - 188
e_771 - 999


Оригинальный текст: <e_156><attr_0><4></a>
Токенезированный: ['<', 'e_156', '><', 'attr_0', '><', '4', '>', '</a>']
Закодированный: [18, 612, 29, 67, 29, 12, 19, 6]
Декодированный: <e_156><attr_0><4></a>

Из хороших новостей: токенизатор правильно выделил атрибуты, сущности и ответы для сравнения в разные токены — ура! Из плохих: токенизатор считает скобочки >< за отдельные токены: это увеличивает длину наших токенизированных предложений и может добавить шум, когда мы соберёмся модель интерпретировать, но для наших целей этого достаточно.

Конфигурация модели

Давайте посчитаем оптимальную конфигурацию для нашего трансфомера. Для этого нужно посчитать длину токенизированных предложений - так мы поймем, какой потребуется контекст. С оригинальным контекстом GPT-2 в 1024 токенов работать смысла нет, так как у нас ограниченный вход, а увеличенный контекст лишь увеличит вычислительные затраты.

Скрытый текст
tokenizer = GPT2TokenizerFast.from_pretrained("tokenizer")

# подсчитаем максимальную длину токенизированного текста, чтобы оптимизировать трансформер
Counter([len(tokenizer.encode(text)) for text in all_texts])

Counter({8: 25000, 14: 235800})

Максимальная длина последовательности оказалась 14 токенов. При работе с GPU эффективнее использовать гиперпараметры, являющиеся степенью двойки или хотя бы чётным числом. Это оптимизирует обработку данных на GPU. Ближайшая к 14 степень двойки — 16, поэтому установим её как максимальную длину нашей последовательности. Также определим конфигурацию модели, взятую из оригинальной статьи.

Скрытый текст
MODEL_SIZES = ['small', 'medium']
MAX_SEQ_LEN = 16

# Используем самую маленькую модель
model_size = MODEL_SIZES[0]

# инициализация конфига - тут указываем все гиперпараметры, которые могут нам потребоваться
if model_size == 'small':
    n_embd = 768
    n_layer = 8
    n_head = 12

else:
    n_embd = 1024
    n_layer = 16
    n_head = 16
    
config = GPT2Config(
    vocab_size=tokenizer.vocab_size,
    n_positions=MAX_SEQ_LEN,
    n_ctx=MAX_SEQ_LEN,
    n_embd=n_embd,
    n_layer=n_layer,
    n_head=n_head,
    pad_token_id=tokenizer.pad_token_id,
)

model = GPT2LMHeadModel(config)

Подготовка данных

Теперь поработаем с данными. Наша задача — токенизировать текст, дополнить его паддинг-токенами до максимальной длины, если необходимо, и преобразовать в формат, подходящий для работы с trainer. На этом, в целом, всё. Использовать будем класс dataset.

Скрытый текст
# функция токенизации
def preprocess_function(examples):
    model_inputs = tokenizer(examples['target_text'], truncation=True, padding='max_length', max_length=MAX_SEQ_LEN)
    return model_inputs
# преобразуем из list в dataset
dataset_train = Dataset.from_list(train)
dataset_valid = Dataset.from_list(valid)
# токенизируем и заполняем [pad]
dataset_train = dataset_train.map(preprocess_function, batched=True)
dataset_valid = dataset_valid.map(preprocess_function, batched=True)
# посмотрим, что получилось
dataset_valid

Dataset({
    features: ['input_text', 'target_text', 'type', 'input_ids', 'attention_mask'],
    num_rows: 14000
})

В тестовом наборе данных получилось 14 тысяч примеров. У нас есть уже знакомые поля input_text, target_text, type. Но еще добавились два новых: input_ids, attention_mask. Посмотрим на них

Скрытый текст
dataset_valid['input_ids'][0], dataset_valid['attention_mask'][0]

([18, 56, 19, 0, 18, 751, 19, 4, 18, 1052, 29, 105, 19, 6, 1, 1],
 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0])

input_ids представляет собой токенизированный текст. Как можно заметить, повторяющиеся 1 в конце — это наши [pad] токены. attention_mask — набор 1 и 0, указывающий трансформеру, какие токены не нужно включать в механизм внимания. В нашем случае это [pad] токены.

Теперь что мы не сделали еще: не подготовили данные для метрик! Следуя оригинальной статье, подготовим train, test (iid) и test (ood). Для этого возьмем все данные из нашего dataset_valid, которые обладают соответствующим типом.

Скрытый текст
train_questions = dataset_valid.filter(lambda x:  x['type'] == 'train_inferred' )
valid_questions_iid = dataset_valid.filter(lambda x: x['type'] == 'test_inferred_iid')
valid_questions_ood = dataset_valid.filter(lambda x: x['type'] == 'test_inferred_ood')

Метрики

Последний рывок перед обучением! Для наглядности и простоты, код будет не самым эффективным и стабильным, но для наших учебных целей вполне приемлем.

Наша модель отдаст нам распределение вероятность на наши токены, поэтому первое, что необходимо сделать - взять наиболее вероятные токены на каждую позицию контекста.

Скрытый текст
  predictions, labels = eval_pred
  # получаем наиболее вероятный токен
  predictions = np.argmax(predictions, axis=-1)

Теперь для простоты реализации, давайте декодируем текст, то есть переведем его обратно из токенов в слова

Скрытый текст
decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=False)
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=False)

А теперь, главный трюк! Так как адекватных токенов-разделителей у нас нет, просто подсчитаем на каком месте должно стоять предсказываемое слово - в нашем случае на 5 для сравнения и 3 для композиции.

Скрытый текст
label_text = [l.split('><')[5].strip() for l in decoded_labels]       
pred_text = [p.split('><')[5].strip() for p in decoded_preds]


# раскомментируйте если хотите работать с задачей композиции
# label_text = [l.split('><')[3].strip() for l in decoded_labels]
# pred_text = [p.split('><')[3].strip() for p in decoded_preds]

И подсчитаем метрику accuracy — долю точно совпавших предсказанных слов.

Скрытый текст
correct = sum([1 if pred == label else 0 for pred, label in zip(pred_text, label_text)])
total = len(label_text)
accuracy = correct / total

Добавим теперь try-except, потому что код… в общем, у нас же нет гарантии, что в предсказанном моделью тексте будет 5 символов «><». Их может быть 3, 2 или вообще не быть, если модель только начала обучаться. Но не обращайте на это внимания — скажем, что легаси, и менять его нельзя — так и пройдем ревью. Кстати, не используйте общий блок try-except как здесь - это плохая практика. И обернем все в функцию.

Скрытый текст
def compute_metrics(eval_pred):
    try:
        predictions, labels = eval_pred
        # получаем наиболее вероятный токен
        predictions = np.argmax(predictions, axis=-1)
        labels[labels == -100] = tokenizer.eos_token_id
        # декодируем текст, чтобы получить представление
        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=False)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=False)
        # к сожалению, авторы не сделали более адекватных ограничителей вопрос - ответ, типа <a></a>,
        # поэтому считаем количество скобочек, где должен находится наш ответ
        label_text = [l.split('><')[5].strip() for l in decoded_labels]
        pred_text = [p.split('><')[5].strip() for p in decoded_preds]

        # раскомментируйте если хотите работать с задачей композиции
        # label_text = [l.split('><')[3].strip() for l in decoded_labels]
        # pred_text = [p.split('><')[3].strip() for p in decoded_preds]


        # пример наших предсказания и ответов
        print('nnnfiltered: ', label_text[:5], pred_text[:5])
        # считаем accuracy
        correct = sum([1 if pred == label else 0 for pred, label in zip(pred_text, label_text)])
        total = len(label_text)
        accuracy = correct / total
    except Exception as e:
        print(repr(e))
        accuracy = 0

    return {"accuracy": accuracy}

Теперь добавим callback для нашего trainer - без него он ничего сохранять не будет.

Скрытый текст
class TrainMetricsCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % args.eval_steps == 0:

            train_metrics = trainer.evaluate(train_questions)
            valid_iid_metrics = trainer.evaluate(valid_questions_iid)
            valid_ood_metrics = trainer.evaluate(valid_questions_ood)

            with open(os.path.join(args.logging_dir, "metrics_log.txt"), "a") as f:
                f.write(f"Train Step {state.global_step}: {train_metrics}n")
                f.write(f"Valid IID Step {state.global_step}: {valid_iid_metrics}n")
                f.write(f"Valid OOD Step {state.global_step}: {valid_ood_metrics}n")

Обучение

Наконец, мы добрались до главного. Давайте определим конфиг для нашего обучения. Вы можете оставить все параметры без изменений, но если ваша видеокарта не справляется с большим размером батча (batch size), или наоборот — поддерживает функции bf16 или torch.compile (требуются новейшие видеокарты), обязательно включите их. Это существенно сократит время обучения.

Скрытый текст
BATCH_SIZE = 512

training_args = TrainingArguments(
    # куда сохранять веса
    output_dir="./results",
    # сколько весов сохранять
    save_total_limit=2,
    # тут понятно
    learning_rate=1e-4,
    # batch_size для этапов тестирования и тренировки
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size =BATCH_SIZE,
    # сколько примеров аккумулировать на gpu во время evaluate() перед тем, как отдать cpu
    # чем больше число, тем меньше задержка из-за передачи данных с gpu на cpu, но расходуется больше vram gpu
    # ставьте максимальное число, которое позволяет ваша видеокарта
    eval_accumulation_steps=5,
    # логгировать ли первый шаг
    logging_first_step=True,
    # раз в сколько шагов считать метрики
    eval_steps = 500,
    # количество эпох обучения
    num_train_epochs=4000,
    # без комментариев
    weight_decay=1,
    # раз в сколько шагов логгировать. это не подсчет метрик - тут просто выводится loss и на каком шаге модель сейчас
    logging_steps=200,
    # куда записывать логи
    logging_dir='.',
    # сколько cpu использовать для предобработки данных
    dataloader_num_workers=4,

    # если ваша видеокарта позволяет, ставьте True
    torch_compile=False,
    # что-то одно
    bf16 = False,
    fp16 = False
)

Теперь определяем специальные классы для trainer и сам trainer. Это DataCollatorForLanguageModeling - он подготавливает данные для задачи предсказания следующего токена. Делает соответствующую предобработку - добавляет padding (нам не нужно), преобразует input_ids и attention_mask в формат задачи и добавляет дополнительные маски, если требуется.

Скрытый текст
# mlm - masked language modelling. когда мы маскируем определенный % нашего текста
# используется, например, в BERT
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_train,
    eval_dataset=valid_questions_iid,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[TrainMetricsCallback()]  # Добавление callback для подсчета метрик
)

И все! Можно запускать обучение.

Скрытый текст
trainer.train()

Теперь вы можете увидеть прогресс обучения модели.

Могут ли трансформеры «думать» - 10

Скажу честно, терпения на Google Colab с T4 видеокартой у меня не хватило, поэтому я воспользовался своей. Проводил я все эксперименты на A100 с 80 ГБ VRAM, но вам столько не нужно. У меня утилизировалось всего 5–6 ГБ. Можете уменьшить batch_size и понизить это потребление ещё сильнее. В любом случае, гроккинг — дело долгое. У меня ушло приблизительно 8 часов, чтобы достичь 100% на тестовом наборе.

Результаты

(Слева) функция потерь (Справа) Точность на данных. Можно заметить, что функция потерь растет до определенного момента, а для OOD все время. Через 150,000 шагов достигается 100% точность на тестовой выборке.

(Слева) функция потерь (Справа) Точность на данных. Можно заметить, что функция потерь растет до определенного момента, а для OOD все время. Через 150,000 шагов достигается 100% точность на тестовой выборке.

Как мы видим, график приблизительно напоминает график из статьи. Только там использовалась log-шкала, а у нас обычная. До 50,000 шага можно наблюдать просадку в точности для OOD данных. Затем точность растет на обоих датасетах и сходится к 100% примерно на 150,000 шаге.

Выводы

Поздравляю! Мы успешно достигли гроккинга. Наша модель прошла путь от неопытного смешарика, который просто запоминает все значения, до опытного эксперта, способного строить сложные цепочки рассуждений.

Данный подход довольно нетривиален и требует очень качественных данных. Как мы видим, работали в этой статье с задачами синтетическими, имеющими графовую структуру. Что будет, если связность графа мала? Как нам применять этот подход в реальных задачах? Этому сейчас посвящена как раз моя кандидатская отчасти, поэтому можете найти ответы на эти вопросы здесь:

Если у вас остались вопросы - приглашаю в комментарии для обсуждения.

Автор: perfect_startup

Источник

* - обязательные к заполнению поля


https://ajax.googleapis.com/ajax/libs/jquery/3.4.1/jquery.min.js