diff options
Diffstat (limited to 'lib/orm/models.py')
-rw-r--r-- | lib/orm/models.py | 295 |
1 files changed, 284 insertions, 11 deletions
diff --git a/lib/orm/models.py b/lib/orm/models.py index 9b4f99ce..f5016b7d 100644 --- a/lib/orm/models.py +++ b/lib/orm/models.py @@ -4,7 +4,7 @@ # # Security Response Tool Implementation # -# Copyright (C) 2017 Wind River Systems +# Copyright (C) 2017-2021 Wind River Systems # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License version 2 as @@ -27,21 +27,29 @@ from django.db import transaction from django.core import validators from django.conf import settings import django.db.models.signals +from django.db.models import F, Q, Sum, Count +from django.contrib.auth.models import AbstractUser, Group, AnonymousUser +from srtgui.api import execute_process, execute_process_close_fds from users.models import SrtUser import sys import os import re +import itertools from signal import SIGUSR1 from datetime import datetime import json +import subprocess +import time +import signal +import pytz import logging logger = logging.getLogger("srt") # quick development/debugging support -from srtgui.api import _log +from srtgui.api import _log, parameter_join # Sqlite support @@ -74,7 +82,6 @@ if 'sqlite' in settings.DATABASES['default']['ENGINE']: return _base_insert(self, *args, **kwargs) QuerySet._insert = _insert - from django.utils import six def _create_object_from_params(self, lookup, params): """ Tries to create an object using passed params. @@ -89,7 +96,6 @@ if 'sqlite' in settings.DATABASES['default']['ENGINE']: return self.get(**lookup), False except self.model.DoesNotExist: pass - six.reraise(*exc_info) QuerySet._create_object_from_params = _create_object_from_params @@ -331,8 +337,10 @@ class Update(): PUBLISH_DATE = "Publish_Date(%s,%s)" AFFECTED_COMPONENT = "Affected_Component(%s,%s)" ACKNOWLEDGE_DATE = "AcknowledgeDate(%s,%s)" + PUBLIC = "Public(%s,%s)" ATTACH_CVE = "Attach_CVE(%s)" DETACH_CVE = "Detach_CVE(%s)" + MERGE_CVE = "Merge_CVE(%s)" ATTACH_VUL = "Attach_Vulnerability(%s)" DETACH_VUL = "Detach_Vulnerability(%s)" ATTACH_INV = "Attach_Investigration(%s)" @@ -404,11 +412,11 @@ class HelpText(models.Model): text = models.TextField() -#UPDATE_FREQUENCY: 0 = every minute, 1 = every hour, 2 = every day, 3 = every week, 4 = every month, 5 = every year +#UPDATE_FREQUENCY: 0 = every n minutes, 1 = every hour, 2 = every day, 3 = every week, 4 = every month, 5 = on demand class DataSource(models.Model): search_allowed_fields = ['key', 'name', 'description', 'init', 'update', 'lookup'] - #UPDATE FREQUENCT + #UPDATE FREQUENCY MINUTELY = 0 HOURLY = 1 DAILY = 2 @@ -416,6 +424,7 @@ class DataSource(models.Model): MONTHLY = 4 ONDEMAND = 5 ONSTARTUP = 6 + PREINIT = 7 FREQUENCY = ( (MINUTELY, 'Minute'), (HOURLY, 'Hourly'), @@ -424,6 +433,7 @@ class DataSource(models.Model): (MONTHLY, 'Monthly'), (ONDEMAND, 'OnDemand'), (ONSTARTUP, 'OnStartup'), + (PREINIT, 'PreInit'), ) # Global date format @@ -434,7 +444,7 @@ class DataSource(models.Model): LOOKUP_MISSING = 'LOOKUP-MISSING' PREVIEW_SOURCE = 'PREVIEW-SOURCE' - key = models.CharField(max_length=20) + key = models.CharField(max_length=80) data = models.CharField(max_length=20) source = models.CharField(max_length=20) name = models.CharField(max_length=20) @@ -570,6 +580,9 @@ class Cve(models.Model): def get_publish_text(self): return Cve.PUBLISH_STATE[int(self.publish_state)][1] @property + def get_public_text(self): + return 'Public' if self.public else 'Private' + @property def is_local(self): try: CveLocal.objects.get(name=self.name) @@ -592,6 +605,47 @@ class Cve(models.Model): if the_comments == the_packages: return the_comments return '%s' % (the_comments) + def propagate_private(self): + # Gather allowed users + user_id_list = [] + for cveaccess in CveAccess.objects.filter(cve=self): + user_id_list.append(cveaccess.user_id) + _log("BOO1:user_id=%s" % cveaccess.user_id) + + # Decend the object tree + for c2v in CveToVulnerablility.objects.filter(cve=self): + vulnerability = Vulnerability.objects.get(id=c2v.vulnerability_id) + _log("BOO2:v=%s,%s" % (vulnerability.name,self.public)) + vulnerability.public = self.public + vulnerability.save() + if not self.public: + # Remove existing users + for va in VulnerabilityAccess.objects.filter(vulnerability=vulnerability): + _log("BOO3:DEL:v=%s,%s" % (vulnerability.name,va.id)) + va.delete() + # Add valid user list + for user_id in user_id_list: + va,create = VulnerabilityAccess.objects.get_or_create(vulnerability=vulnerability,user_id=user_id) + _log("BOO4:ADD:v=%s,%s,%s" % (vulnerability.name,va.id,user_id)) + va.save() + + for v2i in VulnerabilityToInvestigation.objects.filter(vulnerability = vulnerability): + investigation = Investigation.objects.get(id=v2i.investigation_id) + _log("BOO5:i=%s,%s" % (investigation.name,self.public)) + investigation.public = self.public + investigation.save() + if not self.public: + # Remove existing users + for ia in InvestigationAccess.objects.filter(investigation=investigation): + _log("BOO6:DEL:v=%s,%s" % (investigation.name,ia.id)) + ia.delete() + # Add valid user list + for user_id in user_id_list: + ia,create = InvestigationAccess.objects.get_or_create(investigation=investigation,user_id=user_id) + _log("BOO7:ADD:i=%s,%s,%s" % (investigation.name,ia.id,user_id)) + ia.save() + + class CveDetail(): # CPE item list @@ -727,6 +781,10 @@ class CveSource(models.Model): cve = models.ForeignKey(Cve,related_name="cve_parent",blank=True, null=True,on_delete=models.CASCADE,) datasource = models.ForeignKey(DataSource,related_name="cve_datasource", blank=True, null=True,on_delete=models.CASCADE,) +class CveAccess(models.Model): + cve = models.ForeignKey(Cve,related_name="cve_users",on_delete=models.CASCADE,) + user = models.ForeignKey(SrtUser,related_name="cve_user",on_delete=models.CASCADE,) + class CveHistory(models.Model): search_allowed_fields = ['cve__name', 'comment', 'date', 'author'] cve = models.ForeignKey(Cve,related_name="cve_history",default=None, null=True, on_delete=models.CASCADE,) @@ -764,8 +822,8 @@ class Package(models.Model): ) mode = models.IntegerField(choices=MODE, default=FOR) - name = models.CharField(max_length=50, blank=True) - realname = models.CharField(max_length=50, blank=True) + name = models.CharField(max_length=80, blank=True) + realname = models.CharField(max_length=80, blank=True) invalidname = models.TextField(blank=True) weight = models.IntegerField(default=0) # computed count data @@ -812,7 +870,7 @@ class Package(models.Model): class PackageToCve(models.Model): package = models.ForeignKey(Package,related_name="package2cve",on_delete=models.CASCADE,) cve = models.ForeignKey(Cve,related_name="cve2package",on_delete=models.CASCADE,) - applicable = models.NullBooleanField(default=True, null=True) + applicable = models.BooleanField(null=True) # CPE Filtering @@ -860,6 +918,11 @@ class CveReference(models.Model): name = models.CharField(max_length=100, null=True) datasource = models.ForeignKey(DataSource,related_name="source_references", blank=True, null=True,on_delete=models.CASCADE,) +class RecipeTable(models.Model): + search_allowed_fields = ['recipe_name'] + recipe_name = models.CharField(max_length=50) + + # PRODUCT class Product(models.Model): @@ -870,7 +933,7 @@ class Product(models.Model): name = models.CharField(max_length=40) version = models.CharField(max_length=40) profile = models.CharField(max_length=40) - cpe = models.CharField(max_length=40) + cpe = models.CharField(max_length=255) defect_tags = models.TextField(blank=True, default='') product_tags = models.TextField(blank=True, default='') @@ -971,6 +1034,9 @@ class Vulnerability(models.Model): if self.cve_primary_name: return "%s (%s)" % (self.name,self.cve_primary_name) return "%s" % (self.name) + @property + def get_public_text(self): + return 'Public' if self.public else 'Private' @staticmethod def new_vulnerability_name(): # get next vulnerability name atomically @@ -1320,6 +1386,9 @@ class Investigation(models.Model): if self.vulnerability and self.vulnerability.cve_primary_name: return "%s (%s)" % (self.name,self.vulnerability.cve_primary_name.name) return "%s" % (self.name) + @property + def get_public_text(self): + return 'Public' if self.public else 'Private' @staticmethod def new_investigation_name(): current_investigation_index,create = SrtSetting.objects.get_or_create(name='current_investigation_index') @@ -1492,6 +1561,210 @@ class ErrorLog(models.Model): def get_severity_text(self): return ErrorLog.SEVERITY[int(self.severity)][1] +class Job(models.Model): + search_allowed_fields = ['name', 'title', 'description', 'status'] + # Job Status + NOTSTARTED = 0 + INPROGRESS = 1 + SUCCESS = 2 + ERRORS = 3 + CANCELLING = 4 + CANCELLED = 5 + STATUS = ( + (NOTSTARTED, 'NotStarted'), + (INPROGRESS, 'InProgress'), + (SUCCESS, 'Success'), + (ERRORS, 'Errors'), + (CANCELLING, 'Cancelling'), + (CANCELLED, 'Cancelled'), + ) + + # Required + name = models.CharField(max_length=50,default='') + description = models.TextField(blank=True) + command = models.TextField(blank=True) + log_file = models.TextField(blank=True) + # Optional + parent_name = models.CharField(max_length=50,default='') + options = models.TextField(blank=True) + user = models.ForeignKey(SrtUser,default=None,null=True,on_delete=models.CASCADE,) + # Managed + status = models.IntegerField(choices=STATUS, default=NOTSTARTED) + pid = models.IntegerField(default=0) + count = models.IntegerField(default=0) + max = models.IntegerField(default=0) + errors = models.IntegerField(default=0) + warnings = models.IntegerField(default=0) + refresh = models.IntegerField(default=0) + message = models.CharField(max_length=50,default='') + started_on = models.DateTimeField(null=True) + completed_on = models.DateTimeField(null=True) + + @property + def get_status_text(self): + for s_val,s_name in Job.STATUS: + if s_val == self.status: + return s_name + return "?STATUS?" + + @staticmethod + def get_recent(user=None): + """ + Return recent jobs as a list; if sprint is set, only return + jobs for that sprint + """ + + if user and not isinstance(user,AnonymousUser): + jobs = Job.objects.filter(user=user) + else: + jobs = Job.objects.all() + + finished_criteria = \ + Q(status=Job.SUCCESS) | \ + Q(status=Job.ERRORS) | \ + Q(status=Job.CANCELLED) + + recent_jobs = list(itertools.chain( + jobs.filter(status=Job.INPROGRESS).order_by("-started_on"), + jobs.filter(finished_criteria).order_by("-completed_on")[:3] + )) + + # add percentage done property to each job; this is used + # to show job progress in mrj_section.html + for job in jobs: + job.percentDone = job.completeper() + job.outcomeText = job.get_status_text + + return recent_jobs + + def completeper(self): + if self.max > 0: + completeper = (self.count * 100) // self.max + else: + completeper = 0 + return completeper + + def eta(self): + eta = datetime.now() + completeper = self.completeper() + if completeper() > 0: + eta += ((eta - self.started_on)*(100-completeper))/completeper + return eta + + @staticmethod + def start(name,description,command,options='',log_file='logs/run_job.log',job_id=1): + # The audit_job.py will set the pid and time values so that there is no db race condition + command = ['bin/common/srtool_job.py','--name',name,'--description',description,'--command',command,'--options',options,'--log',log_file] + if job_id: + command.extend(['--job-id',str(job_id)]) + _log("JOB_START:%s" % parameter_join(command)) +# subprocess.Popen(command,close_fds=True) +# result_returncode,result_stdout,result_stderr = execute_process(command) + execute_process_close_fds(command) + + def cancel(self): + if self.status == Job.INPROGRESS: + try: + if self.pid: + os.kill(self.pid, signal.SIGTERM) #or signal.SIGKILL + except Exception as e: + _log("ERROR_JOB:Cancel:%s" % (e)) + try: + self.status = Job.CANCELLING + self.completed_on = datetime.now() + self.pid = 0 + self.save() + except Exception as e: + _log("ERROR_JOB:Cancelled:%s" % (e)) + + def done(self): + if not self.pid: + return + if self.status == Job.INPROGRESS: + self.pid = 0 + self.completed_on = datetime.now() + self.status = Job.SUCCESS + ### TODO COUNT ERRORS AND WARNINGS + self.save() + elif self.status == Job.CANCELLING: + self.pid = 0 + self.completed_on = datetime.now() + self.status = Job.CANCELLED + self.errors = 1 + self.save() + + @staticmethod + def preclear_jobs(user=None,user_id=0,user_none=False): + # NOTE: preclear completed jobs so that this page comes up clean + # without completed progress bars hanging around + if (not user_id) and (not user) and (not user_none): + return + if user_none: + user_id = None + elif not user_id: + user_id = user.id + for job in Job.objects.filter(user_id=user_id): + if job.status in (Job.SUCCESS,Job.ERRORS): + job.delete() + +# Wrapper class to run internal 'jobs' with the progress bar +class Job_Local(): + job = None + log_file_fd = None + INTERNAL_COMMAND = '<internal>' + DEFAULT = -1 + DEFAULT_LOG = '.job_log.txt' + + def __init__(self, name, description='', options='', log_file=DEFAULT_LOG, user=None): + self.job = Job(name=name, description=description, options=options, log_file=log_file, user=user) + self.job.command = self.INTERNAL_COMMAND + self.job.started_on = datetime.now(pytz.utc) + self.job.completed_on = None + if log_file: + self.log_file_fd = open(self.job.log_file, 'w') + self.log_file_fd.write(f"JOB_START: {name},{description} @{self.job.started_on}\n" ) + self.job.status = Job.INPROGRESS + self.job.save() + + # If cnt == DEFAULT, increment existing cnt value + # If max == DEFAULT, use existing max value + def update(self,message,count=DEFAULT,max=DEFAULT): + if count == self.DEFAULT: + self.job.count += 1 + else: + self.job.count = count + if max != self.DEFAULT: + self.job.max = max + if self.job.count > self.job.max: + self.job.count = self.job.max + self.job.message = message + if True and self.log_file_fd: + self.log_file_fd.write(f"JOB_UPDATE({self.job.message},{self.job.count},{self.job.max})\n") + self.log_file_fd.flush() + self.job.save() + def add_warning(self,msg): + self.job.warnings += 1 + self.job.save() + if self.log_file_fd: + self.log_file_fd.write("WARNING: " + msg + "\n" ) + def add_error(self,msg): + self.job.errors += 1 + self.job.save() + if self.log_file_fd: + self.log_file_fd.write("ERROR: " + msg + "\n" ) + def done(self,sleep_time=4): + if sleep_time: + time.sleep(sleep_time) + self.update('Done',self.job.max,self.job.max) + self.job.completed_on = datetime.now(pytz.utc) + self.job.status = Job.ERRORS if self.job.errors else Job.SUCCESS + self.job.save() + if self.log_file_fd: + self.log_file_fd.write(f"JOB_STOP: W={self.job.warnings},E={self.job.errors} @{self.job.completed_on}\n" ) + self.log_file_fd.flush() + self.log_file_fd.close() + self.log_file_fd = None + # # Database Cache Support # |