# Copyright (c) 2017, Daniele Venzano
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Interface to PostgreSQL for Zoe state."""

import logging

import psycopg2
import psycopg2.extras

from zoe_lib.config import get_conf
from zoe_lib.version import SQL_SCHEMA_VERSION
import zoe_lib.exceptions
from .service import ServiceTable
from .execution import ExecutionTable
from .port import PortTable
from .quota import QuotaTable
from .user import UserTable
from .role import RoleTable
log = logging.getLogger(name=__name__)  # module-level logger

# Teach psycopg2 to adapt plain Python dicts passed as query parameters
# into JSON values (psycopg2.extras.Json).
psycopg2.extensions.register_adapter(dict, psycopg2.extras.Json)


class SQLManager:
    """Gateway to the PostgreSQL database holding the Zoe state.

    Meant to be used as a singleton. Each deployment keeps its tables in a
    PostgreSQL schema named after ``conf.deployment_name``; the shared
    ``public.versions`` table records the SQL schema version each deployment
    was initialized with.
    """

    def __init__(self, conf):
        """Read the connection settings from *conf* and open a connection.

        :param conf: configuration object; the attributes ``dbuser``,
            ``dbpass``, ``dbhost``, ``dbport``, ``dbname`` and
            ``deployment_name`` are read.
        """
        self.dbuser = conf.dbuser
        self.password = conf.dbpass
        self.host = conf.dbhost
        self.port = conf.dbport
        self.dbname = conf.dbname
        self.schema = conf.deployment_name
        self.conn = None
        self._connect()

    def _connect(self):
        """Open a new database connection and store it in ``self.conn``.

        The settings are passed to psycopg2 as keyword arguments instead of a
        DSN string built by concatenation. psycopg2 quotes each value, so a
        password (or any other setting) containing spaces, quotes or
        backslashes, or an empty one, can no longer corrupt the connection
        string. Any previous connection object is simply replaced.
        """
        self.conn = psycopg2.connect(
            dbname=self.dbname,
            user=self.dbuser,
            password=self.password,
            host=self.host,
            port=str(self.port),
        )

    def cursor(self):
        """Get a cursor, making sure the connection to the database is established.

        Returns a ``DictCursor`` whose ``search_path`` is set to this
        deployment's schema followed by ``public``. The method reconnects once
        in either of two cases:

        - creating the cursor raises ``InterfaceError``;
        - setting the search path raises ``InternalError`` (for instance, a
          transaction aborted by an earlier failed statement).
        """
        try:
            cur = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
        except psycopg2.InterfaceError:
            self._connect()
            cur = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)

        # NOTE(review): the schema name is interpolated unquoted, so it must be
        # a plain SQL identifier; confirm deployment names are validated.
        try:
            cur.execute('SET search_path TO {},public'.format(self.schema))
        except psycopg2.InternalError:
            self._connect()
            cur = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
            cur.execute('SET search_path TO {},public'.format(self.schema))

        return cur

    def commit(self):
        """Commit a transaction."""
        self.conn.commit()

    @property
    def executions(self) -> ExecutionTable:
        """Access the execution state (a new table wrapper on every access)."""
        return ExecutionTable(self)

    @property
    def services(self) -> ServiceTable:
        """Access the service state (a new table wrapper on every access)."""
        return ServiceTable(self)

    @property
    def ports(self) -> PortTable:
        """Access the port state (a new table wrapper on every access)."""
        return PortTable(self)

    @property
    def quota(self) -> QuotaTable:
        """Access the quota state (a new table wrapper on every access)."""
        return QuotaTable(self)

    @property
    def role(self) -> RoleTable:
        """Access the role state (a new table wrapper on every access)."""
        return RoleTable(self)

    @property
    def user(self) -> UserTable:
        """Access the user state (a new table wrapper on every access)."""
        return UserTable(self)

    def _create_tables(self):
        """Ask every table wrapper to create its table(s).

        Quota, role and user are created before executions, services and
        ports. Presumably later tables reference earlier ones, so keep this
        order (TODO confirm against the table definitions).
        """
        self.quota.create()
        self.role.create()
        self.user.create()
        self.executions.create()
        self.services.create()
        self.ports.create()

    def init_db(self, force=False):
        """DB init entrypoint.

        Ensures ``public.versions`` exists. If this deployment has no version
        row yet, registers it, creates its schema when missing, and creates
        all tables. Everything is committed as one transaction. On any error
        the transaction is rolled back and the error is re-raised. The cursor
        is always closed.

        The deployment name comes from this instance's configuration
        (``self.schema``), the same name ``cursor()`` puts on the search path.

        :param force: first delete this deployment's version row and drop its
            schema (CASCADE), so everything is recreated from scratch
        :raises zoe_lib.exceptions.ZoeLibException: if the deployment is
            registered with a schema version other than SQL_SCHEMA_VERSION
        """
        schema = self.schema
        cur = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
        try:
            cur.execute("CREATE TABLE IF NOT EXISTS public.versions (deployment text, version integer)")
            cur.execute('SET search_path TO {},public'.format(schema))

            if force:
                cur.execute("DELETE FROM public.versions WHERE deployment = %s", (schema,))
                cur.execute('DROP SCHEMA IF EXISTS {} CASCADE'.format(schema))

            if not self._check_schema_version(cur, schema):
                self._create_tables()

            self.commit()
        except Exception:
            # Do not leave the shared connection inside a failed or
            # half-applied transaction.
            if not self.conn.closed:
                self.conn.rollback()
            raise
        finally:
            cur.close()

    def _check_schema_version(self, cur, deployment_name):
        """Check if the schema version matches this source code version.

        :param cur: open cursor; statements run here belong to the caller's
            transaction
        :param deployment_name: deployment name, also used as the PostgreSQL
            schema name
        :returns: True if the deployment is registered with SQL_SCHEMA_VERSION.
            False if it was not registered yet. In that case this method
            inserts its version row and creates its schema when missing, and
            the caller must create the tables.
        :raises zoe_lib.exceptions.ZoeLibException: if the deployment is
            registered with a different schema version
        """
        cur.execute("SELECT version FROM public.versions WHERE deployment = %s", (deployment_name,))
        row = cur.fetchone()
        if row is not None:
            if row[0] == SQL_SCHEMA_VERSION:
                return True
            raise zoe_lib.exceptions.ZoeLibException('SQL database schema version mismatch: need {}, found {}'.format(SQL_SCHEMA_VERSION, row[0]))

        # First initialization of this deployment: record its version and make
        # sure its schema exists before the tables are created in it.
        cur.execute("INSERT INTO public.versions (deployment, version) VALUES (%s, %s)", (deployment_name, SQL_SCHEMA_VERSION))
        # NOTE(review): CREATE SCHEMA below uses an unquoted identifier, which
        # PostgreSQL folds to lower case, while this lookup compares the raw
        # name. Names with upper-case letters would not match; confirm that
        # deployment names are lower-case.
        cur.execute("SELECT EXISTS(SELECT 1 FROM pg_catalog.pg_namespace WHERE nspname = %s)", (deployment_name,))
        if not cur.fetchone()[0]:
            cur.execute('CREATE SCHEMA {}'.format(deployment_name))
        return False  # Tables need to be created