aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorShav Kinderlehrer <[email protected]>2024-04-07 20:06:27 -0400
committerShav Kinderlehrer <[email protected]>2024-04-07 20:39:02 -0400
commit3b7d5454ccb2875fb45a46ea138d1bffec5b7542 (patch)
tree372304aee1540daa1b42767fa47b26886e432d9d
parentf080854b84d80f6063b4f9392d059a84ec09e66c (diff)
downloadchela-3b7d5454ccb2875fb45a46ea138d1bffec5b7542.tar.gz
chela-3b7d5454ccb2875fb45a46ea138d1bffec5b7542.zip
Add support for Unix sockets and custom alphabetv1.1.0
-rw-r--r--Cargo.lock5
-rw-r--r--Cargo.toml7
-rw-r--r--README.md6
-rw-r--r--src/get.rs29
-rw-r--r--src/main.rs125
5 files changed, 138 insertions, 34 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 131718f..15a7252 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -194,16 +194,19 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "chela"
-version = "1.0.0"
+version = "1.1.0"
dependencies = [
"axum",
"color-eyre",
"eyre",
+ "hyper",
+ "hyper-util",
"info_utils",
"serde",
"sqids",
"sqlx",
"tokio",
+ "tower",
"url",
]
diff --git a/Cargo.toml b/Cargo.toml
index 741083f..7f8867f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,17 +1,20 @@
[package]
name = "chela"
-version = "1.0.0"
+version = "1.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
-axum = "0.7.5"
+axum = { version = "0.7.5", features = ["tokio"] }
color-eyre = "0.6.3"
eyre = "0.6.12"
+hyper = "1.2.0"
+hyper-util = { version = "0.1.3", features = ["tokio"] }
info_utils = "2.2.3"
serde = "1.0.197"
sqids = "0.4.1"
sqlx = { version = "0.7.4", features = ["runtime-tokio", "postgres", "macros", "migrate", "tls-rustls"] }
tokio = { version = "1.37.0", features = ["full"] }
+tower = "0.4.13"
url = { version = "2.5.0", features = ["serde"] }
diff --git a/README.md b/README.md
index 2f24b88..60ba570 100644
--- a/README.md
+++ b/README.md
@@ -63,6 +63,12 @@ A page that Chela will redirect to when `/` is requested instead of replying wit
##### `CHELA_BEHIND_PROXY`
If this variable is set, Chela will use the `X-Real-IP` header as the client IP address rather than the connection address.
+##### `CHELA_UNIX_SOCKET`
+If you would like Chela to listen for HTTP requests over a Unix socket, set this variable to the socket path that it should use. By default, Chela will listen via a Tcp socket.
+
+##### `CHELA_ALPHABET`
+If this variable is set, Chela will use the characters in `CHELA_ALPHABET` to create IDs for URLs. The default alphabet is `abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ`. See [here](https://sqids.org/faq#unique)
+
### Manually
#### Build
```bash
diff --git a/src/get.rs b/src/get.rs
index 61c5dd2..efe2711 100644
--- a/src/get.rs
+++ b/src/get.rs
@@ -9,6 +9,7 @@ use axum::Extension;
use info_utils::prelude::*;
use crate::ServerState;
+use crate::UdsConnectInfo;
use crate::UrlRow;
pub async fn index(Extension(state): Extension<ServerState>) -> impl IntoResponse {
@@ -34,6 +35,16 @@ pub async fn index(Extension(state): Extension<ServerState>) -> impl IntoRespons
.into_response()
}
+pub async fn id_unix(
+ headers: HeaderMap,
+ ConnectInfo(addr): ConnectInfo<UdsConnectInfo>,
+ Extension(state): Extension<ServerState>,
+ Path(id): Path<String>,
+) -> impl IntoResponse {
+ let ip = format!("{:?}", addr.peer_addr);
+ run_id(headers, ip, state, id).await
+}
+
/// # Panics
/// Will panic if `parse()` fails
pub async fn id(
@@ -42,8 +53,17 @@ pub async fn id(
Extension(state): Extension<ServerState>,
Path(id): Path<String>,
) -> impl IntoResponse {
- let mut show_request = false;
let ip = get_ip(&headers, addr, &state).unwrap_or_default();
+ run_id(headers, ip, state, id).await
+}
+
+async fn run_id(
+ headers: HeaderMap,
+ ip: String,
+ state: ServerState,
+ id: String,
+) -> impl IntoResponse {
+ let mut show_request = false;
log!("Request for '{}' from {}", id.clone(), ip);
let mut use_id = id;
if use_id.ends_with('+') {
@@ -66,7 +86,7 @@ pub async fn id(
.into_response();
}
log!("Redirecting {} -> {}", it.id, it.url);
- save_analytics(headers, it.clone(), addr, state).await;
+ save_analytics(headers, it.clone(), ip, state).await;
let mut response_headers = HeaderMap::new();
response_headers.insert("Cache-Control", "private, max-age=90".parse().unwrap());
response_headers.insert("Location", it.url.parse().unwrap());
@@ -92,9 +112,8 @@ pub async fn id(
(StatusCode::NOT_FOUND, Html("<pre>Not found.</pre>")).into_response()
}
-async fn save_analytics(headers: HeaderMap, item: UrlRow, addr: SocketAddr, state: ServerState) {
+async fn save_analytics(headers: HeaderMap, item: UrlRow, ip: String, state: ServerState) {
let id = item.id;
- let ip = get_ip(&headers, addr, &state);
let referer = match headers.get("referer") {
Some(it) => {
if let Ok(i) = it.to_str() {
@@ -130,7 +149,7 @@ VALUES ($1,$2,$3,$4)
.await;
if res.is_ok() {
- log!("Saved analytics for '{id}' from {}", ip.unwrap_or_default());
+ log!("Saved analytics for '{id}' from {}", ip);
}
}
diff --git a/src/main.rs b/src/main.rs
index e307b8e..1dde8a7 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,18 +1,23 @@
-use std::net::SocketAddr;
-
-use url::Url;
-
+use axum::extract::connect_info;
+use axum::http::Request;
use axum::routing::{get, post};
use axum::Router;
use sqlx::postgres::PgPoolOptions;
use sqlx::{Pool, Postgres};
-use sqids::Sqids;
+use hyper::body::Incoming;
+use hyper_util::rt::{TokioExecutor, TokioIo};
+use hyper_util::server;
+use info_utils::prelude::*;
use serde::Deserialize;
+use sqids::Sqids;
+use tower::Service;
+use url::Url;
-use info_utils::prelude::*;
+use std::env;
+use std::sync::Arc;
pub mod get;
pub mod post;
@@ -39,22 +44,39 @@ pub struct CreateForm {
pub url: url::Url,
}
+#[derive(Clone)]
+#[allow(dead_code)]
+pub struct UdsConnectInfo {
+ pub peer_addr: Arc<tokio::net::unix::SocketAddr>,
+ pub peer_cred: tokio::net::unix::UCred,
+}
+
+impl connect_info::Connected<&tokio::net::UnixStream> for UdsConnectInfo {
+ fn connect_info(target: &tokio::net::UnixStream) -> Self {
+ let peer_addr = target.peer_addr().unwrap();
+ let peer_cred = target.peer_cred().unwrap();
+
+ Self {
+ peer_addr: Arc::new(peer_addr),
+ peer_cred,
+ }
+ }
+}
+
#[tokio::main]
async fn main() -> eyre::Result<()> {
color_eyre::install()?;
let db_pool = init_db().await?;
- let host = std::env::var("CHELA_HOST").unwrap_or("localhost".to_string());
+ let host = env::var("CHELA_HOST").unwrap_or("localhost".to_string());
+ let alphabet = env::var("CHELA_ALPHABET")
+ .unwrap_or("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ".to_string());
let sqids = Sqids::builder()
- .alphabet(
- "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
- .chars()
- .collect(),
- )
+ .alphabet(alphabet.chars().collect())
.blocklist(["create".to_string()].into())
.build()?;
- let main_page_redirect = std::env::var("CHELA_MAIN_PAGE_REDIRECT").unwrap_or_default();
- let behind_proxy = std::env::var("CHELA_BEHIND_PROXY").is_ok();
+ let main_page_redirect = env::var("CHELA_MAIN_PAGE_REDIRECT").unwrap_or_default();
+ let behind_proxy = env::var("CHELA_BEHIND_PROXY").is_ok();
let server_state = ServerState {
db_pool,
host,
@@ -63,17 +85,69 @@ async fn main() -> eyre::Result<()> {
behind_proxy,
};
- let address = std::env::var("CHELA_LISTEN_ADDRESS").unwrap_or("0.0.0.0".to_string());
- let port = 3000;
+ serve(server_state).await?;
+ Ok(())
+}
+
+async fn serve(state: ServerState) -> eyre::Result<()> {
+ let unix_socket = env::var("CHELA_UNIX_SOCKET").unwrap_or_default();
+ if unix_socket.is_empty() {
+ let router = Router::new()
+ .route("/", get(get::index))
+ .route("/:id", get(get::id))
+ .route("/create", get(get::create_id))
+ .route("/", post(post::create_link))
+ .layer(axum::Extension(state));
+ let address = env::var("CHELA_LISTEN_ADDRESS").unwrap_or("0.0.0.0".to_string());
+ let port = 3000;
+ let listener = tokio::net::TcpListener::bind(format!("{address}:{port}")).await?;
+ log!("Listening at {}:{}", address, port);
+ axum::serve(
+ listener,
+ router.into_make_service_with_connect_info::<std::net::SocketAddr>(),
+ )
+ .await?;
+ } else {
+ let router = Router::new()
+ .route("/", get(get::index))
+ .route("/:id", get(get::id_unix))
+ .route("/create", get(get::create_id))
+ .route("/", post(post::create_link))
+ .layer(axum::Extension(state));
+ let unix_socket_path = std::path::Path::new(&unix_socket);
+ if unix_socket_path.exists() {
+ tokio::fs::remove_file(unix_socket_path).await?;
+ }
+ let listener = tokio::net::UnixListener::bind(unix_socket_path)?;
+ log!("Listening via Unix socket at {}", unix_socket);
+ tokio::spawn(async move {
+ let mut service = router.into_make_service_with_connect_info::<UdsConnectInfo>();
+ loop {
+ let (socket, _remote_addr) = listener.accept().await.unwrap();
+ let tower_service = match service.call(&socket).await {
+ Ok(value) => value,
+ Err(err) => match err {},
+ };
+
+ tokio::spawn(async move {
+ let socket = TokioIo::new(socket);
+ let hyper_service =
+ hyper::service::service_fn(move |request: Request<Incoming>| {
+ tower_service.clone().call(request)
+ });
+
+ if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new())
+ .serve_connection_with_upgrades(socket, hyper_service)
+ .await
+ {
+ warn!("Failed to serve connection: {}", err);
+ }
+ });
+ }
+ })
+ .await?;
+ }
- let router = init_routes(server_state);
- let listener = tokio::net::TcpListener::bind(format!("{address}:{port}")).await?;
- log!("Listening at {}:{}", address, port);
- axum::serve(
- listener,
- router.into_make_service_with_connect_info::<SocketAddr>(),
- )
- .await?;
Ok(())
}
@@ -81,7 +155,7 @@ async fn init_db() -> eyre::Result<Pool<Postgres>> {
let db_pool = PgPoolOptions::new()
.max_connections(15)
.connect(
- std::env::var("DATABASE_URL")
+ env::var("DATABASE_URL")
.expect("DATABASE_URL must be set")
.as_str(),
)
@@ -128,7 +202,6 @@ CREATE TABLE IF NOT EXISTS chela.tracking (
fn init_routes(state: ServerState) -> Router {
Router::new()
.route("/", get(get::index))
- .route("/:id", get(get::id))
.route("/create", get(get::create_id))
.route("/", post(post::create_link))
.layer(axum::Extension(state))