fix: make stripe webhooks explicitly toggleable with mandatory secret validation (#23)

Co-authored-by: userAdityaa <aditya.chaudhary1558@gmail.com>
Co-committed-by: userAdityaa <aditya.chaudhary1558@gmail.com>
This commit is contained in:
2026-04-17 22:57:37 +00:00
committed by hodlbod
parent 87dcf53d74
commit 44f9928070
11 changed files with 145 additions and 41 deletions
+4 -3
View File
@@ -139,7 +139,7 @@ impl Api {
api: Arc::new(self),
};
Router::new()
let router = Router::new()
.route("/identity", get(get_identity))
.route("/plans", get(list_plans))
.route("/plans/:id", get(get_plan))
@@ -158,8 +158,9 @@ impl Api {
"/tenants/:pubkey/stripe/session",
get(create_stripe_session),
)
.route("/stripe/webhook", post(stripe_webhook))
.with_state(state)
.route("/stripe/webhook", post(stripe_webhook));
router.with_state(state)
}
fn extract_auth_pubkey(&self, headers: &HeaderMap) -> std::result::Result<String, ApiError> {
+105 -6
View File
@@ -95,6 +95,9 @@ impl Billing {
panic!("missing STRIPE_SECRET_KEY environment variable");
}
let stripe_webhook_secret = std::env::var("STRIPE_WEBHOOK_SECRET").unwrap_or_default();
if stripe_webhook_secret.trim().is_empty() {
panic!("missing STRIPE_WEBHOOK_SECRET environment variable");
}
let btc_quote_api_base =
std::env::var("BTC_PRICE_API_BASE").unwrap_or_else(|_| COINBASE_SPOT_API.to_string());
Self {
@@ -949,7 +952,8 @@ mod tests {
use sqlx::SqlitePool;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
use std::str::FromStr;
use std::sync::{Mutex, OnceLock};
use std::sync::OnceLock;
use tokio::sync::Mutex;
fn env_lock() -> &'static Mutex<()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
@@ -964,6 +968,14 @@ mod tests {
}
}
#[allow(unused_unsafe)]
fn set_stripe_webhook_secret(value: Option<&str>) {
match value {
Some(v) => unsafe { std::env::set_var("STRIPE_WEBHOOK_SECRET", v) },
None => unsafe { std::env::remove_var("STRIPE_WEBHOOK_SECRET") },
}
}
struct StripeSecretKeyGuard {
previous: Option<String>,
}
@@ -982,6 +994,24 @@ mod tests {
}
}
struct StripeWebhookSecretGuard {
previous: Option<String>,
}
impl StripeWebhookSecretGuard {
fn set(value: Option<&str>) -> Self {
let previous = std::env::var("STRIPE_WEBHOOK_SECRET").ok();
set_stripe_webhook_secret(value);
Self { previous }
}
}
impl Drop for StripeWebhookSecretGuard {
fn drop(&mut self) {
set_stripe_webhook_secret(self.previous.as_deref());
}
}
async fn test_pool() -> SqlitePool {
let connect_options = SqliteConnectOptions::from_str("sqlite::memory:")
.expect("valid sqlite memory url")
@@ -1003,8 +1033,9 @@ mod tests {
#[tokio::test]
async fn billing_new_panics_without_stripe_secret_key() {
let _lock = env_lock().lock().expect("acquire env lock");
let _env = StripeSecretKeyGuard::set(None);
let _lock = env_lock().lock().await;
let _secret_env = StripeSecretKeyGuard::set(None);
let _webhook_env = StripeWebhookSecretGuard::set(Some("whsec_test_dummy"));
let pool = test_pool().await;
let query = Query::new(pool.clone());
@@ -1034,9 +1065,76 @@ mod tests {
}
#[tokio::test]
async fn billing_new_accepts_non_empty_stripe_secret_key() {
let _lock = env_lock().lock().expect("acquire env lock");
let _env = StripeSecretKeyGuard::set(Some("sk_test_dummy"));
async fn billing_new_panics_without_stripe_webhook_secret() {
let _lock = env_lock().lock().await;
let _secret_env = StripeSecretKeyGuard::set(Some("sk_test_dummy"));
let _webhook_env = StripeWebhookSecretGuard::set(None);
let pool = test_pool().await;
let query = Query::new(pool.clone());
let command = Command::new(pool);
let robot = Robot::test_stub();
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
Billing::new(query, command, robot)
}));
let panic_payload = match result {
Ok(_) => panic!("constructor should panic when STRIPE_WEBHOOK_SECRET is missing"),
Err(payload) => payload,
};
let panic_msg = if let Some(msg) = panic_payload.downcast_ref::<&str>() {
(*msg).to_string()
} else if let Some(msg) = panic_payload.downcast_ref::<String>() {
msg.clone()
} else {
String::new()
};
assert!(
panic_msg.contains("missing STRIPE_WEBHOOK_SECRET environment variable"),
"unexpected panic: {panic_msg}"
);
}
#[tokio::test]
async fn billing_new_panics_with_blank_stripe_webhook_secret() {
let _lock = env_lock().lock().await;
let _secret_env = StripeSecretKeyGuard::set(Some("sk_test_dummy"));
let _webhook_env = StripeWebhookSecretGuard::set(Some(" "));
let pool = test_pool().await;
let query = Query::new(pool.clone());
let command = Command::new(pool);
let robot = Robot::test_stub();
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
Billing::new(query, command, robot)
}));
let panic_payload = match result {
Ok(_) => panic!("constructor should panic when STRIPE_WEBHOOK_SECRET is blank"),
Err(payload) => payload,
};
let panic_msg = if let Some(msg) = panic_payload.downcast_ref::<&str>() {
(*msg).to_string()
} else if let Some(msg) = panic_payload.downcast_ref::<String>() {
msg.clone()
} else {
String::new()
};
assert!(
panic_msg.contains("missing STRIPE_WEBHOOK_SECRET environment variable"),
"unexpected panic: {panic_msg}"
);
}
#[tokio::test]
async fn billing_new_accepts_non_empty_stripe_secrets() {
let _lock = env_lock().lock().await;
let _secret_env = StripeSecretKeyGuard::set(Some("sk_test_dummy"));
let _webhook_env = StripeWebhookSecretGuard::set(Some("whsec_test_dummy"));
let pool = test_pool().await;
let billing = Billing::new(
@@ -1046,5 +1144,6 @@ mod tests {
);
assert_eq!(billing.stripe_secret_key, "sk_test_dummy");
assert_eq!(billing.stripe_webhook_secret, "whsec_test_dummy");
}
}
+1 -1
View File
@@ -32,7 +32,7 @@ impl Command {
.fetch_one(&mut **tx)
.await?
}
_ => anyhow::bail!("unknown resource_type: {}", resource_type),
_ => anyhow::bail!("unknown resource_type: {resource_type}"),
};
let id = uuid::Uuid::new_v4().to_string();
+1 -1
View File
@@ -169,7 +169,7 @@ impl Infra {
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
anyhow::bail!("zooid sync returned {}: {}", status, body)
anyhow::bail!("zooid sync returned {status}: {body}")
}
Ok(())
}
+1 -1
View File
@@ -3,8 +3,8 @@ mod billing;
mod command;
mod infra;
mod models;
mod query;
mod pool;
mod query;
mod robot;
use anyhow::Result;
+1 -2
View File
@@ -21,8 +21,7 @@ pub async fn create_pool() -> Result<SqlitePool> {
std::fs::create_dir_all(parent)?;
}
let connect_options =
SqliteConnectOptions::from_str(&database_url)?.create_if_missing(true);
let connect_options = SqliteConnectOptions::from_str(&database_url)?.create_if_missing(true);
let pool = SqlitePoolOptions::new()
.max_connections(5)