use axum::{ extract::{Multipart, Path, State}, http::StatusCode, response::{IntoResponse, Redirect, Response}, }; use askama::Template; use askama_web::WebTemplate; use sha2::{Digest, Sha256}; use sqlx::FromRow; use crate::auth_guard::{AuthUser, OptionalAuth}; use crate::csam::CsamResult; use crate::state::AppState; const MAX_FILE_SIZE: usize = 10 * 1024 * 1024; const FREE_STORAGE_LIMIT: i64 = 50 * 1024 * 1024; const PRO_STORAGE_LIMIT: i64 = 1024 * 1024 * 1024; const THUMB_MAX_DIM: u32 = 300; const FREE_EXPIRY_DAYS: i64 = 90; const PRO_EXPIRY_DAYS: i64 = 90; const ALLOWED_TYPES: &[&str] = &["image/jpeg", "image/png", "image/gif", "image/webp"]; fn sanitize_filename(input: &str) -> String { input.chars() .map(|c| if c.is_alphanumeric() || c == '.' || c == '_' || c == '-' { c } else { '_' }) .collect() } fn detect_content_type(data: &[u8]) -> Option<&'static str> { if data.len() < 12 { return None; } if data[0..3] == [0xFF, 0xD8, 0xFF] { Some("image/jpeg") } else if data[0..4] == [0x89, 0x50, 0x4E, 0x47] { Some("image/png") } else if data[0..3] == [0x47, 0x49, 0x46] { Some("image/gif") } else if data[0..4] == [0x52, 0x49, 0x46, 0x46] && data[8..12] == [0x57, 0x45, 0x42, 0x50] { Some("image/webp") } else { None } } fn generate_thumbnail_from_image(img: &::image::DynamicImage) -> Result, String> { let thumb = img.resize(THUMB_MAX_DIM, THUMB_MAX_DIM, ::image::imageops::FilterType::Lanczos3); let mut buf = std::io::Cursor::new(Vec::new()); thumb .write_to(&mut buf, ::image::ImageFormat::WebP) .map_err(|e| format!("thumbnail encode failed: {e}"))?; Ok(buf.into_inner()) } fn format_time(t: time::OffsetDateTime) -> String { format!( "{:04}-{:02}-{:02} {:02}:{:02} UTC", t.year(), t.month() as u8, t.day(), t.hour(), t.minute(), ) } fn format_size(bytes: i64) -> String { if bytes < 1024 { format!("{bytes} B") } else if bytes < 1024 * 1024 { format!("{:.1} KB", bytes as f64 / 1024.0) } else { format!("{:.1} MB", bytes as f64 / (1024.0 * 1024.0)) } } #[derive(FromRow)] #[allow(dead_code)] struct ImageRow { id: String, user_id: String, filename: String, content_type: String, size_bytes: i64, width: Option, height: Option, created_at: time::OffsetDateTime, expires_at: Option, } #[derive(Template, WebTemplate)] #[template(path = "upload.html")] #[allow(dead_code)] pub struct UploadTemplate { pub email: Option, } pub async fn index(AuthUser(user): AuthUser) -> UploadTemplate { UploadTemplate { email: user.email, } } pub async fn upload( State(state): State, AuthUser(user): AuthUser, mut multipart: Multipart, ) -> Result { let is_pro = user.is_pro(); { let mut limiter = state.upload_limiter.lock().await; let entries = limiter.entry(user.sub.clone()).or_default(); if !crate::state::check_rate_limit(entries, is_pro) { return Err((StatusCode::TOO_MANY_REQUESTS, "upload rate limit exceeded").into_response()); } } let storage_limit = if is_pro { PRO_STORAGE_LIMIT } else { FREE_STORAGE_LIMIT }; let used: (Option,) = sqlx::query_as("SELECT sum(size_bytes) FROM images WHERE user_id = $1") .bind(&user.sub) .fetch_one(&state.db) .await .map_err(|e| { tracing::error!("storage query failed: {e}"); (StatusCode::INTERNAL_SERVER_ERROR, "internal error").into_response() })?; let used_bytes = used.0.unwrap_or(0); let mut file_data: Option<(String, String, Vec)> = None; while let Ok(Some(field)) = multipart.next_field().await { if field.name() == Some("file") { let filename = field .file_name() .unwrap_or("upload") .to_string(); let content_type = field .content_type() .unwrap_or("application/octet-stream") .to_string(); let data = field.bytes().await.map_err(|e| { tracing::error!("multipart read failed: {e}"); (StatusCode::BAD_REQUEST, "failed to read upload").into_response() })?; file_data = Some((filename, content_type, data.to_vec())); break; } } let (filename, claimed_type, data) = file_data.ok_or_else(|| { (StatusCode::BAD_REQUEST, "no file uploaded").into_response() })?; if data.len() > MAX_FILE_SIZE { return Err((StatusCode::BAD_REQUEST, "file too large (10MB max)").into_response()); } let detected_type = detect_content_type(&data).ok_or_else(|| { (StatusCode::BAD_REQUEST, "unsupported image format").into_response() })?; if !ALLOWED_TYPES.contains(&claimed_type.as_str()) || detected_type != claimed_type { return Err( (StatusCode::BAD_REQUEST, "content type mismatch or unsupported format").into_response(), ); } if let Some(scanner) = &state.csam_scanner { match scanner.check(&data).await { Ok(CsamResult::Match) => { let hash = hex::encode(Sha256::digest(&data)); tracing::error!("CSAM match detected for user {} hash {}", user.sub, hash); let _ = sqlx::query( "INSERT INTO csam_incidents (user_id, image_hash) VALUES ($1, $2)", ) .bind(&user.sub) .bind(&hash) .execute(&state.db) .await; return Err((StatusCode::FORBIDDEN, "upload rejected").into_response()); } Ok(CsamResult::Clean) => {} Err(e) => { tracing::error!("CSAM scanner unavailable: {e}"); return Err(( StatusCode::SERVICE_UNAVAILABLE, "upload temporarily unavailable", ) .into_response()); } } } let size_bytes = data.len() as i64; if used_bytes + size_bytes > storage_limit { return Err(( StatusCode::FORBIDDEN, "storage limit reached -- upgrade to pro for more space", ) .into_response()); } const MAX_DECODED_BYTES: u64 = 100 * 1024 * 1024; let img = ::image::load_from_memory(&data).map_err(|_| { (StatusCode::BAD_REQUEST, "invalid image data").into_response() })?; let (w, h) = (img.width(), img.height()); if (w as u64) * (h as u64) * 4 > MAX_DECODED_BYTES { return Err((StatusCode::BAD_REQUEST, "image dimensions too large").into_response()); } let (width, height) = (Some(w as i32), Some(h as i32)); let thumb_data = generate_thumbnail_from_image(&img).map_err(|e| { tracing::error!("thumbnail generation failed: {e}"); (StatusCode::INTERNAL_SERVER_ERROR, "thumbnail failed").into_response() })?; let expires_at = if !is_pro { Some(time::OffsetDateTime::now_utc() + time::Duration::days(FREE_EXPIRY_DAYS)) } else if user.content_expires() { Some(time::OffsetDateTime::now_utc() + time::Duration::days(PRO_EXPIRY_DAYS)) } else { None }; let id = nanoid::nanoid!(8); let filename: String = filename .chars() .map(|c| if c.is_alphanumeric() || c == '.' || c == '_' || c == '-' { c } else { '_' }) .collect(); let original_key = format!("{id}/{filename}"); let thumb_key = format!("{id}/thumb.webp"); state .storage .put(&original_key, &data, &claimed_type) .await .map_err(|e| { tracing::error!("S3 upload failed: {e}"); (StatusCode::INTERNAL_SERVER_ERROR, "upload failed").into_response() })?; state .storage .put(&thumb_key, &thumb_data, "image/webp") .await .map_err(|e| { tracing::error!("S3 thumbnail upload failed: {e}"); (StatusCode::INTERNAL_SERVER_ERROR, "thumbnail upload failed").into_response() })?; sqlx::query( "INSERT INTO images (id, user_id, filename, content_type, size_bytes, width, height, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)", ) .bind(&id) .bind(&user.sub) .bind(&filename) .bind(&claimed_type) .bind(size_bytes) .bind(width) .bind(height) .bind(expires_at) .execute(&state.db) .await .map_err(|e| { tracing::error!("image insert failed: {e}"); (StatusCode::INTERNAL_SERVER_ERROR, "internal error").into_response() })?; Ok(Redirect::to(&format!("/{id}"))) } #[derive(Template, WebTemplate)] #[template(path = "view.html")] pub struct ViewTemplate { pub id: String, pub filename: String, pub content_type: String, pub size: String, pub width: Option, pub height: Option, pub created_at: String, pub is_owner: bool, } pub async fn view( State(state): State, OptionalAuth(user): OptionalAuth, Path(id): Path, ) -> Result { let row = sqlx::query_as::<_, ImageRow>( "SELECT * FROM images WHERE id = $1 AND hidden = false", ) .bind(&id) .fetch_optional(&state.db) .await .map_err(|e| { tracing::error!("image query failed: {e}"); StatusCode::INTERNAL_SERVER_ERROR })? .ok_or(StatusCode::NOT_FOUND)?; let is_owner = user.map(|u| u.sub == row.user_id).unwrap_or(false); Ok(ViewTemplate { id: row.id, filename: row.filename, content_type: row.content_type, size: format_size(row.size_bytes), width: row.width, height: row.height, created_at: format_time(row.created_at), is_owner, }) } pub async fn raw( State(state): State, Path(id): Path, ) -> Result { let row = sqlx::query_as::<_, ImageRow>( "SELECT * FROM images WHERE id = $1 AND hidden = false", ) .bind(&id) .fetch_optional(&state.db) .await .map_err(|e| { tracing::error!("image query failed: {e}"); StatusCode::INTERNAL_SERVER_ERROR })? .ok_or(StatusCode::NOT_FOUND)?; let key = format!("{}/{}", row.id, row.filename); let (data, content_type) = state.storage.get(&key).await.map_err(|e| { tracing::error!("S3 get failed: {e}"); StatusCode::INTERNAL_SERVER_ERROR })?; Ok(( [(axum::http::header::CONTENT_TYPE, content_type)], data, ) .into_response()) } pub async fn thumb( State(state): State, Path(id): Path, ) -> Result { let exists: bool = sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM images WHERE id = $1 AND hidden = false)") .bind(&id) .fetch_one(&state.db) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; if !exists { return Err(StatusCode::NOT_FOUND); } let key = format!("{id}/thumb.webp"); let (data, _) = state.storage.get(&key).await.map_err(|e| { tracing::error!("S3 thumb get failed: {e}"); StatusCode::INTERNAL_SERVER_ERROR })?; Ok(( [(axum::http::header::CONTENT_TYPE, "image/webp".to_string())], data, ) .into_response()) } #[allow(dead_code)] pub struct ImageEntry { pub id: String, pub filename: String, pub size: String, pub created_at: String, } #[derive(Template, WebTemplate)] #[template(path = "my.html")] pub struct MyTemplate { pub images: Vec, } pub async fn my_images( State(state): State, AuthUser(user): AuthUser, ) -> Result { let rows = sqlx::query_as::<_, ImageRow>( "SELECT * FROM images WHERE user_id = $1 ORDER BY created_at DESC", ) .bind(&user.sub) .fetch_all(&state.db) .await .map_err(|e| { tracing::error!("image list query failed: {e}"); StatusCode::INTERNAL_SERVER_ERROR })?; let images = rows .into_iter() .map(|r| ImageEntry { id: r.id, filename: r.filename, size: format_size(r.size_bytes), created_at: format_time(r.created_at), }) .collect(); Ok(MyTemplate { images }) } pub async fn delete( State(state): State, AuthUser(user): AuthUser, Path(id): Path, ) -> Result { let row = sqlx::query_as::<_, ImageRow>("SELECT * FROM images WHERE id = $1 AND user_id = $2") .bind(&id) .bind(&user.sub) .fetch_optional(&state.db) .await .map_err(|e| { tracing::error!("image query failed: {e}"); StatusCode::INTERNAL_SERVER_ERROR })? .ok_or(StatusCode::NOT_FOUND)?; let original_key = format!("{}/{}", row.id, row.filename); let thumb_key = format!("{}/thumb.webp", row.id); let _ = state.storage.delete(&original_key).await; let _ = state.storage.delete(&thumb_key).await; sqlx::query("DELETE FROM images WHERE id = $1") .bind(&id) .execute(&state.db) .await .map_err(|e| { tracing::error!("image delete failed: {e}"); StatusCode::INTERNAL_SERVER_ERROR })?; Ok(Redirect::to("/my")) } fn check_internal_token(headers: &axum::http::HeaderMap) -> Result<(), StatusCode> { let expected = std::env::var("INTERNAL_API_TOKEN").unwrap_or_default(); if expected.is_empty() { return Err(StatusCode::FORBIDDEN); } match headers.get("x-internal-token").and_then(|v| v.to_str().ok()) { Some(token) if token == expected => Ok(()), _ => Err(StatusCode::FORBIDDEN), } } pub async fn internal_hide( headers: axum::http::HeaderMap, State(state): State, Path(id): Path, ) -> StatusCode { if let Err(code) = check_internal_token(&headers) { return code; } match sqlx::query("UPDATE images SET hidden = true WHERE id = $1") .bind(&id) .execute(&state.db) .await { Ok(r) if r.rows_affected() > 0 => StatusCode::OK, Ok(_) => StatusCode::NOT_FOUND, Err(e) => { tracing::error!("hide image failed: {e}"); StatusCode::INTERNAL_SERVER_ERROR } } } pub async fn internal_unhide( headers: axum::http::HeaderMap, State(state): State, Path(id): Path, ) -> StatusCode { if let Err(code) = check_internal_token(&headers) { return code; } match sqlx::query("UPDATE images SET hidden = false WHERE id = $1") .bind(&id) .execute(&state.db) .await { Ok(r) if r.rows_affected() > 0 => StatusCode::OK, Ok(_) => StatusCode::NOT_FOUND, Err(e) => { tracing::error!("unhide image failed: {e}"); StatusCode::INTERNAL_SERVER_ERROR } } } pub async fn internal_delete( headers: axum::http::HeaderMap, State(state): State, Path(id): Path, ) -> StatusCode { if let Err(code) = check_internal_token(&headers) { return code; } let row = match sqlx::query_as::<_, ImageRow>("SELECT * FROM images WHERE id = $1") .bind(&id) .fetch_optional(&state.db) .await { Ok(Some(r)) => r, Ok(None) => return StatusCode::NOT_FOUND, Err(e) => { tracing::error!("internal delete query failed: {e}"); return StatusCode::INTERNAL_SERVER_ERROR; } }; let original_key = format!("{}/{}", row.id, row.filename); let thumb_key = format!("{}/thumb.webp", row.id); let _ = state.storage.delete(&original_key).await; let _ = state.storage.delete(&thumb_key).await; match sqlx::query("DELETE FROM images WHERE id = $1") .bind(&id) .execute(&state.db) .await { Ok(_) => StatusCode::OK, Err(e) => { tracing::error!("internal delete failed: {e}"); StatusCode::INTERNAL_SERVER_ERROR } } }