- PVSM.RU - https://www.pvsm.ru -

Распознавание русского рукописного текста

Привет! Меня зовут Арсений, я работаю ML-инженером в компании Вита и параллельно учусь на втором курсе магистратуры AI Talent Hub [1]. В этой статье я хочу поделиться опытом разработки модели для распознавания русского рукописного текста.

Несмотря на всеобщую цифровизацию, огромное количество документов до сих пор заполняется от руки — медицинские карты, заявления, анкеты, почтовые адреса и многое другое. Автоматизация процесса распознавания таких документов может значительно ускорить их обработку и снизить количество ошибок при ручном вводе.

В статье я подробно расскажу о всех этапах создания модели:

  1. Какие данные использовал и где их взял;

  2. Какую архитектуру выбрал и почему;

  3. Как проходил процесс подготовки данных и обучения модели;

  4. Как организовал инференс.

Надеюсь, мой опыт будет полезен тем, кто столкнется с похожей задачей или просто интересуется данной тематикой.

Поехали!

Данные

В качестве данных для обучения использовался Cyrillic Handwriting Dataset [2] с Kaggle. Это набор рукописных текстов на кириллице, специально созданный для задач OCR (Optical Character Recognition). Датасет содержит 73830 примеров и уже разделен на train и test выборки в соотношении 95% и 5% соответственно.

Особенности датасета:

  • Каждый файл представляет собой PNG-изображение с текстом в одну строку;

  • Длина текста не превышает 40 символов;

  • Тексты написаны разными людьми, что обеспечивает разнообразие почерков;

  • Есть как цветные изображения, так и черно-белые;

  • Каждое изображение сопровождается правильной расшифровкой написанного.

Рис. 1: Примеры из Cyrillic Handwriting Dataset.

Рис. 1: Примеры из Cyrillic Handwriting Dataset.

Модель

Распознавание русского рукописного текста - 2

В качестве базовой архитектуры возьмем trocr-base-handwritten [3] от Microsoft. Это трансформерная модель, которую зафайнтюнили на датасете IAM [4] (примерно такие же картинки с текстом, как и в Cyrillic Handwriting, только на английском). На упомянутом наборе данных эта модель является SOTA [5]

Рис. 2: Архитектура TrOCR из оригинальной статьи

Рис. 2: Архитектура TrOCR из оригинальной статьи

TrOCR — это энкодер-декодерная модель, состоящая из image transformer в качестве энкодера и text transformer в качестве декодера. Энкодер был инициализирован с весами BEiT, в то время как декодер — с весами RoBERTa. А вот веса cross-attention между ними были инициализированы уже случайно.

Изображения представляются модели в виде последовательности патчей фиксированного размера (разрешение 16x16 пикселей). Перед передачей в слои энкодера к последовательности добавляется позиционный эмбеддинг. Затем декодер текста авторегрессионно генерирует токены.

TrOCR статья: https://arxiv.org/abs/2109.10282 [6]

TrOCR документация: https://huggingface.co/transformers/master/model_doc/trocr.html [7]

Подготовка данных и дообучение модели

Далее мы разберем как подготовить данные и зафайнтюнить выбранную модель под нашу задачу. Мы будем использовать класс VisionEncoderDecoderModel [8] из библиотеки transformers, который можно использовать для объединения любого image Transformer encoder (например, ViT, BEiT) с любым text Transformer в качестве декодера (например, BERT, RoBERTa, GPT-2). Примером этого является TrOCR, поскольку он имеет архитектуру энкодер-декодера, как уже говорилось ранее.

Установка библиотек

Для начала нам потребуется установить следующие библиотеки:

  • transformers — для работы с моделью;

  • evaluate и jiwer — для расчетов метрики.

Это можно сделать командой:

!pip install -q transformers evaluate jiwer

Мы не будем использовать datasets от HuggingFace для предобработки данных, а воспользуемся старым добрым Dataset из torch.

Распознавание русского рукописного текста - 4

Теперь проверим, что нам доступна CUDA:

import torch

torch.cuda.is_available()

