Fix database scheme and introduce automatic database migration
This commit is contained in:
@@ -7,7 +7,7 @@ import argon2
|
||||
|
||||
from .dns import DNSService
|
||||
from .email import EmailService
|
||||
from .models import init_database
|
||||
from .models import create_tables, init_database
|
||||
from .ratelimit import RateLimiter
|
||||
|
||||
|
||||
@@ -35,8 +35,9 @@ class Application:
|
||||
self.rate_limiter = None
|
||||
|
||||
def init_database(self):
|
||||
"""Initialize database connection."""
|
||||
"""Initialize database connection and run migrations."""
|
||||
init_database(self.config)
|
||||
create_tables()
|
||||
logging.debug("Database initialized")
|
||||
|
||||
def init_dns(self):
|
||||
|
||||
@@ -15,13 +15,6 @@ from .models import (
|
||||
from .validation import encode_hostname, encode_zone, ValidationError
|
||||
|
||||
|
||||
def cmd_init_db(args, app):
|
||||
"""Initialize database tables."""
|
||||
create_tables()
|
||||
print("Database tables created.")
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_user_list(args, app):
|
||||
"""List all users."""
|
||||
users = User.select()
|
||||
@@ -198,9 +191,11 @@ def cmd_hostname_add(args, app):
|
||||
print(f"Error: User '{username}' not found.")
|
||||
return 1
|
||||
|
||||
# Check if hostname exists
|
||||
if Hostname.select().where(Hostname.hostname == hostname_str).exists():
|
||||
print(f"Error: Hostname '{hostname_str}' already exists.")
|
||||
# Check if hostname+zone exists
|
||||
if Hostname.select().where(
|
||||
(Hostname.hostname == hostname_str) & (Hostname.zone == zone)
|
||||
).exists():
|
||||
print(f"Error: Hostname '{hostname_str}' in zone '{zone}' exists.")
|
||||
return 1
|
||||
|
||||
# Create hostname
|
||||
@@ -217,17 +212,18 @@ def cmd_hostname_add(args, app):
|
||||
|
||||
def cmd_hostname_delete(args, app):
|
||||
"""Delete a hostname."""
|
||||
# Validate and encode hostname
|
||||
# Validate and encode hostname and zone
|
||||
try:
|
||||
hostname_str = encode_hostname(args.hostname)
|
||||
zone = encode_zone(args.zone)
|
||||
except ValidationError as e:
|
||||
print(f"Error: {e}")
|
||||
return 1
|
||||
|
||||
try:
|
||||
hostname = get_hostname(hostname_str)
|
||||
hostname = get_hostname(hostname_str, zone)
|
||||
except DoesNotExist:
|
||||
print(f"Error: Hostname '{hostname_str}' not found.")
|
||||
print(f"Error: Hostname '{hostname_str}' in zone '{zone}' not found.")
|
||||
return 1
|
||||
|
||||
# Delete DNS records if active
|
||||
@@ -262,17 +258,18 @@ def cmd_hostname_delete(args, app):
|
||||
|
||||
def cmd_hostname_modify(args, app):
|
||||
"""Modify hostname settings."""
|
||||
# Validate and encode hostname
|
||||
# Validate and encode hostname and zone
|
||||
try:
|
||||
hostname_str = encode_hostname(args.hostname)
|
||||
zone = encode_zone(args.zone)
|
||||
except ValidationError as e:
|
||||
print(f"Error: {e}")
|
||||
return 1
|
||||
|
||||
try:
|
||||
hostname = get_hostname(hostname_str)
|
||||
hostname = get_hostname(hostname_str, zone)
|
||||
except DoesNotExist:
|
||||
print(f"Error: Hostname '{hostname_str}' not found.")
|
||||
print(f"Error: Hostname '{hostname_str}' in zone '{zone}' not found.")
|
||||
return 1
|
||||
|
||||
# Get new TTLs
|
||||
|
||||
@@ -15,7 +15,6 @@ from .cli import (
|
||||
cmd_hostname_delete,
|
||||
cmd_hostname_list,
|
||||
cmd_hostname_modify,
|
||||
cmd_init_db,
|
||||
cmd_user_add,
|
||||
cmd_user_delete,
|
||||
cmd_user_email,
|
||||
@@ -94,12 +93,14 @@ def build_parser():
|
||||
"delete", help="Delete hostname"
|
||||
)
|
||||
hostname_delete.add_argument("hostname", help="Hostname (FQDN)")
|
||||
hostname_delete.add_argument("zone", help="DNS zone")
|
||||
hostname_delete.set_defaults(func=cmd_hostname_delete)
|
||||
|
||||
hostname_modify = hostname_subparsers.add_parser(
|
||||
"modify", help="Modify hostname"
|
||||
)
|
||||
hostname_modify.add_argument("hostname", help="Hostname (FQDN)")
|
||||
hostname_modify.add_argument("zone", help="DNS zone")
|
||||
hostname_modify.add_argument("--dns-ttl", type=int, help="DNS record TTL")
|
||||
hostname_modify.add_argument("--expiry-ttl", type=int, help="Expiry TTL")
|
||||
hostname_modify.set_defaults(func=cmd_hostname_modify)
|
||||
@@ -163,7 +164,7 @@ def main():
|
||||
|
||||
# Handle --init-db
|
||||
if args.init_db:
|
||||
return cmd_init_db(args, app)
|
||||
return
|
||||
|
||||
# Handle --daemon
|
||||
if args.daemon:
|
||||
|
||||
@@ -20,7 +20,7 @@ from peewee import (
|
||||
db = SqliteDatabase(None)
|
||||
|
||||
# Current database schema version
|
||||
DATABASE_VERSION = 1
|
||||
DATABASE_VERSION = 2
|
||||
|
||||
|
||||
class BaseModel(Model):
|
||||
@@ -48,7 +48,7 @@ class Hostname(BaseModel):
|
||||
|
||||
id = AutoField()
|
||||
user = ForeignKeyField(User, backref="hostnames", on_delete="RESTRICT")
|
||||
hostname = CharField(max_length=255, unique=True)
|
||||
hostname = CharField(max_length=255)
|
||||
zone = CharField(max_length=255)
|
||||
dns_ttl = IntegerField()
|
||||
expiry_ttl = IntegerField()
|
||||
@@ -59,6 +59,9 @@ class Hostname(BaseModel):
|
||||
|
||||
class Meta:
|
||||
table_name = "hostnames"
|
||||
indexes = (
|
||||
(('hostname', 'zone'), True),
|
||||
)
|
||||
|
||||
|
||||
class Version(BaseModel):
|
||||
@@ -112,11 +115,78 @@ def init_database(config: dict):
|
||||
db.connect()
|
||||
|
||||
|
||||
def _migrate_v1_to_v2_sqlite():
|
||||
"""SQLite: recreate table (no ALTER TABLE for constraints)."""
|
||||
db.execute_sql('PRAGMA foreign_keys=OFF')
|
||||
db.execute_sql('''
|
||||
CREATE TABLE hostnames_new (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE RESTRICT,
|
||||
hostname VARCHAR(255) NOT NULL,
|
||||
zone VARCHAR(255) NOT NULL,
|
||||
dns_ttl INTEGER NOT NULL,
|
||||
expiry_ttl INTEGER NOT NULL,
|
||||
last_ipv4 VARCHAR(15),
|
||||
last_ipv4_update DATETIME,
|
||||
last_ipv6 VARCHAR(45),
|
||||
last_ipv6_update DATETIME,
|
||||
UNIQUE(hostname, zone)
|
||||
)
|
||||
''')
|
||||
db.execute_sql('''
|
||||
INSERT INTO hostnames_new
|
||||
SELECT id, user_id, hostname, zone, dns_ttl, expiry_ttl,
|
||||
last_ipv4, last_ipv4_update, last_ipv6, last_ipv6_update
|
||||
FROM hostnames
|
||||
''')
|
||||
db.execute_sql('DROP TABLE hostnames')
|
||||
db.execute_sql('ALTER TABLE hostnames_new RENAME TO hostnames')
|
||||
db.execute_sql('PRAGMA foreign_keys=ON')
|
||||
|
||||
|
||||
def _migrate_v1_to_v2_mariadb():
|
||||
"""MariaDB: ALTER TABLE to change constraints."""
|
||||
db.execute_sql('ALTER TABLE hostnames DROP INDEX hostnames_hostname')
|
||||
db.execute_sql(
|
||||
'ALTER TABLE hostnames ADD UNIQUE INDEX '
|
||||
'hostnames_hostname_zone (hostname, zone)'
|
||||
)
|
||||
|
||||
|
||||
def migrate_v1_to_v2():
|
||||
"""Migrate v1 to v2: hostname+zone unique constraint."""
|
||||
backend = db.__class__.__name__
|
||||
|
||||
if backend == 'SqliteDatabase':
|
||||
_migrate_v1_to_v2_sqlite()
|
||||
else:
|
||||
_migrate_v1_to_v2_mariadb()
|
||||
|
||||
Version.update(version=2).execute()
|
||||
logging.info("Database migrated from v1 to v2")
|
||||
|
||||
|
||||
def check_and_migrate():
|
||||
"""Check DB version and run migrations if needed."""
|
||||
try:
|
||||
version_row = Version.get()
|
||||
current = version_row.version
|
||||
except DoesNotExist:
|
||||
return
|
||||
|
||||
if current < 2:
|
||||
logging.info("Migrating DB: v%d -> v%d", current, DATABASE_VERSION)
|
||||
migrate_v1_to_v2()
|
||||
|
||||
|
||||
def create_tables():
|
||||
"""Create database tables if they don't exist."""
|
||||
if db.table_exists('version'):
|
||||
check_and_migrate()
|
||||
return
|
||||
|
||||
db.create_tables([User, Hostname, Version])
|
||||
if Version.select().count() == 0:
|
||||
Version.create(version=DATABASE_VERSION)
|
||||
Version.create(version=DATABASE_VERSION)
|
||||
logging.debug("Database tables created")
|
||||
|
||||
|
||||
@@ -136,12 +206,13 @@ def get_user(username: str):
|
||||
return User.get(User.username == username)
|
||||
|
||||
|
||||
def get_hostname(hostname: str):
|
||||
def get_hostname(hostname, zone):
|
||||
"""
|
||||
Get hostname by name.
|
||||
Get hostname by name and zone.
|
||||
|
||||
Args:
|
||||
hostname: Hostname to look up.
|
||||
zone: DNS zone.
|
||||
|
||||
Returns:
|
||||
Hostname instance.
|
||||
@@ -149,7 +220,9 @@ def get_hostname(hostname: str):
|
||||
Raises:
|
||||
DoesNotExist: If hostname not found.
|
||||
"""
|
||||
return Hostname.get(Hostname.hostname == hostname)
|
||||
return Hostname.get(
|
||||
(Hostname.hostname == hostname) & (Hostname.zone == zone)
|
||||
)
|
||||
|
||||
|
||||
def get_hostname_for_user(hostname: str, user: User):
|
||||
|
||||
Reference in New Issue
Block a user