1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
|
#!/usr/bin/env python3
"""Database migration helper: copies tables between SQLite, MySQL and Postgres."""
# Module Imports
import sys
import sqlite3
# The MySQL and Postgres drivers are optional; the script still works for the
# database types whose drivers are installed.  Catch only ImportError so a
# genuine failure (e.g. KeyboardInterrupt) is not silently swallowed.
try:
    import MySQLdb
except ImportError:
    print("NOTE: 'MySQLdb' not currently installed")
try:
    import psycopg2
except ImportError:
    print("NOTE: 'psycopg2' not currently installed")
import time
from progress.bar import Bar
from pick import pick
import yaml
import os
import argparse

# Global variables (overridden from the command line in __main__)
verbose = False  # extra per-row / per-key debug output
cmd_skip = 0     # debugging: skip this many source records per table
cmd_count = 0    # debugging: stop after this many records per table
def get_connection(config, db_type):
    """Open and return a DB-API connection for *db_type*.

    "sqlite" uses config['path']; "mysql" passes config as MySQLdb keyword
    arguments; any other type is treated as Postgres (psycopg2).
    """
    if db_type == "sqlite":
        # SQLite only needs a filesystem path.
        return sqlite3.connect(config['path'])
    if db_type == "mysql":
        return MySQLdb.connect(**config)
    # Fall through: every other type is assumed to be Postgres.
    return psycopg2.connect(**config)
def get_connections(config):
    """Return a (source, destination) connection pair described by the config.

    config['source'] / config['destination'] each carry a 'name' (the key of
    the connection settings elsewhere in config) and a 'type'.
    """
    src = config['source']
    dst = config['destination']
    source_conn = get_connection(config[src['name']], src['type'])
    dest_conn = get_connection(config[dst['name']], dst['type'])
    return source_conn, dest_conn
# Returns foreign key list for a given table
def get_foreign_key_list(conn, table, source_type="sqlite"):
    """Return the names of the tables *table* holds foreign keys into.

    Only SQLite sources are supported; any other source_type prints an error
    and terminates the program.
    """
    cur = conn.cursor()
    if verbose: print("TABLE:%s" % table)
    if source_type != "sqlite":
        print(f"""ERROR: foreign key search for '{source_type}' databases not yet supported""")
        # sys.exit instead of the site-provided exit(); same SystemExit, but
        # available even when the interpreter runs without the site module
        sys.exit(1)
    cur.execute(f"""PRAGMA foreign_key_list({table});""")
    foreign_keys = []
    for foreign_key in cur:
        # Example sqlite row: (0, 0, 'users_srtuser', 'user_id', 'id', 'NO ACTION', 'NO ACTION', 'NONE')
        if verbose: print(" KEY:%s" % str(foreign_key))
        # Index 2 of the pragma row is the referenced table's name
        foreign_keys.append(foreign_key[2])
    return foreign_keys
# returns dictionary with keys as table names, and values as dictionary with column names and count from source and dest conn
def get_db_info(conn, dest_conn=None, source_type="sqlite", mysql_db=None):
    """Inspect the source schema and return per-table metadata.

    Returns {table_name: {'columns': [...], 'types': [...],
    'source_count': int, 'foreign_keys': [...], 'dest_count': int}} where
    'dest_count' is present only when dest_conn is given.
    """
    # Per-dialect catalog queries; anything that is not sqlite or mysql is
    # treated as Postgres.
    catalog_queries = {
        "sqlite": """SELECT m.name as table_name, p.name as column_name, p.type as type FROM sqlite_master AS m JOIN pragma_table_info(m.name) AS p where table_name != 'sqlite_sequence' ORDER BY m.name, p.cid""",
        "mysql": f"""SELECT * FROM information_schema.columns where table_name like '%%' and table_schema = '{mysql_db}' order by table_name,ordinal_position""",
    }
    pg_query = """SELECT * FROM information_schema.columns where table_schema = 'public' order by table_name,ordinal_position"""
    cur = conn.cursor()
    cur.execute(catalog_queries.get(source_type, pg_query))
    col_names = [desc[0] for desc in cur.description]
    rows = [dict(zip(col_names, values)) for values in cur.fetchall()]
    if source_type != "sqlite":
        # information_schema rows come back with upper-case column keys;
        # normalize to the sqlite query's lower-case key names
        rows = [
            {'column_name': r['COLUMN_NAME'], 'table_name': r['TABLE_NAME'], 'type': r['DATA_TYPE']}
            for r in rows
        ]
    tables = {}
    for row in rows:
        info = tables.setdefault(row['table_name'], {'columns': [], 'types': []})
        info['columns'].append(row['column_name'])
        info['types'].append(row['type'])
    # Row counts and foreign keys from the source side
    for name, info in tables.items():
        count_cur = conn.cursor()
        count_cur.execute(f"SELECT count(*) from {name}")
        info['source_count'] = count_cur.fetchone()[0]
        info['foreign_keys'] = get_foreign_key_list(conn, name, source_type)
    # Row counts from the destination side, when available
    if dest_conn is not None:
        for name, info in tables.items():
            dest_cur = dest_conn.cursor()
            dest_cur.execute(f"SELECT count(*) from {name}")
            info['dest_count'] = dest_cur.fetchone()[0]
    return tables
