aboutsummaryrefslogtreecommitdiffstats
path: root/bin/dev_tools/db_migrations.py
blob: 1fb65562b205b150656fe7f4a6407c3213dac28f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
#!/usr/bin/env python3

# Module Imports
# Standard-library imports
import argparse
import os
import sqlite3
import sys
import time

# Optional database drivers: each backend is only needed when the config
# actually selects it, so a missing driver is a note, not a fatal error.
# Catch ImportError specifically (a bare `except:` would also swallow
# KeyboardInterrupt/SystemExit).
try:
    import MySQLdb
except ImportError:
    print("NOTE: 'MySQLdb' not currently installed")
try:
    import psycopg2
except ImportError:
    print("NOTE: 'psycopg2' not currently installed")

# Third-party helpers
import yaml
from pick import pick
from progress.bar import Bar

# Global flags/counters, overwritten from command-line args in __main__
verbose = False   # extra per-row diagnostic output
cmd_skip = 0      # debugging: skip this many leading records per table
cmd_count = 0     # debugging: stop after this many records per table

def get_connection(config, db_type):
    """Open and return a DB connection for the given backend.

    'sqlite' expects config['path']; 'mysql' and anything else (assumed
    postgres) pass the config mapping straight to the driver as kwargs.
    """
    if db_type == "sqlite":
        return sqlite3.connect(config['path'])
    if db_type == "mysql":
        return MySQLdb.connect(**config)
    # Default: postgres via psycopg2
    return psycopg2.connect(**config)

def get_connections(config):
    """Return a (source, destination) connection pair described by config."""
    src = config['source']
    dst = config['destination']
    return (get_connection(config[src['name']], src['type']),
            get_connection(config[dst['name']], dst['type']))

# Returns foreign key list for a given table
def get_foreign_key_list(conn, table, source_type="sqlite"):
    cur = conn.cursor()
    if verbose: print("TABLE:%s" % table)
    if source_type == "sqlite":
        sql = f"""PRAGMA foreign_key_list({table});"""
    else:
        print(f"""ERROR: foreign key search for '{source_type}' databases not yet supported""")
        exit(1)
    cur.execute(sql)
    foreign_keys = []
    for foreign_key in cur:
        #Example sqlite: (0, 0, 'users_srtuser', 'user_id', 'id', 'NO ACTION', 'NO ACTION', 'NONE')
        if verbose: print("  KEY:%s" % str(foreign_key))
        foreign_keys.append(foreign_key[2])
    return foreign_keys

# returns dictionary with keys as table names, and values as dictionary with column names and count from source and dest conn
def get_db_info(conn, dest_conn=None, source_type="sqlite", mysql_db=None):
    """Build a per-table summary of the source schema.

    Returns {table_name: {'columns': [...], 'types': [...],
    'source_count': int, 'foreign_keys': [...], 'dest_count': int}}.
    'dest_count' is only present when dest_conn is given. mysql_db is the
    schema name, used only when source_type == "mysql".
    """
    # Column-listing queries per backend, each ordered (table, column position)
    # so columns stay in declaration order.
    sqlite_sql = """SELECT m.name as table_name, p.name as column_name, p.type as type FROM sqlite_master AS m JOIN pragma_table_info(m.name)  AS p where table_name != 'sqlite_sequence' ORDER BY m.name, p.cid"""
    mysql_sql = f"""SELECT * FROM information_schema.columns where table_name like '%%' and table_schema = '{mysql_db}' order by table_name,ordinal_position"""
    pg_sql = """SELECT * FROM information_schema.columns where  table_schema = 'public' order by table_name,ordinal_position"""
    cur = conn.cursor()
    sql = sqlite_sql if source_type == "sqlite" else mysql_sql if source_type == "mysql" else pg_sql
    cur.execute(sql)
    # Convert each row tuple into a dict keyed by the cursor's column names.
    columns = cur.description
    results =  [{columns[index][0]:column for index, column in enumerate(value)} for value in cur.fetchall()]
    if source_type != "sqlite":
        # information_schema returns upper-case keys; normalize to the
        # lower-case names the sqlite query produces.
        results = [{'column_name': col['COLUMN_NAME'], 'table_name': col['TABLE_NAME'], 'type': col['DATA_TYPE']}for col in results]
    # Group the flat (table, column, type) rows by table.
    tables = {}
    for i in results:
        if i['table_name'] not in tables:
            tables[i['table_name']] = {'columns': [], 'types': []}
        tables[i['table_name']]['columns'].append(i['column_name'])
        tables[i['table_name']]['types'].append(i['type'])
    # Row counts and foreign-key lists from the source.
    for table in tables:
        cur = conn.cursor()
        sql = f"SELECT count(*) from {table}"
        cur.execute(sql)
        results = cur.fetchone()[0]
        tables[table]['source_count'] = results
        tables[table]['foreign_keys'] = get_foreign_key_list(conn, table, source_type)
    # Row counts from the destination, when one was supplied.
    # NOTE(review): assumes every source table already exists on the
    # destination (e.g. migrations were applied first) -- verify with callers.
    if dest_conn is not None:
        for table in tables:
            cur = dest_conn.cursor()
            sql = f"SELECT count(*) from {table}"
            cur.execute(sql)
            results = cur.fetchone()[0]
            tables[table]['dest_count'] = results
    return tables

