diff options
Diffstat (limited to 'bin/common/srtool_sql.py')
-rwxr-xr-x | bin/common/srtool_sql.py | 492 |
1 files changed, 492 insertions, 0 deletions
#################################
# Python SQL helper methods
#
# Provide SQL extended support via wrappers
# * Enable retry for errors, specifically database locks
# * Capture start/stop second+millisecond timestamps
# * Provide post-dump of time tracking and retry counts
#
# Solution source:
# https://stackoverflow.com/questions/15143871/simplest-way-to-retry-sqlite-query-if-db-is-locked
# Quote: "Python will retry regularly if the table is locked. It will not retry if the Database is locked."
#

import sys
import time
import subprocess
from datetime import datetime, date
from collections import OrderedDict
import sqlite3
import re
import os
import yaml
from types import SimpleNamespace

# Globals
SQL_TRACE = False           # when True, record (context,start,stop,retries) for every SQL action
SQL_VERBOSE = False         # when True, record the text of every executed SQL command
SQL_CONTEXT = "NN"          # short tag identifying the caller context in trace logs
SQL_TIMEOUT_MAX = 10        # maximum number of attempts for a failing SQL action
SQL_TIMEOUT_TIME = 0.0001   # seconds to sleep between retries
SQL_LOG_DIR = 'logs'

# Load the database configuration (selects the active section of srt_dbconfig.yml)
SRT_BASE_DIR = os.getenv('SRT_BASE_DIR', '.')
srt_dbconfig = None
srt_dbtype = None
try:
    with open(f"{SRT_BASE_DIR}/srt_dbconfig.yml", "r") as ymlfile:
        SRT_DBCONFIG = yaml.safe_load(ymlfile)
        SRT_DBSELECT = SRT_DBCONFIG['dbselect']
        srt_dbconfig = SRT_DBCONFIG[SRT_DBSELECT]
        srt_dbtype = srt_dbconfig['dbtype']
except OSError:
    # Fall through to the explicit error message below
    pass
if not srt_dbtype:
    print(f"ERROR: Missing {SRT_BASE_DIR}/srt_dbconfig.yml")
    exit(1)
# Only pull in the heavyweight drivers when the configuration (or an
# explicit environment override) says they are needed
if ("mysql" == srt_dbtype) or ('1' == os.getenv('SRT_MYSQL', '0')):
    import MySQLdb
if ("postgres" == srt_dbtype) or ('1' == os.getenv('SRT_POSTGRES', '0')):
    import psycopg2
    from psycopg2.extras import RealDictCursor

# quick development/debugging support
def _log(msg):
    # SRTDBG_LVL: 1 = print to stdout, 2 = append to SRTDBG_LOG (default 2).
    # Environment values arrive as strings, so normalize to int before comparing.
    dbg_lvl = int(os.environ.get('SRTDBG_LVL', 2))
    dbg_log = os.environ.get('SRTDBG_LOG', '/tmp/srt_dbg.log')
    if 1 == dbg_lvl:
        print(msg)
    elif 2 == dbg_lvl:
        with open(dbg_log, 'a') as f1:
            f1.write("|" + msg + "|\n")

#with open(f"{SRT_BASE_DIR}/db_migration_config.yml", "r") as migfile:
#    DB_MIG_CONFIG = yaml.safe_load(migfile)

#################################
# Debug Support
#

SQL_TRACE_log = []      # accumulated (context,start,stop,retries) tuples
SQL_VERBOSE_log = []    # accumulated SQL command strings

# Enable debug tracking, optional context
def SQL_DEBUG(is_trace, context=None, is_verbose=False):
    global SQL_TRACE
    global SQL_VERBOSE
    global SQL_CONTEXT
    SQL_TRACE = is_trace
    if context:
        SQL_CONTEXT = context
    if is_verbose:
        # NOTE: original assigned 'context' here by mistake; the verbose
        # flag is what callers are asking to enable
        SQL_VERBOSE = is_verbose
    if SQL_TRACE:
        print("SRTSQL_DEBUG:Trace=%s,Context=%s,Verbose=%s)" % (SQL_TRACE, context, is_verbose))
        sys.stdout.flush()

def _SQL_GET_MS():
    # Compact microsecond timestamp (minute+second+usec folded into one int);
    # only meaningful for ordering/diffing within the same hour
    if not SQL_TRACE:
        return 0
    dt = datetime.now()
    return (dt.minute * 100000000) + (dt.second * 1000000) + dt.microsecond