# Orders the table list from no foreign key dependencies to all satisfied
def gen_table_order_sql(source_conn, tables):
    """Topologically order table names so each table follows its FK targets.

    tables: dict mapping table name -> info dict containing a 'foreign_keys'
    list (as produced by get_db_info).  'django_migrations' is always left
    out so the migrations table is never overwritten.  A self-referential
    foreign key counts as satisfied (previously it could never resolve and
    forced a hard exit).  Exits the program on a genuine dependency cycle.
    """
    # Ordered table names: goal state
    table_names_ordered = []
    # As yet un-ordered [name, pending_foreign_keys] pairs: initial state
    table_names_unordered = []
    for table in tables:
        # Never overwrite the migrations table
        if 'django_migrations' == table:
            continue
        table_names_unordered.append([table, tables[table]['foreign_keys'].copy()])
    if verbose: print("Len(table_names_unordered) = %d" % len(table_names_unordered))
    iteration = 0
    while len(table_names_unordered):
        change = False
        iteration += 1
        # Walk backwards so deletions do not disturb the remaining indices
        for i in range(len(table_names_unordered), 0, -1):
            i_index = i - 1
            if verbose: print("Pass %s:(%s)=%s" % (iteration, i_index, str(table_names_unordered[i_index])))
            table_name = table_names_unordered[i_index][0]
            foreign_keys = table_names_unordered[i_index][1]
            # If newly satisfied dependency, remove dependency
            if foreign_keys:
                for j in range(len(foreign_keys), 0, -1):
                    j_index = j - 1
                    # Satisfied when the referenced table is already ordered,
                    # or when the table references itself (self-FK)
                    if foreign_keys[j_index] in table_names_ordered or foreign_keys[j_index] == table_name:
                        del foreign_keys[j_index]
                        change = True
            # If no pending dependencies, promote
            if not foreign_keys:
                table_names_ordered.append(table_name)
                # Remove old name from unordered list
                del table_names_unordered[i_index]
                change = True
                if verbose: print(" * Promote:%s" % table_name)
        # Sanity check: a full pass with no progress means an unresolvable loop
        if not change:
            print("ERROR: Unresolvable table dependency loop")
            for t in table_names_ordered:
                print(" Resolved:%s" % t)
            for t in table_names_unordered:
                print(" Unresolved:%s" % str(t))
            sys.exit(1)
    return table_names_ordered
# Pre-clear the destination tables, in reverse dependency order (most to least)
def clear_dest_tables(dest_conn, table_names_ordered, tables, destination_type):
    """Delete all rows from every destination table, most-dependent first,
    so foreign key constraints are never violated.

    Commits only when every DELETE succeeded; on the first failure the
    exception and the offending SQL are printed and nothing is committed.
    """
    bar = Bar('Pre-clearing destination tables', max=len(table_names_ordered))
    success = True
    cur = dest_conn.cursor()
    # Reverse order: children (most dependencies) are emptied before parents
    for i in range(len(table_names_ordered), 0, -1):
        i_index = i - 1
        sql = "DELETE from %s;" % table_names_ordered[i_index]
        try:
            # No bind parameters: the previous explicit `execute(sql, None)`
            # broke sqlite3 destinations (its parameters must be a sequence)
            # and was a no-op for MySQLdb/psycopg2.
            cur.execute(sql)
            bar.next()
        except Exception as e:
            success = False
            print(f"\n\nException:\n{e}\n\nSQL: {sql}\nparams: None")
            break
    bar.finish()
    if success:
        dest_conn.commit()