# Orders the table list from no foreign key dependencies to all satisfied
def gen_table_order_sql(source_conn, tables):
    """Topologically sort table names by foreign-key dependencies.

    Tables with no (remaining) dependencies are promoted first, so the
    returned list can be inserted front-to-back without violating foreign
    keys. 'django_migrations' is always excluded. Exits the program if a
    dependency cycle prevents a full ordering. source_conn is currently
    unused (kept for interface symmetry with the other *_sql helpers).
    """
    # Ordered table names: goal state
    table_names_ordered = []
    # As yet un-ordered table names: initial state
    table_names_unordered = []
    for table in tables:
        # Never overwrite the migrations table
        if 'django_migrations' == table:
            continue
        # Copy the foreign-key list: it is destructively pruned below.
        table_names_unordered.append([table,tables[table]['foreign_keys'].copy()])

    if verbose: print("Len(table_names_unordered) = %d" % len(table_names_unordered))
    interation = 0
    # Repeated passes: each pass prunes satisfied dependencies and promotes
    # dependency-free tables. Iterating backwards makes in-loop deletion safe.
    while len(table_names_unordered):
        change = False
        interation += 1
        for i in range(len(table_names_unordered),0,-1):
            i_index = i-1
            if verbose: print("Pass %s:(%s)=%s" % (interation,i_index,str(table_names_unordered[i_index])))
            table_name = table_names_unordered[i_index][0]
            foreign_keys = table_names_unordered[i_index][1]
            # If newly satisfied dependency, remove dependency
            if foreign_keys:
                for j in range(len(foreign_keys),0,-1):
                    j_index = j-1
                    # Found in resolved ordered list
                    if foreign_keys[j_index] in table_names_ordered:
                        del table_names_unordered[i_index][1][j_index]
                        change = True
            # If no pending dependencies, promote
            # (foreign_keys aliases the pruned list, so this sees deletions)
            if not foreign_keys:
                # No pending dependencies, so move
                table_names_ordered.append(table_name)
                # Remove old name from unordered list
                del table_names_unordered[i_index]
                change = True
                if verbose: print(" * Promote:%s" % table_name)
        # Sanity Check for unresolvable loops
        # (a full pass with no progress means a dependency cycle)
        if not change:
            print("ERROR: Unresolvable table dependency loop")
            for t in table_names_ordered:
                print("  Resolved:%s" % t)
            for t in table_names_unordered:
                print("  Unresolved:%s" % str(t))
            exit(1)
    return table_names_ordered

# Pre-clear the destination tables, in reverse dependency order (most to least)
def clear_dest_tables(dest_conn, table_names_ordered, tables, destination_type):
    """Delete all rows from every destination table, most-dependent first.

    table_names_ordered is in least-to-most dependency order, so it is
    walked backwards to clear child tables before the tables they
    reference. Commits only if every DELETE succeeded, leaving the
    destination untouched on failure.
    """
    bar = Bar('Pre-clearing destination tables', max=len(table_names_ordered))
    success = True
    cur = dest_conn.cursor()
    for table_name in reversed(table_names_ordered):
        sql = "DELETE from %s;" % table_name
        try:
            # No parameter argument: sqlite3 raises on an explicit None
            # where MySQLdb/psycopg2 merely ignore it.
            cur.execute(sql)
            bar.next()
        except Exception as e:
            success = False
            print(f"\n\nException:\n{e}\n\nSQL: {sql}")
            break
    bar.finish()
    if success:
        dest_conn.commit()