def _SQL_TRACE_LOG_ADD(start, stop, loop):
    # Append one trace record; no-op unless tracing is enabled
    global SQL_TRACE_log
    if not SQL_TRACE:
        return
    SQL_TRACE_log.append([SQL_CONTEXT, start, stop, loop])

def SQL_DUMP():
    # Write the accumulated trace records to logs/SQL_TRACE_<context>.log
    if not SQL_TRACE:
        return
    if not os.path.isdir(SQL_LOG_DIR):
        os.makedirs(SQL_LOG_DIR)
    log_file = '%s/SQL_TRACE_%s.log' % (SQL_LOG_DIR, SQL_CONTEXT)
    with open(log_file, 'w') as fd:
        print(" (Context) Start         Stop (Retries)", file=fd)
        print("===============================================", file=fd)
        for context, start, stop, loop in SQL_TRACE_log:
            print("sql_dump:(%3s) %d to %d (%d)" % (context[:3], start, stop, loop), file=fd)
    print("SQL debug trace log:%s" % log_file)

def SQL_DUMP_COMPARE(param, is_csv=False):
    # Merge two trace logs ("tag1,tag2"), sort by start time, and print a
    # table (or CSV) showing per-record write time and inter-record gaps
    tag1, tag2 = param.split(',')
    if not os.path.isdir(SQL_LOG_DIR):
        os.makedirs(SQL_LOG_DIR)
    log1 = '%s/SQL_TRACE_%s.log' % (SQL_LOG_DIR, tag1)
    log2 = '%s/SQL_TRACE_%s.log' % (SQL_LOG_DIR, tag2)

    log = []

    def load_log(logfile):
        # Example line: sql_dump:(JOB) 39290879 to 39293849 (0)
        p = re.compile(r'sql_dump:\((\w+)\) (\d+) to (\d+) \((\d+)\)')
        with open(logfile, 'r') as fs:
            for line in fs.readlines():
                m = p.match(line)
                if not m:
                    continue
                tag, start, stop, retry = m.groups()
                log.append([tag, int(start), int(stop), retry])

    # Load the logs
    load_log(log1)
    load_log(log2)

    # Sort merged records on the start timestamp
    log.sort(key=lambda e: e[1])

    # Display log table with diffs
    if not is_csv:
        print(" #    |Tag|Start uSec|Stop uSec |Re|(diff prev  )|(diff next  )|(diff write)")
        print("======|===|==========|==========|==|============|============|============")
    else:
        print("Index,Tag,Start,Stop,Retries,Diff_prev,Diff_next,Diff write")
    logmax = len(log)
    for i, (tag, start, stop, retry) in enumerate(log):
        # Gap between this record's start and the previous record's stop
        pre_diff = 0 if i == 0 else log[i][1] - log[i - 1][2]
        # Gap between the next record's start and this record's stop
        post_diff = 0 if i == (logmax - 1) else log[i + 1][1] - log[i][2]
        # Time spent inside this record
        write_diff = log[i][2] - log[i][1]
        if not is_csv:
            print("[%4d]:%s,%010d,%010d,%s (^ %8d) (v %8d) (~ %8d)" % (i, tag, start, stop, retry, pre_diff, post_diff, write_diff))
        else:
            print("%d,%s,%010d,%010d,%s,%8d,%8d,%8d" % (i, tag, start, stop, retry, pre_diff, post_diff, write_diff))
    if SQL_VERBOSE:
        print('')
        print('Executed SQL commands:')
        for line in SQL_VERBOSE_log:
            print(line)
        print('')

