aboutsummaryrefslogtreecommitdiffstats
path: root/bin/common/srtool_sql.py
diff options
context:
space:
mode:
Diffstat (limited to 'bin/common/srtool_sql.py')
-rwxr-xr-xbin/common/srtool_sql.py492
1 files changed, 492 insertions, 0 deletions
diff --git a/bin/common/srtool_sql.py b/bin/common/srtool_sql.py
new file mode 100755
index 00000000..673793d6
--- /dev/null
+++ b/bin/common/srtool_sql.py
@@ -0,0 +1,492 @@
+#################################
+# Python SQL helper methods
+#
+# Provide SQL extended support via wrappers
+# * Enable retry for errors, specifically database locks
+# * Capture start/stop second+millisecond timestamps
+# * Provide post-dump of time tracking and retry counts
+#
+# Solution source:
+# https://stackoverflow.com/questions/15143871/simplest-way-to-retry-sqlite-query-if-db-is-locked
+# Quote: "Python will retry regularly if the table is locked. It will not retry if the Database is locked."
+#
+
+import sys
+import time
+import subprocess
+from datetime import datetime, date
+from collections import OrderedDict
+import sqlite3
+import re
+import os
+import yaml
+from types import SimpleNamespace
+
# Globals
SQL_TRACE = False          # when True, capture start/stop timing for SQL actions
SQL_VERBOSE = False        # when True, record the text of every executed SQL command
SQL_CONTEXT = "NN"         # short caller tag stored with each trace entry
SQL_TIMEOUT_MAX = 10       # max attempts before _SQL_ACTION stops retrying
SQL_TIMEOUT_TIME = 0.0001  # seconds slept between retry attempts
SQL_LOG_DIR = 'logs'       # directory receiving SQL_TRACE_<context>.log files
+
# Load the database configuration from srt_dbconfig.yml under SRT_BASE_DIR.
# The 'dbselect' key names the active configuration section; that section's
# 'dbtype' selects the backend (sqlite, mysql or postgres).
SRT_BASE_DIR = os.getenv('SRT_BASE_DIR', '.')
srt_dbconfig = None
srt_dbtype = None
with open(f"{SRT_BASE_DIR}/srt_dbconfig.yml", "r") as ymlfile:
    SRT_DBCONFIG = yaml.safe_load(ymlfile)
    SRT_DBSELECT = SRT_DBCONFIG['dbselect']
    srt_dbconfig = SRT_DBCONFIG[SRT_DBSELECT]
    srt_dbtype = srt_dbconfig['dbtype']
if not srt_dbtype:
    # BUGFIX: the old message claimed the config file was missing (with a
    # stray trailing quote); if the file were missing, open() above would
    # already have raised. This branch means 'dbtype' was empty/absent.
    print(f"ERROR: missing or empty 'dbtype' in {SRT_BASE_DIR}/srt_dbconfig.yml")
    sys.exit(1)
# Import the heavyweight DB drivers only when the selected backend (or an
# explicit environment override) requires them.
if ("mysql" == srt_dbtype) or ('1' == os.getenv('SRT_MYSQL', '0')):
    import MySQLdb
if ("postgres" == srt_dbtype) or ('1' == os.getenv('SRT_POSTGRES', '0')):
    import psycopg2
    from psycopg2.extras import RealDictCursor
