1use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports};
10use mas_storage::{
11 Clock, Page, Pagination,
12 upstream_oauth2::{
13 UpstreamOAuthProviderFilter, UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
14 },
15};
16use opentelemetry_semantic_conventions::attribute::DB_QUERY_TEXT;
17use rand::RngCore;
18use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
19use sea_query_binder::SqlxBinder;
20use sqlx::{PgConnection, types::Json};
21use tracing::{Instrument, info_span};
22use ulid::Ulid;
23use uuid::Uuid;
24
25use crate::{
26 DatabaseError, DatabaseInconsistencyError,
27 filter::{Filter, StatementExt},
28 iden::UpstreamOAuthProviders,
29 pagination::QueryBuilderExt,
30 tracing::ExecuteExt,
31};
32
33pub struct PgUpstreamOAuthProviderRepository<'c> {
36 conn: &'c mut PgConnection,
37}
38
39impl<'c> PgUpstreamOAuthProviderRepository<'c> {
40 pub fn new(conn: &'c mut PgConnection) -> Self {
43 Self { conn }
44 }
45}
46
/// Raw row of the `upstream_oauth_providers` table, as read by `lookup`,
/// `list` and `all_enabled`.
///
/// Enum-like, algorithm and URL columns are kept as plain strings here. They
/// are parsed into their domain types by the `TryFrom<ProviderLookup>`
/// implementation for [`UpstreamOAuthProvider`].
///
/// `#[enum_def]` generates `ProviderLookupIden`, whose variants `list` uses as
/// column aliases. Renaming a field therefore also renames the SQL alias that
/// `FromRow` maps back onto it.
#[derive(sqlx::FromRow)]
#[enum_def]
struct ProviderLookup {
    upstream_oauth_provider_id: Uuid,
    issuer: Option<String>,
    human_name: Option<String>,
    brand_name: Option<String>,
    // Parsed into a scope value during conversion.
    scope: String,
    client_id: String,
    encrypted_client_secret: Option<String>,
    token_endpoint_signing_alg: Option<String>,
    token_endpoint_auth_method: String,
    id_token_signed_response_alg: String,
    fetch_userinfo: bool,
    userinfo_signed_response_alg: Option<String>,
    created_at: DateTime<Utc>,
    disabled_at: Option<DateTime<Utc>>,
    // JSON column, decoded through sqlx's `Json` wrapper.
    claims_imports: Json<UpstreamOAuthProviderClaimsImports>,
    jwks_uri_override: Option<String>,
    authorization_endpoint_override: Option<String>,
    token_endpoint_override: Option<String>,
    userinfo_endpoint_override: Option<String>,
    discovery_mode: String,
    pkce_mode: String,
    response_mode: Option<String>,
    // Nullable JSON list of (name, value) pairs; `None` becomes an empty list
    // when converted into an `UpstreamOAuthProvider`.
    additional_parameters: Option<Json<Vec<(String, String)>>>,
    forward_login_hint: bool,
    on_backchannel_logout: String,
}
76
77impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
78 type Error = DatabaseInconsistencyError;
79
80 fn try_from(value: ProviderLookup) -> Result<Self, Self::Error> {
81 let id = value.upstream_oauth_provider_id.into();
82 let scope = value.scope.parse().map_err(|e| {
83 DatabaseInconsistencyError::on("upstream_oauth_providers")
84 .column("scope")
85 .row(id)
86 .source(e)
87 })?;
88 let token_endpoint_auth_method = value.token_endpoint_auth_method.parse().map_err(|e| {
89 DatabaseInconsistencyError::on("upstream_oauth_providers")
90 .column("token_endpoint_auth_method")
91 .row(id)
92 .source(e)
93 })?;
94 let token_endpoint_signing_alg = value
95 .token_endpoint_signing_alg
96 .map(|x| x.parse())
97 .transpose()
98 .map_err(|e| {
99 DatabaseInconsistencyError::on("upstream_oauth_providers")
100 .column("token_endpoint_signing_alg")
101 .row(id)
102 .source(e)
103 })?;
104 let id_token_signed_response_alg =
105 value.id_token_signed_response_alg.parse().map_err(|e| {
106 DatabaseInconsistencyError::on("upstream_oauth_providers")
107 .column("id_token_signed_response_alg")
108 .row(id)
109 .source(e)
110 })?;
111
112 let userinfo_signed_response_alg = value
113 .userinfo_signed_response_alg
114 .map(|x| x.parse())
115 .transpose()
116 .map_err(|e| {
117 DatabaseInconsistencyError::on("upstream_oauth_providers")
118 .column("userinfo_signed_response_alg")
119 .row(id)
120 .source(e)
121 })?;
122
123 let authorization_endpoint_override = value
124 .authorization_endpoint_override
125 .map(|x| x.parse())
126 .transpose()
127 .map_err(|e| {
128 DatabaseInconsistencyError::on("upstream_oauth_providers")
129 .column("authorization_endpoint_override")
130 .row(id)
131 .source(e)
132 })?;
133
134 let token_endpoint_override = value
135 .token_endpoint_override
136 .map(|x| x.parse())
137 .transpose()
138 .map_err(|e| {
139 DatabaseInconsistencyError::on("upstream_oauth_providers")
140 .column("token_endpoint_override")
141 .row(id)
142 .source(e)
143 })?;
144
145 let userinfo_endpoint_override = value
146 .userinfo_endpoint_override
147 .map(|x| x.parse())
148 .transpose()
149 .map_err(|e| {
150 DatabaseInconsistencyError::on("upstream_oauth_providers")
151 .column("userinfo_endpoint_override")
152 .row(id)
153 .source(e)
154 })?;
155
156 let jwks_uri_override = value
157 .jwks_uri_override
158 .map(|x| x.parse())
159 .transpose()
160 .map_err(|e| {
161 DatabaseInconsistencyError::on("upstream_oauth_providers")
162 .column("jwks_uri_override")
163 .row(id)
164 .source(e)
165 })?;
166
167 let discovery_mode = value.discovery_mode.parse().map_err(|e| {
168 DatabaseInconsistencyError::on("upstream_oauth_providers")
169 .column("discovery_mode")
170 .row(id)
171 .source(e)
172 })?;
173
174 let pkce_mode = value.pkce_mode.parse().map_err(|e| {
175 DatabaseInconsistencyError::on("upstream_oauth_providers")
176 .column("pkce_mode")
177 .row(id)
178 .source(e)
179 })?;
180
181 let response_mode = value
182 .response_mode
183 .map(|x| x.parse())
184 .transpose()
185 .map_err(|e| {
186 DatabaseInconsistencyError::on("upstream_oauth_providers")
187 .column("response_mode")
188 .row(id)
189 .source(e)
190 })?;
191
192 let additional_authorization_parameters = value
193 .additional_parameters
194 .map(|Json(x)| x)
195 .unwrap_or_default();
196
197 let on_backchannel_logout = value.on_backchannel_logout.parse().map_err(|e| {
198 DatabaseInconsistencyError::on("upstream_oauth_providers")
199 .column("on_backchannel_logout")
200 .row(id)
201 .source(e)
202 })?;
203
204 Ok(UpstreamOAuthProvider {
205 id,
206 issuer: value.issuer,
207 human_name: value.human_name,
208 brand_name: value.brand_name,
209 scope,
210 client_id: value.client_id,
211 encrypted_client_secret: value.encrypted_client_secret,
212 token_endpoint_auth_method,
213 token_endpoint_signing_alg,
214 id_token_signed_response_alg,
215 fetch_userinfo: value.fetch_userinfo,
216 userinfo_signed_response_alg,
217 created_at: value.created_at,
218 disabled_at: value.disabled_at,
219 claims_imports: value.claims_imports.0,
220 authorization_endpoint_override,
221 token_endpoint_override,
222 userinfo_endpoint_override,
223 jwks_uri_override,
224 discovery_mode,
225 pkce_mode,
226 response_mode,
227 additional_authorization_parameters,
228 forward_login_hint: value.forward_login_hint,
229 on_backchannel_logout,
230 })
231 }
232}
233
234impl Filter for UpstreamOAuthProviderFilter<'_> {
235 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
236 sea_query::Condition::all().add_option(self.enabled().map(|enabled| {
237 Expr::col((
238 UpstreamOAuthProviders::Table,
239 UpstreamOAuthProviders::DisabledAt,
240 ))
241 .is_null()
242 .eq(enabled)
243 }))
244 }
245}
246
247#[async_trait]
248impl UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'_> {
249 type Error = DatabaseError;
250
251 #[tracing::instrument(
252 name = "db.upstream_oauth_provider.lookup",
253 skip_all,
254 fields(
255 db.query.text,
256 upstream_oauth_provider.id = %id,
257 ),
258 err,
259 )]
260 async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error> {
261 let res = sqlx::query_as!(
262 ProviderLookup,
263 r#"
264 SELECT
265 upstream_oauth_provider_id,
266 issuer,
267 human_name,
268 brand_name,
269 scope,
270 client_id,
271 encrypted_client_secret,
272 token_endpoint_signing_alg,
273 token_endpoint_auth_method,
274 id_token_signed_response_alg,
275 fetch_userinfo,
276 userinfo_signed_response_alg,
277 created_at,
278 disabled_at,
279 claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
280 jwks_uri_override,
281 authorization_endpoint_override,
282 token_endpoint_override,
283 userinfo_endpoint_override,
284 discovery_mode,
285 pkce_mode,
286 response_mode,
287 additional_parameters as "additional_parameters: Json<Vec<(String, String)>>",
288 forward_login_hint,
289 on_backchannel_logout
290 FROM upstream_oauth_providers
291 WHERE upstream_oauth_provider_id = $1
292 "#,
293 Uuid::from(id),
294 )
295 .traced()
296 .fetch_optional(&mut *self.conn)
297 .await?;
298
299 let res = res
300 .map(UpstreamOAuthProvider::try_from)
301 .transpose()
302 .map_err(DatabaseError::from)?;
303
304 Ok(res)
305 }
306
307 #[tracing::instrument(
308 name = "db.upstream_oauth_provider.add",
309 skip_all,
310 fields(
311 db.query.text,
312 upstream_oauth_provider.id,
313 upstream_oauth_provider.issuer = params.issuer,
314 upstream_oauth_provider.client_id = %params.client_id,
315 ),
316 err,
317 )]
318 async fn add(
319 &mut self,
320 rng: &mut (dyn RngCore + Send),
321 clock: &dyn Clock,
322 params: UpstreamOAuthProviderParams,
323 ) -> Result<UpstreamOAuthProvider, Self::Error> {
324 let created_at = clock.now();
325 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
326 tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id));
327
328 sqlx::query!(
329 r#"
330 INSERT INTO upstream_oauth_providers (
331 upstream_oauth_provider_id,
332 issuer,
333 human_name,
334 brand_name,
335 scope,
336 token_endpoint_auth_method,
337 token_endpoint_signing_alg,
338 id_token_signed_response_alg,
339 fetch_userinfo,
340 userinfo_signed_response_alg,
341 client_id,
342 encrypted_client_secret,
343 claims_imports,
344 authorization_endpoint_override,
345 token_endpoint_override,
346 userinfo_endpoint_override,
347 jwks_uri_override,
348 discovery_mode,
349 pkce_mode,
350 response_mode,
351 forward_login_hint,
352 on_backchannel_logout,
353 created_at
354 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11,
355 $12, $13, $14, $15, $16, $17, $18, $19, $20,
356 $21, $22, $23)
357 "#,
358 Uuid::from(id),
359 params.issuer.as_deref(),
360 params.human_name.as_deref(),
361 params.brand_name.as_deref(),
362 params.scope.to_string(),
363 params.token_endpoint_auth_method.to_string(),
364 params
365 .token_endpoint_signing_alg
366 .as_ref()
367 .map(ToString::to_string),
368 params.id_token_signed_response_alg.to_string(),
369 params.fetch_userinfo,
370 params
371 .userinfo_signed_response_alg
372 .as_ref()
373 .map(ToString::to_string),
374 ¶ms.client_id,
375 params.encrypted_client_secret.as_deref(),
376 Json(¶ms.claims_imports) as _,
377 params
378 .authorization_endpoint_override
379 .as_ref()
380 .map(ToString::to_string),
381 params
382 .token_endpoint_override
383 .as_ref()
384 .map(ToString::to_string),
385 params
386 .userinfo_endpoint_override
387 .as_ref()
388 .map(ToString::to_string),
389 params.jwks_uri_override.as_ref().map(ToString::to_string),
390 params.discovery_mode.as_str(),
391 params.pkce_mode.as_str(),
392 params.response_mode.as_ref().map(ToString::to_string),
393 params.forward_login_hint,
394 params.on_backchannel_logout.as_str(),
395 created_at,
396 )
397 .traced()
398 .execute(&mut *self.conn)
399 .await?;
400
401 Ok(UpstreamOAuthProvider {
402 id,
403 issuer: params.issuer,
404 human_name: params.human_name,
405 brand_name: params.brand_name,
406 scope: params.scope,
407 client_id: params.client_id,
408 encrypted_client_secret: params.encrypted_client_secret,
409 token_endpoint_signing_alg: params.token_endpoint_signing_alg,
410 token_endpoint_auth_method: params.token_endpoint_auth_method,
411 id_token_signed_response_alg: params.id_token_signed_response_alg,
412 fetch_userinfo: params.fetch_userinfo,
413 userinfo_signed_response_alg: params.userinfo_signed_response_alg,
414 created_at,
415 disabled_at: None,
416 claims_imports: params.claims_imports,
417 authorization_endpoint_override: params.authorization_endpoint_override,
418 token_endpoint_override: params.token_endpoint_override,
419 userinfo_endpoint_override: params.userinfo_endpoint_override,
420 jwks_uri_override: params.jwks_uri_override,
421 discovery_mode: params.discovery_mode,
422 pkce_mode: params.pkce_mode,
423 response_mode: params.response_mode,
424 additional_authorization_parameters: params.additional_authorization_parameters,
425 on_backchannel_logout: params.on_backchannel_logout,
426 forward_login_hint: params.forward_login_hint,
427 })
428 }
429
430 #[tracing::instrument(
431 name = "db.upstream_oauth_provider.delete_by_id",
432 skip_all,
433 fields(
434 db.query.text,
435 upstream_oauth_provider.id = %id,
436 ),
437 err,
438 )]
439 async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> {
440 {
443 let span = info_span!(
444 "db.oauth2_client.delete_by_id.authorization_sessions",
445 upstream_oauth_provider.id = %id,
446 { DB_QUERY_TEXT } = tracing::field::Empty,
447 );
448 sqlx::query!(
449 r#"
450 DELETE FROM upstream_oauth_authorization_sessions
451 WHERE upstream_oauth_provider_id = $1
452 "#,
453 Uuid::from(id),
454 )
455 .record(&span)
456 .execute(&mut *self.conn)
457 .instrument(span)
458 .await?;
459 }
460
461 {
464 let span = info_span!(
465 "db.oauth2_client.delete_by_id.links",
466 upstream_oauth_provider.id = %id,
467 { DB_QUERY_TEXT } = tracing::field::Empty,
468 );
469 sqlx::query!(
470 r#"
471 DELETE FROM upstream_oauth_links
472 WHERE upstream_oauth_provider_id = $1
473 "#,
474 Uuid::from(id),
475 )
476 .record(&span)
477 .execute(&mut *self.conn)
478 .instrument(span)
479 .await?;
480 }
481
482 let res = sqlx::query!(
483 r#"
484 DELETE FROM upstream_oauth_providers
485 WHERE upstream_oauth_provider_id = $1
486 "#,
487 Uuid::from(id),
488 )
489 .traced()
490 .execute(&mut *self.conn)
491 .await?;
492
493 DatabaseError::ensure_affected_rows(&res, 1)
494 }
495
496 #[tracing::instrument(
497 name = "db.upstream_oauth_provider.add",
498 skip_all,
499 fields(
500 db.query.text,
501 upstream_oauth_provider.id = %id,
502 upstream_oauth_provider.issuer = params.issuer,
503 upstream_oauth_provider.client_id = %params.client_id,
504 ),
505 err,
506 )]
507 async fn upsert(
508 &mut self,
509 clock: &dyn Clock,
510 id: Ulid,
511 params: UpstreamOAuthProviderParams,
512 ) -> Result<UpstreamOAuthProvider, Self::Error> {
513 let created_at = clock.now();
514
515 let created_at = sqlx::query_scalar!(
516 r#"
517 INSERT INTO upstream_oauth_providers (
518 upstream_oauth_provider_id,
519 issuer,
520 human_name,
521 brand_name,
522 scope,
523 token_endpoint_auth_method,
524 token_endpoint_signing_alg,
525 id_token_signed_response_alg,
526 fetch_userinfo,
527 userinfo_signed_response_alg,
528 client_id,
529 encrypted_client_secret,
530 claims_imports,
531 authorization_endpoint_override,
532 token_endpoint_override,
533 userinfo_endpoint_override,
534 jwks_uri_override,
535 discovery_mode,
536 pkce_mode,
537 response_mode,
538 additional_parameters,
539 forward_login_hint,
540 ui_order,
541 on_backchannel_logout,
542 created_at
543 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
544 $11, $12, $13, $14, $15, $16, $17, $18, $19, $20,
545 $21, $22, $23, $24, $25)
546 ON CONFLICT (upstream_oauth_provider_id)
547 DO UPDATE
548 SET
549 issuer = EXCLUDED.issuer,
550 human_name = EXCLUDED.human_name,
551 brand_name = EXCLUDED.brand_name,
552 scope = EXCLUDED.scope,
553 token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,
554 token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,
555 id_token_signed_response_alg = EXCLUDED.id_token_signed_response_alg,
556 fetch_userinfo = EXCLUDED.fetch_userinfo,
557 userinfo_signed_response_alg = EXCLUDED.userinfo_signed_response_alg,
558 disabled_at = NULL,
559 client_id = EXCLUDED.client_id,
560 encrypted_client_secret = EXCLUDED.encrypted_client_secret,
561 claims_imports = EXCLUDED.claims_imports,
562 authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,
563 token_endpoint_override = EXCLUDED.token_endpoint_override,
564 userinfo_endpoint_override = EXCLUDED.userinfo_endpoint_override,
565 jwks_uri_override = EXCLUDED.jwks_uri_override,
566 discovery_mode = EXCLUDED.discovery_mode,
567 pkce_mode = EXCLUDED.pkce_mode,
568 response_mode = EXCLUDED.response_mode,
569 additional_parameters = EXCLUDED.additional_parameters,
570 forward_login_hint = EXCLUDED.forward_login_hint,
571 ui_order = EXCLUDED.ui_order,
572 on_backchannel_logout = EXCLUDED.on_backchannel_logout
573 RETURNING created_at
574 "#,
575 Uuid::from(id),
576 params.issuer.as_deref(),
577 params.human_name.as_deref(),
578 params.brand_name.as_deref(),
579 params.scope.to_string(),
580 params.token_endpoint_auth_method.to_string(),
581 params
582 .token_endpoint_signing_alg
583 .as_ref()
584 .map(ToString::to_string),
585 params.id_token_signed_response_alg.to_string(),
586 params.fetch_userinfo,
587 params
588 .userinfo_signed_response_alg
589 .as_ref()
590 .map(ToString::to_string),
591 ¶ms.client_id,
592 params.encrypted_client_secret.as_deref(),
593 Json(¶ms.claims_imports) as _,
594 params
595 .authorization_endpoint_override
596 .as_ref()
597 .map(ToString::to_string),
598 params
599 .token_endpoint_override
600 .as_ref()
601 .map(ToString::to_string),
602 params
603 .userinfo_endpoint_override
604 .as_ref()
605 .map(ToString::to_string),
606 params.jwks_uri_override.as_ref().map(ToString::to_string),
607 params.discovery_mode.as_str(),
608 params.pkce_mode.as_str(),
609 params.response_mode.as_ref().map(ToString::to_string),
610 Json(¶ms.additional_authorization_parameters) as _,
611 params.forward_login_hint,
612 params.ui_order,
613 params.on_backchannel_logout.as_str(),
614 created_at,
615 )
616 .traced()
617 .fetch_one(&mut *self.conn)
618 .await?;
619
620 Ok(UpstreamOAuthProvider {
621 id,
622 issuer: params.issuer,
623 human_name: params.human_name,
624 brand_name: params.brand_name,
625 scope: params.scope,
626 client_id: params.client_id,
627 encrypted_client_secret: params.encrypted_client_secret,
628 token_endpoint_signing_alg: params.token_endpoint_signing_alg,
629 token_endpoint_auth_method: params.token_endpoint_auth_method,
630 id_token_signed_response_alg: params.id_token_signed_response_alg,
631 fetch_userinfo: params.fetch_userinfo,
632 userinfo_signed_response_alg: params.userinfo_signed_response_alg,
633 created_at,
634 disabled_at: None,
635 claims_imports: params.claims_imports,
636 authorization_endpoint_override: params.authorization_endpoint_override,
637 token_endpoint_override: params.token_endpoint_override,
638 userinfo_endpoint_override: params.userinfo_endpoint_override,
639 jwks_uri_override: params.jwks_uri_override,
640 discovery_mode: params.discovery_mode,
641 pkce_mode: params.pkce_mode,
642 response_mode: params.response_mode,
643 additional_authorization_parameters: params.additional_authorization_parameters,
644 forward_login_hint: params.forward_login_hint,
645 on_backchannel_logout: params.on_backchannel_logout,
646 })
647 }
648
649 #[tracing::instrument(
650 name = "db.upstream_oauth_provider.disable",
651 skip_all,
652 fields(
653 db.query.text,
654 %upstream_oauth_provider.id,
655 ),
656 err,
657 )]
658 async fn disable(
659 &mut self,
660 clock: &dyn Clock,
661 mut upstream_oauth_provider: UpstreamOAuthProvider,
662 ) -> Result<UpstreamOAuthProvider, Self::Error> {
663 let disabled_at = clock.now();
664 let res = sqlx::query!(
665 r#"
666 UPDATE upstream_oauth_providers
667 SET disabled_at = $2
668 WHERE upstream_oauth_provider_id = $1
669 "#,
670 Uuid::from(upstream_oauth_provider.id),
671 disabled_at,
672 )
673 .traced()
674 .execute(&mut *self.conn)
675 .await?;
676
677 DatabaseError::ensure_affected_rows(&res, 1)?;
678
679 upstream_oauth_provider.disabled_at = Some(disabled_at);
680
681 Ok(upstream_oauth_provider)
682 }
683
684 #[tracing::instrument(
685 name = "db.upstream_oauth_provider.list",
686 skip_all,
687 fields(
688 db.query.text,
689 ),
690 err,
691 )]
692 async fn list(
693 &mut self,
694 filter: UpstreamOAuthProviderFilter<'_>,
695 pagination: Pagination,
696 ) -> Result<Page<UpstreamOAuthProvider>, Self::Error> {
697 let (sql, arguments) = Query::select()
698 .expr_as(
699 Expr::col((
700 UpstreamOAuthProviders::Table,
701 UpstreamOAuthProviders::UpstreamOAuthProviderId,
702 )),
703 ProviderLookupIden::UpstreamOauthProviderId,
704 )
705 .expr_as(
706 Expr::col((
707 UpstreamOAuthProviders::Table,
708 UpstreamOAuthProviders::Issuer,
709 )),
710 ProviderLookupIden::Issuer,
711 )
712 .expr_as(
713 Expr::col((
714 UpstreamOAuthProviders::Table,
715 UpstreamOAuthProviders::HumanName,
716 )),
717 ProviderLookupIden::HumanName,
718 )
719 .expr_as(
720 Expr::col((
721 UpstreamOAuthProviders::Table,
722 UpstreamOAuthProviders::BrandName,
723 )),
724 ProviderLookupIden::BrandName,
725 )
726 .expr_as(
727 Expr::col((UpstreamOAuthProviders::Table, UpstreamOAuthProviders::Scope)),
728 ProviderLookupIden::Scope,
729 )
730 .expr_as(
731 Expr::col((
732 UpstreamOAuthProviders::Table,
733 UpstreamOAuthProviders::ClientId,
734 )),
735 ProviderLookupIden::ClientId,
736 )
737 .expr_as(
738 Expr::col((
739 UpstreamOAuthProviders::Table,
740 UpstreamOAuthProviders::EncryptedClientSecret,
741 )),
742 ProviderLookupIden::EncryptedClientSecret,
743 )
744 .expr_as(
745 Expr::col((
746 UpstreamOAuthProviders::Table,
747 UpstreamOAuthProviders::TokenEndpointSigningAlg,
748 )),
749 ProviderLookupIden::TokenEndpointSigningAlg,
750 )
751 .expr_as(
752 Expr::col((
753 UpstreamOAuthProviders::Table,
754 UpstreamOAuthProviders::TokenEndpointAuthMethod,
755 )),
756 ProviderLookupIden::TokenEndpointAuthMethod,
757 )
758 .expr_as(
759 Expr::col((
760 UpstreamOAuthProviders::Table,
761 UpstreamOAuthProviders::IdTokenSignedResponseAlg,
762 )),
763 ProviderLookupIden::IdTokenSignedResponseAlg,
764 )
765 .expr_as(
766 Expr::col((
767 UpstreamOAuthProviders::Table,
768 UpstreamOAuthProviders::FetchUserinfo,
769 )),
770 ProviderLookupIden::FetchUserinfo,
771 )
772 .expr_as(
773 Expr::col((
774 UpstreamOAuthProviders::Table,
775 UpstreamOAuthProviders::UserinfoSignedResponseAlg,
776 )),
777 ProviderLookupIden::UserinfoSignedResponseAlg,
778 )
779 .expr_as(
780 Expr::col((
781 UpstreamOAuthProviders::Table,
782 UpstreamOAuthProviders::CreatedAt,
783 )),
784 ProviderLookupIden::CreatedAt,
785 )
786 .expr_as(
787 Expr::col((
788 UpstreamOAuthProviders::Table,
789 UpstreamOAuthProviders::DisabledAt,
790 )),
791 ProviderLookupIden::DisabledAt,
792 )
793 .expr_as(
794 Expr::col((
795 UpstreamOAuthProviders::Table,
796 UpstreamOAuthProviders::ClaimsImports,
797 )),
798 ProviderLookupIden::ClaimsImports,
799 )
800 .expr_as(
801 Expr::col((
802 UpstreamOAuthProviders::Table,
803 UpstreamOAuthProviders::JwksUriOverride,
804 )),
805 ProviderLookupIden::JwksUriOverride,
806 )
807 .expr_as(
808 Expr::col((
809 UpstreamOAuthProviders::Table,
810 UpstreamOAuthProviders::TokenEndpointOverride,
811 )),
812 ProviderLookupIden::TokenEndpointOverride,
813 )
814 .expr_as(
815 Expr::col((
816 UpstreamOAuthProviders::Table,
817 UpstreamOAuthProviders::AuthorizationEndpointOverride,
818 )),
819 ProviderLookupIden::AuthorizationEndpointOverride,
820 )
821 .expr_as(
822 Expr::col((
823 UpstreamOAuthProviders::Table,
824 UpstreamOAuthProviders::UserinfoEndpointOverride,
825 )),
826 ProviderLookupIden::UserinfoEndpointOverride,
827 )
828 .expr_as(
829 Expr::col((
830 UpstreamOAuthProviders::Table,
831 UpstreamOAuthProviders::DiscoveryMode,
832 )),
833 ProviderLookupIden::DiscoveryMode,
834 )
835 .expr_as(
836 Expr::col((
837 UpstreamOAuthProviders::Table,
838 UpstreamOAuthProviders::PkceMode,
839 )),
840 ProviderLookupIden::PkceMode,
841 )
842 .expr_as(
843 Expr::col((
844 UpstreamOAuthProviders::Table,
845 UpstreamOAuthProviders::ResponseMode,
846 )),
847 ProviderLookupIden::ResponseMode,
848 )
849 .expr_as(
850 Expr::col((
851 UpstreamOAuthProviders::Table,
852 UpstreamOAuthProviders::AdditionalParameters,
853 )),
854 ProviderLookupIden::AdditionalParameters,
855 )
856 .expr_as(
857 Expr::col((
858 UpstreamOAuthProviders::Table,
859 UpstreamOAuthProviders::ForwardLoginHint,
860 )),
861 ProviderLookupIden::ForwardLoginHint,
862 )
863 .expr_as(
864 Expr::col((
865 UpstreamOAuthProviders::Table,
866 UpstreamOAuthProviders::OnBackchannelLogout,
867 )),
868 ProviderLookupIden::OnBackchannelLogout,
869 )
870 .from(UpstreamOAuthProviders::Table)
871 .apply_filter(filter)
872 .generate_pagination(
873 (
874 UpstreamOAuthProviders::Table,
875 UpstreamOAuthProviders::UpstreamOAuthProviderId,
876 ),
877 pagination,
878 )
879 .build_sqlx(PostgresQueryBuilder);
880
881 let edges: Vec<ProviderLookup> = sqlx::query_as_with(&sql, arguments)
882 .traced()
883 .fetch_all(&mut *self.conn)
884 .await?;
885
886 let page = pagination
887 .process(edges)
888 .try_map(UpstreamOAuthProvider::try_from)?;
889
890 return Ok(page);
891 }
892
893 #[tracing::instrument(
894 name = "db.upstream_oauth_provider.count",
895 skip_all,
896 fields(
897 db.query.text,
898 ),
899 err,
900 )]
901 async fn count(
902 &mut self,
903 filter: UpstreamOAuthProviderFilter<'_>,
904 ) -> Result<usize, Self::Error> {
905 let (sql, arguments) = Query::select()
906 .expr(
907 Expr::col((
908 UpstreamOAuthProviders::Table,
909 UpstreamOAuthProviders::UpstreamOAuthProviderId,
910 ))
911 .count(),
912 )
913 .from(UpstreamOAuthProviders::Table)
914 .apply_filter(filter)
915 .build_sqlx(PostgresQueryBuilder);
916
917 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
918 .traced()
919 .fetch_one(&mut *self.conn)
920 .await?;
921
922 count
923 .try_into()
924 .map_err(DatabaseError::to_invalid_operation)
925 }
926
927 #[tracing::instrument(
928 name = "db.upstream_oauth_provider.all_enabled",
929 skip_all,
930 fields(
931 db.query.text,
932 ),
933 err,
934 )]
935 async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error> {
936 let res = sqlx::query_as!(
937 ProviderLookup,
938 r#"
939 SELECT
940 upstream_oauth_provider_id,
941 issuer,
942 human_name,
943 brand_name,
944 scope,
945 client_id,
946 encrypted_client_secret,
947 token_endpoint_signing_alg,
948 token_endpoint_auth_method,
949 id_token_signed_response_alg,
950 fetch_userinfo,
951 userinfo_signed_response_alg,
952 created_at,
953 disabled_at,
954 claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
955 jwks_uri_override,
956 authorization_endpoint_override,
957 token_endpoint_override,
958 userinfo_endpoint_override,
959 discovery_mode,
960 pkce_mode,
961 response_mode,
962 additional_parameters as "additional_parameters: Json<Vec<(String, String)>>",
963 forward_login_hint,
964 on_backchannel_logout
965 FROM upstream_oauth_providers
966 WHERE disabled_at IS NULL
967 ORDER BY ui_order ASC, upstream_oauth_provider_id ASC
968 "#,
969 )
970 .traced()
971 .fetch_all(&mut *self.conn)
972 .await?;
973
974 let res: Result<Vec<_>, _> = res.into_iter().map(TryInto::try_into).collect();
975 Ok(res?)
976 }
977}