use tokio_postgres::Client; #[derive(Debug, thiserror::Error)] pub enum MigrateError { #[error("user not found in source: {0}")] UserNotFound(String), #[error("network not found: {0}")] NetworkNotFound(String), #[error("database error: {0}")] Db(#[from] tokio_postgres::Error), } struct UserRow { username: String, password: Option, admin: bool, nick: Option, realname: Option, enabled: bool, max_networks: i32, } struct NetworkRow { id: i32, name: Option, addr: String, nick: Option, username: Option, realname: Option, certfp: Option, pass: Option, connect_commands: Option, sasl_mechanism: Option, sasl_plain_username: Option, sasl_plain_password: Option, sasl_external_cert: Option>, sasl_external_key: Option>, auto_away: bool, enabled: bool, } pub async fn migrate_user( source: &Client, target: &Client, username: &str, ) -> Result<(), MigrateError> { let user = fetch_user(source, username).await?; let new_user_id = insert_user(target, &user).await?; let networks = fetch_networks(source, username).await?; for net in &networks { let new_net_id = insert_network(target, new_user_id, net).await?; copy_channels(source, target, net.id, new_net_id).await?; copy_delivery_receipts(source, target, net.id, new_net_id).await?; copy_read_receipts(source, target, net.id, new_net_id).await?; copy_web_push_subscriptions_for_network(source, target, new_user_id, net.id, new_net_id) .await?; } copy_web_push_subscriptions_user_only(source, target, username, new_user_id).await?; source .execute(r#"DELETE FROM "User" WHERE username = $1"#, &[&username]) .await?; Ok(()) } pub async fn reverse_migrate_user( source: &Client, target: &Client, username: &str, keep_network: &str, ) -> Result<(), MigrateError> { let user = fetch_user(source, username).await?; let new_user_id = insert_user(target, &user).await?; let networks = fetch_networks(source, username).await?; let net = networks .iter() .find(|n| { n.name.as_deref() == Some(keep_network) || n.addr == keep_network }) .ok_or_else(|| MigrateError::NetworkNotFound(keep_network.to_string()))?; let new_net_id = insert_network(target, new_user_id, net).await?; copy_channels(source, target, net.id, new_net_id).await?; copy_delivery_receipts(source, target, net.id, new_net_id).await?; copy_read_receipts(source, target, net.id, new_net_id).await?; copy_web_push_subscriptions_for_network(source, target, new_user_id, net.id, new_net_id) .await?; source .execute(r#"DELETE FROM "User" WHERE username = $1"#, &[&username]) .await?; Ok(()) } pub async fn remove_user(client: &Client, username: &str) -> Result<(), MigrateError> { client .execute(r#"DELETE FROM "User" WHERE username = $1"#, &[&username]) .await?; Ok(()) } async fn fetch_user(client: &Client, username: &str) -> Result { let row = client .query_opt( r#"SELECT username, password, admin, nick, realname, enabled, max_networks FROM "User" WHERE username = $1"#, &[&username], ) .await? .ok_or_else(|| MigrateError::UserNotFound(username.to_string()))?; Ok(UserRow { username: row.get("username"), password: row.get("password"), admin: row.get("admin"), nick: row.get("nick"), realname: row.get("realname"), enabled: row.get("enabled"), max_networks: row.get("max_networks"), }) } async fn insert_user(client: &Client, user: &UserRow) -> Result { let row = client .query_one( r#"INSERT INTO "User" (username, password, admin, nick, realname, enabled, max_networks) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING id"#, &[ &user.username, &user.password, &user.admin, &user.nick, &user.realname, &user.enabled, &user.max_networks, ], ) .await?; Ok(row.get("id")) } async fn fetch_networks( client: &Client, username: &str, ) -> Result, MigrateError> { let rows = client .query( r#"SELECT n.id, n.name, n.addr, n.nick, n.username, n.realname, n.certfp, n.pass, n.connect_commands, n.sasl_mechanism::text, n.sasl_plain_username, n.sasl_plain_password, n.sasl_external_cert, n.sasl_external_key, n.auto_away, n.enabled FROM "Network" n JOIN "User" u ON n."user" = u.id WHERE u.username = $1 ORDER BY n.id"#, &[&username], ) .await?; Ok(rows .iter() .map(|r| NetworkRow { id: r.get("id"), name: r.get("name"), addr: r.get("addr"), nick: r.get("nick"), username: r.get("username"), realname: r.get("realname"), certfp: r.get("certfp"), pass: r.get("pass"), connect_commands: r.get("connect_commands"), sasl_mechanism: r.get(9), sasl_plain_username: r.get("sasl_plain_username"), sasl_plain_password: r.get("sasl_plain_password"), sasl_external_cert: r.get("sasl_external_cert"), sasl_external_key: r.get("sasl_external_key"), auto_away: r.get("auto_away"), enabled: r.get("enabled"), }) .collect()) } async fn insert_network( client: &Client, user_id: i32, net: &NetworkRow, ) -> Result { let row = client .query_one( r#"INSERT INTO "Network" ("user", name, addr, nick, username, realname, certfp, pass, connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, auto_away, enabled) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10::sasl_mechanism, $11, $12, $13, $14, $15, $16) RETURNING id"#, &[ &user_id, &net.name, &net.addr, &net.nick, &net.username, &net.realname, &net.certfp, &net.pass, &net.connect_commands, &net.sasl_mechanism, &net.sasl_plain_username, &net.sasl_plain_password, &net.sasl_external_cert, &net.sasl_external_key, &net.auto_away, &net.enabled, ], ) .await?; Ok(row.get("id")) } async fn copy_channels( source: &Client, target: &Client, src_net_id: i32, dst_net_id: i32, ) -> Result<(), MigrateError> { let rows = source .query( r#"SELECT name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, detach_on FROM "Channel" WHERE network = $1"#, &[&src_net_id], ) .await?; for r in &rows { let name: String = r.get("name"); let key: Option = r.get("key"); let detached: bool = r.get("detached"); let detached_internal_msgid: Option = r.get("detached_internal_msgid"); let relay_detached: i32 = r.get("relay_detached"); let reattach_on: i32 = r.get("reattach_on"); let detach_after: i32 = r.get("detach_after"); let detach_on: i32 = r.get("detach_on"); target .execute( r#"INSERT INTO "Channel" (network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, detach_on) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"#, &[ &dst_net_id, &name, &key, &detached, &detached_internal_msgid, &relay_detached, &reattach_on, &detach_after, &detach_on, ], ) .await?; } Ok(()) } async fn copy_delivery_receipts( source: &Client, target: &Client, src_net_id: i32, dst_net_id: i32, ) -> Result<(), MigrateError> { let rows = source .query( r#"SELECT target, client, internal_msgid FROM "DeliveryReceipt" WHERE network = $1"#, &[&src_net_id], ) .await?; for r in &rows { let receipt_target: String = r.get("target"); let receipt_client: String = r.get("client"); let internal_msgid: String = r.get("internal_msgid"); target .execute( r#"INSERT INTO "DeliveryReceipt" (network, target, client, internal_msgid) VALUES ($1, $2, $3, $4)"#, &[&dst_net_id, &receipt_target, &receipt_client, &internal_msgid], ) .await?; } Ok(()) } async fn copy_read_receipts( source: &Client, target: &Client, src_net_id: i32, dst_net_id: i32, ) -> Result<(), MigrateError> { let rows = source .query( r#"SELECT target, timestamp FROM "ReadReceipt" WHERE network = $1"#, &[&src_net_id], ) .await?; for r in &rows { let receipt_target: String = r.get("target"); let timestamp: chrono::DateTime = r.get("timestamp"); target .execute( r#"INSERT INTO "ReadReceipt" (network, target, timestamp) VALUES ($1, $2, $3)"#, &[&dst_net_id, &receipt_target, ×tamp], ) .await?; } Ok(()) } async fn copy_web_push_subscriptions_for_network( source: &Client, target: &Client, dst_user_id: i32, src_net_id: i32, dst_net_id: i32, ) -> Result<(), MigrateError> { let rows = source .query( r#"SELECT created_at, updated_at, endpoint, key_vapid, key_auth, key_p256dh FROM "WebPushSubscription" WHERE network = $1"#, &[&src_net_id], ) .await?; for r in &rows { let created_at: chrono::DateTime = r.get("created_at"); let updated_at: chrono::DateTime = r.get("updated_at"); let endpoint: String = r.get("endpoint"); let key_vapid: Option = r.get("key_vapid"); let key_auth: Option = r.get("key_auth"); let key_p256dh: Option = r.get("key_p256dh"); target .execute( r#"INSERT INTO "WebPushSubscription" (created_at, updated_at, "user", network, endpoint, key_vapid, key_auth, key_p256dh) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"#, &[ &created_at, &updated_at, &dst_user_id, &dst_net_id, &endpoint, &key_vapid, &key_auth, &key_p256dh, ], ) .await?; } Ok(()) } async fn copy_web_push_subscriptions_user_only( source: &Client, target: &Client, username: &str, dst_user_id: i32, ) -> Result<(), MigrateError> { let rows = source .query( r#"SELECT wps.created_at, wps.updated_at, wps.endpoint, wps.key_vapid, wps.key_auth, wps.key_p256dh FROM "WebPushSubscription" wps JOIN "User" u ON wps."user" = u.id WHERE u.username = $1 AND wps.network IS NULL"#, &[&username], ) .await?; for r in &rows { let created_at: chrono::DateTime = r.get("created_at"); let updated_at: chrono::DateTime = r.get("updated_at"); let endpoint: String = r.get("endpoint"); let key_vapid: Option = r.get("key_vapid"); let key_auth: Option = r.get("key_auth"); let key_p256dh: Option = r.get("key_p256dh"); target .execute( r#"INSERT INTO "WebPushSubscription" (created_at, updated_at, "user", network, endpoint, key_vapid, key_auth, key_p256dh) VALUES ($1, $2, $3, NULL, $4, $5, $6, $7)"#, &[ &created_at, &updated_at, &dst_user_id, &endpoint, &key_vapid, &key_auth, &key_p256dh, ], ) .await?; } Ok(()) } #[cfg(test)] mod tests { use super::*; #[test] fn migrate_error_display_user_not_found() { let err = MigrateError::UserNotFound("alice".to_string()); assert_eq!(err.to_string(), "user not found in source: alice"); } #[test] fn migrate_error_display_network_not_found() { let err = MigrateError::NetworkNotFound("libera".to_string()); assert_eq!(err.to_string(), "network not found: libera"); } }