# Transfer the tables, one by one
def transfer_sql(source_conn, dest_conn, table_names_ordered, tables, source_type, destination_type):
    """Copy every table's rows from source to destination, in dependency order.

    Rows are committed in batches of 100 per table, plus a final commit per
    table on success. Honors the module-level cmd_skip/cmd_count debugging
    counters and the verbose flag (which switches the progress bar from
    per-table to per-row granularity).
    """
    source_cur = source_conn.cursor()
    dest_cur = dest_conn.cursor()

    print("Transfer_sql...")

    # Progress-bar sizing: per-row when verbose, per-table otherwise.
    if verbose:
        bar_max = 0
        for table in tables:
            bar_max += int(tables[table]['source_count'])
    else:
        bar_max = len(table_names_ordered)
    bar = Bar('Transfering data by table', max=bar_max)

    for table in table_names_ordered:
        success = True
        count = 0

        # Quote column names per destination dialect (postgres: ", else: `).
        # NOTE(review): this mutates tables[table]['columns'] in place, so
        # calling this function twice would double-quote the names.
        q = '`' if destination_type != "postgres" else '"'
        tables[table]['columns'] = [f'{q}{i}{q}' for i in tables[table]['columns']]
        sql = f"""SELECT {','.join(tables[table]['columns'])} from {table};"""
        source_cur.execute(sql)
        for entry_count,entry in enumerate(source_cur):
            # Development/debug support
            if cmd_skip and (entry_count < cmd_skip): continue
            if cmd_count and ((entry_count - cmd_skip) > cmd_count): break

            entry = list(entry)
            if table == "orm_cve":
                # entry[-2] is presumably the acknowledge-date column --
                # clear it when empty or a 'RESERVED' placeholder.
                if entry[-2] == '' or (entry[-2] is not None and 'RESERVED' in entry[-2]):
                    entry[-2] = None # set acknowledge date to None if currently value is invalid
            if destination_type == "postgres":
                for i in range(len(entry)): # handle lack of booleans in non postgres
                    if "bool" in tables[table]['types'][i]:
                        entry[i] = entry[i] != 0

            # NOTE(review): '%s' placeholders assume a pyformat/format
            # paramstyle driver (MySQLdb/psycopg2); a sqlite3 destination
            # would need '?' -- confirm sqlite is only ever a source.
            sql = f"""INSERT INTO {table} ({','.join(tables[table]['columns'])}) VALUES ({','.join(['%s'] * len(entry))});"""
            try:
                dest_cur.execute(sql, entry)
                if verbose: bar.next()
                # Commit batches as we go
                count += 1
                if 0 == (count % 100):
                    dest_conn.commit()
            except Exception as e:
                success = False
                print(f"\n\nException:\n{e}\n\nSQL: {sql}\nparams: {entry}")
                break

        # Commit the balance of this table
        if not verbose:
            bar.next()
        if success:
            dest_conn.commit()
    bar.finish()

def run_tests(tables, source_conn, dest_conn):
    """Validate a completed migration.

    First compares per-table row counts between source and destination,
    then (when orm_cve counts match) spot-checks description/comment
    lengths row-by-row in batches of 1000. Closes both connections
    before returning.
    """
    print('running tests!')
    matching_counts = 0
    mismatched_tables = []
    for table in tables:
        table_info = tables[table]
        if table_info['source_count'] == table_info['dest_count']:
            matching_counts += 1
        else:
            mismatched_tables.append(table)

    print(f'Matching Tables Counts between source and destination out of total tables:{matching_counts}/{len(tables)}')
    print(f'Mismatched tables: {mismatched_tables}')
    source_count = tables['orm_cve']['source_count']
    dest_count = tables['orm_cve']['dest_count']
    if source_count != dest_count:
        # Row-aligned length comparison only makes sense with equal counts.
        print('orm_cve count does not match between source and destination, not checking description lengths')
        source_conn.close()
        dest_conn.close()
        return

    source_curr = source_conn.cursor()
    dest_curr = dest_conn.cursor()
    # ORDER BY NAME gives both sides the same row order for LIMIT/OFFSET paging.
    query = 'select length(description) as dl, length(comments) as cl from orm_cve order by NAME LIMIT 1000 OFFSET '
    mismatch = False
    batches = source_count // 1000 + 1
    print(f"Numbers of rows in orm_cve: {source_count}")
    bar = Bar('Checking description lengths in batches of 1000', max=batches)
    for batch in range(batches):
        offset_query = f'{query}{batch * 1000}'
        source_curr.execute(offset_query)
        dest_curr.execute(offset_query)
        columns = source_curr.description

        source = [{columns[index][0]:column for index, column in enumerate(value)} for value in source_curr.fetchall()]
        dest = [{columns[index][0]:column for index, column in enumerate(value)} for value in dest_curr.fetchall()]
        mismatch = False
        # Distinct name 'row' so the batch index above is not shadowed.
        for row in range(len(source)):
            if source[row]['dl'] != dest[row]['dl'] or source[row]['cl'] != dest[row]['cl']:
                print(f'source:\n{source[row]}\n\ndestination: {dest[row]}\n\n')
                mismatch = True
                break
        bar.next()
        if mismatch:
            break
    bar.finish()
    if mismatch:
        print("Error: mismatched length of description in orm_cve")
    else:
        print("Success: Description and comment length matches for every row in orm_cve")
    source_conn.close()
    dest_conn.close()

