- PVSM.RU - https://www.pvsm.ru -
Привет! Меня зовут Арсений, я работаю ML-инженером в компании Вита и параллельно учусь на втором курсе магистратуры AI Talent Hub [1]. В этой статье я хочу поделиться опытом разработки модели для распознавания русского рукописного текста.
Несмотря на всеобщую цифровизацию, огромное количество документов до сих пор заполняется от руки — медицинские карты, заявления, анкеты, почтовые адреса и многое другое. Автоматизация процесса распознавания таких документов может значительно ускорить их обработку и снизить количество ошибок при ручном вводе.
В статье я подробно расскажу о всех этапах создания модели:
Какие данные использовал и где их взял;
Какую архитектуру выбрал и почему;
Как проходил процесс подготовки данных и обучения модели;
Как организовал инференс.
Надеюсь, мой опыт будет полезен тем, кто столкнется с похожей задачей или просто интересуется данной тематикой.
Поехали!
В качестве данных для обучения использовался Cyrillic Handwriting Dataset [2] с Kaggle. Это набор рукописных текстов на кириллице, специально созданный для задач OCR (Optical Character Recognition). Датасет содержит 73830 примеров и уже разделен на train и test выборки в соотношении 95% и 5% соответственно.
Особенности датасета:
Каждый файл представляет собой PNG-изображение с текстом в одну строку;
Длина текста не превышает 40 символов;
Тексты написаны разными людьми, что обеспечивает разнообразие почерков;
Есть как цветные изображения, так и черно-белые;
Каждое изображение сопровождается правильной расшифровкой написанного.

В качестве базовой архитектуры возьмем trocr-base-handwritten [3] от Microsoft. Это трансформерная модель, которую зафайнтюнили на датасете IAM [4] (примерно такие же картинки с текстом, как и в Cyrillic Handwriting, только на английском). На упомянутом наборе данных эта модель является SOTA [5].
TrOCR — это энкодер-декодерная модель, состоящая из image transformer в качестве энкодера и text transformer в качестве декодера. Энкодер был инициализирован с весами BEiT, в то время как декодер — с весами RoBERTa. А вот веса cross-attention между ними были инициализированы уже случайно.
Изображения представляются модели в виде последовательности патчей фиксированного размера (разрешение 16x16 пикселей). Перед передачей в слои энкодера к последовательности добавляется позиционный эмбеддинг. Затем декодер текста авторегрессионно генерирует токены.
TrOCR статья: https://arxiv.org/abs/2109.10282 [6]
TrOCR документация: https://huggingface.co/transformers/master/model_doc/trocr.html [7]
Далее мы разберем как подготовить данные и зафайнтюнить выбранную модель под нашу задачу. Мы будем использовать класс VisionEncoderDecoderModel [8] из библиотеки transformers, который можно использовать для объединения любого image Transformer encoder (например, ViT, BEiT) с любым text Transformer в качестве декодера (например, BERT, RoBERTa, GPT-2). Примером этого является TrOCR, поскольку он имеет архитектуру энкодер-декодера, как уже говорилось ранее.
Для начала нам потребуется установить следующие библиотеки:
transformers — для работы с моделью;
evaluate и jiwer — для расчетов метрики.
Это можно сделать командой:
!pip install -q transformers evaluate jiwer
Мы не будем использовать datasets от HuggingFace для предобработки данных, а воспользуемся старым добрым Dataset из torch.

Теперь проверим, что нам доступна CUDA:
import torch
torch.cuda.is_available()
Если все хорошо, то получим True.
В нашем случае использовалась видеокарточка H100, но подойдет и менее мощная.

