- PVSM.RU - https://www.pvsm.ru -

Умножение Монтгомери

Деление целых чисел — это долго и сложно. Вычислять остаток от деления — нисколько не проще. При этом в спортивном программировании, да и в прикладной математике типа криптографии, задача умножения чисел по модулю встречается повсеместно.

Один из вариантов эффективного решения — умножать по модулю, вообще при этом не используя операции деления, с помощью алгоритма Монтгомери.

Про него я и хотел бы поговорить.

Постановка задачи

Для простоты изложения, я буду приводить алгоритм для 32-битных целых чисел. Про более широкие числа тоже будут определённые заметки, просто без кода.

Итак, положим, что у нас есть нечётное число int N, по модулю которого необходимо производить вычисления. Например, это может быть знакомое многим число 1_000_000_007. Алгоритмы для чётных N можно свести к алгоритмам для нечётных, просто потребуется заметное количество дополнительного кода для контроля младших бит, так что оставим это на другой раз.

Наша цель — иметь эффективный аналог следующего кода:

int mulMod(int a, int b, int N) {
    long longA = Integer.toUnsignedLong(a);
    long longB = Integer.toUnsignedLong(b);
    long longN = Integer.toUnsignedLong(N);

    return (int) Long.remainderUnsigned(longA * longB, longN);
}

Код будем писать для беззнаковых 32-битных целых, так интереснее. Для них в Java мы вынуждены использовать знаковый int и специальные методы. Конвертация в long же необходима для контроля за переполнением при умножении.

Часто на протяжении статьи я буду передавать 32-битные числа как long только ради того, чтобы не повторять Integer.toUnsignedLong множество раз. Это вопрос в основном краткости кода, а не того, что мне нужны беззнаковые числа. Да и конвертация в long не бесплатная, если что.

Так же у нас есть число R=232, обозначающее разрядность чисел, с которыми мы работаем в коде. Именования N и R будут сохранены до конца статьи.

В теории можно брать любое значение R, большее N и взаимно-простое с ним. На практике удобно брать такое R, что умножение по модулю R эффективно реализуемо в коде. Умножение 32-битных чисел по модулю 232 в Java (да и в других языках) — это просто умножение в int, но в теории мы могли бы брать и 264, и 2128, и другие разрядности.

Форма Монтгомери

Вместо того, чтобы работать с непосредственно значениями, над которыми необходимо производить вычисления, мы будем использовать так называемую «форму Монтгомери» для этих значений:

m(a)=(aR)pmod{N}

От условного a % N она отличается предварительным умножением на R. Данная форма является линейным отображением, т.е. её можно спокойно складывать и вычитать:

begin{align} & aR + bR=(a+b)R \ & aR - bR=(a-b)R end{align}

А с умножением не всё так просто, ведь m(ab)neq m(a)cdot m(b). Вместо этого имеем следующее:

m(a)cdot m(b)=(aRcdot bR)pmod{N}=(abR)Rpmod{N}=m(m(ab))

Другими словами, m(ab)=m^{-1}(m(a)cdot m(b)). Если научиться эффективно вычислять m^{-1}(a), то убьём двух зайцев одним ударом — и умножать научимся, и сможем преобразовывать результат в первоначальный вид.

Также было бы здорово уметь вычислять саму форму Монтгомери, не прибегая к делению. Об этом мы тоже поговорим, просто чуть позже.

Редукция Монтгомери (REDC)

Вычисление m^{-1}(T) принято называть редукцией. На просторах интернета существует 2 версии данного алгоритма, мало отличающиеся друг от друга.

Параметр я назвал T для того, чтобы текст был ближе к Википедии [1]. Он может быть равен форме Монтгомери какого-то числа, если нам нужно преобразовать его в исходную форму (вычисление m^{-1}(m(a))), либо же он может быть равен произведению двух форм Монтгомери (вычисление m^{-1}(m(a)cdot m(b))). Т.е. в нашем примере T — это 64-битное число, т.к. нельзя терять биты переполнения от произведения.

Первым я опишу как раз алгоритм, который приведён в Википедии. Для начала положим, что мы нашли число N' такое, что Ncdot N' equiv (-1)pmod{R}, которое в коде я буду обозначать M. С помощью него нужно выполнить следующую процедуру:

