aboutsummaryrefslogtreecommitdiffstats
path: root/lib/python2.7/site-packages/sqlalchemy_migrate-0.7.2-py2.7.egg/migrate/versioning/version.py
blob: d5a5be98bae5f2a18c24271c8a0002c97d9d63c3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import re
import shutil
import logging

from migrate import exceptions
from migrate.versioning import pathed, script
from datetime import datetime


# Module-level logger named after this module.
log = logging.getLogger(__name__)

class VerNum(object):
    """A version number that behaves like a string and int at the same time

    Instances are interned: every value that normalizes to the same
    non-negative integer (e.g. ``1``, ``'1'`` and ``'001'``) yields the very
    same object.  Instances compare and hash as that integer, so they work
    as dict keys and with ``max``/``sorted`` on both Python 2 and Python 3.

    :raises ValueError: if the value is negative or not int-convertible
    """

    # Interning cache: normalized decimal string -> VerNum instance.
    _instances = dict()

    def __new__(cls, value):
        # Normalize *before* the cache lookup so that equivalent spellings
        # ('001' vs 1) share one instance, and validate before caching so a
        # rejected value never lands in the cache.
        num = int(value)
        if num < 0:
            raise ValueError("Version number cannot be negative")
        val = str(num)
        if val not in cls._instances:
            cls._instances[val] = super(VerNum, cls).__new__(cls)
        return cls._instances[val]

    def __init__(self, value):
        # __new__ already validated the value.  __init__ still runs on every
        # construction of an interned instance and re-stores the same string.
        self.value = str(int(value))

    def __add__(self, value):
        ret = int(self) + int(value)
        return VerNum(ret)

    def __sub__(self, value):
        return self + (int(value) * -1)

    def __cmp__(self, value):
        # Python 2 three-way comparison, kept for backward compatibility.
        # The rich comparisons below take precedence where they apply.
        return int(self) - int(value)

    @staticmethod
    def _coerce(value):
        """Return *value* as an int, or None if it cannot be converted."""
        try:
            return int(value)
        except (TypeError, ValueError):
            return None

    def __eq__(self, value):
        other = self._coerce(value)
        if other is None:
            return NotImplemented
        return int(self) == other

    def __ne__(self, value):
        # Python 2 does not derive __ne__ from __eq__.
        result = self.__eq__(value)
        if result is NotImplemented:
            return result
        return not result

    def __lt__(self, value):
        other = self._coerce(value)
        if other is None:
            return NotImplemented
        return int(self) < other

    def __le__(self, value):
        other = self._coerce(value)
        if other is None:
            return NotImplemented
        return int(self) <= other

    def __gt__(self, value):
        other = self._coerce(value)
        if other is None:
            return NotImplemented
        return int(self) > other

    def __ge__(self, value):
        other = self._coerce(value)
        if other is None:
            return NotImplemented
        return int(self) >= other

    def __hash__(self):
        # Hash as the integer so it stays consistent with __eq__.
        return hash(int(self))

    def __repr__(self):
        return "<VerNum(%s)>" % self.value

    def __str__(self):
        return str(self.value)

    def __int__(self):
        return int(self.value)


class Collection(pathed.Pathed):
    """A collection of versioning scripts in a repository"""

    # Version scripts start with a version number of at least three digits.
    FILENAME_WITH_VERSION = re.compile(r'^(\d{3,}).*')

    def __init__(self, path):
        """Collect current version scripts in repository
        and store them in self.versions

        :param path: directory holding the version scripts
        :raises Exception: if the directory contains an entry named ``1``,
            which indicates the old directory-per-version repository format
        """
        super(Collection, self).__init__(path)

        # Create temporary list of files, allowing skipped version numbers.
        filenames = os.listdir(path)
        if '1' in filenames:
            # deprecation
            raise Exception('It looks like you have a repository in the old '
                'format (with directories for each version). '
                'Please convert repository before proceeding.')

        # Group filenames by their leading version number.
        by_number = dict()
        for filename in filenames:
            match = self.FILENAME_WITH_VERSION.match(filename)
            if match is None:
                continue  # Must be a helper file or something, let's ignore it.
            by_number.setdefault(int(match.group(1)), []).append(filename)

        # Create the versions member where the keys
        # are VerNum's and the values are Version's.
        self.versions = dict()
        for num, names in by_number.items():
            self.versions[VerNum(num)] = Version(num, path, names)

    @property
    def latest(self):
        """:returns: Latest version in Collection (``VerNum(0)`` if empty)"""
        # list() keeps this working where dict.keys() is a view (Python 3).
        return max([VerNum(0)] + list(self.versions))

    def _next_ver_num(self, use_timestamp_numbering):
        """Return the number to use for a newly created version.

        With timestamp numbering, the number is the current UTC time as
        ``YYYYMMDDHHMMSS``; otherwise it is ``latest + 1``.
        """
        if use_timestamp_numbering == True:
            return VerNum(int(datetime.utcnow().strftime('%Y%m%d%H%M%S')))
        else:
            return self.latest + 1

    @staticmethod
    def _description_suffix(description):
        """Turn *description* into a filename suffix that is either empty
        or starts with exactly one underscore."""
        extra = str_to_filename(description)
        if extra:
            if extra == '_':
                extra = ''
            elif not extra.startswith('_'):
                extra = '_%s' % extra
        return extra

    def create_new_python_version(self, description, **k):
        """Create Python files for new version

        :param description: free-form text embedded in the script filename
        :param k: ``use_timestamp_numbering`` selects the numbering scheme.
            All other keyword arguments are passed to
            ``script.PythonScript.create``.
        """
        ver = self._next_ver_num(k.pop('use_timestamp_numbering', False))
        filename = '%03d%s.py' % (ver, self._description_suffix(description))
        filepath = self._version_path(filename)

        script.PythonScript.create(filepath, **k)
        self.versions[ver] = Version(ver, self.path, [filename])

    def create_new_sql_version(self, database, description, **k):
        """Create SQL files for new version

        One upgrade and one downgrade script are created for *database*.

        :param database: database name embedded in the script filenames
        :param description: free-form text embedded in the script filenames
        :param k: ``use_timestamp_numbering`` selects the numbering scheme.
            All other keyword arguments are passed to
            ``script.SqlScript.create``.
        """
        ver = self._next_ver_num(k.pop('use_timestamp_numbering', False))
        extra = self._description_suffix(description)
        self.versions[ver] = Version(ver, self.path, [])

        # Create new files.
        for op in ('upgrade', 'downgrade'):
            filename = '%03d%s_%s_%s.sql' % (ver, extra, database, op)
            filepath = self._version_path(filename)
            script.SqlScript.create(filepath, **k)
            self.versions[ver].add_script(filepath)

    def version(self, vernum=None):
        """Returns latest Version if vernum is not given.
        Otherwise, returns wanted version

        :raises KeyError: if the requested version is not in the collection
        """
        if vernum is None:
            vernum = self.latest
        return self.versions[VerNum(vernum)]

    @classmethod
    def clear(cls):
        super(Collection, cls).clear()

    def _version_path(self, ver):
        """Returns path of file in versions repository

        :param ver: file name (despite the parameter name) inside the
            repository directory
        """
        return os.path.join(self.path, str(ver))