Если все хорошо, то получим True.

В нашем случае использовалась видеокарточка H100, но подойдет и менее мощная.

Предобработка данных

Распознавание русского рукописного текста - 5

Перед тем, как обучать модель, надо загрузить и подготовить данные. Давайте этим и займемся!

Сначала мы загрузим данные. Скачать их можно тут [2]. Мы будем использовать только train часть из датасета, так как примеров там предостаточно. Для этого сделаем pandas dataframe из файла train.tsv.

import pandas as pd

train_val_df = pd.read_csv(
   "cyrillic-handwriting-dataset/train.tsv",
   sep="t",
   header=None,
   names=["file_name", "text"],
)

train_val_df.head()

file_name

text

0

aa1.png

Молдова

1

aa1007.png

продолжила борьбу

2

aa101.png

разработанные

3

aa1012.png

Плачи

4

aa1013.png

Гимны богам

В таблице присутствует одна строка с пустым text, из-за которой может поломаться все обучение, поэтому дропнем ее.

train_val_df = train_val_df.dropna()

Следующим шагом разделим данные на трейн + валидацию в пропорции 80/20 через train_test_split из sklearn.

from sklearn.model_selection import train_test_split

train_df, eval_df = train_test_split(train_val_df, test_size=0.2)
# we reset the indices to start from zero
train_df.reset_index(drop=True, inplace=True)
eval_df.reset_index(drop=True, inplace=True)

Каждый элемент итогового датасета должен возвращать:

  1. pixel_values (значения пикселей исходных изображений) — это будет input для модели;

  2. labels — input_ids соответствующего текста на изображении.

Мы будем использовать TrOCRProcessor [9] для приведения данных в нужные форматы. TrOCRProcessor — это обертка над ViTFeatureExtractor [10] и RobertaTokenizer [11]. Первый будет использоваться для ресайзинга и нормализации изображений. Второй - для кодирования и декодирования текста в/из input_ids.

Штош, давайте напишем класс нашего датасета. Для этого нам нужно наследоваться от torch.utils.data.Dataset и определить 3 метода.

  1. Метод __init__ для инициализации экземпляра класса. В качестве входных параметров будем предавать следующее:

    1. root_dir — путь к директории, где хранятся изображения;

    2. df — pd.DataFrame, содержащий столбцы file_name и text;

    3. processor —  предобученный TrOCRProcessor;

    4. max_target_length — максимальная длина для текстовых меток (labels).

  2. Метод __len__, который будет возвращать количество элементов в датасете.

  3. Метод __getitem__:

    • Принимает на вход индекс idx;

    • Извлекает имя файла, изображения и соответствующий текст из df по индексу idx;

    • Загружает изображение из root_dir и конвертирует его в формат RGB;

    • Обрабатывает изображение с помощью processor, который возвращает тензоры, представляющие пиксели изображения;

    • Кодирует текст в input_ids с помощью токенизатора, который является частью processor, метки заполняются pad-токенами до max_target_length;

    • Заменяет токены pad-токены на -100, чтобы они игнорировались функцией потерь во время обучения;

    • Возвращает словарь encoding, содержащий pixel_values и labels.

import torch
from torch.utils.data import Dataset
from PIL import Image


class CHDataset(Dataset):
   def __init__(self, root_dir, df, processor, max_target_length=128):
       self.root_dir = root_dir
       self.df = df
       self.processor = processor
       self.max_target_length = max_target_length

   def __len__(self):
       return len(self.df)

   def __getitem__(self, idx):
       # get file name + text
       file_name = self.df["file_name"][idx]
       text = self.df["text"][idx]
       # prepare image (i.e. resize + normalize)
       image = Image.open(self.root_dir + file_name).convert("RGB")
       pixel_values = self.processor(image, return_tensors="pt").pixel_values
       # add labels (input_ids) by encoding the text
       labels = self.processor.tokenizer(
           text, padding="max_length", max_length=self.max_target_length
       ).input_ids
       # important: make sure that PAD tokens are ignored by the loss function
       labels = [
           label if label != self.processor.tokenizer.pad_token_id else -100
           for label in labels
       ]

       encoding = {
           "pixel_values": pixel_values.squeeze(),
           "labels": torch.tensor(labels),
       }
       return encoding

