1use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{
10 AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session,
11};
12use mas_iana::oauth::PkceCodeChallengeMethod;
13use mas_storage::{Clock, oauth2::OAuth2AuthorizationGrantRepository};
14use oauth2_types::{requests::ResponseMode, scope::Scope};
15use rand::RngCore;
16use sqlx::PgConnection;
17use ulid::Ulid;
18use url::Url;
19use uuid::Uuid;
20
21use crate::{DatabaseError, DatabaseInconsistencyError, tracing::ExecuteExt};
22
/// An implementation of [`OAuth2AuthorizationGrantRepository`] backed by a
/// PostgreSQL connection
pub struct PgOAuth2AuthorizationGrantRepository<'c> {
    // Exclusive borrow of the connection every query in this repository runs on.
    conn: &'c mut PgConnection,
}
28
29impl<'c> PgOAuth2AuthorizationGrantRepository<'c> {
30 pub fn new(conn: &'c mut PgConnection) -> Self {
33 Self { conn }
34 }
35}
36
/// Raw row of the `oauth2_authorization_grants` table, as selected by
/// `lookup` and `find_by_code`, before validation.
///
/// Converted into an [`AuthorizationGrant`] through its `TryFrom` impl, which
/// rejects inconsistent column combinations.
// The bool fields mirror boolean database columns, so they are kept as-is
// rather than folded into an enum. NOTE(review): only two bools are present,
// below clippy's default `max-struct-bools`; confirm whether this allow is
// still needed (e.g. a stricter workspace clippy config) or drop it.
#[allow(clippy::struct_excessive_bools)]
struct GrantLookup {
    oauth2_authorization_grant_id: Uuid,
    created_at: DateTime<Utc>,
    // Stage timestamps: which ones are set, together with `oauth2_session_id`,
    // determines the grant's `AuthorizationGrantStage`.
    cancelled_at: Option<DateTime<Utc>>,
    fulfilled_at: Option<DateTime<Utc>>,
    exchanged_at: Option<DateTime<Utc>>,
    // Serialized `Scope`, parsed back on conversion.
    scope: String,
    state: Option<String>,
    nonce: Option<String>,
    // Serialized `Url`, parsed back on conversion.
    redirect_uri: String,
    // Serialized `ResponseMode`, parsed back on conversion.
    response_mode: String,
    // Whether the `code` response type was requested; must agree with
    // `authorization_code` being set.
    response_type_code: bool,
    response_type_id_token: bool,
    authorization_code: Option<String>,
    // PKCE challenge; must be set together with `code_challenge_method`.
    code_challenge: Option<String>,
    // Either "plain" or "S256" when set; anything else is an inconsistency.
    code_challenge_method: Option<String>,
    login_hint: Option<String>,
    locale: Option<String>,
    oauth2_client_id: Uuid,
    // Session the grant was fulfilled with; set by `fulfill`.
    oauth2_session_id: Option<Uuid>,
}
59
60impl TryFrom<GrantLookup> for AuthorizationGrant {
61 type Error = DatabaseInconsistencyError;
62
63 fn try_from(value: GrantLookup) -> Result<Self, Self::Error> {
64 let id = value.oauth2_authorization_grant_id.into();
65 let scope: Scope = value.scope.parse().map_err(|e| {
66 DatabaseInconsistencyError::on("oauth2_authorization_grants")
67 .column("scope")
68 .row(id)
69 .source(e)
70 })?;
71
72 let stage = match (
73 value.fulfilled_at,
74 value.exchanged_at,
75 value.cancelled_at,
76 value.oauth2_session_id,
77 ) {
78 (None, None, None, None) => AuthorizationGrantStage::Pending,
79 (Some(fulfilled_at), None, None, Some(session_id)) => {
80 AuthorizationGrantStage::Fulfilled {
81 session_id: session_id.into(),
82 fulfilled_at,
83 }
84 }
85 (Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => {
86 AuthorizationGrantStage::Exchanged {
87 session_id: session_id.into(),
88 fulfilled_at,
89 exchanged_at,
90 }
91 }
92 (None, None, Some(cancelled_at), None) => {
93 AuthorizationGrantStage::Cancelled { cancelled_at }
94 }
95 _ => {
96 return Err(
97 DatabaseInconsistencyError::on("oauth2_authorization_grants")
98 .column("stage")
99 .row(id),
100 );
101 }
102 };
103
104 let pkce = match (value.code_challenge, value.code_challenge_method) {
105 (Some(challenge), Some(challenge_method)) if challenge_method == "plain" => {
106 Some(Pkce {
107 challenge_method: PkceCodeChallengeMethod::Plain,
108 challenge,
109 })
110 }
111 (Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce {
112 challenge_method: PkceCodeChallengeMethod::S256,
113 challenge,
114 }),
115 (None, None) => None,
116 _ => {
117 return Err(
118 DatabaseInconsistencyError::on("oauth2_authorization_grants")
119 .column("code_challenge_method")
120 .row(id),
121 );
122 }
123 };
124
125 let code: Option<AuthorizationCode> =
126 match (value.response_type_code, value.authorization_code, pkce) {
127 (false, None, None) => None,
128 (true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }),
129 _ => {
130 return Err(
131 DatabaseInconsistencyError::on("oauth2_authorization_grants")
132 .column("authorization_code")
133 .row(id),
134 );
135 }
136 };
137
138 let redirect_uri = value.redirect_uri.parse().map_err(|e| {
139 DatabaseInconsistencyError::on("oauth2_authorization_grants")
140 .column("redirect_uri")
141 .row(id)
142 .source(e)
143 })?;
144
145 let response_mode = value.response_mode.parse().map_err(|e| {
146 DatabaseInconsistencyError::on("oauth2_authorization_grants")
147 .column("response_mode")
148 .row(id)
149 .source(e)
150 })?;
151
152 Ok(AuthorizationGrant {
153 id,
154 stage,
155 client_id: value.oauth2_client_id.into(),
156 code,
157 scope,
158 state: value.state,
159 nonce: value.nonce,
160 response_mode,
161 redirect_uri,
162 created_at: value.created_at,
163 response_type_id_token: value.response_type_id_token,
164 login_hint: value.login_hint,
165 locale: value.locale,
166 })
167 }
168}
169
170#[async_trait]
171impl OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantRepository<'_> {
172 type Error = DatabaseError;
173
174 #[tracing::instrument(
175 name = "db.oauth2_authorization_grant.add",
176 skip_all,
177 fields(
178 db.query.text,
179 grant.id,
180 grant.scope = %scope,
181 %client.id,
182 ),
183 err,
184 )]
185 async fn add(
186 &mut self,
187 rng: &mut (dyn RngCore + Send),
188 clock: &dyn Clock,
189 client: &Client,
190 redirect_uri: Url,
191 scope: Scope,
192 code: Option<AuthorizationCode>,
193 state: Option<String>,
194 nonce: Option<String>,
195 response_mode: ResponseMode,
196 response_type_id_token: bool,
197 login_hint: Option<String>,
198 locale: Option<String>,
199 ) -> Result<AuthorizationGrant, Self::Error> {
200 let code_challenge = code
201 .as_ref()
202 .and_then(|c| c.pkce.as_ref())
203 .map(|p| &p.challenge);
204 let code_challenge_method = code
205 .as_ref()
206 .and_then(|c| c.pkce.as_ref())
207 .map(|p| p.challenge_method.to_string());
208 let code_str = code.as_ref().map(|c| &c.code);
209
210 let created_at = clock.now();
211 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
212 tracing::Span::current().record("grant.id", tracing::field::display(id));
213
214 sqlx::query!(
215 r#"
216 INSERT INTO oauth2_authorization_grants (
217 oauth2_authorization_grant_id,
218 oauth2_client_id,
219 redirect_uri,
220 scope,
221 state,
222 nonce,
223 response_mode,
224 code_challenge,
225 code_challenge_method,
226 response_type_code,
227 response_type_id_token,
228 authorization_code,
229 login_hint,
230 locale,
231 created_at
232 )
233 VALUES
234 ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
235 "#,
236 Uuid::from(id),
237 Uuid::from(client.id),
238 redirect_uri.to_string(),
239 scope.to_string(),
240 state,
241 nonce,
242 response_mode.to_string(),
243 code_challenge,
244 code_challenge_method,
245 code.is_some(),
246 response_type_id_token,
247 code_str,
248 login_hint,
249 locale,
250 created_at,
251 )
252 .traced()
253 .execute(&mut *self.conn)
254 .await?;
255
256 Ok(AuthorizationGrant {
257 id,
258 stage: AuthorizationGrantStage::Pending,
259 code,
260 redirect_uri,
261 client_id: client.id,
262 scope,
263 state,
264 nonce,
265 response_mode,
266 created_at,
267 response_type_id_token,
268 login_hint,
269 locale,
270 })
271 }
272
273 #[tracing::instrument(
274 name = "db.oauth2_authorization_grant.lookup",
275 skip_all,
276 fields(
277 db.query.text,
278 grant.id = %id,
279 ),
280 err,
281 )]
282 async fn lookup(&mut self, id: Ulid) -> Result<Option<AuthorizationGrant>, Self::Error> {
283 let res = sqlx::query_as!(
284 GrantLookup,
285 r#"
286 SELECT oauth2_authorization_grant_id
287 , created_at
288 , cancelled_at
289 , fulfilled_at
290 , exchanged_at
291 , scope
292 , state
293 , redirect_uri
294 , response_mode
295 , nonce
296 , oauth2_client_id
297 , authorization_code
298 , response_type_code
299 , response_type_id_token
300 , code_challenge
301 , code_challenge_method
302 , login_hint
303 , locale
304 , oauth2_session_id
305 FROM
306 oauth2_authorization_grants
307
308 WHERE oauth2_authorization_grant_id = $1
309 "#,
310 Uuid::from(id),
311 )
312 .traced()
313 .fetch_optional(&mut *self.conn)
314 .await?;
315
316 let Some(res) = res else { return Ok(None) };
317
318 Ok(Some(res.try_into()?))
319 }
320
321 #[tracing::instrument(
322 name = "db.oauth2_authorization_grant.find_by_code",
323 skip_all,
324 fields(
325 db.query.text,
326 ),
327 err,
328 )]
329 async fn find_by_code(
330 &mut self,
331 code: &str,
332 ) -> Result<Option<AuthorizationGrant>, Self::Error> {
333 let res = sqlx::query_as!(
334 GrantLookup,
335 r#"
336 SELECT oauth2_authorization_grant_id
337 , created_at
338 , cancelled_at
339 , fulfilled_at
340 , exchanged_at
341 , scope
342 , state
343 , redirect_uri
344 , response_mode
345 , nonce
346 , oauth2_client_id
347 , authorization_code
348 , response_type_code
349 , response_type_id_token
350 , code_challenge
351 , code_challenge_method
352 , login_hint
353 , locale
354 , oauth2_session_id
355 FROM
356 oauth2_authorization_grants
357
358 WHERE authorization_code = $1
359 "#,
360 code,
361 )
362 .traced()
363 .fetch_optional(&mut *self.conn)
364 .await?;
365
366 let Some(res) = res else { return Ok(None) };
367
368 Ok(Some(res.try_into()?))
369 }
370
371 #[tracing::instrument(
372 name = "db.oauth2_authorization_grant.fulfill",
373 skip_all,
374 fields(
375 db.query.text,
376 %grant.id,
377 client.id = %grant.client_id,
378 %session.id,
379 ),
380 err,
381 )]
382 async fn fulfill(
383 &mut self,
384 clock: &dyn Clock,
385 session: &Session,
386 grant: AuthorizationGrant,
387 ) -> Result<AuthorizationGrant, Self::Error> {
388 let fulfilled_at = clock.now();
389 let res = sqlx::query!(
390 r#"
391 UPDATE oauth2_authorization_grants
392 SET fulfilled_at = $2
393 , oauth2_session_id = $3
394 WHERE oauth2_authorization_grant_id = $1
395 "#,
396 Uuid::from(grant.id),
397 fulfilled_at,
398 Uuid::from(session.id),
399 )
400 .traced()
401 .execute(&mut *self.conn)
402 .await?;
403
404 DatabaseError::ensure_affected_rows(&res, 1)?;
405
406 let grant = grant
408 .fulfill(fulfilled_at, session)
409 .map_err(DatabaseError::to_invalid_operation)?;
410
411 Ok(grant)
412 }
413
414 #[tracing::instrument(
415 name = "db.oauth2_authorization_grant.exchange",
416 skip_all,
417 fields(
418 db.query.text,
419 %grant.id,
420 client.id = %grant.client_id,
421 ),
422 err,
423 )]
424 async fn exchange(
425 &mut self,
426 clock: &dyn Clock,
427 grant: AuthorizationGrant,
428 ) -> Result<AuthorizationGrant, Self::Error> {
429 let exchanged_at = clock.now();
430 let res = sqlx::query!(
431 r#"
432 UPDATE oauth2_authorization_grants
433 SET exchanged_at = $2
434 WHERE oauth2_authorization_grant_id = $1
435 "#,
436 Uuid::from(grant.id),
437 exchanged_at,
438 )
439 .traced()
440 .execute(&mut *self.conn)
441 .await?;
442
443 DatabaseError::ensure_affected_rows(&res, 1)?;
444
445 let grant = grant
446 .exchange(exchanged_at)
447 .map_err(DatabaseError::to_invalid_operation)?;
448
449 Ok(grant)
450 }
451}