Джедайская техника уменьшения сверточных сетей — pruning

в 7:10, , рубрики: cnn, convolutional neural network, CUDA, deep learning, eco, gpu, keras, machine learning, neural networks, Nvidia, optimization, pruning, python, speedup, TensorFlow, tensorrt, yolo, высокая производительность, искусственный интеллект, Исследования и прогнозы в IT, машинное обучение

image

Перед тобой снова задача детектирования объектов. Приоритет — скорость работы при приемлемой точности. Берешь архитектуру YOLOv3 и дообучаешь. Точность(mAp75) больше 0.95. Но скорость прогона всё еще низкая. Черт.

Сегодня обойдём стороной квантизацию. А под катом рассмотрим Model Pruning — обрезание избыточных частей сети для ускорения Inference без потери точности. Наглядно — откуда, сколько и как можно вырезать. Разберем, как сделать это вручную и где можно автоматизировать. В конце — репозиторий на keras.

Введение

На прошлом месте работы, пермском Macroscop, я обрёл одну привычку — всегда следить за временем исполнения алгоритмов. А время прогона сетей всегда проверять через фильтр адекватности. Обычно state-of-the-art в проде не проходят этот фильтр, что и привело меня к Pruning.

Pruning — тема старая, о которой рассказывали в cтэндфордских лекциях в 2017 году. Основная идея — уменьшение размера обученной сети без потери точности путем удаления различных узлов. Звучит клево, но я редко слышу о его применении. Наверное, не хватает имплементаций, нет русскоязычных статей или просто все считают pruning ноу-хау и молчат.
Но го разбирать

Взгляд в биологию

Люблю, когда в Deep Learning заглядывают идеи, пришедшие из биологии. Им, как и эволюции, можно доверять (а ты знал, что ReLU весьма похожа на функцию активации нейронов в мозге?)

Процесс Model Pruning тоже близок к биологии. Реакцию сети здесь можно сравнить с пластичностью мозга. Пара интересных примеров есть в книге Нормана Дойджа:

  1. Мозг женщины, имевшей от рождения только одну половину, перепрограммировал сам себя для выполнения функций отсутствующей половины
  2. Парень отстрелил себе часть мозга, отвечающую за зрение. Со временем другие части мозга взяли на себя эти функции. (повторить не пытаемся)

Так и из вашей модели можно вырезать часть слабых свёрток. В крайнем случае, оставшиеся свёртки помогут заменить вырезанные.

Любишь Transfer Learning или учишь с нуля?

Вариант номер один. Ты используешь Transfer Learning на Yolov3. Retina, Mask-RCNN или U-Net. Но чаще всего нам не нужно распознавать 80 классов объектов, как в COCO. В моей практике все ограничивается 1-2 классами. Можно предположить, что архитектура для 80 классов здесь избыточна. Напрашивается мысль, что архитектуру нужно уменьшить. Причем, хотелось бы сделать это без потери имеющихся предобученных весов.

Вариант номер два. Может быть, у тебя много данных и вычислительных ресурсов или просто нужна сверхкастомная архитектура. Неважно. Но ты учишь сеть с нуля. Обычный порядок — смотрим на структуру данных, подбираем ИЗБЫТОЧНУЮ по мощности архитектуру и пушим дропауты от переобучения. Я видел дропауты 0.6, Карл.

В обоих случаях сеть можно уменьшать. Промотивировали. Теперь идем разбираться, что за обрезание pruning

Общий алгоритм

Мы решили, что можем удалять свертки. Выглядит это весьма просто:

Джедайская техника уменьшения сверточных сетей — pruning - 2

Удаление любой свертки — это стресс для сети, который обычно ведет за собой и некоторый рост ошибки. С одной стороны, этот рост ошибки является показателем того, насколько правильно мы удаляем свертки (например, большой рост говорит о том, что мы делаем что-то не так). Но небольшой рост вполне допустим и зачастую устраняется последующим легким дообучением с небольшим LR. Добавляем шаг дообучения:

Джедайская техника уменьшения сверточных сетей — pruning - 3

Теперь нам нужно понять, когда же мы хотим остановить наш цикл Learning<->Pruning. Здесь могут быть экзотические варианты, когда нам нужно уменьшать сеть до определенного размера и скорости прогона (например, для мобильных устройств). Однако, самый частый вариант — это продолжение цикла, пока ошибка не станет выше допустимой. Добавляем условие:

Джедайская техника уменьшения сверточных сетей — pruning - 4

Итак, алгоритм становится понятным. Остается разобрать, как определить удаляемые свертки.

Поиск удаляемых сверток

Нам нужно удалить какие-то свертки. Рваться напролом и "отстреливать" любые — плохая идея, хоть и будет работать. Но раз есть голова, можно подумать и попытаться выделить для удаления "слабые" свертки. Вариантов есть несколько:

  1. Наименьшая L1-мера или low_magnitude_pruning. Идея, говорящая о том, что свертки с малыми значениями весов, вносят малый вклад в итоговое принятие решения
  2. Наименьшая L1-мера с учетом среднего и стандартного отклонения. Дополняем оценкой характера распределения.
  3. Маскирование сверток и исключение наименее влияющих на итоговую точность. Более точное определение малозначимых свёрток, но весьма затратное по времени и ресурсам.
  4. Другие

Каждый из вариантов имеет право на жизнь и свои особенности реализации. Здесь рассмотрим вариант с наименьшей L1-мерой

Ручной процесс для YOLOv3

