После просмотра курса Programming Languages и прочтения Functional JavaScript захотелось повторить все эти крутые штуки в python. Часть вещей получилось сделать красиво и легко, остальное вышло страшным и непригодным для использования.
Статья включает в себя:
- немного непонятных слов;
- каррирование;
- pattern matching;
- рекурсия (включая хвостовую).
Статья рассчитана на python 3.3+.
Немного непонятных слов
На python можно писать в функциональном стиле, ведь в нем есть анонимные функции:
sum_x_y = lambda x, y: x + y
print(sum_x_y(1, 2)) # 3
Функции высшего порядка (принимающие или возвращающие другие функции):
def call_and_twice(fnc, x, y):
return fnc(x, y) * 2
print(call_and_twice(sum_x_y, 3, 4)) # 14
Замыкания:
def closure_sum(x):
fnc = lambda y: x + y
return fnc
sum_with_3 = closure_sum(3)
print(sum_with_3(12)) # 15
Tuple unpacking(почти pattern matching):
a, b, c = [1, 2, 3]
print(a, b, c) # 1 2 3
hd, *tl = range(5)
print(hd, 'tl:', *tl) # 0 tl: 1 2 3 4
И крутые модули functools и itertools.
Каррирование
Преобразование функции от многих аргументов в функцию, берущую свои аргументы по одному.
Рассмотрим самый простой случай, каррируем функцию sum_x_y
:
sum_x_y_carry = lambda x: lambda y: sum_x_y(x, y)
print(sum_x_y_carry(5)(12)) # 17
Что-то совсем не круто, попробуем так:
sum_with_12 = sum_x_y_carry(12)
print(sum_with_12(1), sum_with_12(12)) # 13 24
sum_with_5 = sum_x_y_carry(5)
print(sum_with_12(10), sum_with_12(17)) # 22 29
Уже интересней, теперь сделаем универсальную функцию для каррирования функций с двумя аргументами, ведь каждый раз писать lambda x: lambda y: zzzz
совсем не круто:
curry_2 = lambda fn: lambda x: lambda y: fn(x, y)
И применим ее к используемой в реальных проектах функции map
:
curry_map_2 = curry_2(map)
@curry_map_2
def twice_or_increase(n):
if n % 2 == 0:
n += 1
if n % 3:
n *= 2
return n
print(*twice_or_increase(range(10))) # 2 2 3 3 10 10 14 14 9 9
print(*twice_or_increase(range(30))) # 2 2 3 3 10 10 14 14 9 9 22 22 26 26 15 15 34 34 38...
Да-да, я использовал каррированый map
как декоратор и нивелировал этим отсутствие многострочных лямбд.
Но не все функции принимают 2 аргумента, поэтому сделаем функцию curry_n
, используя partial, замыкания и немножко рекурсии:
from functools import partial
def curry_n(fn, n):
def aux(x, n=None, args=None): # вспомогательная функция
args = args + [x] # добавим аргумент в список всех аргументов
return partial(aux, n=n - 1, args=args) if n > 1 else fn(*args) # вернем функцию с одним аргументом, созданную из aux либо вызовем изначальную с полученными аргументами
return partial(aux, n=n, args=[])
И в очередной раз применим к map
, но уже с 3 аргументами:
curry_3_map = curry_n(map, 3)
И сделаем функцию для сложения элементов списка с элементами списка 1..10:
sum_arrays = curry_3_map(lambda x, y: x + y)
sum_with_range_10 = sum_arrays(range(10))
print(*sum_with_range_10(range(100, 0, -10))) # 100 91 82 73 64 55 46 37 28 19
print(*sum_with_range_10(range(10))) # 0 2 4 6 8 10 12 14 16 18
Так как curry_2
— это частный случай curry_n
, то можно сделать:
curry_2 = partial(curry_n, n=2)
И для примера применим его к filter
:
curry_filter = curry_2(filter)
only_odd = curry_filter(lambda n: n % 2)
print(*only_odd(range(10))) # 1 3 5 7 9
print(*only_odd(range(-10, 0, 1))) # -9 -7 -5 -3 -1
Pattern matching
Метод анализа списков или других структур данных на наличие в них заданных образцов.
Pattern matching — это то, что больше всего мне понравилось в sml и хуже всего вышло в python.
Придумаем себе цель — написать функцию, которая:
- если принимает список чисел, возвращает их произведение;
- если принимает список строк, возвращает одну большую объединенную строку.
Создадим вспомогательный exception и функцию для его «бросания», которую будем использовать, когда сопоставление не проходит:
class NotMatch(Exception):
"""Not match"""
def not_match(x):
raise NotMatch(x)
И функцию, которая делает проверку и возвращает объект, либо бросает exception:
match = lambda check, obj: obj if check(obj) else not_match(obj)
match_curry = curry_n(match, 2)
Теперь мы можем создать проверку типа:
instance_of = lambda type_: match_curry(lambda obj: isinstance(obj, type_))
Тогда для int
:
is_int = instance_of(int)
print(is_int(2)) # 2
try:
is_int('str')
except NotMatch:
print('not int') # not int
Создадим проверку типа для списка, проверяя его каждый элемент:
is_array_of = lambda matcher: match_curry(lambda obj: all(map(matcher, obj)))
И тогда для int
:
is_array_of_int = is_array_of(is_int)
print(is_array_of_int([1, 2, 3])) # 1 2 3
try:
is_array_of_int('str')
except NotMatch:
print('not int') # not int
И теперь аналогично для str
:
is_str = instance_of(str)
is_array_of_str = is_array_of(is_str)
Также добавим функцию, возвращающую свой аргумент, идемпотентную =)
identity = lambda x: x
print(identity(10)) # 10
print(identity(20)) # 20
И проверку на пустой список:
is_blank = match_curry(lambda xs: len(xs) == 0)
print(is_blank([])) # []
try:
is_blank([1, 2, 3])
except NotMatch:
print('not blank') # not blank
Теперь создадим функцию для разделения списка на первый элемент и остаток с применением к ним «проверок»:
def hd_tl(match_x, match_xs, arr):
x, *xs = arr
return match_x(x), match_xs(xs)
hd_tl_partial = lambda match_x, match_xs: partial(hd_tl, match_x, match_xs)
И рассмотрим самый простой пример с identity
:
hd_tl_identity = hd_tl_partial(identity, identity)
print(hd_tl_identity(range(5))) # 0 [1, 2, 3, 4]
А теперь с числами:
hd_tl_ints = hd_tl_partial(is_int, is_array_of_int)
print(hd_tl_ints(range(2, 6))) # 2 [3, 4, 5]
try:
hd_tl_ints(['str', 1, 2])
except NotMatch:
print('not ints') # not ints
А теперь саму функцию, которая будет перебирать все проверки. Она очень простая:
def pattern_match(patterns, args):
for pattern, fnc in patterns:
try:
return fnc(pattern(args))
except NotMatch:
continue
raise NotMatch(args)
pattern_match_curry = curry_n(pattern_match, 2)
Но зато она неудобна в использовании и требует целый мир скобок, например, нужная нам функция будет выглядеть так:
sum_or_multiply = pattern_match_curry((
(hd_tl_partial(identity, is_blank), lambda arr: arr[0]), # x::[] -> x
(hd_tl_ints, lambda arr: arr[0] * sum_or_multiply(arr[1])), # x::xs -> x * sum_or_multiply (xs) где type(x) == int
(hd_tl_partial(is_str, is_array_of_str), lambda arr: arr[0] + sum_or_multiply(arr[1])), # x::xs -> x + sum_or_multiply (xs) где type(x) == str
))
Теперь проверим ее в действии:
print(sum_or_multiply(range(1, 10))) # 362880
print(sum_or_multiply(['a', 'b', 'c'])) # abc
Ура! Оно работает =)
Рекурсия
Во всех классных языках программирования крутые ребята реализуют map
через рекурсию, чем мы хуже? Тем более мы уже умеем pattern matching:
r_map = lambda fn, arg: pattern_match((
(hd_tl_partial(identity, is_blank), lambda arr: [fn(arr[0])]), # x::[] -> fn(x)
(
hd_tl_partial(identity, identity),
lambda arr: [fn(arr[0])] + r_map(fn, arr[1]) # x::xs -> fn(x)::r_map(fn, xs)
),
), arg)
print(r_map(lambda x: x**2, range(10))) # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
Теперь каррируем:
r_map_curry = curry_n(r_map, 2)
twice = r_map_curry(lambda x: x * 2)
print(twice(range(10)))
try:
print(twice(range(1000))) # [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
except RuntimeError as e:
print(e) # maximum recursion depth exceeded in comparison
Что-то пошло не так, попробуем хвостовую рекурсию.
Для этого создадим «проверку» на None
:
is_none = match_curry(lambda obj: obj is None)
И проверку пары:
pair = lambda match_x, match_y: lambda arr: (match_x(arr[0]), match_y(arr[1]))
А теперь и сам map
:
def r_map_tail(fn, arg):
aux = lambda arg: pattern_match((
(pair(identity, is_none), lambda arr: aux([arr[0], []])), # если аккумулятор None, делаем его []
(
pair(hd_tl_partial(identity, is_blank), identity),
lambda arr: arr[1] + [fn(arr[0][0])] # если (x::[], acc), то прибавляем к аккумулятору fn(x) и возвращаем его
),
(
pair(hd_tl_partial(identity, identity), identity),
lambda arr: aux([arr[0][1], arr[1] + [fn(arr[0][0])]]) # если (x::xs, acc), то делаем рекурсивный вызов с xs и аккумулятором + fn(x)
),
), arg)
return aux([arg, None])
Теперь опробуем наше чудо:
r_map_tail_curry = curry_n(r_map_tail, 2)
twice_tail = r_map_tail_curry(lambda x: x * 2)
print(twice_tail(range(10))) # [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
try:
print(twice_tail(range(10000)))
except RuntimeError as e:
print(e) # maximum recursion depth exceeded
Вот ведь незадача — python не оптимизирует хвостовую рекурсию. Но теперь на помощь нам придут костыли:
def tail_fnc(fn):
called = False
calls = []
def run():
while len(calls): # вызываем функцию с аргументами из списка
res = fn(*calls.pop())
return res
def call(*args):
nonlocal called
calls.append(args) # добавляем аргументы в список
if not called: # проверяем вызвалась ли функция, если нет - запускаем цикл
called = True
return run()
return call
Теперь реализуем с этим map
:
def r_map_really_tail(fn, arg):
aux = tail_fnc(lambda arg: pattern_match(( # декорируем вспомогательную функцию
(pair(identity, is_none), lambda arr: aux([arr[0], []])), # если аккумулятор None, делаем его []
(
pair(hd_tl_partial(identity, is_blank), identity),
lambda arr: arr[1] + [fn(arr[0][0])] # если (x::[], acc), то прибавляем к аккумулятору fn(x) и возвращаем его
),
(
pair(hd_tl_partial(identity, identity), identity),
lambda arr: aux([arr[0][1], arr[1] + [fn(arr[0][0])]]) # если (x::xs, acc), то делаем рекурсивный вызов с xs и аккумулятором + fn(x)
),
), arg))
return aux([arg, None])
r_map_really_tail_curry = curry_n(r_map_really_tail, 2)
twice_really_tail = r_map_really_tail_curry(lambda x: x * 2)
print(twice_really_tail(range(1000))) # [0, 2, 4, 6, 8, 10, 12, 14, 16, 18...
Теперь и это заработало =)
Не все так страшно
Если забыть про наш ужасный pattern matching, то рекурсивный map
можно реализовать вполне аккуратно:
def tail_r_map(fn, arr_):
@tail_fnc
def aux(arr, acc=None):
x, *xs = arr
if xs:
return aux(xs, acc + [fn(x)])
else:
return acc + [fn(x)]
return aux(arr_, [])
curry_tail_r_map = curry_2(tail_r_map)
И сделаем на нем умножение всех нечетных чисел в списке на 2:
@curry_tail_r_map
def twice_if_odd(x):
if x % 2 == 0:
return x * 2
else:
return x
print(twice_if_odd(range(10000))) # [0, 1, 4, 3, 8, 5, 12, 7, 16, 9, 20, 11, 24, 13, 28, 15, 32, 17, 36, 19...
Получилось вполне аккуратно, хоть медленно и ненужно. Как минимум из-за скорости. Сравним производительность разных вариантов map
:
from time import time
checker = lambda x: x ** 2 + x
limit = 100000
start = time()
xs = [checker(x) for x in range(limit)][::-1]
print('inline for:', time() - start)
start = time()
xs = list(map(checker, xs))[::-1]
print('map:', time() - start)
calculate = curry_tail_r_map(checker)
start = time()
xs = calculate(xs)[::-1]
print('r_map without pattern matching:', time() - start)
calculate = r_map_really_tail_curry(checker)
start = time()
xs = calculate(xs)[::-1]
print('r_map with pattern matching:', time() - start)
После чего получим:
inline for: 0.010764837265014648
map: 0.010544061660766602
r_map without pattern matching: 4.720803737640381
r_map with pattern matching: 8.376755237579346
Вариант с pattern matching'ом оказался самым медленным, а встроенные map и for оказались самыми быстрыми.
Заключение
Из этой статьи в реальных приложениях можно использовать, пожалуй, только каррирование. Остальное либо нечитаемый, либо тормозной велосипед =)
Все примеры доступны на github.
Автор: nvbn