def SQL_FETCH_INDEXES(conn, dbconfig=None):
    """Return the column-index map of every table as ("TABLE_COLUMN", index) tuples.

    Any "orm_" table-name prefix is stripped and names are upper-cased.
    Errors are returned as [("ERROR", message)] so callers always get a
    consistent list-of-tuples shape.

    :param conn: open DB connection (postgres) — the sqlite path shells out
        to the sqlite3 CLI against dbconfig['path'] instead
    :param dbconfig: database config dict; defaults to the global config
    """
    if not dbconfig:
        dbconfig = srt_dbconfig
    dbtype = dbconfig['dbtype']

    cur = conn.cursor()
    if 'postgres' == dbtype:
        sql = "SELECT * FROM information_schema.columns where table_schema = 'public' order by table_name,ordinal_position;"
        cur.execute(sql)
        columns = cur.description
        results = [{columns[index][0]: column for index, column in enumerate(value)} for value in cur.fetchall()]
        # Group columns per table, keyed by 0-based ordinal position
        tables = {}
        for row in results:
            table_name = row['table_name']
            if table_name not in tables:
                tables[table_name] = {}
            # ordinal_position is 1-based in information_schema
            tables[table_name][row['ordinal_position'] - 1] = row['column_name']
        ret_list = []
        for table_name, col_map in tables.items():
            table_name = table_name.replace('orm_', '')
            for offset, (_, column_name) in enumerate(sorted(col_map.items())):
                ret_list.append(("{}_{}".format(table_name.upper(), column_name.upper()), offset))
        return ret_list
    elif 'sqlite' == dbtype:
        database_file = dbconfig['path']
        create_re = re.compile(r"CREATE TABLE[A-Z ]* \"(\w+)\" \((.+)\);")
        try:
            cmd = ('sqlite3', database_file, '.schema')  # must be abstracted
            output = subprocess.check_output(cmd, stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as e:
            return [("ERROR", "(%d) %s" % (e.returncode, e.output))]
        ret_list = []
        for line in output.decode("utf-8").splitlines():
            match = create_re.match(line)
            if not match:
                continue
            table = match.group(1).upper()
            table = table.replace('ORM_', '')
            for i, col in enumerate(match.group(2).split(',')):
                col = col.strip()
                # Column definitions look like: "name" TYPE ... — strip the
                # surrounding quotes to recover the bare name
                name = col[1:]
                try:
                    name = name[:name.index('"')]
                except Exception as e:
                    return [("ERROR", "%s:%s:" % (e, col))]
                ret_list.append(("%s_%s" % (table, name.upper()), i))
        return ret_list
    else:
        return [("ERROR", "No support for MySQL or MariahDB. Update coming..."), ]

#################################
# SQL wrapper methods
#

def _SQL_ACTION(action, cur_conn, sql=None, params=None, dbconfig=None):
    """Run one SQL action with retry, to ride out transient errors (e.g. locks).

    :param action: 'exec' (run sql on a cursor/connection) or 'commit'
    :param cur_conn: cursor for 'exec', connection for 'commit'
    :param sql: SQL text using '?' placeholders (translated per backend)
    :param params: optional bind parameters
    :param dbconfig: database config dict; defaults to the global config
    :return: the execute()/commit() result; for non-sqlite backends the
        cursor/connection itself is returned (legacy caller expectation)
    """
    if not dbconfig:
        dbconfig = srt_dbconfig
    dbtype = dbconfig['dbtype']

    ret = None
    timeout_count = 0
    if SQL_VERBOSE:
        SQL_VERBOSE_log.append("SQL_ACTION:%s:%s:%s:%s:" % (action, cur_conn, sql, params))
    sleep_time = SQL_TIMEOUT_TIME
    start = _SQL_GET_MS()

    # Translate the SQL once, before the retry loop: retrying must not
    # re-append "RETURNING *" or re-run the replacements
    if ('exec' == action) and (dbtype != "sqlite"):
        # Account for the placeholder difference between mysql/postgres and sqlite
        sql = sql.replace("?", "%s")
        if dbtype == "postgres":  # postgres case-insensitivity handling
            sql = sql.replace('`', '"')  # replace backticks with double quotes
            if "INSERT INTO" in sql:
                # RETURNING * lets SQL_GET_LAST_ROW_INSERTED_ID read the new row
                sql += " RETURNING *"
            # Mixed-case columns must be double-quoted or postgres folds them
            camel_case_columns = ["lastModifiedDate", "publishedDate", "cvssV3_baseScore", "cvssV3_baseSeverity", "cvssV2_baseScore", "cvssV2_severity"]
            for col in camel_case_columns:
                if col in sql and f'"{col}"' not in sql:
                    sql = sql.replace(f'{col}', f'"{col}"')

    for _ in range(0, SQL_TIMEOUT_MAX):
        try:
            if 'exec' == action:
                if params:
                    ret = cur_conn.execute(sql, params)
                else:
                    ret = cur_conn.execute(sql)
            elif 'commit' == action:
                ret = cur_conn.commit()
        except Exception as e:
            print(f"Error occured while running\nsql: {sql}\nparams:{params}\naction:{action}")
            print(e)
            time.sleep(sleep_time)
            timeout_count += 1
        else:
            # Success: record the trace entry and stop.
            # (The original placed the break in a 'finally' clause, which ran
            # even after an exception and therefore defeated the retry loop.)
            _SQL_TRACE_LOG_ADD(start, _SQL_GET_MS(), timeout_count)
            break
    else:
        # Give up, dump what we had, and trigger a proper error by running
        # the action one last time without catching the exception
        _SQL_TRACE_LOG_ADD(start, _SQL_GET_MS(), timeout_count)
        SQL_DUMP()
        if 'exec' == action:
            ret = cur_conn.execute(sql, params)
        elif 'commit' == action:
            ret = cur_conn.commit()
    if not dbtype == "sqlite":
        ret = cur_conn
    return ret
def SQL_CONNECT(column_names=False, dbconfig=None):
    """Open a database connection for the configured backend.

    :param column_names: when True, configure name-addressable rows
        (RealDictCursor for postgres, sqlite3.Row for sqlite; no effect
        on mysql)
    :param dbconfig: database config dict; defaults to the global config
    :return: a new DB-API connection
    """
    if not dbconfig:
        dbconfig = srt_dbconfig
    dbtype = dbconfig['dbtype']

    if dbtype == "mysql":
        return MySQLdb.connect(
            passwd=dbconfig["passwd"],
            db=dbconfig["name"],
            host=dbconfig["host"],
            user=dbconfig["user"],
            port=dbconfig["port"]
        )
    elif dbtype == "postgres":
        # Build the keyword set once; only the cursor factory differs
        pg_kwargs = {
            'password': dbconfig["passwd"],
            'database': dbconfig["name"],
            'host': dbconfig["host"],
            'user': dbconfig["user"],
            'port': dbconfig["port"],
        }
        if column_names:
            pg_kwargs['cursor_factory'] = RealDictCursor
        return psycopg2.connect(**pg_kwargs)
    else:  # Sqlite
        conn = sqlite3.connect(dbconfig["path"])
        if column_names:
            conn.row_factory = sqlite3.Row
        return conn

def SQL_CURSOR(conn, dbconfig=None):
    # Return a cursor for the given connection
    return conn.cursor()

def SQL_EXECUTE(cur, sql, params=None, dbconfig=None):
    # Execute sql (with optional params) through the retrying action wrapper
    return _SQL_ACTION('exec', cur, sql, params, dbconfig)

def SQL_COMMIT(conn, dbconfig=None):
    # Commit through the retrying action wrapper.
    # NOTE: the original passed dbconfig positionally into the 'sql' slot,
    # silently dropping the caller's config — pass it by keyword.
    return _SQL_ACTION('commit', conn, dbconfig=dbconfig)

def SQL_CLOSE_CUR(cur, dbconfig=None):
    # Close a cursor
    return cur.close()

def SQL_CLOSE_CONN(conn, dbconfig=None):
    # Close a connection
    return conn.close()

def SQL_GET_LAST_ROW_INSERTED_ID(cur, dbconfig=None):
    """Return the primary key of the row just inserted via this cursor.

    Postgres has no lastrowid, so the value comes from the row produced by
    the "RETURNING *" clause appended by _SQL_ACTION.
    """
    if not dbconfig:
        dbconfig = srt_dbconfig
    dbtype = dbconfig['dbtype']

    if dbtype == "postgres":
        return SQL_FETCH_ONE(cur).id
    else:
        return cur.lastrowid

def SQL_FETCH_ONE(cur, dbconfig=None):
    # Fetch one row as an attribute-addressable namespace
    # (raises TypeError if the result set is empty — callers rely on a row)
    columns = cur.description
    result = {columns[index][0]: column for index, column in enumerate(cur.fetchone())}
    return SimpleNamespace(**result)

def SQL_FETCH_ALL(cur, dbconfig=None):
    # Fetch all rows as a list of attribute-addressable namespaces
    columns = cur.description
    results = [{columns[index][0]: column for index, column in enumerate(value)} for value in cur.fetchall()]
    return [SimpleNamespace(**result) for result in results]
def GET_DB_TYPE(dbconfig=None):
    # Return the configured database type string (e.g. 'sqlite', 'postgres')
    if not dbconfig:
        dbconfig = srt_dbconfig
    return dbconfig['dbtype']

def _SQL_VALUES_CLAUSE(override_values, marker, count):
    # Build the VALUES placeholder clause for a batch insert.
    # marker is the backend placeholder ('?' for sqlite, '%s' otherwise).
    if override_values is None:
        return "(" + ','.join([marker] * count) + ")"
    if isinstance(override_values, list):
        return "(" + ','.join([str(ov) for ov in override_values]) + ")"
    # Caller supplied a pre-formatted clause string
    return override_values

def SQL_BATCH_WRITE(cur_conn, table, records, dbconfig=None, fields=None, override_values=None):
    '''
    Batch write wrapper function
    - Records must contain tuples of the same length

    :param cur_conn: SQL connection
    :param table: target table name
    :param records: list of tuples containing records to be inserted
    :param dbconfig: dbconfig['dbtype'] contains DB type
    :param fields: list of specified fields to insert into
    :param override_values: list of specified values
    :raises Exception: on invalid parameters or ragged records
    :return: SQL DB connection's cursor
    '''
    if not dbconfig:
        dbconfig = srt_dbconfig
    dbtype = dbconfig['dbtype']

    # invalid parameters check
    if cur_conn is None or table is None or records is None:
        raise Exception("SQL Batch Write Failed: invalid parameters provided")
    if not isinstance(records, (list, tuple)):
        raise Exception("SQL Batch Write Failed: records must be of type 'list' or 'tuple'")
    if not records:
        # Explicit message instead of an opaque IndexError on records[0]
        raise Exception("SQL Batch Write Failed: no records provided")

    # all records must carry the same number of fields
    # (was BaseException, which escapes 'except Exception' handlers)
    std_record_ct = len(records[0])
    for record in records:
        if len(record) != std_record_ct:
            raise Exception("SQL Batch Write Failed: incorrect number of fields supplied")

    # optional explicit column list
    if fields is not None:
        _fields = "(" + ','.join([str(field) for field in fields]) + ")"
    else:
        _fields = ''

    # bulk insert
    if dbtype == 'sqlite':
        _ov = _SQL_VALUES_CLAUSE(override_values, '?', std_record_ct)
        cur_conn.executemany(f"INSERT INTO {table}{_fields} VALUES{_ov};", records)
    elif dbtype == 'postgres':
        _ov = _SQL_VALUES_CLAUSE(override_values, '%s', std_record_ct)
        psycopg2.extras.execute_batch(cur_conn, f"INSERT INTO {table}{_fields} VALUES{_ov};", records)

    # conn.commit()
    return cur_conn


def SQL_BATCH_UPDATE(cur_conn, table, values_list, set_field, where_field, dbconfig=None):
    '''
    Batch update wrapper function
    - Records must contain tuples of the same length

    :param cur_conn: SQL connection
    :param table: target table name
    :param values_list: parameter values provided to the SQL query
    :param set_field: list containing the 'SET' parameterized fields in the SQL query
    :param where_field: list containing the 'WHERE' parameterized fields in the SQL query
    :param dbconfig: dbconfig['dbtype'] contains DB type
    :raises Exception: on invalid parameters or field/value count mismatch
    :return: SQL DB connection's cursor
    '''
    if not dbconfig:
        dbconfig = srt_dbconfig
    dbtype = dbconfig['dbtype']

    # invalid parameters check
    if (cur_conn is None) or (table is None) or (set_field is None) or (where_field is None) or (values_list is None):
        raise Exception("SQL Batch Update Failed: invalid parameters provided")

    # each values tuple must supply one value per SET field plus one per WHERE field
    if (len(set_field) + len(where_field)) != len(values_list[0]):
        raise Exception(f"SQL Batch Update Failed: number of fields and values supplied mismatches ({len(set_field)},{len(where_field)},{len(values_list)})")

    # NOTE: WHERE conditions are joined with AND — the original joined them
    # with ',' which is invalid SQL for multiple conditions
    if dbtype == 'sqlite':
        set_comm = " SET " + ", ".join([f"{s_field} = ?" for s_field in set_field])
        where_comm = " WHERE " + " AND ".join([f"{w_field} = ?" for w_field in where_field])
        sql = f"UPDATE {table}" + set_comm + where_comm + ";"
        cur_conn.executemany(sql, values_list)
    elif dbtype == 'postgres':
        set_comm = " SET " + ", ".join([f"{s_field} = %s" for s_field in set_field])
        where_comm = " WHERE " + " AND ".join([f"{w_field} = %s" for w_field in where_field])
        sql = f"UPDATE {table}" + set_comm + where_comm + ";"
        psycopg2.extras.execute_batch(cur_conn, sql, values_list)

    # conn.commit()
    return cur_conn