aboutsummaryrefslogtreecommitdiffstats
path: root/recipes-security/bastille/files/set_required_questions.py
blob: 4a28358c3501afdf4432b2d9e1d55d7961cbd437 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#!/usr/bin/env python

#Signed-off-by: Anne Mulhern <mulhern@yoctoproject.org>

import argparse, os, shutil, sys, tempfile, traceback
from os import path



def get_config(lines):
  """
  From a sequence of lines retrieve the question file name, question identifier
  pairs.

  Every line that is neither blank nor a comment (first non-blank character
  "#") must have the form <question file>.<question identifier>=<value>.
  The value is ignored and may itself contain "="; surrounding whitespace is
  stripped from the file name and the identifier.

  @param name lines An iterable of configuration-file lines.
  @return A generator of (question file name, question identifier) pairs.
  @raise ValueError On the first line that has no "=" or whose coordinates
                    are not exactly two non-empty, dot-separated parts.
  """
  for l in lines:
    stripped = l.strip()
    # Blank lines and comments carry no question assignment.
    if not stripped or stripped.startswith("#"):
      continue
    # Split on the first "=" only, so a value containing "=" is accepted.
    (coord, sep, _) = stripped.partition("=")
    if not sep:
      raise ValueError("Badly formatted line %s: expected <file>.<identifier>=<value>." % stripped)
    parts = coord.split(".")
    if len(parts) != 2 or not parts[0].strip() or not parts[1].strip():
      raise ValueError("Badly formatted coordinates %s in line %s." % (coord, stripped))
    yield parts[0].strip(), parts[1].strip()



def check_contains(line, name):
  """
  Check if the value field for REQUIRE_DISTRO contains the given name.

  The value field is everything after the first ":" on the line. It is
  treated as a whitespace-separated list of names, and a name matches only
  as a whole token.

  @param name line The REQUIRE_DISTRO line, e.g. "REQUIRE_DISTRO: LINUX DB".
  @param name name The name to look for in the value field of the line.
  @return True if name is one of the tokens of the value field.
  @raise ValueError If the line contains no ":".
  """
  try:
    # maxsplit=1: a ":" inside the value field must not break the split.
    (label, distros) = line.split(":", 1)
  except ValueError as e:
    raise ValueError("Error splitting REQUIRE_DISTRO line: %s" % e)
  return name in distros.split()



def add_requires(the_ident, distro, lines):
  """
  Yield a sequence of lines the same as lines except that where
  the_ident matches a question identifier change the REQUIRE_DISTRO so that
  it includes the specified distro.

  Only the REQUIRE_DISTRO line of the matched question is considered. The
  search for it stops at the next "LABEL:" line, so a question without a
  REQUIRE_DISTRO line is left unchanged rather than having a later
  question's line modified. Every input line is yielded exactly once.

  @param name the_ident The question identifier to be matched.
  @param name distro The distribution to add to the question's REQUIRE_DISTRO
                     field.
  @param name lines The lines to be processed: any iterable (file object,
                    list, ...).
  """
  # One shared iterator lets each phase resume where the previous one
  # stopped; iterating a list directly would restart it in every loop.
  lines = iter(lines)

  # Phase 1: copy lines up to and including the matching LABEL line.
  for l in lines:
    yield l
    if l.startswith("LABEL:") and l.split(":", 1)[1].strip() == the_ident:
      break

  # Phase 2: inside the matched question, make REQUIRE_DISTRO include distro.
  for l in lines:
    if l.startswith("REQUIRE_DISTRO"):
      if check_contains(l, distro):
        yield l
      else:
        yield l.rstrip() + " " + distro + "\n"
      break
    yield l
    if l.startswith("LABEL:"):
      # Next question reached: the matched one had no REQUIRE_DISTRO line.
      break

  # Phase 3: copy the remainder untouched.
  for l in lines:
    yield l