В исходной архитектуре содержатся остаточные блоки. Но какими бы крутыми они ни были для глубоких сетей, нам они несколько помешают. Сложность в том, что нельзя удалять сверки с разными индексами в этих слоях:

Джедайская техника уменьшения сверточных сетей — pruning - 5

Поэтому выделим слои, из которых мы можем свободно удалять сверки:

Джедайская техника уменьшения сверточных сетей — pruning - 6

Теперь построим цикл работы:

  1. Выгружаем активации
  2. Прикидываем, сколько вырезать
  3. Вырезаем
  4. Учим 10 эпох с LR=1e-4
  5. Тестируем

Выгружать свертки полезно, чтобы оценить, какую часть мы можем удалить на определённом шаге. Примеры выгрузки:

Джедайская техника уменьшения сверточных сетей — pruning - 7

Видим, что практически везде 5% свёрток имеют весьма низкую L1-норму и мы можем их удалить. На каждом шаге такая выгрузка повторялась и производилась оценка, из каких слоев и сколько можно вырезать.

Весь процесс уложился в 4 шага(тут и везде числа для RTX 2060 Super):

Шаг mAp75 Число параметров, млн Размер сети, мб От изначальной, % Время прогона, мс Условие обрезания
0 0.9656 60 241 100 180 -
1 0.9622 55 218 91 175 5% от всех
2 0.9625 50 197 83 168 5% от всех
3 0.9633 39 155 64 155 15% для слоев с 400+ сверток
4 0.9555 31 124 51 146 10% для слоев с 100+ сверток

К 2 шагу добавился один положительный эффект — в память влез батч-сайз 4, что весьма ускорило процесс дообучения.
На 4 шаге процесс был остановлен, т.к. даже длительное дообучение не поднимало mAp75 до старых значений.
В итоге удалось ускорить инференс на 15%, уменьшить размер на 35% и не потерять в точности.

Автоматизация для более простых архитектур

Для более простых архитектур сетей(без условных add, concaternate и residual блоков) вполне можно ориентироваться на обработку всех сверточных слоёв и автоматизировать процесс вырезания сверток.

Такой вариант я заимплементировал здесь.
Всё просто: с вас только функция потерь, оптимизатор и батч-генераторы:

import pruning
from keras.optimizers import Adam
from keras.utils import Sequence

train_batch_generator = BatchGenerator...
score_batch_generator = BatchGenerator...

opt = Adam(lr=1e-4)
pruner = pruning.Pruner("config.json", "categorical_crossentropy", opt)

pruner.prune(train_batch, valid_batch)

При необходимости можно изменить параметры конфиги:

{
    "input_model_path": "model.h5",
    "output_model_path": "model_pruned.h5",
    "finetuning_epochs": 10, # the number of epochs for train between pruning steps
    "stop_loss": 0.1, # loss for stopping process
    "pruning_percent_step": 0.05, # part of convs for delete on every pruning step
    "pruning_standart_deviation_part": 0.2 # shift for limit pruning part
}

Дополнительно реализовано ограничение на основании стандартного отклонения. Цель — ограничить часть удаляемых, исключая свертки с уже "достаточными" L1-мерами:

Джедайская техника уменьшения сверточных сетей — pruning - 8

Тем самым, мы позволяем удалить только слабые свертки из распределений подобных правому и не влиять на удаление из распределений подобных левому:

Джедайская техника уменьшения сверточных сетей — pruning - 9

При приближении распределения к нормальному коэффициент pruning_standart_deviation_part можно подобрать из:

Джедайская техника уменьшения сверточных сетей — pruning - 10
Я рекомендую допущение в 2 сигма. Или можно не ориентироваться на эту особенность, оставив значение < 1.0.

На выходе получается график размера сети, потери и времени прогона сети по всему тесту, отнормированные к 1.0. Например, здесь размер сети был уменьшен почти в 2 раза без потери в качестве (небольшая сверточная сеть на 100к весов):

Джедайская техника уменьшения сверточных сетей — pruning - 11

Скорость прогона подвержена нормальным флуктуациям и практически не изменилась. Этому есть объяснение:

  1. Число сверток меняется с удобного (32, 64, 128) на не самые удобные для видеокарт — 27, 51 и тд. Тут могу ошибиться, но скорее всего это влияет.
  2. Архитектура не широкая, но последовательная. Уменьшая ширину, мы не трогаем глубину. Тем самым уменьшаем загрузку, но не меняем скорость.

Поэтому улучшение выразилось в уменьшении загрузки CUDA при прогоне на 20-30%, но не в уменьшении времени прогона

Итоги

Порефлексируем. Мы рассмотрели 2 варианта pruning — для YOLOv3(когда приходится работать руками) и для сетей с архитектурами попроще. Видно, что в обоих случаях можно добиться уменьшения размера сети и ускорения без потери точности. Результаты:

  • Уменьшение размера
  • Ускорение прогона
  • Уменьшение загрузки CUDA
  • Как следствие, экологичность (Мы оптимизируем будущее использование вычислительных ресурсов. Где-то радуется одна Грета Тунберг)

Appendix

  • После шага pruning можно докрутить и квантизацию (например с TensorRT)
  • Tensorflow предоставляет возможности для low_magnitude_pruning. Работает.
  • Репозиторий хочу развивать и буду рад помощи

Автор: Дмитрий Пагин

Источник

* - обязательные к заполнению поля


https://ajax.googleapis.com/ajax/libs/jquery/3.4.1/jquery.min.js