diff options
Diffstat (limited to 'src/main.rs')
-rw-r--r-- | src/main.rs | 125 |
1 files changed, 99 insertions, 26 deletions
diff --git a/src/main.rs b/src/main.rs index e307b8e..1dde8a7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,18 +1,23 @@ -use std::net::SocketAddr; - -use url::Url; - +use axum::extract::connect_info; +use axum::http::Request; use axum::routing::{get, post}; use axum::Router; use sqlx::postgres::PgPoolOptions; use sqlx::{Pool, Postgres}; -use sqids::Sqids; +use hyper::body::Incoming; +use hyper_util::rt::{TokioExecutor, TokioIo}; +use hyper_util::server; +use info_utils::prelude::*; use serde::Deserialize; +use sqids::Sqids; +use tower::Service; +use url::Url; -use info_utils::prelude::*; +use std::env; +use std::sync::Arc; pub mod get; pub mod post; @@ -39,22 +44,39 @@ pub struct CreateForm { pub url: url::Url, } +#[derive(Clone)] +#[allow(dead_code)] +pub struct UdsConnectInfo { + pub peer_addr: Arc<tokio::net::unix::SocketAddr>, + pub peer_cred: tokio::net::unix::UCred, +} + +impl connect_info::Connected<&tokio::net::UnixStream> for UdsConnectInfo { + fn connect_info(target: &tokio::net::UnixStream) -> Self { + let peer_addr = target.peer_addr().unwrap(); + let peer_cred = target.peer_cred().unwrap(); + + Self { + peer_addr: Arc::new(peer_addr), + peer_cred, + } + } +} + #[tokio::main] async fn main() -> eyre::Result<()> { color_eyre::install()?; let db_pool = init_db().await?; - let host = std::env::var("CHELA_HOST").unwrap_or("localhost".to_string()); + let host = env::var("CHELA_HOST").unwrap_or("localhost".to_string()); + let alphabet = env::var("CHELA_ALPHABET") + .unwrap_or("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ".to_string()); let sqids = Sqids::builder() - .alphabet( - "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" - .chars() - .collect(), - ) + .alphabet(alphabet.chars().collect()) .blocklist(["create".to_string()].into()) .build()?; - let main_page_redirect = std::env::var("CHELA_MAIN_PAGE_REDIRECT").unwrap_or_default(); - let behind_proxy = std::env::var("CHELA_BEHIND_PROXY").is_ok(); + let main_page_redirect = env::var("CHELA_MAIN_PAGE_REDIRECT").unwrap_or_default(); + let behind_proxy = env::var("CHELA_BEHIND_PROXY").is_ok(); let server_state = ServerState { db_pool, host, @@ -63,17 +85,69 @@ async fn main() -> eyre::Result<()> { behind_proxy, }; - let address = std::env::var("CHELA_LISTEN_ADDRESS").unwrap_or("0.0.0.0".to_string()); - let port = 3000; + serve(server_state).await?; + Ok(()) +} + +async fn serve(state: ServerState) -> eyre::Result<()> { + let unix_socket = env::var("CHELA_UNIX_SOCKET").unwrap_or_default(); + if unix_socket.is_empty() { + let router = Router::new() + .route("/", get(get::index)) + .route("/:id", get(get::id)) + .route("/create", get(get::create_id)) + .route("/", post(post::create_link)) + .layer(axum::Extension(state)); + let address = env::var("CHELA_LISTEN_ADDRESS").unwrap_or("0.0.0.0".to_string()); + let port = 3000; + let listener = tokio::net::TcpListener::bind(format!("{address}:{port}")).await?; + log!("Listening at {}:{}", address, port); + axum::serve( + listener, + router.into_make_service_with_connect_info::<std::net::SocketAddr>(), + ) + .await?; + } else { + let router = Router::new() + .route("/", get(get::index)) + .route("/:id", get(get::id_unix)) + .route("/create", get(get::create_id)) + .route("/", post(post::create_link)) + .layer(axum::Extension(state)); + let unix_socket_path = std::path::Path::new(&unix_socket); + if unix_socket_path.exists() { + tokio::fs::remove_file(unix_socket_path).await?; + } + let listener = tokio::net::UnixListener::bind(unix_socket_path)?; + log!("Listening via Unix socket at {}", unix_socket); + tokio::spawn(async move { + let mut service = router.into_make_service_with_connect_info::<UdsConnectInfo>(); + loop { + let (socket, _remote_addr) = listener.accept().await.unwrap(); + let tower_service = match service.call(&socket).await { + Ok(value) => value, + Err(err) => match err {}, + }; + + tokio::spawn(async move { + let socket = TokioIo::new(socket); + let hyper_service = + hyper::service::service_fn(move |request: Request<Incoming>| { + tower_service.clone().call(request) + }); + + if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new()) + .serve_connection_with_upgrades(socket, hyper_service) + .await + { + warn!("Failed to serve connection: {}", err); + } + }); + } + }) + .await?; + } - let router = init_routes(server_state); - let listener = tokio::net::TcpListener::bind(format!("{address}:{port}")).await?; - log!("Listening at {}:{}", address, port); - axum::serve( - listener, - router.into_make_service_with_connect_info::<SocketAddr>(), - ) - .await?; Ok(()) } @@ -81,7 +155,7 @@ async fn init_db() -> eyre::Result<Pool<Postgres>> { let db_pool = PgPoolOptions::new() .max_connections(15) .connect( - std::env::var("DATABASE_URL") + env::var("DATABASE_URL") .expect("DATABASE_URL must be set") .as_str(), ) @@ -128,7 +202,6 @@ CREATE TABLE IF NOT EXISTS chela.tracking ( fn init_routes(state: ServerState) -> Router { Router::new() .route("/", get(get::index)) - .route("/:id", get(get::id)) .route("/create", get(get::create_id)) .route("/", post(post::create_link)) .layer(axum::Extension(state)) |