begin{align} & m leftarrow ((Tpmod{R})cdot N') pmod{R} \ & t::leftarrow (T + mN) / R \ & if:(t ge N) \&;;;; m^{-1}(T) leftarrow (t-N) \ & else \&;;;; m^{-1}(T) leftarrow t end{align}

Операции вроде x mod R и x / R — тривиальны, и для R=232 представляют собой
x & 0xFFFFFFFFL и x >>> 32 соответственно. Для других разрядностей это тоже будут конъюнкция с маской и сдвиг. Если перевести это на привычный язык программного кода, то получится следующее:

static long redc(long T, long N, int M) {
    long m = Integer.toUnsignedLong(((int) T) * M);
    long t = (T + m * N) >>> 32;

    return t >= N ? t - N : t;
}

Помним, что N передаётся как long для того, чтобы каждый раз не приходилось вызывать Integer.toUnsignedLong. M же передавать в виде long необязательно.

Обоснование правильности приведённых формул довольно-таки скучное, при желании его можно найти самостоятельно на той же Википедии. Главное тут вот что:

begin{align} &T < N^2 < NR \& mN < NR end{align}Biggr}Rightarrow (T+mN) < 2NR

Именно тот факт, что 0 <= t < 2*N, позволяет заменить деление сравнением и вычитанием.

