mas_storage_pg/upstream_oauth2/
provider.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports};
10use mas_storage::{
11    Clock, Page, Pagination,
12    upstream_oauth2::{
13        UpstreamOAuthProviderFilter, UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
14    },
15};
16use opentelemetry_semantic_conventions::attribute::DB_QUERY_TEXT;
17use rand::RngCore;
18use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
19use sea_query_binder::SqlxBinder;
20use sqlx::{PgConnection, types::Json};
21use tracing::{Instrument, info_span};
22use ulid::Ulid;
23use uuid::Uuid;
24
25use crate::{
26    DatabaseError, DatabaseInconsistencyError,
27    filter::{Filter, StatementExt},
28    iden::UpstreamOAuthProviders,
29    pagination::QueryBuilderExt,
30    tracing::ExecuteExt,
31};
32
33/// An implementation of [`UpstreamOAuthProviderRepository`] for a PostgreSQL
34/// connection
35pub struct PgUpstreamOAuthProviderRepository<'c> {
36    conn: &'c mut PgConnection,
37}
38
39impl<'c> PgUpstreamOAuthProviderRepository<'c> {
40    /// Create a new [`PgUpstreamOAuthProviderRepository`] from an active
41    /// PostgreSQL connection
42    pub fn new(conn: &'c mut PgConnection) -> Self {
43        Self { conn }
44    }
45}
46
/// One row of the `upstream_oauth_providers` table as read from the database.
///
/// Field names must match the selected column names: `sqlx::FromRow` maps
/// columns by name, and `#[enum_def]` generates the `ProviderLookupIden`
/// identifiers that are used as column aliases when queries are built with
/// `sea_query`. Textual columns are parsed into their typed representations
/// by the `TryFrom<ProviderLookup>` implementation for
/// [`UpstreamOAuthProvider`], which reports unparsable values as a
/// `DatabaseInconsistencyError`.
#[derive(sqlx::FromRow)]
#[enum_def]
struct ProviderLookup {
    upstream_oauth_provider_id: Uuid,
    issuer: Option<String>,
    human_name: Option<String>,
    brand_name: Option<String>,
    // Parsed into a typed scope on conversion.
    scope: String,
    client_id: String,
    encrypted_client_secret: Option<String>,
    token_endpoint_signing_alg: Option<String>,
    token_endpoint_auth_method: String,
    id_token_signed_response_alg: String,
    fetch_userinfo: bool,
    userinfo_signed_response_alg: Option<String>,
    created_at: DateTime<Utc>,
    // `None` while the provider is enabled.
    disabled_at: Option<DateTime<Utc>>,
    // Stored as a JSON column and decoded by sqlx.
    claims_imports: Json<UpstreamOAuthProviderClaimsImports>,
    jwks_uri_override: Option<String>,
    authorization_endpoint_override: Option<String>,
    token_endpoint_override: Option<String>,
    userinfo_endpoint_override: Option<String>,
    discovery_mode: String,
    pkce_mode: String,
    response_mode: Option<String>,
    // Nullable JSON list of (name, value) pairs; a NULL column is treated as
    // an empty list on conversion.
    additional_parameters: Option<Json<Vec<(String, String)>>>,
    forward_login_hint: bool,
    on_backchannel_logout: String,
}
76
77impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
78    type Error = DatabaseInconsistencyError;
79
80    fn try_from(value: ProviderLookup) -> Result<Self, Self::Error> {
81        let id = value.upstream_oauth_provider_id.into();
82        let scope = value.scope.parse().map_err(|e| {
83            DatabaseInconsistencyError::on("upstream_oauth_providers")
84                .column("scope")
85                .row(id)
86                .source(e)
87        })?;
88        let token_endpoint_auth_method = value.token_endpoint_auth_method.parse().map_err(|e| {
89            DatabaseInconsistencyError::on("upstream_oauth_providers")
90                .column("token_endpoint_auth_method")
91                .row(id)
92                .source(e)
93        })?;
94        let token_endpoint_signing_alg = value
95            .token_endpoint_signing_alg
96            .map(|x| x.parse())
97            .transpose()
98            .map_err(|e| {
99                DatabaseInconsistencyError::on("upstream_oauth_providers")
100                    .column("token_endpoint_signing_alg")
101                    .row(id)
102                    .source(e)
103            })?;
104        let id_token_signed_response_alg =
105            value.id_token_signed_response_alg.parse().map_err(|e| {
106                DatabaseInconsistencyError::on("upstream_oauth_providers")
107                    .column("id_token_signed_response_alg")
108                    .row(id)
109                    .source(e)
110            })?;
111
112        let userinfo_signed_response_alg = value
113            .userinfo_signed_response_alg
114            .map(|x| x.parse())
115            .transpose()
116            .map_err(|e| {
117                DatabaseInconsistencyError::on("upstream_oauth_providers")
118                    .column("userinfo_signed_response_alg")
119                    .row(id)
120                    .source(e)
121            })?;
122
123        let authorization_endpoint_override = value
124            .authorization_endpoint_override
125            .map(|x| x.parse())
126            .transpose()
127            .map_err(|e| {
128                DatabaseInconsistencyError::on("upstream_oauth_providers")
129                    .column("authorization_endpoint_override")
130                    .row(id)
131                    .source(e)
132            })?;
133
134        let token_endpoint_override = value
135            .token_endpoint_override
136            .map(|x| x.parse())
137            .transpose()
138            .map_err(|e| {
139                DatabaseInconsistencyError::on("upstream_oauth_providers")
140                    .column("token_endpoint_override")
141                    .row(id)
142                    .source(e)
143            })?;
144
145        let userinfo_endpoint_override = value
146            .userinfo_endpoint_override
147            .map(|x| x.parse())
148            .transpose()
149            .map_err(|e| {
150                DatabaseInconsistencyError::on("upstream_oauth_providers")
151                    .column("userinfo_endpoint_override")
152                    .row(id)
153                    .source(e)
154            })?;
155
156        let jwks_uri_override = value
157            .jwks_uri_override
158            .map(|x| x.parse())
159            .transpose()
160            .map_err(|e| {
161                DatabaseInconsistencyError::on("upstream_oauth_providers")
162                    .column("jwks_uri_override")
163                    .row(id)
164                    .source(e)
165            })?;
166
167        let discovery_mode = value.discovery_mode.parse().map_err(|e| {
168            DatabaseInconsistencyError::on("upstream_oauth_providers")
169                .column("discovery_mode")
170                .row(id)
171                .source(e)
172        })?;
173
174        let pkce_mode = value.pkce_mode.parse().map_err(|e| {
175            DatabaseInconsistencyError::on("upstream_oauth_providers")
176                .column("pkce_mode")
177                .row(id)
178                .source(e)
179        })?;
180
181        let response_mode = value
182            .response_mode
183            .map(|x| x.parse())
184            .transpose()
185            .map_err(|e| {
186                DatabaseInconsistencyError::on("upstream_oauth_providers")
187                    .column("response_mode")
188                    .row(id)
189                    .source(e)
190            })?;
191
192        let additional_authorization_parameters = value
193            .additional_parameters
194            .map(|Json(x)| x)
195            .unwrap_or_default();
196
197        let on_backchannel_logout = value.on_backchannel_logout.parse().map_err(|e| {
198            DatabaseInconsistencyError::on("upstream_oauth_providers")
199                .column("on_backchannel_logout")
200                .row(id)
201                .source(e)
202        })?;
203
204        Ok(UpstreamOAuthProvider {
205            id,
206            issuer: value.issuer,
207            human_name: value.human_name,
208            brand_name: value.brand_name,
209            scope,
210            client_id: value.client_id,
211            encrypted_client_secret: value.encrypted_client_secret,
212            token_endpoint_auth_method,
213            token_endpoint_signing_alg,
214            id_token_signed_response_alg,
215            fetch_userinfo: value.fetch_userinfo,
216            userinfo_signed_response_alg,
217            created_at: value.created_at,
218            disabled_at: value.disabled_at,
219            claims_imports: value.claims_imports.0,
220            authorization_endpoint_override,
221            token_endpoint_override,
222            userinfo_endpoint_override,
223            jwks_uri_override,
224            discovery_mode,
225            pkce_mode,
226            response_mode,
227            additional_authorization_parameters,
228            forward_login_hint: value.forward_login_hint,
229            on_backchannel_logout,
230        })
231    }
232}
233
234impl Filter for UpstreamOAuthProviderFilter<'_> {
235    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
236        sea_query::Condition::all().add_option(self.enabled().map(|enabled| {
237            Expr::col((
238                UpstreamOAuthProviders::Table,
239                UpstreamOAuthProviders::DisabledAt,
240            ))
241            .is_null()
242            .eq(enabled)
243        }))
244    }
245}
246
247#[async_trait]
248impl UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'_> {
249    type Error = DatabaseError;
250
251    #[tracing::instrument(
252        name = "db.upstream_oauth_provider.lookup",
253        skip_all,
254        fields(
255            db.query.text,
256            upstream_oauth_provider.id = %id,
257        ),
258        err,
259    )]
260    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error> {
261        let res = sqlx::query_as!(
262            ProviderLookup,
263            r#"
264                SELECT
265                    upstream_oauth_provider_id,
266                    issuer,
267                    human_name,
268                    brand_name,
269                    scope,
270                    client_id,
271                    encrypted_client_secret,
272                    token_endpoint_signing_alg,
273                    token_endpoint_auth_method,
274                    id_token_signed_response_alg,
275                    fetch_userinfo,
276                    userinfo_signed_response_alg,
277                    created_at,
278                    disabled_at,
279                    claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
280                    jwks_uri_override,
281                    authorization_endpoint_override,
282                    token_endpoint_override,
283                    userinfo_endpoint_override,
284                    discovery_mode,
285                    pkce_mode,
286                    response_mode,
287                    additional_parameters as "additional_parameters: Json<Vec<(String, String)>>",
288                    forward_login_hint,
289                    on_backchannel_logout
290                FROM upstream_oauth_providers
291                WHERE upstream_oauth_provider_id = $1
292            "#,
293            Uuid::from(id),
294        )
295        .traced()
296        .fetch_optional(&mut *self.conn)
297        .await?;
298
299        let res = res
300            .map(UpstreamOAuthProvider::try_from)
301            .transpose()
302            .map_err(DatabaseError::from)?;
303
304        Ok(res)
305    }
306
307    #[tracing::instrument(
308        name = "db.upstream_oauth_provider.add",
309        skip_all,
310        fields(
311            db.query.text,
312            upstream_oauth_provider.id,
313            upstream_oauth_provider.issuer = params.issuer,
314            upstream_oauth_provider.client_id = %params.client_id,
315        ),
316        err,
317    )]
318    async fn add(
319        &mut self,
320        rng: &mut (dyn RngCore + Send),
321        clock: &dyn Clock,
322        params: UpstreamOAuthProviderParams,
323    ) -> Result<UpstreamOAuthProvider, Self::Error> {
324        let created_at = clock.now();
325        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
326        tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id));
327
328        sqlx::query!(
329            r#"
330            INSERT INTO upstream_oauth_providers (
331                upstream_oauth_provider_id,
332                issuer,
333                human_name,
334                brand_name,
335                scope,
336                token_endpoint_auth_method,
337                token_endpoint_signing_alg,
338                id_token_signed_response_alg,
339                fetch_userinfo,
340                userinfo_signed_response_alg,
341                client_id,
342                encrypted_client_secret,
343                claims_imports,
344                authorization_endpoint_override,
345                token_endpoint_override,
346                userinfo_endpoint_override,
347                jwks_uri_override,
348                discovery_mode,
349                pkce_mode,
350                response_mode,
351                forward_login_hint,
352                on_backchannel_logout,
353                created_at
354            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11,
355                      $12, $13, $14, $15, $16, $17, $18, $19, $20,
356                      $21, $22, $23)
357        "#,
358            Uuid::from(id),
359            params.issuer.as_deref(),
360            params.human_name.as_deref(),
361            params.brand_name.as_deref(),
362            params.scope.to_string(),
363            params.token_endpoint_auth_method.to_string(),
364            params
365                .token_endpoint_signing_alg
366                .as_ref()
367                .map(ToString::to_string),
368            params.id_token_signed_response_alg.to_string(),
369            params.fetch_userinfo,
370            params
371                .userinfo_signed_response_alg
372                .as_ref()
373                .map(ToString::to_string),
374            &params.client_id,
375            params.encrypted_client_secret.as_deref(),
376            Json(&params.claims_imports) as _,
377            params
378                .authorization_endpoint_override
379                .as_ref()
380                .map(ToString::to_string),
381            params
382                .token_endpoint_override
383                .as_ref()
384                .map(ToString::to_string),
385            params
386                .userinfo_endpoint_override
387                .as_ref()
388                .map(ToString::to_string),
389            params.jwks_uri_override.as_ref().map(ToString::to_string),
390            params.discovery_mode.as_str(),
391            params.pkce_mode.as_str(),
392            params.response_mode.as_ref().map(ToString::to_string),
393            params.forward_login_hint,
394            params.on_backchannel_logout.as_str(),
395            created_at,
396        )
397        .traced()
398        .execute(&mut *self.conn)
399        .await?;
400
401        Ok(UpstreamOAuthProvider {
402            id,
403            issuer: params.issuer,
404            human_name: params.human_name,
405            brand_name: params.brand_name,
406            scope: params.scope,
407            client_id: params.client_id,
408            encrypted_client_secret: params.encrypted_client_secret,
409            token_endpoint_signing_alg: params.token_endpoint_signing_alg,
410            token_endpoint_auth_method: params.token_endpoint_auth_method,
411            id_token_signed_response_alg: params.id_token_signed_response_alg,
412            fetch_userinfo: params.fetch_userinfo,
413            userinfo_signed_response_alg: params.userinfo_signed_response_alg,
414            created_at,
415            disabled_at: None,
416            claims_imports: params.claims_imports,
417            authorization_endpoint_override: params.authorization_endpoint_override,
418            token_endpoint_override: params.token_endpoint_override,
419            userinfo_endpoint_override: params.userinfo_endpoint_override,
420            jwks_uri_override: params.jwks_uri_override,
421            discovery_mode: params.discovery_mode,
422            pkce_mode: params.pkce_mode,
423            response_mode: params.response_mode,
424            additional_authorization_parameters: params.additional_authorization_parameters,
425            on_backchannel_logout: params.on_backchannel_logout,
426            forward_login_hint: params.forward_login_hint,
427        })
428    }
429
430    #[tracing::instrument(
431        name = "db.upstream_oauth_provider.delete_by_id",
432        skip_all,
433        fields(
434            db.query.text,
435            upstream_oauth_provider.id = %id,
436        ),
437        err,
438    )]
439    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> {
440        // Delete the authorization sessions first, as they have a foreign key
441        // constraint on the links and the providers.
442        {
443            let span = info_span!(
444                "db.oauth2_client.delete_by_id.authorization_sessions",
445                upstream_oauth_provider.id = %id,
446                { DB_QUERY_TEXT } = tracing::field::Empty,
447            );
448            sqlx::query!(
449                r#"
450                    DELETE FROM upstream_oauth_authorization_sessions
451                    WHERE upstream_oauth_provider_id = $1
452                "#,
453                Uuid::from(id),
454            )
455            .record(&span)
456            .execute(&mut *self.conn)
457            .instrument(span)
458            .await?;
459        }
460
461        // Delete the links next, as they have a foreign key constraint on the
462        // providers.
463        {
464            let span = info_span!(
465                "db.oauth2_client.delete_by_id.links",
466                upstream_oauth_provider.id = %id,
467                { DB_QUERY_TEXT } = tracing::field::Empty,
468            );
469            sqlx::query!(
470                r#"
471                    DELETE FROM upstream_oauth_links
472                    WHERE upstream_oauth_provider_id = $1
473                "#,
474                Uuid::from(id),
475            )
476            .record(&span)
477            .execute(&mut *self.conn)
478            .instrument(span)
479            .await?;
480        }
481
482        let res = sqlx::query!(
483            r#"
484                DELETE FROM upstream_oauth_providers
485                WHERE upstream_oauth_provider_id = $1
486            "#,
487            Uuid::from(id),
488        )
489        .traced()
490        .execute(&mut *self.conn)
491        .await?;
492
493        DatabaseError::ensure_affected_rows(&res, 1)
494    }
495
496    #[tracing::instrument(
497        name = "db.upstream_oauth_provider.add",
498        skip_all,
499        fields(
500            db.query.text,
501            upstream_oauth_provider.id = %id,
502            upstream_oauth_provider.issuer = params.issuer,
503            upstream_oauth_provider.client_id = %params.client_id,
504        ),
505        err,
506    )]
507    async fn upsert(
508        &mut self,
509        clock: &dyn Clock,
510        id: Ulid,
511        params: UpstreamOAuthProviderParams,
512    ) -> Result<UpstreamOAuthProvider, Self::Error> {
513        let created_at = clock.now();
514
515        let created_at = sqlx::query_scalar!(
516            r#"
517                INSERT INTO upstream_oauth_providers (
518                    upstream_oauth_provider_id,
519                    issuer,
520                    human_name,
521                    brand_name,
522                    scope,
523                    token_endpoint_auth_method,
524                    token_endpoint_signing_alg,
525                    id_token_signed_response_alg,
526                    fetch_userinfo,
527                    userinfo_signed_response_alg,
528                    client_id,
529                    encrypted_client_secret,
530                    claims_imports,
531                    authorization_endpoint_override,
532                    token_endpoint_override,
533                    userinfo_endpoint_override,
534                    jwks_uri_override,
535                    discovery_mode,
536                    pkce_mode,
537                    response_mode,
538                    additional_parameters,
539                    forward_login_hint,
540                    ui_order,
541                    on_backchannel_logout,
542                    created_at
543                ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
544                          $11, $12, $13, $14, $15, $16, $17, $18, $19, $20,
545                          $21, $22, $23, $24, $25)
546                ON CONFLICT (upstream_oauth_provider_id)
547                    DO UPDATE
548                    SET
549                        issuer = EXCLUDED.issuer,
550                        human_name = EXCLUDED.human_name,
551                        brand_name = EXCLUDED.brand_name,
552                        scope = EXCLUDED.scope,
553                        token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,
554                        token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,
555                        id_token_signed_response_alg = EXCLUDED.id_token_signed_response_alg,
556                        fetch_userinfo = EXCLUDED.fetch_userinfo,
557                        userinfo_signed_response_alg = EXCLUDED.userinfo_signed_response_alg,
558                        disabled_at = NULL,
559                        client_id = EXCLUDED.client_id,
560                        encrypted_client_secret = EXCLUDED.encrypted_client_secret,
561                        claims_imports = EXCLUDED.claims_imports,
562                        authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,
563                        token_endpoint_override = EXCLUDED.token_endpoint_override,
564                        userinfo_endpoint_override = EXCLUDED.userinfo_endpoint_override,
565                        jwks_uri_override = EXCLUDED.jwks_uri_override,
566                        discovery_mode = EXCLUDED.discovery_mode,
567                        pkce_mode = EXCLUDED.pkce_mode,
568                        response_mode = EXCLUDED.response_mode,
569                        additional_parameters = EXCLUDED.additional_parameters,
570                        forward_login_hint = EXCLUDED.forward_login_hint,
571                        ui_order = EXCLUDED.ui_order,
572                        on_backchannel_logout = EXCLUDED.on_backchannel_logout
573                RETURNING created_at
574            "#,
575            Uuid::from(id),
576            params.issuer.as_deref(),
577            params.human_name.as_deref(),
578            params.brand_name.as_deref(),
579            params.scope.to_string(),
580            params.token_endpoint_auth_method.to_string(),
581            params
582                .token_endpoint_signing_alg
583                .as_ref()
584                .map(ToString::to_string),
585            params.id_token_signed_response_alg.to_string(),
586            params.fetch_userinfo,
587            params
588                .userinfo_signed_response_alg
589                .as_ref()
590                .map(ToString::to_string),
591            &params.client_id,
592            params.encrypted_client_secret.as_deref(),
593            Json(&params.claims_imports) as _,
594            params
595                .authorization_endpoint_override
596                .as_ref()
597                .map(ToString::to_string),
598            params
599                .token_endpoint_override
600                .as_ref()
601                .map(ToString::to_string),
602            params
603                .userinfo_endpoint_override
604                .as_ref()
605                .map(ToString::to_string),
606            params.jwks_uri_override.as_ref().map(ToString::to_string),
607            params.discovery_mode.as_str(),
608            params.pkce_mode.as_str(),
609            params.response_mode.as_ref().map(ToString::to_string),
610            Json(&params.additional_authorization_parameters) as _,
611            params.forward_login_hint,
612            params.ui_order,
613            params.on_backchannel_logout.as_str(),
614            created_at,
615        )
616        .traced()
617        .fetch_one(&mut *self.conn)
618        .await?;
619
620        Ok(UpstreamOAuthProvider {
621            id,
622            issuer: params.issuer,
623            human_name: params.human_name,
624            brand_name: params.brand_name,
625            scope: params.scope,
626            client_id: params.client_id,
627            encrypted_client_secret: params.encrypted_client_secret,
628            token_endpoint_signing_alg: params.token_endpoint_signing_alg,
629            token_endpoint_auth_method: params.token_endpoint_auth_method,
630            id_token_signed_response_alg: params.id_token_signed_response_alg,
631            fetch_userinfo: params.fetch_userinfo,
632            userinfo_signed_response_alg: params.userinfo_signed_response_alg,
633            created_at,
634            disabled_at: None,
635            claims_imports: params.claims_imports,
636            authorization_endpoint_override: params.authorization_endpoint_override,
637            token_endpoint_override: params.token_endpoint_override,
638            userinfo_endpoint_override: params.userinfo_endpoint_override,
639            jwks_uri_override: params.jwks_uri_override,
640            discovery_mode: params.discovery_mode,
641            pkce_mode: params.pkce_mode,
642            response_mode: params.response_mode,
643            additional_authorization_parameters: params.additional_authorization_parameters,
644            forward_login_hint: params.forward_login_hint,
645            on_backchannel_logout: params.on_backchannel_logout,
646        })
647    }
648
649    #[tracing::instrument(
650        name = "db.upstream_oauth_provider.disable",
651        skip_all,
652        fields(
653            db.query.text,
654            %upstream_oauth_provider.id,
655        ),
656        err,
657    )]
658    async fn disable(
659        &mut self,
660        clock: &dyn Clock,
661        mut upstream_oauth_provider: UpstreamOAuthProvider,
662    ) -> Result<UpstreamOAuthProvider, Self::Error> {
663        let disabled_at = clock.now();
664        let res = sqlx::query!(
665            r#"
666                UPDATE upstream_oauth_providers
667                SET disabled_at = $2
668                WHERE upstream_oauth_provider_id = $1
669            "#,
670            Uuid::from(upstream_oauth_provider.id),
671            disabled_at,
672        )
673        .traced()
674        .execute(&mut *self.conn)
675        .await?;
676
677        DatabaseError::ensure_affected_rows(&res, 1)?;
678
679        upstream_oauth_provider.disabled_at = Some(disabled_at);
680
681        Ok(upstream_oauth_provider)
682    }
683
684    #[tracing::instrument(
685        name = "db.upstream_oauth_provider.list",
686        skip_all,
687        fields(
688            db.query.text,
689        ),
690        err,
691    )]
692    async fn list(
693        &mut self,
694        filter: UpstreamOAuthProviderFilter<'_>,
695        pagination: Pagination,
696    ) -> Result<Page<UpstreamOAuthProvider>, Self::Error> {
697        let (sql, arguments) = Query::select()
698            .expr_as(
699                Expr::col((
700                    UpstreamOAuthProviders::Table,
701                    UpstreamOAuthProviders::UpstreamOAuthProviderId,
702                )),
703                ProviderLookupIden::UpstreamOauthProviderId,
704            )
705            .expr_as(
706                Expr::col((
707                    UpstreamOAuthProviders::Table,
708                    UpstreamOAuthProviders::Issuer,
709                )),
710                ProviderLookupIden::Issuer,
711            )
712            .expr_as(
713                Expr::col((
714                    UpstreamOAuthProviders::Table,
715                    UpstreamOAuthProviders::HumanName,
716                )),
717                ProviderLookupIden::HumanName,
718            )
719            .expr_as(
720                Expr::col((
721                    UpstreamOAuthProviders::Table,
722                    UpstreamOAuthProviders::BrandName,
723                )),
724                ProviderLookupIden::BrandName,
725            )
726            .expr_as(
727                Expr::col((UpstreamOAuthProviders::Table, UpstreamOAuthProviders::Scope)),
728                ProviderLookupIden::Scope,
729            )
730            .expr_as(
731                Expr::col((
732                    UpstreamOAuthProviders::Table,
733                    UpstreamOAuthProviders::ClientId,
734                )),
735                ProviderLookupIden::ClientId,
736            )
737            .expr_as(
738                Expr::col((
739                    UpstreamOAuthProviders::Table,
740                    UpstreamOAuthProviders::EncryptedClientSecret,
741                )),
742                ProviderLookupIden::EncryptedClientSecret,
743            )
744            .expr_as(
745                Expr::col((
746                    UpstreamOAuthProviders::Table,
747                    UpstreamOAuthProviders::TokenEndpointSigningAlg,
748                )),
749                ProviderLookupIden::TokenEndpointSigningAlg,
750            )
751            .expr_as(
752                Expr::col((
753                    UpstreamOAuthProviders::Table,
754                    UpstreamOAuthProviders::TokenEndpointAuthMethod,
755                )),
756                ProviderLookupIden::TokenEndpointAuthMethod,
757            )
758            .expr_as(
759                Expr::col((
760                    UpstreamOAuthProviders::Table,
761                    UpstreamOAuthProviders::IdTokenSignedResponseAlg,
762                )),
763                ProviderLookupIden::IdTokenSignedResponseAlg,
764            )
765            .expr_as(
766                Expr::col((
767                    UpstreamOAuthProviders::Table,
768                    UpstreamOAuthProviders::FetchUserinfo,
769                )),
770                ProviderLookupIden::FetchUserinfo,
771            )
772            .expr_as(
773                Expr::col((
774                    UpstreamOAuthProviders::Table,
775                    UpstreamOAuthProviders::UserinfoSignedResponseAlg,
776                )),
777                ProviderLookupIden::UserinfoSignedResponseAlg,
778            )
779            .expr_as(
780                Expr::col((
781                    UpstreamOAuthProviders::Table,
782                    UpstreamOAuthProviders::CreatedAt,
783                )),
784                ProviderLookupIden::CreatedAt,
785            )
786            .expr_as(
787                Expr::col((
788                    UpstreamOAuthProviders::Table,
789                    UpstreamOAuthProviders::DisabledAt,
790                )),
791                ProviderLookupIden::DisabledAt,
792            )
793            .expr_as(
794                Expr::col((
795                    UpstreamOAuthProviders::Table,
796                    UpstreamOAuthProviders::ClaimsImports,
797                )),
798                ProviderLookupIden::ClaimsImports,
799            )
800            .expr_as(
801                Expr::col((
802                    UpstreamOAuthProviders::Table,
803                    UpstreamOAuthProviders::JwksUriOverride,
804                )),
805                ProviderLookupIden::JwksUriOverride,
806            )
807            .expr_as(
808                Expr::col((
809                    UpstreamOAuthProviders::Table,
810                    UpstreamOAuthProviders::TokenEndpointOverride,
811                )),
812                ProviderLookupIden::TokenEndpointOverride,
813            )
814            .expr_as(
815                Expr::col((
816                    UpstreamOAuthProviders::Table,
817                    UpstreamOAuthProviders::AuthorizationEndpointOverride,
818                )),
819                ProviderLookupIden::AuthorizationEndpointOverride,
820            )
821            .expr_as(
822                Expr::col((
823                    UpstreamOAuthProviders::Table,
824                    UpstreamOAuthProviders::UserinfoEndpointOverride,
825                )),
826                ProviderLookupIden::UserinfoEndpointOverride,
827            )
828            .expr_as(
829                Expr::col((
830                    UpstreamOAuthProviders::Table,
831                    UpstreamOAuthProviders::DiscoveryMode,
832                )),
833                ProviderLookupIden::DiscoveryMode,
834            )
835            .expr_as(
836                Expr::col((
837                    UpstreamOAuthProviders::Table,
838                    UpstreamOAuthProviders::PkceMode,
839                )),
840                ProviderLookupIden::PkceMode,
841            )
842            .expr_as(
843                Expr::col((
844                    UpstreamOAuthProviders::Table,
845                    UpstreamOAuthProviders::ResponseMode,
846                )),
847                ProviderLookupIden::ResponseMode,
848            )
849            .expr_as(
850                Expr::col((
851                    UpstreamOAuthProviders::Table,
852                    UpstreamOAuthProviders::AdditionalParameters,
853                )),
854                ProviderLookupIden::AdditionalParameters,
855            )
856            .expr_as(
857                Expr::col((
858                    UpstreamOAuthProviders::Table,
859                    UpstreamOAuthProviders::ForwardLoginHint,
860                )),
861                ProviderLookupIden::ForwardLoginHint,
862            )
863            .expr_as(
864                Expr::col((
865                    UpstreamOAuthProviders::Table,
866                    UpstreamOAuthProviders::OnBackchannelLogout,
867                )),
868                ProviderLookupIden::OnBackchannelLogout,
869            )
870            .from(UpstreamOAuthProviders::Table)
871            .apply_filter(filter)
872            .generate_pagination(
873                (
874                    UpstreamOAuthProviders::Table,
875                    UpstreamOAuthProviders::UpstreamOAuthProviderId,
876                ),
877                pagination,
878            )
879            .build_sqlx(PostgresQueryBuilder);
880
881        let edges: Vec<ProviderLookup> = sqlx::query_as_with(&sql, arguments)
882            .traced()
883            .fetch_all(&mut *self.conn)
884            .await?;
885
886        let page = pagination
887            .process(edges)
888            .try_map(UpstreamOAuthProvider::try_from)?;
889
890        return Ok(page);
891    }
892
893    #[tracing::instrument(
894        name = "db.upstream_oauth_provider.count",
895        skip_all,
896        fields(
897            db.query.text,
898        ),
899        err,
900    )]
901    async fn count(
902        &mut self,
903        filter: UpstreamOAuthProviderFilter<'_>,
904    ) -> Result<usize, Self::Error> {
905        let (sql, arguments) = Query::select()
906            .expr(
907                Expr::col((
908                    UpstreamOAuthProviders::Table,
909                    UpstreamOAuthProviders::UpstreamOAuthProviderId,
910                ))
911                .count(),
912            )
913            .from(UpstreamOAuthProviders::Table)
914            .apply_filter(filter)
915            .build_sqlx(PostgresQueryBuilder);
916
917        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
918            .traced()
919            .fetch_one(&mut *self.conn)
920            .await?;
921
922        count
923            .try_into()
924            .map_err(DatabaseError::to_invalid_operation)
925    }
926
927    #[tracing::instrument(
928        name = "db.upstream_oauth_provider.all_enabled",
929        skip_all,
930        fields(
931            db.query.text,
932        ),
933        err,
934    )]
935    async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error> {
936        let res = sqlx::query_as!(
937            ProviderLookup,
938            r#"
939                SELECT
940                    upstream_oauth_provider_id,
941                    issuer,
942                    human_name,
943                    brand_name,
944                    scope,
945                    client_id,
946                    encrypted_client_secret,
947                    token_endpoint_signing_alg,
948                    token_endpoint_auth_method,
949                    id_token_signed_response_alg,
950                    fetch_userinfo,
951                    userinfo_signed_response_alg,
952                    created_at,
953                    disabled_at,
954                    claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
955                    jwks_uri_override,
956                    authorization_endpoint_override,
957                    token_endpoint_override,
958                    userinfo_endpoint_override,
959                    discovery_mode,
960                    pkce_mode,
961                    response_mode,
962                    additional_parameters as "additional_parameters: Json<Vec<(String, String)>>",
963                    forward_login_hint,
964                    on_backchannel_logout
965                FROM upstream_oauth_providers
966                WHERE disabled_at IS NULL
967                ORDER BY ui_order ASC, upstream_oauth_provider_id ASC
968            "#,
969        )
970        .traced()
971        .fetch_all(&mut *self.conn)
972        .await?;
973
974        let res: Result<Vec<_>, _> = res.into_iter().map(TryInto::try_into).collect();
975        Ok(res?)
976    }
977}