use tokio_postgres::Client; pub fn tenant_role_name(bouncer_name: &str) -> String { format!("soju_{}", bouncer_name.replace('-', "_")) } pub fn tenant_db_name(bouncer_name: &str) -> String { format!("soju_{}", bouncer_name.replace('-', "_")) } fn quote_literal(s: &str) -> String { format!("'{}'", s.replace('\'', "''")) } pub fn create_role_sql(bouncer_name: &str, password: &str) -> String { let role = tenant_role_name(bouncer_name); format!("CREATE ROLE \"{role}\" WITH LOGIN PASSWORD {}", quote_literal(password)) } pub fn create_database_sql(bouncer_name: &str) -> String { let role = tenant_role_name(bouncer_name); let db = tenant_db_name(bouncer_name); format!("CREATE DATABASE \"{db}\" OWNER \"{role}\"") } pub fn drop_database_sql(bouncer_name: &str) -> String { let db = tenant_db_name(bouncer_name); format!("DROP DATABASE IF EXISTS \"{db}\"") } pub fn drop_role_sql(bouncer_name: &str) -> String { let role = tenant_role_name(bouncer_name); format!("DROP ROLE IF EXISTS \"{role}\"") } pub fn terminate_connections_sql(bouncer_name: &str) -> String { let db = tenant_db_name(bouncer_name); format!( "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = {}", quote_literal(&db) ) } pub fn build_tenant_uri(host: &str, port: u16, bouncer_name: &str, password: &str) -> String { let role = tenant_role_name(bouncer_name); let db = tenant_db_name(bouncer_name); let sslmode = std::env::var("SOJU_DB_SSLMODE").unwrap_or_else(|_| "require".to_string()); format!("host={host} port={port} user={role} password={password} dbname={db} sslmode={sslmode}") } pub fn generate_password() -> String { use rand::Rng; use rand::distr::Alphanumeric; rand::rng() .sample_iter(&Alphanumeric) .take(32) .map(char::from) .collect() } pub async fn provision_tenant_db( client: &Client, bouncer_name: &str, password: &str, ) -> Result<(), Box> { let role = tenant_role_name(bouncer_name); let create_role = format!( "DO $$ BEGIN CREATE ROLE \"{}\" WITH LOGIN PASSWORD {}; EXCEPTION WHEN duplicate_object THEN NULL; END $$", role, quote_literal(password) ); client.batch_execute(&create_role).await?; let db = tenant_db_name(bouncer_name); let row = client .query_one( "SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname = $1)", &[&db], ) .await?; let exists: bool = row.get(0); if !exists { client.batch_execute(&create_database_sql(bouncer_name)).await?; } Ok(()) } pub async fn deprovision_tenant_db( client: &Client, bouncer_name: &str, ) -> Result<(), Box> { client.batch_execute(&terminate_connections_sql(bouncer_name)).await?; client.batch_execute(&drop_database_sql(bouncer_name)).await?; client.batch_execute(&drop_role_sql(bouncer_name)).await?; Ok(()) } #[cfg(test)] mod tests { use super::*; #[test] fn tenant_role_name_replaces_hyphens() { assert_eq!(tenant_role_name("my-bouncer"), "soju_my_bouncer"); } #[test] fn tenant_db_name_replaces_hyphens() { assert_eq!(tenant_db_name("my-bouncer"), "soju_my_bouncer"); } #[test] fn create_role_sql_uses_quoted_identifier() { let sql = create_role_sql("my-bouncer", "secretpass"); assert!(sql.contains("\"soju_my_bouncer\"")); assert!(sql.contains("PASSWORD")); } #[test] fn create_database_sql_sets_owner() { let sql = create_database_sql("my-bouncer"); assert!(sql.contains("OWNER \"soju_my_bouncer\"")); } #[test] fn drop_database_sql_uses_if_exists() { let sql = drop_database_sql("my-bouncer"); assert!(sql.contains("DROP DATABASE IF EXISTS \"soju_my_bouncer\"")); } #[test] fn drop_role_sql_uses_if_exists() { let sql = drop_role_sql("my-bouncer"); assert!(sql.contains("DROP ROLE IF EXISTS \"soju_my_bouncer\"")); } #[test] fn terminate_connections_sql_targets_correct_db() { let sql = terminate_connections_sql("my-bouncer"); assert!(sql.contains("soju_my_bouncer")); assert!(sql.contains("pg_terminate_backend")); } #[test] fn build_tenant_uri_has_all_fields() { let uri = build_tenant_uri("db.svc", 5432, "my-bouncer", "generated-pass"); assert!(uri.contains("user=soju_my_bouncer")); assert!(uri.contains("dbname=soju_my_bouncer")); assert!(uri.contains("password=generated-pass")); assert!(uri.contains("host=db.svc")); assert!(uri.contains("port=5432")); assert!(uri.contains("sslmode=require")); } #[test] fn quote_literal_escapes_single_quotes() { assert_eq!(quote_literal("it's"), "'it''s'"); } #[test] fn quote_literal_no_quotes() { assert_eq!(quote_literal("simple"), "'simple'"); } }