Перед тем, как обучать модель, надо загрузить и подготовить данные. Давайте этим и займемся!
Сначала мы загрузим данные. Скачать их можно тут [2]. Мы будем использовать только train часть из датасета, так как примеров там предостаточно. Для этого сделаем pandas dataframe из файла train.tsv.
import pandas as pd
train_val_df = pd.read_csv(
"cyrillic-handwriting-dataset/train.tsv",
sep="t",
header=None,
names=["file_name", "text"],
)
train_val_df.head()
|
file_name |
text |
|
|
0 |
aa1.png |
Молдова |
|
1 |
aa1007.png |
продолжила борьбу |
|
2 |
aa101.png |
разработанные |
|
3 |
aa1012.png |
Плачи |
|
4 |
aa1013.png |
Гимны богам |
В таблице присутствует одна строка с пустым text, из-за которой может поломаться все обучение, поэтому дропнем ее.
train_val_df = train_val_df.dropna()
Следующим шагом разделим данные на трейн + валидацию в пропорции 80/20 через train_test_split из sklearn.
from sklearn.model_selection import train_test_split
train_df, eval_df = train_test_split(train_val_df, test_size=0.2)
# we reset the indices to start from zero
train_df.reset_index(drop=True, inplace=True)
eval_df.reset_index(drop=True, inplace=True)
Каждый элемент итогового датасета должен возвращать:
pixel_values (значения пикселей исходных изображений) — это будет input для модели;
labels — input_ids соответствующего текста на изображении.
Мы будем использовать TrOCRProcessor [9] для приведения данных в нужные форматы. TrOCRProcessor — это обертка над ViTFeatureExtractor [10] и RobertaTokenizer [11]. Первый будет использоваться для ресайзинга и нормализации изображений. Второй - для кодирования и декодирования текста в/из input_ids.
Штош, давайте напишем класс нашего датасета. Для этого нам нужно наследоваться от torch.utils.data.Dataset и определить 3 метода.
Метод __init__ для инициализации экземпляра класса. В качестве входных параметров будем предавать следующее:
root_dir — путь к директории, где хранятся изображения;
df — pd.DataFrame, содержащий столбцы file_name и text;
processor — предобученный TrOCRProcessor;
max_target_length — максимальная длина для текстовых меток (labels).
Метод __len__, который будет возвращать количество элементов в датасете.
Метод __getitem__:
Принимает на вход индекс idx;
Извлекает имя файла, изображения и соответствующий текст из df по индексу idx;
Загружает изображение из root_dir и конвертирует его в формат RGB;
Обрабатывает изображение с помощью processor, который возвращает тензоры, представляющие пиксели изображения;
Кодирует текст в input_ids с помощью токенизатора, который является частью processor, метки заполняются pad-токенами до max_target_length;
Заменяет токены pad-токены на -100, чтобы они игнорировались функцией потерь во время обучения;
Возвращает словарь encoding, содержащий pixel_values и labels.
import torch
from torch.utils.data import Dataset
from PIL import Image
class CHDataset(Dataset):
def __init__(self, root_dir, df, processor, max_target_length=128):
self.root_dir = root_dir
self.df = df
self.processor = processor
self.max_target_length = max_target_length
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
# get file name + text
file_name = self.df["file_name"][idx]
text = self.df["text"][idx]
# prepare image (i.e. resize + normalize)
image = Image.open(self.root_dir + file_name).convert("RGB")
pixel_values = self.processor(image, return_tensors="pt").pixel_values
# add labels (input_ids) by encoding the text
labels = self.processor.tokenizer(
text, padding="max_length", max_length=self.max_target_length
).input_ids
# important: make sure that PAD tokens are ignored by the loss function
labels = [
label if label != self.processor.tokenizer.pad_token_id else -100
for label in labels
]
encoding = {
"pixel_values": pixel_values.squeeze(),
"labels": torch.tensor(labels),
}
return encoding
Давайте инициализируем наборы данных для обучения и оценки:
from transformers import TrOCRProcessor
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
train_dataset = CHDataset(
root_dir="cyrillic-handwriting-dataset/train/",
df=train_df,
processor=processor,
)
eval_dataset = CHDataset(
root_dir="cyrillic-handwriting-dataset/train/",
df=eval_df,
processor=processor,
)
Посмотрим на количество примеров в обеих подвыборках:
print("Number of training examples:", len(train_dataset))
print("Number of validation examples:", len(eval_dataset))
Number of training examples: 57827
Number of validation examples: 14457
И поймем какие размерности у их элементов:
encoding = train_dataset[0]
for k, v in encoding.items():
print(k, v.shape)
pixel_values torch.Size([3, 384, 384])
labels torch.Size([128])
Получили, что pixel_values — это тензор размером [3, 384, 384]. Здесь 3 обозначает количество цветовых каналов в формате RGB, а 384x384 — это размеры изображения после изменения его размера (ресайзинга). А labels — тензор размером [128] с учетом pad-токенов (128, потому что указали такой max_target_length).
Данные готовы и наконец-то можно приступать к дообучению.
Здесь мы инициализируем модель TrOCR с её предобученными весами. Обратите внимание, что веса головы языкового моделирования инициализированы из претрейна, так как модель уже была обучена генерировать текст на этапе предобучения. Подробности можно найти в статье [6].
from transformers import VisionEncoderDecoderModel
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
Важно установить несколько атрибутов, а именно:
атрибуты, необходимые для создания decoder_input_ids из labels (модель автоматически создаст decoder_input_ids, сдвинув labels на одну позицию вправо и добавив decoder_start_token_id в начало, а также заменив идентификаторы, равные -100, на pad_token_id)
размер словаря модели (для языковой модели, расположенной поверх декодера)
параметры beam-search [12], которые используются при генерации текста.
# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size
# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4
Далее определим некоторые гиперпараметры обучения, создав экземпляр training_args. Важно отметить, что существует множество других параметров, полный список которых можно найти в документации [13]. Например, вы можете задать размер батча для обучения/оценки, решить, использовать ли mixed precision training, установить частоту сохранения модели и т.д.
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
training_args = Seq2SeqTrainingArguments(
num_train_epochs=5,
predict_with_generate=True,
eval_strategy="epoch",
save_strategy="epoch",
per_device_train_batch_size=64,
per_device_eval_batch_size=64,
fp16=True,
output_dir="./trained_models",
)
Мы будем оценивать модель по Character Error Rate (CER) (подробнее см. здесь [14]).
import evaluate
cer_metric = evaluate.load("cer")
Функция compute_metrics принимает на вход EvalPrediction (который является NamedTuple) и должна возвращать словарь. На этапе оценки модель вернёт EvalPrediction, который состоит из двух элементов:
predictions: предсказания, сделанные моделью.
label_ids: фактические истинные метки.
def compute_metrics(pred):
labels_ids = pred.label_ids
pred_ids = pred.predictions
pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)
cer = cer_metric.compute(predictions=pred_str, references=label_str)
return {"cer": cer}
Давайте начнем обучение! Мы также предоставляем default_data_collator для Trainer, который используется для объединения примеров в батчи.
Следует учесть, что оценка занимает довольно много времени, так как мы используем beam search для декодирования, что требует нескольких прямых проходов для каждого примера.
from transformers import default_data_collator
# instantiate trainer
trainer = Seq2SeqTrainer(
model=model,
processing_class=processor.tokenizer,
args=training_args,
compute_metrics=compute_metrics,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=default_data_collator,
)
trainer.train()
После обучения мы получили вот такие вот результаты:
Training Loss: 0.026100;
Validation Loss: 0.120961;
CER: 0.048542.
Вполне неплохой результат.
Осталось только сохранить модель, например, локально для дальнейшего использования.
processor.save_pretrained("model/TrOCRInferenceModel/weights")
model.save_pretrained("model/TrOCRInferenceModel/weights")
После обучения мы можем легко загрузить модель, используя метод .from_pretrained(output_dir), так как на предыдущем шаге мы сохранили ее.
Загрузим модель и изображение из тестовой выборки:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import requests
image = Image.open(
"cyrillic-handwriting-dataset/test/test4.png"
).convert("RGB")
processor = TrOCRProcessor.from_pretrained(
"model/TrOCRInferenceModel/weights"
)
model = VisionEncoderDecoderModel.from_pretrained(
"model/TrOCRInferenceModel/weights"
)
Посмотрим, что там за текст на картинке.

