db_init.py 4.54 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2016, Daniele Venzano
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16 17
"""Database initialization."""

18 19 20
import psycopg2
import psycopg2.extras

21
import zoe_api.exceptions
22
from zoe_lib.config import get_conf
23

# Schema version written to public.versions for a newly initialised deployment
# and compared against the stored value by check_schema_version().
SQL_SCHEMA_VERSION = 4  # ---> Increment this value every time the schema changes !!! <---


def version_table(cur):
    """Create the version table.

    ``public.versions`` records, for each deployment name, the schema version
    its tables were created with. The table is only created if missing, so
    calling this on an already-initialised database is harmless.

    :param cur: an open DB-API cursor.
    """
    cur.execute("CREATE TABLE IF NOT EXISTS public.versions (deployment text, version integer)")


def schema(cur, deployment_name):
    """Create the schema for the configured deployment name.

    Nothing is created if a schema with exactly this name is already listed
    in ``pg_catalog.pg_namespace``.

    :param cur: an open DB-API cursor.
    :param deployment_name: name of the schema to create. It is interpolated
        into ``CREATE SCHEMA`` unquoted, so it must be a plain SQL identifier
        coming from trusted configuration.
    """
    cur.execute("SELECT EXISTS(SELECT 1 FROM pg_catalog.pg_namespace WHERE nspname = %s)", (deployment_name,))
    already_exists = cur.fetchone()[0]
    if already_exists:
        return
    # Identifiers cannot be bound as query parameters, hence str.format.
    # NOTE(review): PostgreSQL folds the unquoted name to lower case, while the
    # existence check above compares it verbatim; keep deployment names lower case.
    cur.execute('CREATE SCHEMA {}'.format(deployment_name))


def check_schema_version(cur, deployment_name):
    """Check if the schema version matches this source code version.

    When *deployment_name* has no row in ``public.versions`` yet, a row with
    ``SQL_SCHEMA_VERSION`` is inserted and the deployment schema is created
    via ``schema()``; the caller is then expected to create the tables.

    :param cur: an open DB-API cursor.
    :param deployment_name: deployment (and schema) name to look up.
    :returns: True if the recorded version equals ``SQL_SCHEMA_VERSION``;
        False if the deployment was just registered and its tables still
        need to be created.
    :raises zoe_api.exceptions.ZoeException: if a different version is recorded.
    """
    cur.execute("SELECT version FROM public.versions WHERE deployment = %s", (deployment_name,))
    row = cur.fetchone()
    if row is None:
        cur.execute("INSERT INTO public.versions (deployment, version) VALUES (%s, %s)", (deployment_name, SQL_SCHEMA_VERSION))
        schema(cur, deployment_name)
        return False  # Tables need to be created

    found_version = row[0]
    if found_version != SQL_SCHEMA_VERSION:
        raise zoe_api.exceptions.ZoeException('SQL database schema version mismatch: need {}, found {}'.format(SQL_SCHEMA_VERSION, found_version))
    return True


def create_tables(cur):
    """Create the Zoe database tables.

    Table names are unqualified, so PostgreSQL creates them in the first
    schema of the connection's ``search_path`` (``init()`` sets it to the
    deployment schema). ``IF NOT EXISTS`` is not used, so this fails if any
    of the tables is already present there.

    :param cur: an open DB-API cursor.
    """
    cur.execute('''CREATE TABLE execution (
        id SERIAL PRIMARY KEY,
        name TEXT NOT NULL,
        user_id TEXT NOT NULL,
        description JSON NOT NULL,
        status TEXT NOT NULL,
        execution_manager_id TEXT NULL,
        time_submit TIMESTAMP NOT NULL,
        time_start TIMESTAMP NULL,
        time_end TIMESTAMP NULL,
        error_message TEXT NULL
        )''')
    # Services reference their execution and are deleted together with it.
    cur.execute('''CREATE TABLE service (
        id SERIAL PRIMARY KEY,
        status TEXT NOT NULL,
        error_message TEXT NULL DEFAULT NULL,
        description JSON NOT NULL,
        execution_id INT REFERENCES execution ON DELETE CASCADE,
        service_group TEXT NOT NULL,
        name TEXT NOT NULL,
        backend_id TEXT NULL DEFAULT NULL,
        backend_status TEXT NOT NULL DEFAULT 'undefined',
        ip_address CIDR NULL DEFAULT NULL,
        essential BOOLEAN NOT NULL DEFAULT FALSE
        )''')
    # Ports reference their service and are deleted together with it.
    cur.execute('''CREATE TABLE port (
        id SERIAL PRIMARY KEY,
        service_id INT REFERENCES service ON DELETE CASCADE,
        internal_name TEXT NOT NULL,
        external_ip INET NULL,
        external_port INT NULL,
        description JSON NOT NULL
    )''')
    # Create oauth_client and oauth_token tables for OAuth2
    cur.execute('''CREATE TABLE oauth_client (
        identifier TEXT PRIMARY KEY,
        secret TEXT,
        role TEXT,
        redirect_uris TEXT,
        authorized_grants TEXT,
        authorized_response_types TEXT
        )''')
    cur.execute('''CREATE TABLE oauth_token (
        client_id TEXT PRIMARY KEY,
        grant_type TEXT,
        token TEXT,
        data TEXT,
        expires_at TIMESTAMP,
        refresh_token TEXT,
        refresh_token_expires_at TIMESTAMP,
        scopes TEXT,
        user_id TEXT
        )''')

def init(force=False):
    """DB init entrypoint.

    Connects to the configured PostgreSQL database, ensures the
    ``public.versions`` table exists, points ``search_path`` at the deployment
    schema, and creates the schema and tables when this deployment has no
    recorded version yet. All statements run in one transaction that is
    committed only if every step succeeds; the connection is always closed.

    :param force: if True, first delete this deployment's row from
        ``public.versions`` and drop its schema (CASCADE), so everything is
        recreated from scratch.
    :raises zoe_api.exceptions.ZoeException: if the recorded schema version
        differs from ``SQL_SCHEMA_VERSION`` (raised by check_schema_version).
    """
    conf = get_conf()
    deployment_name = conf.deployment_name

    # Keyword arguments let psycopg2 quote/escape each value; the previous
    # string-concatenated DSN broke on empty passwords or values with spaces.
    conn = psycopg2.connect(dbname=conf.dbname,
                            user=conf.dbuser,
                            password=conf.dbpass,
                            host=conf.dbhost,
                            port=str(conf.dbport))
    try:
        cur = conn.cursor()

        version_table(cur)
        # Unqualified table names now resolve to the deployment schema first.
        cur.execute('SET search_path TO {},public'.format(deployment_name))

        if force:
            cur.execute("DELETE FROM public.versions WHERE deployment = %s", (deployment_name,))
            cur.execute('DROP SCHEMA IF EXISTS {} CASCADE'.format(deployment_name))

        if not check_schema_version(cur, deployment_name):
            create_tables(cur)

        conn.commit()
        cur.close()
    finally:
        # Release the connection even if a statement or the version check raised;
        # uncommitted work is discarded.
        conn.close()