diff --git a/posthog/api/user.py b/posthog/api/user.py index 1c4e28bae0..6934a3550c 100644 --- a/posthog/api/user.py +++ b/posthog/api/user.py @@ -199,7 +199,13 @@ class UserSerializer(serializers.ModelSerializer): def get_has_sso_enforcement(self, instance: User) -> bool: from posthog.models.organization_domain import OrganizationDomain - return bool(OrganizationDomain.objects.get_sso_enforcement_for_email_address(instance.email)) + organization = instance.current_organization + if not organization: + return False + + return bool( + OrganizationDomain.objects.get_sso_enforcement_for_email_address(instance.email, organization=organization) + ) def validate_set_current_organization(self, value: str) -> Organization: try: diff --git a/posthog/models/organization_domain.py b/posthog/models/organization_domain.py index 8b142c83de..7379ab53f9 100644 --- a/posthog/models/organization_domain.py +++ b/posthog/models/organization_domain.py @@ -62,19 +62,22 @@ class OrganizationDomainManager(models.Manager): return True return False - def get_sso_enforcement_for_email_address(self, email: str) -> Optional[str]: + def get_sso_enforcement_for_email_address( + self, email: str, organization: Organization | None = None + ) -> Optional[str]: """ Returns the specific `sso_enforcement` applicable for an email address or an `OrganizationDomain` objects. Validates SSO providers are properly configured and all the proper licenses exist. """ domain = email[email.index("@") + 1 :] - query = ( - self.verified_domains() - .filter(domain__iexact=domain) - .exclude(sso_enforcement="") - .values("sso_enforcement", "organization_id", "organization__available_product_features") - .first() - ) + queryset = self.verified_domains().filter(domain__iexact=domain).exclude(sso_enforcement="") + + if organization is not None: + queryset = queryset.filter(organization=organization) + + query = queryset.values( + "sso_enforcement", "organization_id", "organization__available_product_features" + ).first() if not query: return None