def repair_sequences_postgres(tables, dest_conn):
    """Advance each table's serial 'id' sequence past the current MAX(id).

    Needed after a bulk copy so future INSERTs don't collide with copied
    rows. Tables without an id sequence (django_session) are skipped.
    Stops at the first failing table.
    """
    bar = Bar('Repairing table sequences', max=len(tables))
    for table in tables:
        if table in ['django_session']:
            bar.next()
            continue
        sql = (f"SELECT setval(pg_get_serial_sequence('{table}', 'id'), "
               f"(SELECT MAX(id) FROM {table})+1);")
        cur = dest_conn.cursor()
        try:
            cur.execute(sql)
            bar.next()
        except Exception as e:
            print(f"\n\nException:\n{e}\n\nSQL: {sql}\n")
            break
    bar.finish()

def main(config, test=False, repair=False, show_order=False):
    """Drive one migration run, or the --repair / --test / --show-order modes."""
    source_conn, dest_conn = get_connections(config)
    mysql_db_name = None
    if config['source']['type'] == "mysql":
        mysql_db_name = config[config['source']['name']]['db']
    tables = get_db_info(source_conn, dest_conn, config['source']['type'], mysql_db_name)

    # Stand-alone maintenance modes: do one thing and exit.
    if repair:
        repair_sequences_postgres(tables, dest_conn)
        source_conn.close()
        dest_conn.close()
        return
    if test:
        run_tests(tables, source_conn, dest_conn)
        return

    # pick() returns (option, index); index 1 means 'select tables'.
    _, want_subset = pick(('all tables', 'select tables'), "Would you like to copy all tables, or specific tables for transfer?")
    if want_subset:
        # Each picked value is ((table_name, info), index); keep the pairs.
        picked = pick(list(tables.items()), f"Please Select which of {len(tables)} tables to copy (use space key to select).\nFormat: Table Name(Current Source Count:Current Destination Count)", multiselect=True, min_selection_count=1, options_map_func= lambda option: f"{option[0]}({option[1]['source_count']}:{option[1]['dest_count']})")
        tables = dict(value[0] for value in picked)

    # Order the table names by foreign key dependecies
    table_names_ordered = gen_table_order_sql(source_conn, tables)
    if show_order:
        print("Ordered Data Tables: %s" % len(table_names_ordered))
        for position, table_name in enumerate(table_names_ordered, start=1):
            print("%2d)  %-30s  %s" % (position, table_name, str(tables[table_name]['foreign_keys'])))
        return

    # Pre-clear the destination tables to remove obsolete data
    clear_dest_tables(dest_conn, table_names_ordered, tables, config['destination']['type'])
    # Transfer the tables, one by one
    transfer_sql(source_conn, dest_conn, table_names_ordered, tables, config['source']['type'], config['destination']['type'])
    # Fix up the table sequences
    repair_sequences_postgres(tables, dest_conn)
    source_conn.close()
    dest_conn.close()

if __name__ == "__main__":
    my_parser = argparse.ArgumentParser(description='DB Migration Script (Postgres/Sqlite/MySql)')
    my_parser.add_argument('--path',default="db_migration_config.yml", type=str,help='the path to configuration file, default is ./db_migration_config.yml')
    my_parser.add_argument('--test',default=False, action="store_true", help='Whether to test migration')
    my_parser.add_argument('--repair', default=False, action="store_true", help="Whether to repair postgres sequences if destination is postgres database")
    my_parser.add_argument('--show-order', '-o', default=False, action="store_true", dest="show_order", help="Show tables in least to most dependency order")
    my_parser.add_argument('--verbose', '-v', default=False, action="store_true", dest="verbose", help="Verbose information")
    my_parser.add_argument('--skip', dest='skip', help='Debugging: skip record count')
    my_parser.add_argument('--count', dest='count', help='Debugging: short run record count')
    args = my_parser.parse_args()

    verbose = args.verbose
    if args.skip:
        cmd_skip = int(args.skip)
    if args.count:
        cmd_count = int(args.count)

    with open(args.path, "r") as ymlfile:
        config = yaml.safe_load(ymlfile)
        main(config, test=args.test, repair=args.repair, show_order=args.show_order)