+
+# quick development/debugging support
+def _log(msg):
+ DBG_LVL = os.environ['SRTDBG_LVL'] if ('SRTDBG_LVL' in os.environ) else 2
+ DBG_LOG = os.environ['SRTDBG_LOG'] if ('SRTDBG_LOG' in os.environ) else '/tmp/srt_dbg.log'
+ if 1 == DBG_LVL:
+ print(msg)
+ elif 2 == DBG_LVL:
+ f1=open(DBG_LOG, 'a')
+ f1.write("|" + msg + "|\n" )
+ f1.close()
+
+#with open(f"{SRT_BASE_DIR}/db_migration_config.yml", "r") as migfile:
+# DB_MIG_CONFIG = yaml.safe_load(migfile)
+
+#################################
+# Debug Support
+#
+
SQL_TRACE_log = []    # [context, start, stop, retry_count] entries appended by _SQL_TRACE_LOG_ADD
SQL_VERBOSE_log = []  # "SQL_ACTION:..." strings appended by _SQL_ACTION when verbose
+
# Enable debug tracking, optional context
def SQL_DEBUG(is_trace, context=None, is_verbose=False):
    """Configure the module-level SQL debug flags.

    :param is_trace: enable/disable timing trace capture (SQL_TRACE)
    :param context: optional short tag recorded with trace entries
    :param is_verbose: enable logging of executed SQL commands
    """
    global SQL_TRACE
    global SQL_VERBOSE
    global SQL_CONTEXT
    SQL_TRACE = is_trace
    if context:
        SQL_CONTEXT = context
    if is_verbose:
        # BUGFIX: was 'SQL_VERBOSE = context', which stored the context
        # string instead of the verbose flag
        SQL_VERBOSE = is_verbose
    if SQL_TRACE:
        print("SRTSQL_DEBUG:Trace=%s,Context=%s,Verbose=%s)" % (SQL_TRACE, context, is_verbose))
        sys.stdout.flush()
+
def _SQL_GET_MS():
    """Return a microsecond counter for trace timestamps (0 when tracing is off).

    Composed from minute/second/microsecond only, so the value wraps every
    hour; it is intended for short-interval diffing, not absolute time.
    """
    if not SQL_TRACE:
        return 0
    now = datetime.now()
    return now.microsecond + (now.second * 1000000) + (now.minute * 100000000)
+
def _SQL_TRACE_LOG_ADD(start, stop, loop):
    """Record one timed SQL action as [context, start, stop, retry-count]."""
    global SQL_TRACE_log
    if SQL_TRACE:
        SQL_TRACE_log.append([SQL_CONTEXT, start, stop, loop])
+
def SQL_DUMP():
    """Write every recorded SQL trace entry to logs/SQL_TRACE_<context>.log."""
    if not SQL_TRACE:
        return
    if not os.path.isdir(SQL_LOG_DIR):
        os.makedirs(SQL_LOG_DIR)
    log_file = '%s/SQL_TRACE_%s.log' % (SQL_LOG_DIR, SQL_CONTEXT)
    with open(log_file, 'w') as fd:
        fd.write(" (Context) Start Stop (Retries)\n")
        fd.write("===============================================\n")
        for entry in SQL_TRACE_log:
            context, start, stop, loop = entry
            fd.write("sql_dump:(%3s) %d to %d (%d)\n" % (context[:3], start, stop, loop))
    print("SQL debug trace log:%s" % log_file)
+
def SQL_DUMP_COMPARE(param,is_csv=False):
    """Merge two SQL trace logs and print them interleaved by start time.

    :param param: comma-separated pair of context tags, e.g. "JOB,WEB";
                  reads logs/SQL_TRACE_<tag>.log for each tag
    :param is_csv: when True emit CSV rows instead of an aligned table
    """
    tag1,tag2 = param.split(',')
    if not os.path.isdir(SQL_LOG_DIR):
        os.makedirs(SQL_LOG_DIR)
    log1 = '%s/SQL_TRACE_%s.log' % (SQL_LOG_DIR,tag1)
    log2 = '%s/SQL_TRACE_%s.log' % (SQL_LOG_DIR,tag2)

    # Merged records: [tag, start, stop, retry]
    log = []

    def load_log(logfile):
        # Parse the lines written by SQL_DUMP; non-matching lines are skipped
        p = re.compile(r'sql_dump:\((\w+)\) (\d+) to (\d+) \((\d+)\)')
        with open(logfile, 'r') as fs:
            for line in fs.readlines():
                # sql_dump:(JOB) 39290879 to 39293849 (0)
                m = p.match(line)
                if not m:
                    continue
                tag,start,stop,retry = m.groups()
                # NOTE: retry is kept as a string; start/stop become ints
                log.append([tag,int(start),int(stop),retry])
    # Load the logs
    load_log(log1)
    load_log(log2)
    # Sort the log
    def sortOnStart(e):
        return e[1]
    log.sort(key=sortOnStart)

    # Display log table with diffs
    if not is_csv:
        print(" # |Tag|Start uSec|Stop uSec |Re|(diff prev )|(diff next )|(diff write)")
        print("======|===|==========|==========|==|============|============|============")
    else:
        print("Index,Tag,Start,Stop,Retries,Diff_prev,Diff_next,Diff write")
    logmax = len(log)
    i = -1
    for tag,start,stop,retry in log:
        i += 1
        # Gap since the previous entry finished (0 for the first row)
        if i == 0:
            pre_diff = 0
        else:
            pre_diff = log[i][1] - log[i-1][2]
        # Gap until the next entry starts (0 for the last row)
        if i == (logmax - 1):
            post_diff = 0
        else:
            post_diff = log[i+1][1] - log[i][2]
        # Duration of this entry itself
        write_diff = log[i][2] - log[i][1]
        if not is_csv:
            print("[%4d]:%s,%010d,%010d,%s (^ %8d) (v %8d) (~ %8d)" % (i,tag,start,stop,retry,pre_diff,post_diff,write_diff))
        else:
            print("%d,%s,%010d,%010d,%s,%8d,%8d,%8d" % (i,tag,start,stop,retry,pre_diff,post_diff,write_diff))
    if SQL_VERBOSE:
        print('')
        print('Executed SQL commands:')
        for line in SQL_VERBOSE_log:
            print(line)
        print('')