Тут написано слово «Определим». Конечно не с первого раза, но разобрать буквы можно. Правда в данном случае буква «п» больше похоже на «й». Посмотрим как с этим справится наша модель.
pixel_values = processor(images=image, return_tensors="pt").pixel_values
generated_ids = model.generate(pixel_values)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
generated_text
«Определим»
Супер! Модель распознала слово без ошибок.
В этой статье мы разобрали как можно дообучить модель TrOCR под задачу распознавания русского рукописного текста и какие данные для этого могут использоваться. Модель получилась довольно точной и может использоваться в реальных кейсах.
P.S.: Свою дообученную модель я разместил по ссылке [15]. Буду только рад, если она кому-то будет полезна.
Материал подготовил магистрант 2 курса AI Talent Hub [16], Арсений Казанцев.
Автор: ai-talent
Источник [17]
Сайт-источник PVSM.RU: https://www.pvsm.ru
Путь до страницы источника: https://www.pvsm.ru/itmo/404835
Ссылки в тексте:
[1] AI Talent Hub: https://ai.itmo.ru/?utm_source=habr&utm_medium=article&utm_campaign=artem
[2] Cyrillic Handwriting Dataset: https://www.kaggle.com/datasets/constantinwerner/cyrillic-handwriting-dataset
[3] trocr-base-handwritten: https://huggingface.co/microsoft/trocr-base-handwritten
[4] IAM: https://fki.tic.heia-fr.ch/databases/iam-handwriting-database
[5] SOTA: https://paperswithcode.com/sota/handwritten-text-recognition-on-iam-line
[6] https://arxiv.org/abs/2109.10282: https://arxiv.org/abs/2109.10282
[7] https://huggingface.co/transformers/master/model_doc/trocr.html: https://huggingface.co/transformers/master/model_doc/trocr.html
[8] VisionEncoderDecoderModel: https://huggingface.co/docs/transformers/model_doc/vision-encoder-decoder
[9] TrOCRProcessor: https://huggingface.co/docs/transformers/model_doc/trocr#transformers.TrOCRProcessor
[10] ViTFeatureExtractor: https://huggingface.co/transformers/v4.5.1/model_doc/vit.html#vitfeatureextractor
[11] RobertaTokenizer: https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaTokenizer
[12] beam-search: https://www.geeksforgeeks.org/introduction-to-beam-search-algorithm/
[13] документации: https://huggingface.co/docs/transformers/main_classes/trainer#seq2seqtrainingarguments
[14] здесь: https://huggingface.co/spaces/evaluate-metric/cer/blob/main/README.md
[15] ссылке: https://huggingface.co/kazars24/trocr-base-handwritten-ru
[16] AI Talent Hub: https://ai.itmo.ru/?utm_source=habr&utm_medium=kazancev&utm_campaign=pedpraktika&utm_content=statya1&utm_term=llm_nlp
[17] Источник: https://habr.com/ru/articles/865688/?utm_campaign=865688&utm_source=habrahabr&utm_medium=rss
Нажмите здесь для печати.