subcog/observability/
request_context.rs1use std::cell::RefCell;
4use std::future::Future;
5use uuid::Uuid;
6
/// Identifier of the request on whose behalf the current code is running.
///
/// A context is propagated either through a per-thread slot (see `enter_request_context`)
/// or through a Tokio task-local (see `scope_request_context`), and read back with
/// `current_request_id`.
///
/// It is a plain value type: cloning copies the ID, and two contexts compare (and hash)
/// equal exactly when their request IDs are equal.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct RequestContext {
    /// The request identifier, stored verbatim as supplied or generated.
    request_id: String,
}
12
13impl RequestContext {
14 #[must_use]
16 pub fn new() -> Self {
17 Self {
18 request_id: Uuid::new_v4().to_string(),
19 }
20 }
21
22 #[must_use]
24 pub fn from_id(request_id: impl Into<String>) -> Self {
25 Self {
26 request_id: request_id.into(),
27 }
28 }
29
30 #[must_use]
32 pub fn request_id(&self) -> &str {
33 &self.request_id
34 }
35}
36
37impl Default for RequestContext {
38 fn default() -> Self {
39 Self::new()
40 }
41}
42
43tokio::task_local! {
44 static TASK_CONTEXT: RequestContext;
45}
46
47thread_local! {
48 static THREAD_CONTEXT: RefCell<Option<RequestContext>> = const { RefCell::new(None) };
49}
50
51pub struct RequestContextGuard {
53 previous: Option<RequestContext>,
54}
55
56impl Drop for RequestContextGuard {
57 fn drop(&mut self) {
58 THREAD_CONTEXT.with(|slot| {
59 *slot.borrow_mut() = self.previous.take();
60 });
61 }
62}
63
64#[must_use]
66pub fn enter_request_context(context: RequestContext) -> RequestContextGuard {
67 let previous = THREAD_CONTEXT.with(|slot| slot.borrow_mut().replace(context));
68 RequestContextGuard { previous }
69}
70
71pub async fn scope_request_context<F, T>(context: RequestContext, fut: F) -> T
73where
74 F: Future<Output = T>,
75{
76 TASK_CONTEXT
77 .scope(context.clone(), async move {
78 let _guard = enter_request_context(context);
79 fut.await
80 })
81 .await
82}
83
84#[must_use]
86pub fn current_request_id() -> Option<String> {
87 if let Ok(id) = TASK_CONTEXT.try_with(|ctx| ctx.request_id.clone()) {
88 return Some(id);
89 }
90
91 THREAD_CONTEXT.with(|slot| slot.borrow().as_ref().map(|ctx| ctx.request_id.clone()))
92}
93
94#[cfg(test)]
95mod tests {
96 use super::*;
97
98 #[test]
99 fn test_thread_context_guard_propagates_request_id() {
100 let context = RequestContext::from_id("thread-test");
101 let _guard = enter_request_context(context);
102 assert_eq!(current_request_id().as_deref(), Some("thread-test"));
103 }
104
105 #[tokio::test]
106 async fn test_scope_request_context_propagates_across_await() {
107 let context = RequestContext::from_id("async-test");
108 let observed = scope_request_context(context, async {
109 tokio::task::yield_now().await;
110 current_request_id()
111 })
112 .await;
113 assert_eq!(observed.as_deref(), Some("async-test"));
114 }
115}