# Transfer the tables, one by one
def transfer_sql(source_conn, dest_conn, table_names_ordered, tables, source_type, destination_type):
    """Copy every listed table's rows from source to destination, in
    dependency order, committing in batches of 100 rows.

    Progress is reported per-row when the module global ``verbose`` is set,
    otherwise per-table.  The globals ``cmd_skip``/``cmd_count`` skip or
    truncate rows for debugging.

    NOTE(review): mutates tables[table]['columns'] in place (wrapping each
    name in quote characters), so a second call on the same dict would
    double-quote the column names.
    """
    source_cur = source_conn.cursor()
    dest_cur = dest_conn.cursor()
    print("Transfer_sql...")
    # Progress-bar granularity: total rows when verbose, table count otherwise
    if verbose:
        bar_max = 0
        for table in tables:
            bar_max += int(tables[table]['source_count'])
    else:
        bar_max = len(table_names_ordered)
    bar = Bar('Transfering data by table', max=bar_max)
    for table in table_names_ordered:
        success = True
        count = 0
        # Identifier quoting: double quotes for Postgres, backticks otherwise
        q = '`' if destination_type != "postgres" else '"'
        tables[table]['columns'] = [f'{q}{i}{q}' for i in tables[table]['columns']]
        sql = f"""SELECT {','.join(tables[table]['columns'])} from {table};"""
        source_cur.execute(sql)
        for entry_count,entry in enumerate(source_cur):
            # Development/debug support: skip leading rows / stop early
            if cmd_skip and (entry_count < cmd_skip): continue
            if cmd_count and ((entry_count - cmd_skip) > cmd_count): break
            entry = list(entry)
            # Special-case scrub for the orm_cve table's second-to-last column
            if table == "orm_cve":
                if entry[-2] == '' or (entry[-2] is not None and 'RESERVED' in entry[-2]):
                    entry[-2] = None # set acknowledge date to None if currently value is invalid
            if destination_type == "postgres":
                for i in range(len(entry)): # handle lack of booleans in non postgres
                    if "bool" in tables[table]['types'][i]:
                        entry[i] = entry[i] != 0
            # NOTE(review): '%s' placeholders fit MySQLdb/psycopg2; a sqlite3
            # destination would need '?' placeholders — confirm intent
            sql = f"""INSERT INTO {table} ({','.join(tables[table]['columns'])}) VALUES ({','.join(['%s'] * len(entry))});"""
            try:
                dest_cur.execute(sql, entry)
                if verbose: bar.next()
                # Commit batches as we go
                count += 1
                if 0 == (count % 100):
                    dest_conn.commit()
            except Exception as e:
                success = False
                print(f"\n\nException:\n{e}\n\nSQL: {sql}\nparams: {entry}")
                break
        # Commit the balance of this table
        if not verbose:
            bar.next()
        if success:
            dest_conn.commit()
    bar.finish()
def run_tests(tables, source_conn, dest_conn):
    """Validate a completed migration.

    First compares per-table row counts between source and destination, then
    spot-checks the orm_cve table by comparing description/comment lengths
    row by row in batches of 1000.  Closes both connections before returning.
    """
    print('running tests!')
    matching_counts = 0
    mismatched_tables = []
    for table in tables:
        table_info = tables[table]
        if table_info['source_count'] == table_info['dest_count']:
            matching_counts += 1
        else:
            mismatched_tables.append(table)
    print(f'Matching Tables Counts between source and destination out of total tables:{matching_counts}/{len(tables)}')
    print(f'Mismatched tables: {mismatched_tables}')
    source_count = tables['orm_cve']['source_count']
    dest_count = tables['orm_cve']['dest_count']
    if source_count != dest_count:
        # (fixed message typo: "checkin" -> "checking")
        print('orm_cve count does not match between source and destination, not checking description lengths')
        source_conn.close()
        dest_conn.close()
        return
    source_curr = source_conn.cursor()
    dest_curr = dest_conn.cursor()
    # Stable ORDER BY so both databases page through rows in the same order
    query = 'select length(description) as dl, length(comments) as cl from orm_cve order by NAME LIMIT 1000 OFFSET '
    mismatch = False
    bars = source_count // 1000 + 1
    print(f"Numbers of rows in orm_cve: {source_count}")
    bar = Bar('Checking description lengths in batches of 1000', max=bars)
    for batch in range(bars):
        offset_query = f'{query}{batch * 1000}'
        source_curr.execute(offset_query)
        dest_curr.execute(offset_query)
        columns = source_curr.description
        source = [{columns[index][0]:column for index, column in enumerate(value)} for value in source_curr.fetchall()]
        dest = [{columns[index][0]:column for index, column in enumerate(value)} for value in dest_curr.fetchall()]
        mismatch = False
        # Loop variable renamed from 'i' so it no longer shadows the batch index
        for row in range(len(source)):
            if source[row]['dl'] != dest[row]['dl'] or source[row]['cl'] != dest[row]['cl']:
                print(f'source:\n{source[row]}\n\ndestination: {dest[row]}\n\n')
                mismatch = True
                break
        bar.next()
        if mismatch:
            break
    bar.finish()
    if mismatch:
        print("Error: mismatched length of description in orm_cve")
    else:
        print("Success: Description and comment length matches for every row in orm_cve")
    source_conn.close()
    dest_conn.close()
