Последние несколько лет в развитии глубоких нейронных сетей происходит настоящая революция: возникают новые архитектуры, совершенствуются фреймворки для разработчиков, а железо для экспериментов можно получить совершенно бесплатно — например, в рамках проекта Google colaboratory. Всем, кому интересно как применить предобученные модели из репозитория Tensorflow Object Detection API к решению своей задачи, используя мощности Colaboratory — добро пожаловать под кат.
Топовые GPU — для всех
Для обучения нейросетей на больших объёмах данных лучше использовать GPU: скорость обучения и инференса будет выше, чем на CPU за счёт эффективного распараллеливания операций по тысячам ядер. До последнего времени применить карточки в своих расчётах можно было, например, с помощью облачных инстансов Amazon. Однако зачем платить за что-то, что можно получить бесплатно? Google Colaboratory предоставляет доступ к карточке Tesla K80.
Доступ к карте можно получить через меню Edit->Notebook setttings->Hardware accelerator:
С этой карточкой можно быстро проводить эксперименты с нейронными сетями — например пройти крутой курс по NLP от ребят из DeepPavlov. Единственный минус сервиса — карточка ваша всего на 12 часов, этому промежуточные результаты нужно куда-нибудь сохранять — если есть необходимость использовать их в дальнейшем.
В этой статье я расскажу как с помощью Colaboratory обучить модель и сохранить её для дальнейшего использования. В статье будет мало кода — весь код, обучающий модель и осуществляющий взаимодействие с Google Drive (для сохранения промежуточных результатов) доступен в моём репозитории. Поехали!
Датасет DeepFashion
Для экспериментов я буду использовать датасет Deep Fashion — это 800к изображений предметов одежды.
Изображения содержат теги, а так же на фото размечены bounding boxes. Мы будем обучать нейросеть детектировать изображения одежды на фото — рисовать bounding box и классифицировать один из трёх классов: upper-body, lower-body и full-body.
Подготовка данных
Вначале скопируйте DeepFashion себе на GoogleDrive датасет из директории Category and Attribute Prediction Benchmark в корневую директорию. Мы будем много работать с GoogleDrive как с файловым хранилищем — копировать оттуда данные и закачивать на Диск результаты работы (например, чекпоинт модели).
Для начала скопируем репозиторий с кодом и установим зависимости:
!rm -r TFFashionDetection
!git clone https://github.com/Dju999/TFFashionDetection.git
!pip install lxml
!pip install -U -q PyDrive
!pip install tqdm
Ещё нужно создать вспомогательный объект для работы с файловой системой Google Drive
from TFFashionDetection.utils.colab_fs import GoogleColabFS
fs = GoogleColabFS()
Для более подробной информации можно посмотреть файл utils.colab_fs.py в репозитории.
Теперь нужно скачать датасет DeepFashion:
!python3 /content/TFFashionDetection/utils/dataset_download.py
В наборе данных присутствуют три директории
- Img — с изображениями предметов одежды
- Eval — содержит текстовый файл с разбиением датасета на train, test, valid
- Anno — тут файлы с тегами, баундинг боксами и прочей служебной информацией
Наша задача — подготовить эти данные для скармливания нейросети: описание файлов в специальном формате, разбиение на train и test.
Детекция изображений на Tensorflow
Google в 2017 году зарелизил Object Detection API — набор моделей и инструментов для детекции изображений. В репозитории очень много скриптов для подготовки обучающих данных, обучения моделей и визуализации результатов — например, отрисовки bounding boxes.
Код ниже устанавливает TF Object Detection API из github-репозитория в среду GoogleColaboratory.
! cd /content; git clone https://github.com/tensorflow/models.git
# установка зависимостей для object detection тут
# https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md
!apt-get install protobuf-compiler python-pil python-lxml python-tk
!pip install Cython
!cd /content; git clone https://github.com/cocodataset/cocoapi.git; cd cocoapi/PythonAPI; make; cp -r pycocotools /content/models/research/
!cd /content/models/research; protoc object_detection/protos/*.proto --python_out=.
# проверка - запускаем тесты
!cd /content/models/research; export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim; python object_detection/builders/model_builder_test.py
Теперь нужно подготовить данные. В Img сложная структура поддиректорий, где каждому классу одежды соответствует своя категория. Код ниже копирует все фотографии в одну директорию, а так же готовит для каждого файла описание в виде tf.train.Example — на всём этом добре будет обучаться модель детекции
В коде есть название модели ssd_mobilenet_v2_coco_2018_03_29 — подходящую модельку можно скачать в Detection Zoo. Модельку можно скачать другую — но тогда нужно будет переписать файл /content/data_dir/tf_api.config.
import sys
import os
import numpy as np
API_PATH = os.path.join('/content', 'models/research')
sys.path.append(API_PATH)
DETECTOR_PATH = os.path.join('/content', 'TFFashionDetection')
sys.path.append(DETECTOR_PATH)
from TFFashionDetection.data_preparator import DataPreparator
from TFFashionDetection.utils.ssd_config import write_config
data_preparator = DataPreparator()
data_preparator.build()
write_config('ssd_mobilenet_v2_coco_2018_03_29')
После того, как выполнение ячейки закончится — можно скачать предобученную модель и запускать обучение модели. Frozen inference graph из репы Object Detetetion позволит быстрее обучить свой детектор. Путь до директории с графом модели нужно будет передать в скрипт обучения /object_detection/train.py.
# скачиваем модель (предобученную)
!python /content/TFFashionDetection/utils/download_tf_zoo_model.py --name ssd_mobilenet_v2_coco_2018_03_29 --dir /content
# запускаем обучение модели
!export PYTHONPATH=$PYTHONPATH:/content/models/research/slim:/content/models/research/;python /content/models/research/object_detection/train.py --logtostderr --pipeline_config_path=/content/data_dir/tf_api.config --train_dir=/content/data_dir/checkpoints
Если всё сделали правильно — увидим, как побегут логи и loss будет уменьшаться на каждой итерации. Когда вы видите, что функция потерь перестала уменьшаться — можно останавливать обучение. Граф модели сохраняется в директорию /content/data_dir/checkpoints — его нужно будет сохранить для дальнейших экспериментов. Обучение модели нужно провести один раз, после этого использовать полученный граф для инференса.
Когда модель обучена — нужно сохранить её на гугл-диск
!cd /content/data_dir; zip -r checkpoint_save_20180514.zip checkpoints/*
import os
fs = GoogleColabFS()
file_name = os.path.join('/content/data_dir', 'checkpoint_save_20180514.zip')
fs.load_to_drive(file_name)
Скачивать с гугл-диска таким же образом
import os
fs.load_file_from_drive('/content', 'checkpoint_save_20180514.zip')
fs.unzip_file('/content', 'checkpoint_save_20180514.zip')
!mkdir /content/deep_detection_model
# экспортируем модель
!export PYTHONPATH=$PYTHONPATH:/content/models/research/slim:/content/models/research/;python /content/models/research/object_detection/export_inference_graph.py --input_type image_tensor --pipeline_config_path=/content/data_dir/tf_api.config --trained_checkpoint_prefix=/content/checkpoints/model.ckpt-2108 --output_directory inference_graph
Для примера выберем случайную фотку и скормим её нашей сети для детекции:
import sys
import os
import matplotlib.pyplot as plt
plt.switch_backend('agg')
sys.path.append(os.path.join('/content', 'models/research'))
from object_detection.utils import visualization_utils as vis_util
from PIL import Image as Pil_image
%matplotlib inline
boxes = np.array([oject_detector.img_detections[3]['category_box']])
def load_image_into_numpy_array(image):
(im_width, im_height) = image.size
return np.array(image.getdata()).reshape(
(im_height, im_width, 3)).astype(np.uint8)
# загружаем картинку и превращаем в массив
image = Pil_image.open(file_path)
image_np = load_image_into_numpy_array(image)
# накладываем на массив bounding boxes
vis_util.draw_bounding_boxes_on_image_array(image_np, boxes)
# сохраняем картинку на диск
result_file_path = os.path.join('/content', 'test.png')
vis_util.save_image_array_as_png(image_np, result_file_path)
# виуализируем картинку, которую сохранили
from IPython.display import Image
Image(result_file_path)
Видим результаты детекции — lower_body предмет одежды
Заключение
TF Object Detection API — крутая технология, которая позволяет использовать в своих моделях State-of-the-Art архитектуры сеток. А Google Colaboratory — отличная площадка для экспериментов, которая позволяет тренировать сети на мощном железе. Код из статьи доступен тут.
Автор: Джумурат Александр