diff --git a/README.md b/README.md index 5703bf2..e6c396f 100644 --- a/README.md +++ b/README.md @@ -317,6 +317,7 @@ wget -qO- --user=username --password=password \ - Rate limiting protects against brute-force attacks - Database file should have restricted permissions - Consider fail2ban for additional protection +- **Database backups**: Create recurring backups before upgrades. Migration errors require manual recovery from backup. ## TTL Behavior diff --git a/src/ddns_service/models.py b/src/ddns_service/models.py index d2ec8f1..ab9d856 100644 --- a/src/ddns_service/models.py +++ b/src/ddns_service/models.py @@ -22,6 +22,25 @@ db = SqliteDatabase(None) # Current database schema version DATABASE_VERSION = 2 +# Migration column mappings: key = target version +# Values: {table: {old_col: new_col}} - None value = drop column +MIGRATION_COLUMN_MAPS = { + 2: { + 'hostnames': { + 'id': 'id', + 'user_id': 'user_id', + 'hostname': 'hostname', + 'zone': 'zone', + 'dns_ttl': 'dns_ttl', + 'expiry_ttl': 'expiry_ttl', + 'last_ipv4': 'last_ipv4', + 'last_ipv4_update': 'last_ipv4_update', + 'last_ipv6': 'last_ipv6', + 'last_ipv6_update': 'last_ipv6_update', + } + } +} + class BaseModel(Model): """Base model with database binding.""" @@ -73,6 +92,13 @@ class Version(BaseModel): table_name = "version" +# Table name to model class mapping for migrations +TABLE_TO_MODEL = { + 'users': User, + 'hostnames': Hostname, +} + + def init_database(config: dict): """ Initialize database connection based on config. @@ -107,7 +133,8 @@ def init_database(config: dict): User._meta.database = db Hostname._meta.database = db Version._meta.database = db - logging.debug(f"Database backend: MariaDB db={config['database']['database']}") + db_name = config['database']['database'] + logging.debug(f"Database backend: MariaDB db={db_name}") else: raise ValueError(f"Unknown database backend: {backend}") @@ -115,55 +142,73 @@ def init_database(config: dict): db.connect() -def _migrate_v1_to_v2_sqlite(): - """SQLite: recreate table (no ALTER TABLE for constraints).""" - db.execute_sql('PRAGMA foreign_keys=OFF') - db.execute_sql(''' - CREATE TABLE hostnames_new ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE RESTRICT, - hostname VARCHAR(255) NOT NULL, - zone VARCHAR(255) NOT NULL, - dns_ttl INTEGER NOT NULL, - expiry_ttl INTEGER NOT NULL, - last_ipv4 VARCHAR(15), - last_ipv4_update DATETIME, - last_ipv6 VARCHAR(45), - last_ipv6_update DATETIME, - UNIQUE(hostname, zone) - ) - ''') - db.execute_sql(''' - INSERT INTO hostnames_new - SELECT id, user_id, hostname, zone, dns_ttl, expiry_ttl, - last_ipv4, last_ipv4_update, last_ipv6, last_ipv6_update - FROM hostnames - ''') - db.execute_sql('DROP TABLE hostnames') - db.execute_sql('ALTER TABLE hostnames_new RENAME TO hostnames') - db.execute_sql('PRAGMA foreign_keys=ON') +def _migrate_table_sqlite(model_class, from_version: int, column_map: dict): + """ + Migrate a single SQLite table using Peewee model for schema. + Args: + model_class: Peewee Model class for new schema. + from_version: Version we're migrating from (for backup suffix). + column_map: {old_column: new_column} mapping, None = drop column. + """ + table_name = model_class._meta.table_name + backup_name = f"{table_name}_{from_version}" -def _migrate_v1_to_v2_mariadb(): - """MariaDB: ALTER TABLE to change constraints.""" - db.execute_sql('ALTER TABLE hostnames DROP INDEX hostnames_hostname') + # Rename existing table to backup + db.execute_sql(f'ALTER TABLE "{table_name}" RENAME TO "{backup_name}"') + + # Drop indexes on backup table (they keep original names after rename) + cursor = db.execute_sql( + "SELECT name FROM sqlite_master WHERE type='index' " + f"AND tbl_name='{backup_name}' AND name NOT LIKE 'sqlite_%'" + ) + for (index_name,) in cursor.fetchall(): + db.execute_sql(f'DROP INDEX "{index_name}"') + + # Create new table using Peewee (includes all indices/constraints) + db.create_tables([model_class]) + + # Build column lists for INSERT + old_cols = [] + new_cols = [] + for old, new in column_map.items(): + if new is not None: + old_cols.append(f'"{old}"') + new_cols.append(f'"{new}"') + + old_cols_str = ', '.join(old_cols) + new_cols_str = ', '.join(new_cols) + + # Copy data from backup to new table db.execute_sql( - 'ALTER TABLE hostnames ADD UNIQUE INDEX ' - 'hostnames_hostname_zone (hostname, zone)' + f'INSERT INTO "{table_name}" ({new_cols_str}) ' + f'SELECT {old_cols_str} FROM "{backup_name}"' ) + # Drop backup table + db.execute_sql(f'DROP TABLE "{backup_name}"') -def migrate_v1_to_v2(): - """Migrate v1 to v2: hostname+zone unique constraint.""" - backend = db.__class__.__name__ - if backend == 'SqliteDatabase': - _migrate_v1_to_v2_sqlite() - else: - _migrate_v1_to_v2_mariadb() +def _migrate_sqlite(from_version: int, to_version: int): + """Migrate SQLite from from_version to to_version.""" + db.execute_sql('PRAGMA foreign_keys=OFF') + try: + tables = MIGRATION_COLUMN_MAPS[to_version] + for table_name, column_map in tables.items(): + model = TABLE_TO_MODEL[table_name] + _migrate_table_sqlite(model, from_version, column_map) + finally: + db.execute_sql('PRAGMA foreign_keys=ON') - Version.update(version=2).execute() - logging.info("Database migrated from v1 to v2") + +def _migrate_mariadb(to_version: int): + """Migrate MariaDB to target version using ALTER TABLE.""" + if to_version == 2: + db.execute_sql('ALTER TABLE hostnames DROP INDEX hostnames_hostname') + db.execute_sql( + 'ALTER TABLE hostnames ADD UNIQUE INDEX ' + 'hostnames_hostname_zone (hostname, zone)' + ) def check_and_migrate(): @@ -174,9 +219,22 @@ def check_and_migrate(): except DoesNotExist: return - if current < 2: - logging.info("Migrating DB: v%d -> v%d", current, DATABASE_VERSION) - migrate_v1_to_v2() + if current >= DATABASE_VERSION: + return + + backend = db.__class__.__name__ + + for target in range(current + 1, DATABASE_VERSION + 1): + logging.info("Migrating DB: v%d -> v%d", target - 1, target) + + if backend == 'SqliteDatabase': + _migrate_sqlite(target - 1, target) + else: + _migrate_mariadb(target) + + Version.update(version=target).execute() + + logging.info("Database migration complete") def create_tables(): @@ -239,9 +297,8 @@ def get_hostname_for_user(hostname: str, user: User): Raises: DoesNotExist: If hostname not found or not owned by user. """ - return Hostname.get( - ((Hostname.hostname + '.' + Hostname.zone) == hostname) & (Hostname.user == user) - ) + fqdn = Hostname.hostname + '.' + Hostname.zone + return Hostname.get((fqdn == hostname) & (Hostname.user == user)) # Re-export DoesNotExist for convenience