def xform_file(qfile, distro, qlabel):
  """
  Transform a Questions file in place.

  The transformed text is written to a temporary file in the same directory
  as qfile. The original's permission bits and timestamps are copied onto it,
  and it then replaces qfile. If anything fails, the temporary file is
  removed and the exception is re-raised.

  @param name qfile The designated questions file.
  @param name distro The distribution to add to the required distributions.
  @param name qlabel The question label for which the distro is to be added.
  @raise IOError/OSError If qfile cannot be read or replaced.
  """
  # Text mode: add_requires yields str, which the default binary-mode temp
  # file rejects on Python 3. Creating it next to qfile keeps the final move
  # a same-filesystem rename instead of a copy.
  questions_out = tempfile.NamedTemporaryFile(
    mode="w", dir=os.path.dirname(os.path.abspath(qfile)), delete=False)
  try:
    with open(qfile) as questions_in:
      questions_out.writelines(add_requires(qlabel, distro, questions_in))
    questions_out.close()
    shutil.copystat(qfile, questions_out.name)
    # shutil.move overwrites an existing destination, so qfile is never
    # deleted before its replacement has been fully written.
    shutil.move(questions_out.name, qfile)
  except BaseException:
    # Cleanup only; the original exception is always re-raised.
    questions_out.close()
    if os.path.exists(questions_out.name):
      os.remove(questions_out.name)
    raise



def handle_args(parser):
  """
  Register this script's command-line arguments on parser and parse them.

  @param name parser An argparse.ArgumentParser to populate.
  @return The namespace produced by parser.parse_args() (reads sys.argv).
  """
  # Positional arguments, in the order they appear on the command line.
  for dest, text in (('config_file', "Configuration file path."),
                     ('questions_dir', "Directory containing Questions files.")):
    parser.add_argument(dest, help=text)
  parser.add_argument('--distro', '-d', default="Yocto",
                      help="The distribution, the default is Yocto.")
  parser.add_argument('--debug', '-b', action='store_true',
                      help="Print debug information.")
  namespace = parser.parse_args()
  return namespace



def check_args(args):
  """
  Normalize the path arguments to absolute paths and validate them.

  Mutates args in place: config_file and questions_dir are replaced by their
  absolute forms before either is checked.

  @param name args The namespace returned by handle_args.
  @raise ValueError If questions_dir is not an existing directory, or
                    config_file is not an existing regular file.
  """
  cfg = os.path.abspath(args.config_file)
  qdir = os.path.abspath(args.questions_dir)
  args.config_file, args.questions_dir = cfg, qdir

  # The Questions directory is checked first, so its error wins when both are bad.
  if os.path.isdir(qdir):
    if os.path.isfile(cfg):
      return
    raise ValueError("Specified configuration file %s not found." % cfg)
  raise ValueError("Specified Questions directory %s does not exist or is not a directory." % qdir)



def _fatal(debug, message):
  """
  Report a fatal error and terminate with a non-zero exit status.

  Intended to be called from inside an except clause. With debug set, the
  active exception's traceback is printed; otherwise message is printed.

  @param name debug Whether to print the traceback instead of message.
  @param name message The user-facing error text.
  """
  if debug:
    traceback.print_exc()
    sys.exit(1)
  sys.exit(message)



def main():
  """
  Entry point: parse and validate arguments, then, for every
  <file>.<identifier> pair in the configuration file, add the distro to that
  question's REQUIRE_DISTRO in <questions_dir>/<file>.txt.

  Any error stops processing and exits with a non-zero status, including
  in --debug mode.
  """
  opts = handle_args(argparse.ArgumentParser(description="A simple script that sets required questions based on the question/answer pairs in a configuration file."))

  try:
    check_args(opts)
  except ValueError as e:
    # Must not fall through: continuing would process with invalid paths.
    _fatal(opts.debug, "Fatal error:\n%s" % e)

  try:
    with open(opts.config_file) as config_in:
      for qfile, qlabel in get_config(config_in):
        questions_file = os.path.join(opts.questions_dir, qfile + ".txt")
        xform_file(questions_file, opts.distro, qlabel)
  # OSError is listed separately because on Python 2 it is not an IOError
  # (os.remove / shutil.move raise it).
  except (IOError, OSError) as e:
    _fatal(opts.debug, "Fatal error reading or writing file:\n%s" % e)
  except ValueError as e:
    _fatal(opts.debug, "Fatal error:\n%s" % e)



if __name__ == "__main__":
  # Run main() only when executed as a script, not when imported.
  main()