+
def SQL_FETCH_INDEXES(conn, dbconfig=None):
    """Return a list of ("TABLE_COLUMN", ordinal) tuples describing the schema.

    Table names have their 'orm_' prefix stripped and everything is upper-cased.
    On failure a single-element list [("ERROR", details)] is returned.
    BUGFIX: removed leftover debug print() calls that spammed stdout with
    every schema line and cursor state.

    :param conn: open DB connection (used for the postgres branch)
    :param dbconfig: configuration dict; defaults to the module-level config
    """
    # Define the database type connection
    if not dbconfig:
        dbconfig = srt_dbconfig
    dbtype = dbconfig['dbtype']

    #Goal: Create a data structure that has:
    #    ordered list of tables
    #    table column names
    #    index(ordinal) of table
    # columns: (table_name, column_name, ordinal_postion) -> Should be list[list]
    #    Formatting should not be done in the subroutine...do it in srtool_common.py
    # (name, value) -> name is table_name,_column name and value is the index
    #    Should be returned as tuples rather than preformatted strings
    #    Returns should be consistent (also for error)

    cur = conn.cursor()
    if 'postgres' == dbtype:
        sql = "SELECT * FROM information_schema.columns where table_schema = 'public' order by table_name,ordinal_position;"
        cur.execute(sql)
        columns = cur.description
        results = [{columns[index][0]:column for index, column in enumerate(value)} for value in cur.fetchall()]
        tables = {}
        # TODO last line of for is hardcoded in postgres format
        for i in results:
            if i['table_name'] not in tables:
                # placeholder key, removed below once real ordinals are filled in
                tables[i['table_name']] = {'ordinal_position' : 'column_name'}
            tables[i['table_name']][i['ordinal_position']-1] = i['column_name']
        for val_d in tables:
            tables[val_d].pop('ordinal_position')
        ret_list = []
        for table in tables:
            table_items = tables[table].items()
            sorted_tabl = sorted(table_items)
            for offset,i in enumerate(sorted_tabl):
                table = table.replace('orm_','')
                ret_list.append(("{}_{}".format(table.upper(), i[1].upper()), offset))
        return(ret_list)
    elif 'sqlite' == dbtype:
        database_file = dbconfig['path']
        create_re = re.compile(r"CREATE TABLE[A-Z ]* \"(\w+)\" \((.+)\);")
        try:
            cmd = ('sqlite3', database_file, '.schema') # must be abstracted
            output = subprocess.check_output(cmd, stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as e:
            return([("ERROR","(%d) %s" % (e.returncode, e.output))])
        ret_list = []
        for line in output.decode("utf-8").splitlines():
            match = create_re.match(line)
            if not match:
                # not a CREATE TABLE line (index, trigger, blank...) - skip
                continue

            table = match.group(1).upper()
            table = table.replace('ORM_','')

            columns = match.group(2)
            for i, col in enumerate(columns.split(',')):
                col = col.strip()
                name = col[1:]
                # validate that the column declaration is quoted; an
                # unquoted declaration produces an ERROR return
                try:
                    name = name[:name.index('"')]
                except Exception as e:
                    return([("ERROR","%s:%s:" % (e,col))])
                # NOTE(review): this overwrites the name parsed above with
                # the text up to the first space (quotes included) -
                # preserved as-is, but it looks unintentional; confirm
                # against consumers of the returned names
                name = col[:col.index(' ')]
                ret_list.append(("%s_%s" % (table.upper(),name.upper()), i))
        return(ret_list)
    else:
        return([("ERROR","No support for MySQL or MariahDB. Update coming..."),])
+
+#################################
+# SQL wrapper methods
+#
+
def _SQL_ACTION(action,cur_conn,sql=None,params=None,dbconfig=None):
    """Execute or commit with retry on transient errors (e.g. DB locks).

    :param action: 'exec' runs *sql* on *cur_conn*; 'commit' commits it
    :param cur_conn: cursor (exec) or connection (commit)
    :param sql: SQL text using '?' placeholders (translated per backend)
    :param params: optional parameter tuple for the statement
    :param dbconfig: configuration dict; defaults to the module-level config
    :return: execute()/commit() result, or cur_conn for non-sqlite backends
    """
    # Define the database type connection
    if not dbconfig:
        dbconfig = srt_dbconfig
    dbtype = dbconfig['dbtype']

    ret = None
    timeout_count = 0
    if SQL_VERBOSE:
        SQL_VERBOSE_log.append("SQL_ACTION:%s:%s:%s:%s:" % (action,cur_conn,sql,params))
    sleep_time = SQL_TIMEOUT_TIME
    start = _SQL_GET_MS()
    for _ in range(0, SQL_TIMEOUT_MAX):
        try:
            if 'exec' == action:
                # to account for difference between mysql/postgres and sqlite
                if not dbtype == "sqlite":
                    sql = sql.replace("?", "%s")
                if dbtype == "postgres": # for postgres case insenstive issue
                    sql = sql.replace('`', '"') # replace backticks with double quotes
                    if "INSERT INTO" in sql:
                        sql += " RETURNING *"
                    # quote known camelCase columns so postgres does not fold them
                    camel_case_columns = ["lastModifiedDate", "publishedDate", "cvssV3_baseScore", "cvssV3_baseSeverity", "cvssV2_baseScore", "cvssV2_severity"]
                    for col in camel_case_columns:
                        if col in sql and f'"{col}"' not in sql:
                            sql = sql.replace(f'{col}', f'"{col}"')
                if params:
                    ret = cur_conn.execute(sql, params)
                else:
                    ret = cur_conn.execute(sql)
            elif 'commit' == action:
                ret = cur_conn.commit()
        except Exception as e:
            # Transient failure (commonly "database is locked"): back off and
            # retry. BUGFIX: the original had 'finally: break', which aborted
            # the loop on the very first iteration - retries never happened
            # and the for/else give-up branch below was unreachable.
            print(f"Error occured while running\nsql: {sql}\nparams:{params}\naction:{action}")
            print(e)
            time.sleep(sleep_time)
            timeout_count += 1
        else:
            # Success: record the timing and stop retrying
            _SQL_TRACE_LOG_ADD(start,_SQL_GET_MS(),timeout_count)
            break
    else:
        # Give up, dump what we had, and trigger a proper error by running
        # once more without a guard. BUGFIX: the original called undefined
        # names SQL_TRACE_log_add / sql_dump here (NameError if reached).
        _SQL_TRACE_LOG_ADD(start,_SQL_GET_MS(),timeout_count)
        SQL_DUMP()
        if 'exec' == action:
            if params:
                ret = cur_conn.execute(sql, params)
            else:
                ret = cur_conn.execute(sql)
        elif 'commit' == action:
            ret = cur_conn.commit()
    if not dbtype == "sqlite":
        ret = cur_conn
    return ret
+
def SQL_CONNECT(column_names=False,dbconfig=None):
    """Open a connection to the configured database backend.

    :param column_names: when True, configure the connection so rows are
                         addressable by column name (sqlite3.Row for sqlite,
                         RealDictCursor for postgres)
    :param dbconfig: configuration dict; defaults to the module-level config
    :return: a DB-API connection object
    """
    cfg = dbconfig if dbconfig else srt_dbconfig
    dbtype = cfg['dbtype']

    if dbtype == "mysql":
        return MySQLdb.connect(
            passwd=cfg["passwd"],
            db=cfg["name"],
            host=cfg["host"],
            user=cfg["user"],
            port=cfg["port"]
        )
    if dbtype == "postgres":
        kwargs = dict(
            password=cfg["passwd"],
            database=cfg["name"],
            host=cfg["host"],
            user=cfg["user"],
            port=cfg["port"],
        )
        if column_names:
            kwargs["cursor_factory"] = RealDictCursor
        return psycopg2.connect(**kwargs)
    # Default: sqlite file database
    connection = sqlite3.connect(cfg["path"])
    if column_names:
        connection.row_factory = sqlite3.Row
    return connection
+
def SQL_CURSOR(conn,dbconfig=None):
    """Return a new cursor for *conn* (dbconfig accepted for API symmetry)."""
    return conn.cursor()
+
def SQL_EXECUTE(cur,sql,params=None,dbconfig=None):
    """Run *sql* (with optional *params*) through the retry wrapper."""
    return _SQL_ACTION('exec', cur, sql, params, dbconfig)
+
def SQL_COMMIT(conn,dbconfig=None):
    """Commit *conn* through the retry wrapper.

    BUGFIX: dbconfig was passed positionally into _SQL_ACTION's 'sql'
    parameter slot, so a caller-supplied dbconfig was silently ignored
    (and a dict landed where SQL text was expected).
    """
    return _SQL_ACTION('commit', conn, dbconfig=dbconfig)
+
def SQL_CLOSE_CUR(cur,dbconfig=None):
    """Close the given cursor (dbconfig accepted for API symmetry)."""
    return cur.close()
+
def SQL_CLOSE_CONN(conn,dbconfig=None):
    """Close the given connection (dbconfig accepted for API symmetry)."""
    return conn.close()
+
def SQL_GET_LAST_ROW_INSERTED_ID(cur,dbconfig=None):
    """Return the id of the most recently inserted row.

    Postgres has no lastrowid, so it relies on the 'RETURNING *' clause
    that _SQL_ACTION appends to INSERTs, fetched via SQL_FETCH_ONE.
    """
    cfg = dbconfig if dbconfig else srt_dbconfig
    if cfg['dbtype'] == "postgres":
        return SQL_FETCH_ONE(cur).id
    return cur.lastrowid
+
def SQL_FETCH_ONE(cur,dbconfig=None):
    """Fetch the next row as an attribute-accessible namespace.

    :return: SimpleNamespace mapping column name -> value, or None when the
             cursor has no more rows. BUGFIX: the original crashed with an
             opaque TypeError (iterating None) on an empty result set.
    """
    row = cur.fetchone()
    if row is None:
        return None
    columns = cur.description
    return SimpleNamespace(**{columns[index][0]: column for index, column in enumerate(row)})
+
def SQL_FETCH_ALL(cur,dbconfig=None):
    """Fetch all remaining rows as attribute-accessible namespaces."""
    names = [desc[0] for desc in cur.description]
    return [SimpleNamespace(**dict(zip(names, row))) for row in cur.fetchall()]
+
def GET_DB_TYPE(dbconfig=None):
    """Return the 'dbtype' of the given (or default module-level) config."""
    return (dbconfig if dbconfig else srt_dbconfig)['dbtype']
+
def SQL_BATCH_WRITE(cur_conn, table, records, dbconfig=None, fields=None, override_values=None):
    '''
    Batch write wrapper function
    - Records must contain tuples of the same length

    :param cur_conn: SQL connection/cursor
    :param table: target table name
    :param records: list of tuples containing records to be inserted
    :param dbconfig: dbconfig['dbtype'] contains DB type
    :param fields: list of specified fields to insert into
    :param override_values: list of per-column value expressions, or a
                            complete "(...)" string, replacing the default
                            placeholder list

    :return: SQL DB connection's cursor
    '''

    # Define the database type connection
    if not dbconfig:
        dbconfig = srt_dbconfig
    dbtype = dbconfig['dbtype']

    # invalid parameters check
    if cur_conn is None or table is None or records is None:
        raise Exception("SQL Batch Write Failed: invalid parameters provided")
    if not isinstance(records, (list, tuple)):
        raise Exception("SQL Batch Write Failed: records must be of type 'list' or 'tuple'")
    # BUGFIX: an empty records list crashed with IndexError on records[0];
    # inserting zero records is a successful no-op
    if len(records) == 0:
        return cur_conn

    # every record must supply the same number of fields
    std_record_ct = len(records[0])
    for record in records:
        if len(record) != std_record_ct:
            # BUGFIX: was 'raise BaseException', which bypasses callers'
            # 'except Exception' handlers
            raise Exception("SQL Batch Write Failed: incorrect number of fields supplied")

    # optional explicit column list
    if fields is not None:
        _fields = "(" + ','.join([str(field) for field in fields]) + ")"
    else:
        _fields = ''

    def _values_clause(placeholder):
        # Build the VALUES clause; override_values wins when supplied.
        if override_values is None:
            return "(" + ','.join([placeholder] * std_record_ct) + ")"
        if isinstance(override_values, list):
            return "(" + ','.join([str(ov) for ov in override_values]) + ")"
        if isinstance(override_values, str):
            return override_values
        # BUGFIX: the original fell through with _ov unbound (NameError)
        raise Exception("SQL Batch Write Failed: invalid override_values type")

    if dbtype == 'sqlite':
        cur_conn.executemany(f"INSERT INTO {table}{_fields} VALUES{_values_clause('?')};", records)
    elif dbtype == 'postgres':
        psycopg2.extras.execute_batch(cur_conn, f"INSERT INTO {table}{_fields} VALUES{_values_clause('%s')};", records)
    # NOTE(review): other backends (mysql) silently do nothing, matching the
    # original behavior

    # conn.commit()
    return cur_conn
+
+
def SQL_BATCH_UPDATE(cur_conn, table, values_list, set_field, where_field, dbconfig=None):
    '''
    Batch update wrapper function
    - each values tuple supplies the SET values followed by the WHERE values

    :param cur_conn: SQL connection/cursor
    :param table: target table name
    :param values_list: list of parameter tuples (SET values then WHERE values)
    :param set_field: list containing the 'SET' parameterized fields in the SQL query
    :param where_field: list containing the 'WHERE' parameterized fields in the SQL query
    :param dbconfig: dbconfig['dbtype'] contains DB type

    :return: SQL DB connection's cursor
    '''

    # Define the database type connection
    if not dbconfig:
        dbconfig = srt_dbconfig
    dbtype = dbconfig['dbtype']

    # invalid parameters check
    if (cur_conn is None) or (table is None) or (set_field is None) or (where_field is None) or (values_list is None):
        raise Exception("SQL Batch Update Failed: invalid parameters provided")
    # BUGFIX: an empty values_list crashed with IndexError on values_list[0];
    # updating zero rows is a successful no-op
    if len(values_list) == 0:
        return cur_conn

    # invalid number of fields supplied check
    if (len(set_field) + len(where_field)) != len(values_list[0]):
        raise Exception(f"SQL Batch Update Failed: number of fields and values supplied mismatches ({len(set_field)},{len(where_field)},{len(values_list)})")

    def _build_sql(placeholder):
        set_comm = ", ".join([f"{s_field} = {placeholder}" for s_field in set_field])
        # BUGFIX: WHERE conditions were joined with ", ", which is invalid
        # SQL whenever more than one where_field is supplied; conditions
        # must be combined with AND
        where_comm = " AND ".join([f"{w_field} = {placeholder}" for w_field in where_field])
        return f"UPDATE {table} SET {set_comm} WHERE {where_comm};"

    if dbtype == 'sqlite':
        cur_conn.executemany(_build_sql("?"), values_list)
    elif dbtype == 'postgres':
        psycopg2.extras.execute_batch(cur_conn, _build_sql("%s"), values_list)

    # conn.commit()
    return cur_conn
+