def repair_sequences_postgres(tables, dest_conn):
    """Advance each table's serial sequence past MAX(id) on the destination,
    so inserts made after the migration do not collide with migrated keys.

    Stops at the first failing table after printing the exception and SQL.
    """
    bar = Bar('Repairing table sequences', max=len(tables))
    for table in tables:
        # django_session is skipped (no integer id sequence to repair here)
        if table in ['django_session']:
            bar.next()
            continue
        id_column = 'id'
        sql = f"SELECT setval(pg_get_serial_sequence('{table}', '{id_column}'), (SELECT MAX({id_column}) FROM {table})+1);"
        cur = dest_conn.cursor()
        try:
            cur.execute(sql)
            bar.next()
        except Exception as e:
            print(f"\n\nException:\n{e}\n\nSQL: {sql}\n")
            break
    bar.finish()
def main(config, test=False, repair=False, show_order=False):
    """Drive the migration described by *config*.

    Modes, checked in this order (each returns without doing the rest):
      repair     -- only repair Postgres sequences on the destination
      test       -- only verify a previous migration (run_tests closes the
                    connections itself)
      show_order -- only print the dependency-ordered table list
    Otherwise: interactively select tables, clear them on the destination,
    transfer all rows, and repair the destination sequences.
    """
    source_conn, dest_conn = get_connections(config)
    # MySQL sources need the schema name to query information_schema
    mysql_db_name = config[config['source']['name']]['db'] if config['source']['type'] == "mysql" else None
    tables = get_db_info(source_conn, dest_conn, config['source']['type'], mysql_db_name)
    if repair:
        repair_sequences_postgres(tables, dest_conn)
        source_conn.close()
        dest_conn.close()
        return
    if test:
        run_tests(tables, source_conn, dest_conn)
        return
    # pick() returns (option, index); index 1 means "select tables"
    _, select_table = pick(('all tables', 'select tables'), "Would you like to copy all tables, or specific tables for transfer?")
    if select_table: # filter tables to the interactively chosen subset
        selection = pick(list(tables.items()), f"Please Select which of {len(tables)} tables to copy (use space key to select).\nFormat: Table Name(Current Source Count:Current Destination Count)", multiselect=True, min_selection_count=1, options_map_func= lambda option: f"{option[0]}({option[1]['source_count']}:{option[1]['dest_count']})")
        selection = [value[0] for value in selection ]
        tables = {item[0]:item[1] for item in selection}
    # Order the table names by foreign key dependencies
    table_names_ordered = gen_table_order_sql(source_conn, tables)
    if show_order:
        print("Ordered Data Tables: %s" % len(table_names_ordered))
        for i,table_name in enumerate(table_names_ordered):
            print("%2d) %-30s %s" % (i+1, table_name, str(tables[table_name]['foreign_keys'])))
        return
    # Pre-clear the destination tables to remove obsolete data
    clear_dest_tables(dest_conn, table_names_ordered, tables, config['destination']['type'])
    # Transfer the tables, one by one
    transfer_sql(source_conn, dest_conn, table_names_ordered, tables, config['source']['type'], config['destination']['type'])
    # Fix up the table sequences
    repair_sequences_postgres(tables, dest_conn)
    source_conn.close()
    dest_conn.close()
if __name__ == "__main__":
    my_parser = argparse.ArgumentParser(description='DB Migration Script (Postgres/Sqlite/MySql)')
    my_parser.add_argument('--path', default="db_migration_config.yml", type=str, help='the path to configuration file, default is ./db_migration_config.yml')
    my_parser.add_argument('--test', default=False, action="store_true", help='Whether to test migration')
    my_parser.add_argument('--repair', default=False, action="store_true", help="Whether to repair postgres sequences if destination is postgres database")
    my_parser.add_argument('--show-order', '-o', default=False, action="store_true", dest="show_order", help="Show tables in least to most dependency order")
    my_parser.add_argument('--verbose', '-v', default=False, action="store_true", dest="verbose", help="Verbose information")
    # type=int makes argparse reject non-numeric values with a clean usage
    # error instead of the uncaught ValueError a bare int() cast produced
    my_parser.add_argument('--skip', dest='skip', type=int, help='Debugging: skip record count')
    my_parser.add_argument('--count', dest='count', type=int, help='Debugging: short run record count')
    args = my_parser.parse_args()
    # Publish the debug switches as the module globals read by transfer_sql
    verbose = args.verbose
    if args.skip:
        cmd_skip = args.skip
    if args.count:
        cmd_count = args.count
    with open(args.path, "r") as ymlfile:
        config = yaml.safe_load(ymlfile)
    main(config, test=args.test, repair=args.repair, show_order=args.show_order)
|