Для второй версии алгоритма нам понадобится другое значение M (N'), а именно такое, которое соответствует уравнению Ncdot N' equiv 1pmod{R}, без минуса. В этом случае код будет следующим:

static long redc(long T, long N, int M) {
    long m = Integer.toUnsignedLong(((int) T) * M);
    long t = (T - m * N) >>> 32;

    return t < 0 ? t + N : t;
}

Видите разницу? + m * N заменилось на - m * N, ведь у m из-за умножение на M, условно говоря, «противоположный знак». Кавычки потому, что все числа тут неотрицательные. Кроме t, конечно — его знак нужен.

Это в свою очередь приводит к тому, что -N < t < N, и здесь уже деление заменяется на сложение, а не вычитание.

Интуитивно кажется, что второй алгоритм на практике чуть более эффективен, ведь N'_2=N^{-1}pmod{R} выглядит проще, чем N'_1=(-N)^{-1}pmod{R}, отличие аж на целый минус!

Как вычислить M

M — это значение, которое для каждого N нужно вычислить лишь один раз. Тем не менее, подходы к его вычислению весьма интересны, поэтому я заострю на них особое внимание.

Расширенный алгоритм Евклида

Обычно в интернете для этого предлагают использовать расширенный алгоритм Евклида [2]. Мол «какой‑то алгоритм существует, а дальше сами разбирайтесь». Подробно объяснять его, опять же, долго и очень скучно. Плюс он должен входить в школьную программу по математике, насколько я помню.

Суть здесь в том, чтобы найти представление gcd(a, b) = s * a + t * b, где gcd — наибольший общий делитель (greatest common divisor). В нашем случае в качестве a и b выступают R и N.

По условию gcd(R, N) = 1. Более того, в 32-битной арифметике умножение на R всегда даёт 0, ведь оно эквивалентно сдвигу влево на 32 бита. Учитывая это, алгоритм Евклида фактически позволит нам найти представление 1 = t * N, т.е. буквально обратное по модулю R. Приведу код, чтобы не пришлось писать самим:

static int inverseExtendedEuclid(long N) {
    long old_r = 1L << 32, r = N;
    long old_t = 0,        t = 1;

    while (r != 0) {
        long q = old_r / r;

        long tmp0 = old_r;
        old_r = r;
        r = tmp0 - q * r;

        long tmp1 = old_t;
        old_t = t;
        t = tmp1 - q * t;

        // Контроль инвариантов, для понятности.
        assert (int) r == (int) t * N;
        assert (int) old_r == (int) old_t * N;
    }

    // Ещё один контроль инвариантов.
    assert r == 0;     // Условие выхода из цикла, мог бы и не писать.
    assert old_r == 1; // Значение gcd(R, N).

    return (int) old_t;
}

Малая теорема Ферма

Согласно известной теореме Эйлера [3], являющейся обобщением малой теоремы Ферма, для любых двух взаимно простых чисел (в нашем случае N и R) выполнено:

N^{varphi(R)} equiv 1pmod{R}

varphi(R) — это значение функции Эйлера, так же известной под именем totient, и равно оно количеству целых чисел, меньших R и при этом взаимно простых с R. Если p - простое число, то varphi(p^n)=(p-1)p^{n-1}, данная формула вполне известна и по ссылке на Википедии можно найти ей объяснение. Привожу я её для того, чтобы мы смогли вычислить varphi(2^{32}), она как раз подходит:

varphi(R)=varphi(2^{32})=(2-1)cdot 2^{31}

Раз нам известно, что N^{2^{31}} equiv 1pmod{R}, то легко заметить, что N^{2^{31}-1} equiv N^{-1}pmod{R}. А значит можно воспользоваться быстрым возведением в степень, заранее зная, что 231-1 состоит из 31-го единичного бита:

static int inverseEuler(int N) {
    int M = N;

    // 30 потому, что 1-я итерация уже выполнена в момент присваивания выше.
    for (int i = 0; i < 30; i++) {
        M = M * M * N;
    }

    return M;
}

Здесь N можно передать как int, поскольку знак N никак не влияет на вычисление 32-битных произведений.

Данный код уже не содержит делений, зато количество умножений в нём достаточно большое — O(log(R)), 60 штук для 32-битных чисел. Это слишком много.

Ещё можно добавить, что раз N^{2^{31}} equiv 1pmod{R}, то на него можно спокойно умножать:

N^{2^{31}}cdot N^{2^{31}-1}=N^{2^{32}-1} equiv N^{-1}pmod{R}

При явном возведении в степень эта формула приведёт в двум лишним умножениям. Тем не менее, именно такое представление является базисом для более продвинутых алгоритмов.

Метод Ньютона

Алгоритм берёт своё название от известного метода Ньютона [4], но дословно его не повторяет.

Суть в том, чтобы построить рекуррентную формулу, сходящуюся к верному решению. Имея изначально формулу Nx equiv 1 pmod{R}, мы можем преобразовать её сперва в Nx^2 equiv x pmod{R}, а уже после этого в x equiv 2x - Nx^2 pmod{R}. На основании этого равенства можно построить следующее рекуррентное соотношение:

begin{align}&x_0=1 \&x_{n+1}=x_n(2-Nx_n)end{align}

Для того, чтобы найти N^{-1}pmod{R}, достаточно всего 5-ти итераций, или если точнее, то log_2(log_2{R}) итераций, что уже значительно меньше, чем в предыдущем способе. Для 64-битных чисел было бы 6 итераций, а для 128-битных — 7. Код для int:

static int inverseNewton(int N) {
    int M = 2 - N;

    // 4 итерации потому, что 1-я итерация уже выполнена в момент присваивания выше.
    for (int i = 0; i < 4; i++) {
        M = M * (2 - N * M);
    }

    return M;
}

Ссылку на доказательство предоставлю позже, чтобы вы по ней спойлеров не начитались.

Алгоритм Дюма

Метод Ньютона — хороший, и с точки зрения оценки трудоёмкости — оптимальный. Всё, что пойдёт далее — это точечные улучшения. Первое из них выглядит, на первый взгляд, не интуитивно. Да и на второй тоже:

static int inverseDumas(int N) {
    int M = 2 - N;
    int y = N - 1;

    // 4 итерации потому, что 1-я итерация уже выполнена в момент присваивания выше.
    for (int i = 0; i < 4; i++) {
        y = y * y;
        M = M * (1 + y);
    }

    return M;
}

Формулы для y подобраны таким образом, чтобы M на каждой итерации был точно таким же, как в методе Ньютона, это несложно доказать с помощью математической индукции. Цель такого изменения станет понятна, если вручную развернуть цикл:

static int inverseDumas(int N) {
    int M = 2 - N;
    int y = N - 1;

    y = y * y;
    M = M * (1 + y);

    y = y * y;
    M = M * (1 + y);

    y = y * y;
    M = M * (1 + y);

    y = y * y;
    M = M * (1 + y);

    return M;
}

В методе Ньютона выражение M = M * (2 - N * M) обязано быть вычисленным по порядку, ведь результат каждой из операций является операндом для следующей операции. В методе Дюма же есть 2 независимых выражения — M = M * (1 + y) и следующий за ним y = y * y, и ваш процессор может вычислять их буквально одновременно, значительно ускоряя весь процесс.

Оптимизация первых итераций

Описывая трудоёмкость, я уже упомянул, что, увеличивая число итераций, можно получить алгоритмы для 64 или 128-битных чисел, к примеру. Аналогично этому, ничего не мешает нам уменьшать число итераций и получать алгоритмы для 16, 8 или даже 4-битных чисел:

static int inverseDumas4(int N) {
    int M = 2 - N;
    int y = N - 1;

    y = y * y;
    M = M * (1 + y);

    // Можно обрезать лишние биты, если нужно.
    // А можно и не обрезать, если не нужно.
    return M & 0xF;
}

Это, конечно, забавное свойство, но зачем оно нам? А затем, что для 4-битных чисел есть более эффективные алгоритмы!

Первый такой алгоритм был предложен самим Монтгомери:

int inverseMontgomery4(int N) {
    return 3 * N ^ 2;
}

Умножение на 3 можно вообще не считать за умножение, ведь оно эквивалентно выражению N + N + N, которое наверняка вычисляется более эффективно.

Альтернативный вариант был найден неуказанным автором с помощью брутфорса:

int inverseBruteforce4(int N) {
    return (N ^ 2) - 2 * N;
}

Скобки обязательны, поскольку у - приоритет выше, чем у ^. Комментарий про умножение тут имеет ещё больше смысла, его ещё и на сдвиг заменить можно.

Оба этих алгоритма доказываются полным перебором. К счастью, существует всего лишь 16 различных 4-битных чисел, так что перебор выходит небольшим.

Если объединить всё сказанное, то получим следующий код:

private static int inverseDumas(int N) {
    int M = 3 * N ^ 2;
    int y = 1 - N * M;

    M = M * (1 + y);

    y = y * y;
    M = M * (1 + y);

    y = y * y;
    M = M * (1 + y);

    return M;
}

Вот и обещанная ссылка на источник [5], из которого взяты эти методы, там же можно найти доказательства. Почему в начале y = 1 - N * M там тоже есть, пересказывать не буду.

Судя по тому, что там сказано, это самый быстрый из известных способов найти N^{-1}pmod{2^{32}}, что, на мой взгляд, очень круто!

Как вычислить форму Монтгомери

Забыл сказать важную вещь: если умножение по модулю нужно сделать всего несколько раз, то, может, и алгоритм Монтгомери не нужен, ведь вычисление формы Монтгомери имеет свою цену.

Но вот что ещё интересно: в зависимости от того, для скольки различных чисел нужно найти форму Монтгомери, мы тоже можем использовать для этого разные подходы.

Самый очевидный подход — тупо взять и посчитать:

static long m(int a, long N) {
    // a * R % N;
    return Long.remainderUnsigned(((long) a) << 32, N);
}

Для данного кода существует один интересный частный случай — вычисление m(1):

static long m1(int N) {
    return Integer.remainderUnsigned(-1, N) + 1;
}

Данный код является лишь адаптацией формулы R pmod{N}=(R - 1) pmod{N} + 1, которая в свою очередь справедлива потому что gcd(R, N)=1, а значит Умножение Монтгомери 0" alt="R pmod{N} > 0" src="http://www.pvsm.ru/images/2024/07/14/umnojenie-montgomeri-38.svg" width="147" height="22" title="Умножение Монтгомери - 38"/>.

А хорош этот код тем, что использует 32-битное деление вместо 64-битного, т.е. работает быстрее.

Алгоритм Лемира

Что делать, если форму Монтгомери нужно вычислять достаточно часто, и делить при этом мы не хотим? Может быть для 32-битных чисел это не так принципиально, а для чисел большей разрядности деление становится всё более и более дорогим.

Выход есть, но он не простой, потому что потребует от нас 128-битного умножения для того, чтобы реализовать деление 64-битных чисел без, собственно, деления. Идея алгоритма взята отсюда [6] и аккуратно перенесена на Java мною лично.

Первое, что нужно сделать — это вычислить

lceil frac{2^{128}}{N}rceil

Тут имеется ввиду округление вверх, так же известное как ceil. Повторяя рассуждения, которые были чуть выше, можно заметить, что это значение совпадает с

lceil frac{2^{128} - 1}{N}rceil

Напомню, что 2128-1 — это 128-битное число, состоящее из 128 единиц.

Результат выражения в коде по ссылке называют M, я повторю это именование, думаю, путаницы не будет. Алгоритмы деления длинных чисел явно выходят за рамки данной статьи, их много и они сложные. Поэтому в данном конкретном случае обойдёмся чем-нибудь простым, например ручным делением с помощью цикла.

static UUID M(long N) {
    // Старшие байты деления совпадут с делением старшей половины числа на N.
    long msb = Long.divideUnsigned(-1L, N);

    // Long.remainderUnsigned(-1L, N);
    long r = -1L - N * msb;

    // Младшие биты будем накапливать в цикле.
    long lsb = 0;
    for (int i = 0; i < 64; i++) {
        lsb <<= 1;
        r = (r << 1) + 1;

        if (N < r) {
            r -= N;
            lsb |= 1;
        }
    }

    // Округление вверх, тот самый ceil.
    lsb++;
    // Обработка переполнения младших байтов.
    if (lsb == 0) msb++;

    // Да, использую UUID для 128-битных чисел, и что?
    return new UUID(msb, lsb);
}

Неэффективный алгоритм деления не особо страшен, поскольку это буквально единственное деление, которое нам нужно.

Что дальше делать с этим числом, спросите? Подставить его в эту формулу:

m(a)=aR pmod{N}=lfloor aRMN / 2^{128}rfloorpmod{2^{64}}

Для того, чтобы вычислить эту формулу и ничего не потерять, нам понадобится:

  • aR — 64-битное число.

  • aRcdot M — это 192-битное число, но нам достаточно вычислить младшие 128 бит. Остальные исчезнут при вычислении mod{2^{64}}.

  • aRMcdot N — это тоже 192-битное число, но нам достаточно вычислить старшие 64 бита, остальные будут проигнорированы при делении на 2128.

  • Из результирующих 64 бит нам в реальности понадобятся только младшие 32. Тем не менее, я раньше уже оговаривал, что беззнаковые 32 битные целые нам проще будет хранить в long, чтобы реже вызывать Integer.toUnsignedLong.

В коде это будет выглядеть следующим образом:

static long m(long a, long N, UUID M) {
    long lsb_M = M.getLeastSignificantBits();
    long msb_M = M.getMostSignificantBits();

    // A = a * R;
    long A = a << 32;

    // L = A * M;
    long lsb_L = lsb_M * A;
    long msb_L = msb_M * A + Math.unsignedMultiplyHigh(lsb_M, A);

    // Дальше идёт то, что у Лемира названо "mul128_u64", а именно вычисление
    // старших 64 бит произведения 128-битной L и 64-битной N.
    long lsb_Bottom = Math.unsignedMultiplyHigh(lsb_L, N);

    long lsb_Top = msb_L * N;
    long msb_Top = Math.unsignedMultiplyHigh(msb_L, N);

    if ((lsb_Bottom & lsb_Top) < 0 || (lsb_Bottom | lsb_Top) < 0 && (lsb_Bottom + lsb_Top) >= 0) {
        msb_Top++;
    }

    return msb_Top & 0xFFFFFFFFL;
}

Данный код требует некоторых пояснений, поскольку он не очень простой.

Во-первых, как вычисляется L. Если положить, что M=msbcdot2^{64} + lsb— разложение на старшие и младшие байты, то получим

Acdot M=[(Acdot msb)pmod{2^{64}} + lfloor frac{Acdot lsb}{2^{64}}rfloor]cdot 2^{64} + [(Acdot lsb)pmod{2^{64}}]

Именно это написано в коде.

Во-вторых, Math.unsignedMultiplyHigh может показаться чем-то незнакомым. Дело в том, что он был добавлен в Java только в 18-й версии, так что если хотите позапускать код под версией поменьше, то воспользуйтесь, пожалуйста, этой копией реализации:

  public static long unsignedMultiplyHigh(long x, long y) {
      // Compute via multiplyHigh() to leverage the intrinsic
      long result = Math.multiplyHigh(x, y);
      result += (y & (x >> 63)); // equivalent to `if (x < 0) result += y;`
      result += (x & (y >> 63)); // equivalent to `if (y < 0) result += x;`
      return result;
  }

В-третьих, что происходит в mul128_u64. Суть там та же, что и при вычислении L, просто гораздо больше движущихся частей. По этой причине формулу я писать не буду, сильно уж большая.

В любом случае, самая загадочная часть данного кода — это условие в if.

if ((a & b) < 0 || (a | b) < 0 && (a + b) >= 0) ...

Данный код проверяет, случится ли целочисленное переполнение, если вычислить сумму a и b как беззнаковое число. Это случается в двух случаях:

  • Когда оба числа отрицательные, т.е. имеют в качества старшего бита единицы — это левая часть дизъюнкции.

  • Когда оба числа имеют разный знак, т.е. у них отличаются старшие биты, и сумма чисел имеет в старшем бите 0 — это правая часть дизъюнкции.

Наверняка существует более элегантный способ проверки переполнения, но мне он в голову не пришёл. Если условие выполнено, то сложение чисел приводит к переполнению и выставлению carry флага, который нужно не забыть прибавить к старшим байтам вычисляемого значения. Это и делается кодом msb_Top++. Надеюсь, ничего не напутал, тесты у меня вроде проходят.

Удивительно, но этот код действительно быстрее, чем 64-битное деление, причём минимум процентов на 25. Во всяком случае, согласно моим тестам. С трудом верится, но что поделать. Данный алгоритм обобщается и на другие разрядности, вероятно там он тоже должен побеждать деление.

Заключение

В первую очередь, хочу выразить благодарность пользователям @encyclopedist [7], @YouDontKnowMe [8] и @Ruimteschroot [9], оставившим комментарии к моей предыдущей статье и побудивших меня разобраться в теме.

Что в итоге. По моему опыту, в задачах спортивного программирования всё зависит от трудоёмкости алгоритма, а не от низкоуровнего тюнинга. Много лет я писал a % N и не испытывал никаких проблем, укладываясь в нужные лимиты по времени. Большей проблемой для меня был парсинг ввода (java.util.Scanner — тормозное зло, не вздумайте использовать).

С другой стороны, есть спортивное программирование, в котором люди соревнуются в скорости программ, а не только в корректности, как например тут [10]. Здесь уже использование обычного деления не прокатит, так что какая-то ниша у подобных алгоритмов всё-же есть.

Ну и естественно криптография. Возводить длинные числа в степени по модулю во время шифрования — это прямо норма.

Автор: Иван Бессонов

Источник [11]


Сайт-источник PVSM.RU: https://www.pvsm.ru

Путь до страницы источника: https://www.pvsm.ru/algoritm-evklida/392667

Ссылки в тексте:

[1] Википедии: https://en.wikipedia.org/wiki/Montgomery_modular_multiplication

[2] расширенный алгоритм Евклида: https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm

[3] теореме Эйлера: https://en.wikipedia.org/wiki/Euler%27s_theorem

[4] метода Ньютона: https://en.wikipedia.org/wiki/Newton%27s_method

[5] ссылка на источник: https://github.com/hurchalla/modular_arithmetic/blob/master/montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/low_level_api/detail/integer_inverse.pdf

[6] взята отсюда: https://github.com/lemire/fastmod/blob/master/include/fastmod.h

[7] @encyclopedist: https://www.pvsm.ru/users/encyclopedist

[8] @YouDontKnowMe: https://www.pvsm.ru/users/youdontknowme

[9] @Ruimteschroot: https://www.pvsm.ru/users/ruimteschroot

[10] тут: https://highload.fun/tasks/10

[11] Источник: https://habr.com/ru/articles/827880/?utm_source=habrahabr&utm_medium=rss&utm_campaign=827880