aboutsummaryrefslogtreecommitdiff
path: root/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/main.rs')
-rw-r--r--src/main.rs125
1 files changed, 99 insertions, 26 deletions
diff --git a/src/main.rs b/src/main.rs
index e307b8e..1dde8a7 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,18 +1,23 @@
-use std::net::SocketAddr;
-
-use url::Url;
-
+use axum::extract::connect_info;
+use axum::http::Request;
use axum::routing::{get, post};
use axum::Router;
use sqlx::postgres::PgPoolOptions;
use sqlx::{Pool, Postgres};
-use sqids::Sqids;
+use hyper::body::Incoming;
+use hyper_util::rt::{TokioExecutor, TokioIo};
+use hyper_util::server;
+use info_utils::prelude::*;
use serde::Deserialize;
+use sqids::Sqids;
+use tower::Service;
+use url::Url;
-use info_utils::prelude::*;
+use std::env;
+use std::sync::Arc;
pub mod get;
pub mod post;
@@ -39,22 +44,39 @@ pub struct CreateForm {
pub url: url::Url,
}
+#[derive(Clone)]
+#[allow(dead_code)]
+pub struct UdsConnectInfo {
+ pub peer_addr: Arc<tokio::net::unix::SocketAddr>,
+ pub peer_cred: tokio::net::unix::UCred,
+}
+
+impl connect_info::Connected<&tokio::net::UnixStream> for UdsConnectInfo {
+ fn connect_info(target: &tokio::net::UnixStream) -> Self {
+ let peer_addr = target.peer_addr().unwrap();
+ let peer_cred = target.peer_cred().unwrap();
+
+ Self {
+ peer_addr: Arc::new(peer_addr),
+ peer_cred,
+ }
+ }
+}
+
#[tokio::main]
async fn main() -> eyre::Result<()> {
color_eyre::install()?;
let db_pool = init_db().await?;
- let host = std::env::var("CHELA_HOST").unwrap_or("localhost".to_string());
+ let host = env::var("CHELA_HOST").unwrap_or("localhost".to_string());
+ let alphabet = env::var("CHELA_ALPHABET")
+ .unwrap_or("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ".to_string());
let sqids = Sqids::builder()
- .alphabet(
- "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
- .chars()
- .collect(),
- )
+ .alphabet(alphabet.chars().collect())
.blocklist(["create".to_string()].into())
.build()?;
- let main_page_redirect = std::env::var("CHELA_MAIN_PAGE_REDIRECT").unwrap_or_default();
- let behind_proxy = std::env::var("CHELA_BEHIND_PROXY").is_ok();
+ let main_page_redirect = env::var("CHELA_MAIN_PAGE_REDIRECT").unwrap_or_default();
+ let behind_proxy = env::var("CHELA_BEHIND_PROXY").is_ok();
let server_state = ServerState {
db_pool,
host,
@@ -63,17 +85,69 @@ async fn main() -> eyre::Result<()> {
behind_proxy,
};
- let address = std::env::var("CHELA_LISTEN_ADDRESS").unwrap_or("0.0.0.0".to_string());
- let port = 3000;
+ serve(server_state).await?;
+ Ok(())
+}
+
+async fn serve(state: ServerState) -> eyre::Result<()> {
+ let unix_socket = env::var("CHELA_UNIX_SOCKET").unwrap_or_default();
+ if unix_socket.is_empty() {
+ let router = Router::new()
+ .route("/", get(get::index))
+ .route("/:id", get(get::id))
+ .route("/create", get(get::create_id))
+ .route("/", post(post::create_link))
+ .layer(axum::Extension(state));
+ let address = env::var("CHELA_LISTEN_ADDRESS").unwrap_or("0.0.0.0".to_string());
+ let port = 3000;
+ let listener = tokio::net::TcpListener::bind(format!("{address}:{port}")).await?;
+ log!("Listening at {}:{}", address, port);
+ axum::serve(
+ listener,
+ router.into_make_service_with_connect_info::<std::net::SocketAddr>(),
+ )
+ .await?;
+ } else {
+ let router = Router::new()
+ .route("/", get(get::index))
+ .route("/:id", get(get::id_unix))
+ .route("/create", get(get::create_id))
+ .route("/", post(post::create_link))
+ .layer(axum::Extension(state));
+ let unix_socket_path = std::path::Path::new(&unix_socket);
+ if unix_socket_path.exists() {
+ tokio::fs::remove_file(unix_socket_path).await?;
+ }
+ let listener = tokio::net::UnixListener::bind(unix_socket_path)?;
+ log!("Listening via Unix socket at {}", unix_socket);
+ tokio::spawn(async move {
+ let mut service = router.into_make_service_with_connect_info::<UdsConnectInfo>();
+ loop {
+ let (socket, _remote_addr) = listener.accept().await.unwrap();
+ let tower_service = match service.call(&socket).await {
+ Ok(value) => value,
+ Err(err) => match err {},
+ };
+
+ tokio::spawn(async move {
+ let socket = TokioIo::new(socket);
+ let hyper_service =
+ hyper::service::service_fn(move |request: Request<Incoming>| {
+ tower_service.clone().call(request)
+ });
+
+ if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new())
+ .serve_connection_with_upgrades(socket, hyper_service)
+ .await
+ {
+ warn!("Failed to serve connection: {}", err);
+ }
+ });
+ }
+ })
+ .await?;
+ }
- let router = init_routes(server_state);
- let listener = tokio::net::TcpListener::bind(format!("{address}:{port}")).await?;
- log!("Listening at {}:{}", address, port);
- axum::serve(
- listener,
- router.into_make_service_with_connect_info::<SocketAddr>(),
- )
- .await?;
Ok(())
}
@@ -81,7 +155,7 @@ async fn init_db() -> eyre::Result<Pool<Postgres>> {
let db_pool = PgPoolOptions::new()
.max_connections(15)
.connect(
- std::env::var("DATABASE_URL")
+ env::var("DATABASE_URL")
.expect("DATABASE_URL must be set")
.as_str(),
)
@@ -128,7 +202,6 @@ CREATE TABLE IF NOT EXISTS chela.tracking (
fn init_routes(state: ServerState) -> Router {
Router::new()
.route("/", get(get::index))
- .route("/:id", get(get::id))
.route("/create", get(get::create_id))
.route("/", post(post::create_link))
.layer(axum::Extension(state))