diff --git a/app.py b/app.py index e4862e0..775e2c9 100644 --- a/app.py +++ b/app.py @@ -51,12 +51,12 @@ def contains_banned_words(text): return True return False -def init_db(): - # TODO: No schema versioning — adding columns in the future requires manual DB updates. Consider a migration tool (e.g. Alembic). - conn = sqlite3.connect(DATABASE) - c = conn.cursor() - c.execute(''' - CREATE TABLE IF NOT EXISTS guests ( +# Each entry is a list of SQL statements for that schema version. +# To add a column or index in the future, append a new list — never modify existing entries. +MIGRATIONS = [ + # v1 — initial schema + [ + '''CREATE TABLE IF NOT EXISTS guests ( id INTEGER PRIMARY KEY AUTOINCREMENT, first_name TEXT NOT NULL, last_name TEXT NOT NULL, @@ -65,13 +65,40 @@ def init_db(): comment TEXT, newsletter_opt_in BOOLEAN DEFAULT 1, timestamp DATETIME DEFAULT CURRENT_TIMESTAMP - ) - ''') - c.execute('CREATE INDEX IF NOT EXISTS idx_guests_id ON guests (id DESC)') - c.execute('CREATE INDEX IF NOT EXISTS idx_guests_email ON guests (email)') - conn.commit() + )''', + 'CREATE INDEX IF NOT EXISTS idx_guests_id ON guests (id DESC)', + 'CREATE INDEX IF NOT EXISTS idx_guests_email ON guests (email)', + ], +] + +def migrate_db(): + conn = sqlite3.connect(DATABASE) + c = conn.cursor() + + # Bootstrap the version table and seed it at 0 if empty + c.execute('CREATE TABLE IF NOT EXISTS schema_version (version INTEGER NOT NULL)') + if c.execute('SELECT COUNT(*) FROM schema_version').fetchone()[0] == 0: + c.execute('INSERT INTO schema_version VALUES (0)') + conn.commit() + + current = c.execute('SELECT version FROM schema_version').fetchone()[0] + pending = MIGRATIONS[current:] + + if not pending: + logger.info("Database schema is up to date at v%d.", current) + conn.close() + return + + for statements in pending: + current += 1 + logger.info("Applying migration v%d...", current) + for sql in statements: + c.execute(sql) + c.execute('UPDATE schema_version SET version = ?', (current,)) + conn.commit() + + logger.info("Database migrated to v%d.", current) conn.close() - logger.info("Database initialized.") def is_valid_email(email): try: @@ -81,7 +108,7 @@ def is_valid_email(email): return False with app.app_context(): - init_db() + migrate_db() @app.route('/', methods=['GET', 'POST']) @limiter.limit("5 per minute", methods=["POST"]) @@ -186,6 +213,6 @@ def api_guests(): return jsonify(guests) if __name__ == '__main__': - init_db() + migrate_db() logger.info("Starting development server at http://0.0.0.0:8000") app.run(host='0.0.0.0', port=8000)