Давайте инициализируем наборы данных для обучения и оценки:

from transformers import TrOCRProcessor

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")

train_dataset = CHDataset(
   root_dir="cyrillic-handwriting-dataset/train/",
   df=train_df,
   processor=processor,
)
eval_dataset = CHDataset(
   root_dir="cyrillic-handwriting-dataset/train/",
   df=eval_df,
   processor=processor,
)

Посмотрим на количество примеров в обеих подвыборках:

print("Number of training examples:", len(train_dataset))
print("Number of validation examples:", len(eval_dataset))

Number of training examples: 57827
Number of validation examples: 14457

И поймем какие размерности у их элементов:

encoding = train_dataset[0]
for k, v in encoding.items():
   print(k, v.shape)

pixel_values torch.Size([3, 384, 384])
labels torch.Size([128])

Получили, что pixel_values — это тензор размером [3, 384, 384]. Здесь 3 обозначает количество цветовых каналов в формате RGB, а 384x384 — это размеры изображения после изменения его размера (ресайзинга). А labels — тензор размером [128] с учетом pad-токенов (128, потому что указали такой max_target_length).

Данные готовы и наконец-то можно приступать к дообучению.

Файнтюнинг модели

Здесь мы инициализируем модель TrOCR с её предобученными весами. Обратите внимание, что веса головы языкового моделирования инициализированы из претрейна, так как модель уже была обучена генерировать текст на этапе предобучения. Подробности можно найти в статье [6].

from transformers import VisionEncoderDecoderModel

model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

Важно установить несколько атрибутов, а именно:

  • атрибуты, необходимые для создания decoder_input_ids из labels (модель автоматически создаст decoder_input_ids, сдвинув labels на одну позицию вправо и добавив decoder_start_token_id в начало, а также заменив идентификаторы, равные -100, на pad_token_id)

  • размер словаря модели (для языковой модели, расположенной поверх декодера)

  • параметры beam-search [12], которые используются при генерации текста.

# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size

# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

Далее определим некоторые гиперпараметры обучения, создав экземпляр training_args. Важно отметить, что существует множество других параметров, полный список которых можно найти в документации [13]. Например, вы можете задать размер батча для обучения/оценки, решить, использовать ли mixed precision training, установить частоту сохранения модели и т.д.

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
   num_train_epochs=5,
   predict_with_generate=True,
   eval_strategy="epoch",
   save_strategy="epoch",
   per_device_train_batch_size=64,
   per_device_eval_batch_size=64,
   fp16=True,
   output_dir="./trained_models",
)

Мы будем оценивать модель по Character Error Rate (CER) (подробнее см. здесь [14]).

import evaluate

cer_metric = evaluate.load("cer")

Функция compute_metrics принимает на вход EvalPrediction (который является NamedTuple) и должна возвращать словарь. На этапе оценки модель вернёт EvalPrediction, который состоит из двух элементов:

  • predictions: предсказания, сделанные моделью.

  • label_ids: фактические истинные метки.

def compute_metrics(pred):
   labels_ids = pred.label_ids
   pred_ids = pred.predictions

   pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
   labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
   label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

   cer = cer_metric.compute(predictions=pred_str, references=label_str)

   return {"cer": cer}

Давайте начнем обучение! Мы также предоставляем default_data_collator для Trainer, который используется для объединения примеров в батчи.

Следует учесть, что оценка занимает довольно много времени, так как мы используем beam search для декодирования, что требует нескольких прямых проходов для каждого примера.

from transformers import default_data_collator

# instantiate trainer
trainer = Seq2SeqTrainer(
   model=model,
   processing_class=processor.tokenizer,
   args=training_args,
   compute_metrics=compute_metrics,
   train_dataset=train_dataset,
   eval_dataset=eval_dataset,
   data_collator=default_data_collator,
)
trainer.train()

