Refactor billing to manage subscriptions/invoices internally

This commit is contained in:
Jon Staab
2026-05-26 14:25:21 -07:00
parent 28cd7b0a9a
commit 7a2baf6f82
23 changed files with 1464 additions and 1694 deletions
+68 -265
View File
@@ -7,91 +7,21 @@
use anyhow::{Result, anyhow};
use hmac::{Hmac, Mac};
use sha2::Sha256;
use std::collections::BTreeMap;
use crate::env::Env;
use crate::env;
const STRIPE_API: &str = "https://api.stripe.com/v1";
// Webhooks
const WEBHOOK_TOLERANCE_SECS: i64 = 300;
#[derive(serde::Deserialize)]
pub struct StripeWebhookEvent {
#[serde(rename = "type")]
pub event_type: String,
pub data: StripeWebhookEventData,
}
#[derive(serde::Deserialize)]
pub struct StripeWebhookEventData {
pub object: serde_json::Value,
}
// API return types
#[derive(serde::Deserialize)]
pub struct StripeSubscription {
pub id: String,
pub status: String,
#[serde(deserialize_with = "deserialize_list")]
pub items: Vec<StripeSubscriptionItem>,
}
#[derive(serde::Deserialize)]
pub struct StripeSubscriptionItem {
pub id: String,
pub price: StripePrice,
#[serde(default = "default_quantity")]
pub quantity: i64,
}
#[derive(serde::Deserialize)]
pub struct StripePrice {
pub id: String,
}
#[derive(serde::Deserialize, serde::Serialize, Clone)]
pub struct StripeInvoice {
pub id: String,
pub customer: String,
pub status: String,
pub amount_due: i64,
pub currency: String,
pub period_start: i64,
pub period_end: i64,
}
#[derive(serde::Deserialize)]
struct StripeList<T> {
data: Vec<T>,
}
fn deserialize_list<'de, D, T>(deserializer: D) -> Result<Vec<T>, D::Error>
where
D: serde::Deserializer<'de>,
T: serde::Deserialize<'de>,
{
Ok(<StripeList<T> as serde::Deserialize>::deserialize(deserializer)?.data)
}
fn default_quantity() -> i64 {
1
}
// Stripe struct and impl
#[derive(Clone)]
pub struct Stripe {
env: Env,
http: reqwest::Client,
}
impl Stripe {
pub fn new(env: &Env) -> Self {
pub fn new() -> Self {
Self {
env: env.clone(),
http: reqwest::Client::new(),
}
}
@@ -101,23 +31,17 @@ impl Stripe {
fn get(&self, path: &str) -> reqwest::RequestBuilder {
self.http
.get(format!("{STRIPE_API}{path}"))
.bearer_auth(&self.env.stripe_secret_key)
.bearer_auth(&env::get().stripe_secret_key)
}
fn post(&self, path: &str) -> reqwest::RequestBuilder {
self.http
.post(format!("{STRIPE_API}{path}"))
.bearer_auth(&self.env.stripe_secret_key)
}
fn delete(&self, path: &str) -> reqwest::RequestBuilder {
self.http
.delete(format!("{STRIPE_API}{path}"))
.bearer_auth(&self.env.stripe_secret_key)
.bearer_auth(&env::get().stripe_secret_key)
}
fn idempotency_key(&self, parts: &[&str]) -> String {
let mut mac = Hmac::<Sha256>::new_from_slice(self.env.stripe_secret_key.as_bytes())
let mut mac = Hmac::<Sha256>::new_from_slice(env::get().stripe_secret_key.as_bytes())
.expect("HMAC accepts any key length");
for (i, part) in parts.iter().enumerate() {
if i > 0 {
@@ -146,153 +70,74 @@ impl Stripe {
Ok(customer_id.to_string())
}
// --- Subscriptions ---
pub async fn get_subscription(
&self,
subscription_id: &str,
) -> Result<Option<StripeSubscription>> {
let body = self
.get(&format!("/subscriptions/{subscription_id}"))
.send_optional_json()
.await?;
body.map(serde_json::from_value)
.transpose()
.map_err(Into::into)
}
/// Stripe requires at least one item to create a subscription, so the desired
/// items are sent inline here; [`crate::billing`] reconciles from there.
pub async fn create_subscription(
&self,
customer_id: &str,
items: &BTreeMap<String, i64>,
) -> Result<StripeSubscription> {
let mut form: Vec<(String, String)> = vec![
("customer".to_string(), customer_id.to_string()),
(
"collection_method".to_string(),
"charge_automatically".to_string(),
),
];
let mut key_parts: Vec<String> =
vec!["create_subscription".to_string(), customer_id.to_string()];
for (index, (price_id, quantity)) in items.iter().enumerate() {
form.push((format!("items[{index}][price]"), price_id.clone()));
form.push((format!("items[{index}][quantity]"), quantity.to_string()));
key_parts.push(format!("{price_id}={quantity}"));
}
let key_refs: Vec<&str> = key_parts.iter().map(String::as_str).collect();
Ok(self
.post("/subscriptions")
.header("Idempotency-Key", self.idempotency_key(&key_refs))
.form(&form)
.send_ok()
.await?
.json()
.await?)
}
pub async fn create_subscription_item(
&self,
subscription_id: &str,
price_id: &str,
quantity: i64,
) -> Result<()> {
let quantity = quantity.to_string();
self.post("/subscription_items")
.header(
"Idempotency-Key",
self.idempotency_key(&["create_subscription_item", subscription_id, price_id]),
)
.form(&[
("subscription", subscription_id),
("price", price_id),
("quantity", quantity.as_str()),
])
.send_ok()
.await?;
Ok(())
}
pub async fn set_subscription_item_quantity(&self, item_id: &str, quantity: i64) -> Result<()> {
self.post(&format!("/subscription_items/{item_id}"))
.form(&[("quantity", quantity.to_string())])
.send_ok()
.await?;
Ok(())
}
pub async fn delete_subscription_item(&self, item_id: &str) -> Result<()> {
self.delete(&format!("/subscription_items/{item_id}"))
.send_ok()
.await?;
Ok(())
}
pub async fn cancel_subscription(&self, subscription_id: &str) -> Result<()> {
self.delete(&format!("/subscriptions/{subscription_id}"))
.send_ok()
.await?;
Ok(())
}
// --- Invoices ---
pub async fn list_invoices(&self, customer_id: &str) -> Result<Vec<StripeInvoice>> {
let list: StripeList<StripeInvoice> = self
.get("/invoices")
.query(&[("customer", customer_id)])
.send_ok()
.await?
.json()
.await?;
Ok(list.data)
}
pub async fn get_invoice(&self, invoice_id: &str) -> Result<Option<StripeInvoice>> {
let body = self
.get(&format!("/invoices/{invoice_id}"))
.send_optional_json()
.await?;
body.map(serde_json::from_value)
.transpose()
.map_err(Into::into)
}
pub async fn pay_invoice(&self, invoice_id: &str) -> Result<()> {
self.post(&format!("/invoices/{invoice_id}/pay"))
.header(
"Idempotency-Key",
self.idempotency_key(&["pay_invoice", invoice_id]),
)
.send_ok()
.await?;
Ok(())
}
pub async fn pay_invoice_out_of_band(&self, invoice_id: &str) -> Result<()> {
self.post(&format!("/invoices/{invoice_id}/pay"))
.header(
"Idempotency-Key",
self.idempotency_key(&["pay_invoice_oob", invoice_id]),
)
.form(&[("paid_out_of_band", "true")])
.send_ok()
.await?;
Ok(())
}
// --- Payment methods ---
pub async fn has_payment_method(&self, customer_id: &str) -> Result<bool> {
/// Return the id of the customer's first saved payment method, or `None` if
/// they have none. The returned `pm_…` id can be charged off-session via
/// [`Self::create_payment_intent`]. We don't track a Stripe "default" payment
/// method, so the first one Stripe lists is the one we'll charge.
pub async fn get_saved_payment_method(&self, customer_id: &str) -> Result<Option<String>> {
let body = self
.get("/payment_methods")
.query(&[("customer", customer_id), ("type", "card")])
.send_json()
.await?;
Ok(body["data"].as_array().is_some_and(|a| !a.is_empty()))
Ok(body["data"]
.as_array()
.and_then(|methods| methods.first())
.and_then(|method| method["id"].as_str())
.map(str::to_string))
}
// --- Intents ---
/// Create and immediately confirm an off-session PaymentIntent charging a
/// saved payment method. `amount` is in the currency's minor units (cents for
/// `usd`). Returns the PaymentIntent id on success.
///
/// A decline or an issuer authentication demand (`authentication_required`,
/// which we can't satisfy off-session) comes back from Stripe as an HTTP
/// error, so the caller naturally falls through to another payment method.
/// The charge is made idempotent on `invoice_id`, so a retried collection
/// reuses the same charge instead of billing the payment method twice.
pub async fn create_payment_intent(
&self,
customer_id: &str,
payment_method_id: &str,
invoice_id: &str,
amount: i64,
currency: &str,
) -> Result<String> {
let amount = amount.to_string();
let body = self
.post("/payment_intents")
.header(
"Idempotency-Key",
self.idempotency_key(&["payment_intent", invoice_id]),
)
.form(&[
("amount", amount.as_str()),
("currency", currency),
("customer", customer_id),
("payment_method", payment_method_id),
("off_session", "true"),
("confirm", "true"),
])
.send_json()
.await?;
// A successful off-session charge settles synchronously. Anything
// else (e.g. `requires_action`) can't be completed without the customer,
// so treat it as a failure and let the caller fall back.
let status = body["status"].as_str().unwrap_or_default();
if status != "succeeded" {
return Err(anyhow!("payment intent not succeeded (status: {status})"));
}
body["id"]
.as_str()
.map(str::to_string)
.ok_or_else(|| anyhow!("missing payment intent id"))
}
// --- Portal ---
@@ -316,47 +161,13 @@ impl Stripe {
.map(str::to_string)
.ok_or_else(|| anyhow!("missing portal session url"))
}
// --- Webhooks ---
pub fn get_webhook_event(&self, payload: &str, signature: &str) -> Result<StripeWebhookEvent> {
let mut timestamp = None;
let mut sig = None;
for part in signature.split(',') {
if let Some(t) = part.strip_prefix("t=") {
timestamp = Some(t);
} else if let Some(v) = part.strip_prefix("v1=") {
sig = Some(v);
}
}
let timestamp = timestamp.ok_or_else(|| anyhow!("missing webhook timestamp"))?;
let signature = sig.ok_or_else(|| anyhow!("missing webhook signature"))?;
let signed_payload = format!("{timestamp}.{payload}");
let mut mac = Hmac::<Sha256>::new_from_slice(self.env.stripe_webhook_secret.as_bytes())
.map_err(|e| anyhow!("invalid webhook secret: {e}"))?;
mac.update(signed_payload.as_bytes());
let expected = hex::encode(mac.finalize().into_bytes());
if expected != signature {
return Err(anyhow!("webhook signature mismatch"));
}
let ts: i64 = timestamp
.parse()
.map_err(|_| anyhow!("bad webhook timestamp"))?;
let now = chrono::Utc::now().timestamp();
if (now - ts).abs() > WEBHOOK_TOLERANCE_SECS {
return Err(anyhow!("webhook timestamp outside tolerance"));
}
Ok(serde_json::from_str(payload)?)
}
}
// Stripe request util
trait StripeRequest {
async fn send_ok(self) -> Result<reqwest::Response>;
async fn send_json(self) -> Result<serde_json::Value>;
async fn send_optional_json(self) -> Result<Option<serde_json::Value>>;
}
impl StripeRequest for reqwest::RequestBuilder {
@@ -367,14 +178,6 @@ impl StripeRequest for reqwest::RequestBuilder {
async fn send_json(self) -> Result<serde_json::Value> {
Ok(self.send_ok().await?.json().await?)
}
async fn send_optional_json(self) -> Result<Option<serde_json::Value>> {
let resp = self.send().await?;
if resp.status() == reqwest::StatusCode::NOT_FOUND {
return Ok(None);
}
Ok(Some(error_for_status(resp).await?.json().await?))
}
}
/// Give callers an actionable message instead of a bare "400 Bad Request"