Multiprocessing with SQLAlchemy in Practice

Preface

Recently my manager told me that the ETL tool I wrote exports from MySQL to CSV too slowly and needs performance work. Fair enough: some of the data is sharded across databases and tables in MySQL, and my current implementation exports those small tables one at a time. That step could easily export several tables concurrently. In theory, as long as network IO is not the bottleneck, dumping multiple tables from MySQL at the same time should improve throughput considerably.

Main Topic

Consulting the Documentation

To use SQLAlchemy concurrently I spent quite a while digging around online and asking for help on StackOverflow. My first idea was actually to combine threading with SQLAlchemy for a multi-threaded export. But on the SQLAlchemy site I found no documentation for that; what I found instead was its guidance on using multiprocessing with SQLAlchemy:

Using Connection Pools with Multiprocessing
It’s critical that when using a connection pool, and by extension when using an Engine created via create_engine(), that the pooled connections are not shared to a forked process. TCP connections are represented as file descriptors, which usually work across process boundaries, meaning this will cause concurrent access to the file descriptor on behalf of two or more entirely independent Python interpreter states.
There are two approaches to dealing with this.
The first is, either create a new Engine within the child process, or upon an existing Engine, call Engine.dispose() before the child process uses any connections. This will remove all existing connections from the pool so that it makes all new ones. Below is a simple version using multiprocessing.Process, but this idea should be adapted to the style of forking in use:

engine = create_engine("...")

def run_in_process():
    engine.dispose()
    with engine.connect() as conn:
        conn.execute("...")

p = Process(target=run_in_process)
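
The docs use multiprocessing.Process, but the same idea carries over to a Pool through its initializer. Here is a minimal sketch under a few assumptions of my own (SQLAlchemy 1.x-style string execution, the same connection URL used later in this post, and the table name from my examples):

from multiprocessing import Pool
from sqlalchemy import create_engine

# Module-level engine, created in the parent before the workers fork.
engine = create_engine('mysql+pymysql://root:ignorance@localhost:3306/')

def init_worker():
    # Runs once in each child: discard any connections inherited from the
    # parent so this process lazily opens fresh ones of its own.
    engine.dispose()

def count_rows(table):
    with engine.connect() as conn:
        return conn.execute('select count(1) from ' + table).scalar()

if __name__ == '__main__':
    with Pool(4, initializer=init_worker) as pool:
        print(pool.map(count_rows, ['zhihu.zhihu_answer_meta']))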

Putting It into Practice

So I decided to give multiprocessing a try. One gotcha with multiprocessing is that it uses pickle to serialize some of your data, because that data has to be copied into the newly spawned process.

About pickle

First we need to understand pickle a little, warts and all. What can actually be pickled? The official docs say:

What can be pickled and unpickled?
The following types can be pickled:
None, True, and False
integers, floating point numbers, complex numbers
strings, bytes, bytearrays
tuples, lists, sets, and dictionaries containing only picklable objects
functions defined at the top level of a module (using def, not lambda)
built-in functions defined at the top level of a module
classes that are defined at the top level of a module
instances of such classes whose __dict__ or the result of calling __getstate__() is picklable (see section Pickling Class Instances for details).
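
A quick sanity check of these rules — a minimal sketch showing that a top-level function pickles fine while a lambda does not (the exact exception type and message vary across Python versions):

import pickle

def top_level_func():
    return 42

# A top-level function is pickled by reference (module + name), so this works.
restored = pickle.loads(pickle.dumps(top_level_func))
print(restored())  # 42

# A lambda has no importable name, so pickle cannot serialize it.
try:
    pickle.dumps(lambda: 42)
except Exception as e:
    print('cannot pickle a lambda:', e)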

Now let's look at a concrete example with SQLAlchemy:

from sqlalchemy import create_engine
import time
import pymysql
from multiprocessing import Pool

# engine = create_engine(f'mysql+pymysql://root:ignorance@localhost:3306/', server_side_cursors=True, pool_size=20)

class Client(object):
    def __init__(self):
        # self.engine = create_engine(f'mysql+pymysql://root:ignorance@localhost:3306/', server_side_cursors=True, pool_size=20)
        self.pool = Pool(5)
        self.connection = pymysql.connect(host='localhost',
                                          user='root',
                                          password='ignorance',
                                          port=3306,
                                          charset='utf8mb4',
                                          cursorclass=pymysql.cursors.DictCursor)
        self.engine = create_engine(f'mysql+pymysql://root:ignorance@localhost:3306/',
                                    server_side_cursors=True, pool_size=20)

    def run_in_process(self, x, y='hehe'):
        print('run in process')
        self.engine.dispose()
        conn = self.engine.connect()
        res = conn.execute('select count(1) from zhihu.zhihu_answer_meta limit 10')
        print(res.fetchall())
        time.sleep(5)
        print(conn)

    def run(self):
        x = 'x'
        res_list = []
        res = self.pool.apply_async(self.run_in_process, args=(x,), kwds={'y': 'shit'})
        res_list.append(res)
        # [each.get(3) for each in res_list]

    def run_pool(self):
        self.pool.close()
        self.pool.join()

client = Client()
for i in range(10):
    client.run()
client.run_pool()

My intent with this code was to connect to MySQL from multiple processes, each with its own connector, and have them query simultaneously, achieving parallelism and better throughput. In practice, however, this code does not execute properly, and there is no error message at all. Why?

Sometimes you may instead see an error like this:

TypeError: can't pickle _thread._local objects

https://stackoverflow.com/questions/58022926/cant-pickle-the-sqlalchemy-engine-in-the-class
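
You can reproduce this outside of multiprocessing by pickling an Engine directly — a minimal sketch using an in-memory SQLite URL purely for illustration (the exact exception depends on your SQLAlchemy and Python versions):

import pickle
from sqlalchemy import create_engine

engine = create_engine('sqlite://')
try:
    # An Engine holds thread-local state and live pool internals,
    # none of which pickle can serialize.
    pickle.dumps(engine)
except Exception as e:
    print(type(e).__name__, e)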

Below is a modified version of the code above:

from sqlalchemy import create_engine
import time
import pymysql
from multiprocessing import Pool

class Client(object):
    def __init__(self):
        self.pool = Pool(5)
        self.connection = pymysql.connect(host='localhost',
                                          user='root',
                                          password='ignorance',
                                          port=3306,
                                          charset='utf8mb4',
                                          cursorclass=pymysql.cursors.DictCursor)

    def run_in_process(self, x, y='hehe'):
        # The engine is now a local variable inside the method rather than
        # an instance attribute, so pickling no longer breaks.
        engine = create_engine(f'mysql+pymysql://root:ignorance@localhost:3306/',
                               server_side_cursors=True, pool_size=20)
        engine.dispose()
        conn = engine.connect()
        res = conn.execute('select count(1) from zhihu.zhihu_answer_meta limit 10')
        print(res.fetchall())
        time.sleep(5)
        print(conn)

    def run(self):
        x = 'x'
        res_list = []
        res = self.pool.apply_async(self.run_in_process, args=(x,), kwds={'y': 'shit'})
        res_list.append(res)
        # [each.get(3) for each in res_list]

    def run_pool(self):
        self.pool.close()
        self.pool.join()

    def __getstate__(self):
        # Added this magic method: drop the unpicklable attributes.
        self_dict = self.__dict__.copy()
        del self_dict['pool']
        del self_dict['connection']  # if connection is not deleted, it fails silently without any errors
        return self_dict

    def __setstate__(self, state):
        # Added this magic method: restore state on unpickling.
        self.__dict__.update(state)

client = Client()
for i in range(10):
    client.run()
client.run_pool()

As you can see, this code differs in a few ways:

It adds the two magic methods __getstate__ and __setstate__. Why are they needed?

First, from self.pool.apply_async(self.run_in_process) we can see that apply_async is invoked on an instance method, so Python has to pickle the entire Client object, including all of its instance attributes. In the first snippet those attributes include pool, connection, engine, and so on; none of these objects can be pickled, so execution breaks. That is where __getstate__ and __setstate__ come in.

__getstate__ is always called right before an object is pickled, and it lets you specify exactly which parts of the object's state you want pickled. On unpickling, __setstate__(state) is called if it is implemented; otherwise the dict returned by __getstate__ is applied to the unpickled instance directly. In the example above, __setstate__ has no real effect and could be omitted.
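
Here is the pattern in isolation — a minimal sketch with a hypothetical Job class holding an unpicklable threading.Lock:

import pickle
import threading

class Job:
    def __init__(self):
        self.name = 'export'
        self.lock = threading.Lock()  # locks cannot be pickled

    def __getstate__(self):
        # Called right before pickling: hand back only the picklable state.
        state = self.__dict__.copy()
        del state['lock']
        return state

    def __setstate__(self, state):
        # Called on unpickling: restore the state and rebuild the lock.
        self.__dict__.update(state)
        self.lock = threading.Lock()

job = pickle.loads(pickle.dumps(Job()))
print(job.name)  # 'export'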

Some details of multiprocessing

Passing multiple arguments to the target function

Sometimes we want to pass specific keyword arguments to the target function of apply_async; these can go in via kwds, for example:

def run_in_process(self, x, y='hehe'):
    print(x)
    print(y)

def run(self):
    x = 'x'
    res = self.pool.apply_async(self.run_in_process, args=(x,), kwds={'y': 'shit'})

The difference between map and apply

map returns results in the same order as its input arguments, while apply_async tasks complete in arbitrary order.
map is typically used to spread a set of arguments over the same function, whereas apply_async can invoke different functions.

Also,
map is roughly equivalent to map_async().get()
apply is roughly equivalent to apply_async().get()
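
A small illustration of the difference, as a sketch with a hypothetical square function:

from multiprocessing import Pool

def square(x):
    return x * x

if __name__ == '__main__':
    with Pool(4) as pool:
        # map: one function over many arguments, results in input order.
        print(pool.map(square, [1, 2, 3, 4]))  # [1, 4, 9, 16]

        # apply_async: schedule calls one by one (they may even be different
        # functions); completion order is unspecified, so collect the results
        # through the returned handles.
        handles = [pool.apply_async(square, (i,)) for i in (1, 2, 3, 4)]
        print([h.get() for h in handles])  # [1, 4, 9, 16]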

Silent errors in child processes

Here is a real trap: sometimes your logic clearly never ran, yet there is no error whatsoever!

import multiprocessing as mp

class Foo():
    # @staticmethod
    def work(self):
        raise ValueError("error")

if __name__ == '__main__':
    pool = mp.Pool()
    foo = Foo()
    res = pool.apply_async(foo.work)
    pool.close()
    pool.join()
    # print(res.get())

In this example, if I never call get() on the object returned by apply_async, no error message appears anywhere. The fix is to call get(), which re-raises the worker's exception in the parent so you can finally see what went wrong.
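
Besides calling get(), Python 3's apply_async also accepts an error_callback that fires in the parent whenever a task raises, so failures surface even if you never collect the result — a minimal sketch:

import multiprocessing as mp

def work():
    raise ValueError("error")

def on_error(exc):
    # Runs in the parent process whenever the task raises.
    print('worker failed:', exc)

if __name__ == '__main__':
    pool = mp.Pool()
    pool.apply_async(work, error_callback=on_error)
    pool.close()
    pool.join()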

Errors after adding a decorator

For example, after putting a decorator on the run_in_process method, it starts failing.

import time

def timeit(method):
    def timed(*args, **kw):
        ts = time.time()
        result = method(*args, **kw)
        te = time.time()
        if 'log_time' in kw:
            name = kw.get('log_name', method.__name__.upper())
            kw['log_time'][name] = int((te - ts) * 1000)
        else:
            print('%r took %2.2f s' % (method.__name__, (te - ts)))
        return result
    return timed

@timeit
def run_in_process(self, x, y='hehe'):
    print(x)
    print(y)

def run(self):
    x = 'x'
    res = self.pool.apply_async(self.run_in_process, args=(x,), kwds={'y': 'shit'})

The error says:

Traceback (most recent call last):
  File "/data/software/miniconda3/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/data/software/miniconda3/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/data/software/miniconda3/lib/python3.7/multiprocessing/pool.py", line 110, in worker
    task = get()
  File "/data/software/miniconda3/lib/python3.7/multiprocessing/queues.py", line 354, in get
    return _ForkingPickler.loads(res)
AttributeError: 'Client' object has no attribute 'timed'

The workaround is a bit clumsy: you cannot use @ directly. You can substitute the decorator like this instead:

from multiprocessing import Pool

def decorate_func(f):
    def _decorate_func(*args, **kwargs):
        print("I'm decorating")
        return f(*args, **kwargs)
    return _decorate_func

def actual_func(x):
    return x ** 2

def wrapped_func(*args, **kwargs):
    # wrapped_func is a plain top-level function, so it pickles fine;
    # the decoration is applied inside the worker at call time.
    return decorate_func(actual_func)(*args, **kwargs)

my_swimming_pool = Pool()
result = my_swimming_pool.apply_async(wrapped_func, (2,))
print(result.get())
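
As an aside, the root cause above is that the wrapper's __name__ is 'timed', so pickle's lookup by name on the Client instance fails. In my experience, applying functools.wraps to the wrapper (so it keeps the original name) often makes the decorated method picklable again; treat this as a hedged alternative rather than a guaranteed fix:

import functools
import time

def timeit(method):
    @functools.wraps(method)  # keep the original __name__/__qualname__ for pickle's name lookup
    def timed(*args, **kw):
        ts = time.time()
        result = method(*args, **kw)
        print('%r took %2.2f s' % (method.__name__, time.time() - ts))
        return result
    return timed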

Handling Ctrl-C errors

Sometimes when we run tasks with multiprocessing and want to stop them, pressing Ctrl+C produces a flood of errors, and occasionally you have to press Ctrl+C several times before everything dies. Below is the best solution I have found:

The trick is to first prevent the child processes from ever receiving KeyboardInterrupt, and leave it entirely to the parent process to catch the interrupt and clean up the pool. This way you avoid writing exception-handling logic in the children, and you prevent the endless stream of errors generated by idle workers.

The solution code:

import multiprocessing
import signal
import time

def init_worker():
    # Workers ignore SIGINT; only the parent handles Ctrl-C.
    signal.signal(signal.SIGINT, signal.SIG_IGN)

def run_worker():
    time.sleep(15)

def main():
    print("Initializing 5 workers")
    pool = multiprocessing.Pool(5, init_worker)
    print("Starting 3 jobs of 15 seconds each")
    for i in range(3):
        pool.apply_async(run_worker)
    try:
        print("Waiting 10 seconds")
        time.sleep(10)
    except KeyboardInterrupt:
        print("Caught KeyboardInterrupt, terminating workers")
        pool.terminate()
        pool.join()
    else:
        print("Quitting normally")
        pool.close()
        pool.join()

if __name__ == '__main__':
    main()

References

https://docs.python.org/3/library/pickle.html
https://stackoverflow.com/questions/25382455/python-notimplementederror-pool-objects-cannot-be-passed-between-processes/25385582#25385582