class Version(object):
    """A single version in a collection
    :param vernum: Version Number
    :param path: Path to script files
    :param filelist: List of scripts
    :type vernum: int, VerNum
    :type path: string
    :type filelist: list
    """

    def __init__(self, vernum, path, filelist):
        self.version = VerNum(vernum)

        # SQL scripts, keyed by database name and then by operation.
        self.sql = dict()
        # The (at most one) Python script of this version.
        self.python = None

        # Named `filename` (not `script`) to avoid shadowing the module.
        for filename in filelist:
            self.add_script(os.path.join(path, filename))

    def script(self, database=None, operation=None):
        """Returns SQL or Python Script

        Lookup order:
        1. A ``.sql`` script for *database* and *operation*.
        2. A ``.sql`` script for the ``'default'`` database.
        3. The Python script.

        :raises AssertionError: if none of these exists
        """
        for db in (database, 'default'):
            # Try to return a .sql script first
            try:
                return self.sql[db][operation]
            except KeyError:
                continue  # No .sql script exists

        # TODO: maybe add force Python parameter?
        if self.python is None:
            # Raised explicitly (not via ``assert``) so the check survives
            # ``python -O``.  AssertionError is kept for callers that may
            # already catch it.
            raise AssertionError(
                "There is no script for %d version" % self.version)
        return self.python

    def add_script(self, path):
        """Add script to Collection/Version

        Only files whose extension is exactly ``.py`` or ``.sql`` are added;
        anything else is ignored.  The real extension is compared so that
        names such as ``001_happy`` or ``001_mysql`` are not mistaken for
        scripts.
        """
        ext = os.path.splitext(path)[1]
        if ext == '.' + Extensions.py:
            self._add_script_py(path)
        elif ext == '.' + Extensions.sql:
            self._add_script_sql(path)

    SQL_FILENAME = re.compile(r'^.*\.sql')

    def _add_script_sql(self, path):
        """File a ``###_description_database_operation.sql`` script under
        ``self.sql[database][operation]``.

        :raises exceptions.ScriptError: if the filename does not contain
            ``.sql`` or has fewer than three underscore-separated parts
        """
        basename = os.path.basename(path)
        invalid = ("Invalid SQL script name %s " % basename +
                   "(needs to be ###_description_database_operation.sql)")

        if not self.SQL_FILENAME.match(basename):
            raise exceptions.ScriptError(invalid)

        # Strip only the extension, then split on underscores.  The last two
        # parts are the database and the operation.
        parts = os.path.splitext(basename)[0].split('_')
        if len(parts) < 3:
            raise exceptions.ScriptError(invalid)
        op = parts[-1]
        dbms = parts[-2]

        # File the script into a dictionary
        self.sql.setdefault(dbms, {})[op] = script.SqlScript(path)

    def _add_script_py(self, path):
        """Register the Python script of this version.

        :raises exceptions.ScriptError: if a Python script is already set
        """
        if self.python is not None:
            raise exceptions.ScriptError('You can only have one Python script '
                'per version, but you have: %s and %s' % (self.python, path))
        self.python = script.PythonScript(path)


class Extensions:
    """Namespace of the file extensions (without the leading dot) that
    identify version scripts."""
    # Raw SQL scripts.
    sql = 'sql'
    # Python migration scripts.
    py = 'py'

def str_to_filename(s):
    """Turn a free-form description into a fragment usable in a filename.

    Spaces, double quotes, single quotes, dots and path separators
    (``/`` and ``\\``) are replaced by underscores.  Any run of consecutive
    underscores is then collapsed into a single one.  Replacing the
    separators keeps a description such as ``"users/roles"`` from pointing
    the generated script into a non-existent subdirectory.

    >>> str_to_filename('add "user" table')
    'add_user_table'
    """
    s = re.sub(r'[ "\'./\\]', '_', s)
    return re.sub(r'_{2,}', '_', s)