После обучения мы получили вот такие вот результаты:

  • Training Loss: 0.026100;

  • Validation Loss: 0.120961;

  • CER: 0.048542.

Вполне неплохой результат.
Осталось только сохранить модель, например, локально для дальнейшего использования.

processor.save_pretrained("model/TrOCRInferenceModel/weights")
model.save_pretrained("model/TrOCRInferenceModel/weights")

Инференс

После обучения мы можем легко загрузить модель, используя метод .from_pretrained(output_dir), так как на предыдущем шаге мы сохранили ее.

Загрузим модель и изображение из тестовой выборки:

from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import requests

image = Image.open(
   "cyrillic-handwriting-dataset/test/test4.png"
).convert("RGB")

processor = TrOCRProcessor.from_pretrained(
   "model/TrOCRInferenceModel/weights"
)
model = VisionEncoderDecoderModel.from_pretrained(
   "model/TrOCRInferenceModel/weights"
)

Посмотрим, что там за текст на картинке.

Распознавание русского рукописного текста - 6

Тут написано слово «Определим». Конечно не с первого раза, но разобрать буквы можно. Правда в данном случае буква «п» больше похоже на «й». Посмотрим как с этим справится наша модель.

pixel_values = processor(images=image, return_tensors="pt").pixel_values

generated_ids = model.generate(pixel_values)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
generated_text

«Определим»

Супер! Модель распознала слово без ошибок.

Вывод

В этой статье мы разобрали как можно дообучить модель TrOCR под задачу распознавания русского рукописного текста и какие данные для этого могут использоваться. Модель получилась довольно точной и может использоваться в реальных кейсах.

P.S.: Свою дообученную модель я разместил по ссылке [15]. Буду только рад, если она кому-то будет полезна.

Материал подготовил магистрант 2 курса AI Talent Hub [16], Арсений Казанцев.

Автор: ai-talent

Источник [17]


Сайт-источник PVSM.RU: https://www.pvsm.ru

Путь до страницы источника: https://www.pvsm.ru/itmo/404835

Ссылки в тексте:

[1] AI Talent Hub: https://ai.itmo.ru/?utm_source=habr&utm_medium=article&utm_campaign=artem

[2] Cyrillic Handwriting Dataset: https://www.kaggle.com/datasets/constantinwerner/cyrillic-handwriting-dataset

[3] trocr-base-handwritten: https://huggingface.co/microsoft/trocr-base-handwritten

[4] IAM: https://fki.tic.heia-fr.ch/databases/iam-handwriting-database

[5] SOTA: https://paperswithcode.com/sota/handwritten-text-recognition-on-iam-line

[6] https://arxiv.org/abs/2109.10282: https://arxiv.org/abs/2109.10282

[7] https://huggingface.co/transformers/master/model_doc/trocr.html: https://huggingface.co/transformers/master/model_doc/trocr.html

[8] VisionEncoderDecoderModel: https://huggingface.co/docs/transformers/model_doc/vision-encoder-decoder

[9] TrOCRProcessor: https://huggingface.co/docs/transformers/model_doc/trocr#transformers.TrOCRProcessor

[10] ViTFeatureExtractor: https://huggingface.co/transformers/v4.5.1/model_doc/vit.html#vitfeatureextractor

[11] RobertaTokenizer: https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaTokenizer

[12] beam-search: https://www.geeksforgeeks.org/introduction-to-beam-search-algorithm/

[13] документации: https://huggingface.co/docs/transformers/main_classes/trainer#seq2seqtrainingarguments

[14] здесь: https://huggingface.co/spaces/evaluate-metric/cer/blob/main/README.md

[15] ссылке: https://huggingface.co/kazars24/trocr-base-handwritten-ru

[16] AI Talent Hub: https://ai.itmo.ru/?utm_source=habr&utm_medium=kazancev&utm_campaign=pedpraktika&utm_content=statya1&utm_term=llm_nlp

[17] Источник: https://habr.com/ru/articles/865688/?utm_campaign=865688&utm_source=habrahabr&utm_medium=rss