Files
posthog/ee/api/test/test_authentication.py

1055 lines
43 KiB
Python

import os
import uuid
import datetime
from typing import cast
import pytest
from freezegun.api import freeze_time
from unittest.mock import patch
from django.core import mail
from django.core.exceptions import ValidationError
from django.test import override_settings
from django.utils import timezone
from rest_framework import status
from social_core.exceptions import AuthFailed, AuthMissingParameter
from social_django.models import UserSocialAuth
from posthog.constants import AvailableFeature
from posthog.models import OrganizationMembership, User
from posthog.models.organization_domain import OrganizationDomain
from ee.api.authentication import CustomGoogleOAuth2
from ee.api.test.base import APILicensedTest
from ee.models.license import License
SAML_MOCK_SETTINGS = {
"SOCIAL_AUTH_SAML_SECURITY_CONFIG": {
"wantAttributeStatement": False, # already present in settings
"allowSingleLabelDomains": True, # to allow `http://testserver` in tests
},
"SITE_URL": "http://localhost:8000", # http://localhost:8010 is now the default, but fixtures use 8000
}
SAML_MOCK_SETTINGS["SOCIAL_AUTH_SAML_SP_ENTITY_ID"] = SAML_MOCK_SETTINGS["SITE_URL"]
GOOGLE_MOCK_SETTINGS = {
"SOCIAL_AUTH_GOOGLE_OAUTH2_KEY": "google_key",
"SOCIAL_AUTH_GOOGLE_OAUTH2_SECRET": "google_secret",
}
GITHUB_MOCK_SETTINGS = {
"SOCIAL_AUTH_GITHUB_KEY": "github_key",
"SOCIAL_AUTH_GITHUB_SECRET": "github_secret",
}
CURRENT_FOLDER = os.path.dirname(__file__)
class TestEELoginPrecheckAPI(APILicensedTest):
CONFIG_AUTO_LOGIN = False
def test_login_precheck_with_enforced_sso(self):
OrganizationDomain.objects.create(
domain="witw.app",
organization=self.organization,
verified_at=timezone.now(),
sso_enforcement="google-oauth2",
)
User.objects.create_and_join(self.organization, "spain@witw.app", self.CONFIG_PASSWORD)
with self.settings(**GOOGLE_MOCK_SETTINGS):
response = self.client.post("/api/login/precheck", {"email": "spain@witw.app"})
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(
response.json(),
{"sso_enforcement": "google-oauth2", "saml_available": False},
)
def test_login_precheck_with_unverified_domain(self):
OrganizationDomain.objects.create(
domain="witw.app",
organization=self.organization,
verified_at=None, # note domain is not verified
sso_enforcement="google-oauth2",
)
with self.settings(**GOOGLE_MOCK_SETTINGS):
response = self.client.post(
"/api/login/precheck", {"email": "i_do_not_exist@witw.app"}
) # Note we didn't create a user that matches, only domain is matched
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json(), {"sso_enforcement": None, "saml_available": False})
def test_login_precheck_with_inexistent_account(self):
OrganizationDomain.objects.create(
domain="anotherdomain.com",
organization=self.organization,
verified_at=timezone.now(),
sso_enforcement="github",
)
User.objects.create_and_join(self.organization, "i_do_not_exist@anotherdomain.com", self.CONFIG_PASSWORD)
with self.settings(**GITHUB_MOCK_SETTINGS):
response = self.client.post("/api/login/precheck", {"email": "i_do_not_exist@anotherdomain.com"})
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json(), {"sso_enforcement": "github", "saml_available": False})
def test_login_precheck_with_enforced_sso_but_improperly_configured_sso(self):
OrganizationDomain.objects.create(
domain="witw.app",
organization=self.organization,
verified_at=timezone.now(),
sso_enforcement="google-oauth2",
)
User.objects.create_and_join(self.organization, "spain@witw.app", self.CONFIG_PASSWORD)
response = self.client.post(
"/api/login/precheck", {"email": "spain@witw.app"}
) # Note Google OAuth is not configured
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json(), {"sso_enforcement": None, "saml_available": False})
class TestEEAuthenticationAPI(APILicensedTest):
CONFIG_EMAIL = "user7@posthog.com"
def create_enforced_domain(self, **kwargs) -> OrganizationDomain:
return OrganizationDomain.objects.create(
**{
"domain": "posthog.com",
"organization": self.organization,
"verified_at": timezone.now(),
"sso_enforcement": "google-oauth2",
**kwargs,
}
)
def test_can_enforce_sso(self):
self.client.logout()
# Can log in with password with SSO configured but not enforced
with self.settings(**GOOGLE_MOCK_SETTINGS):
response = self.client.post(
"/api/login",
{"email": self.CONFIG_EMAIL, "password": self.CONFIG_PASSWORD},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json(), {"success": True})
# Forcing SSO disables regular API password login
self.create_enforced_domain()
with self.settings(**GOOGLE_MOCK_SETTINGS):
response = self.client.post(
"/api/login",
{"email": self.CONFIG_EMAIL, "password": self.CONFIG_PASSWORD},
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(
response.json(),
{
"type": "validation_error",
"code": "sso_enforced",
"detail": "You can only login with SSO for this account (google-oauth2).",
"attr": None,
},
)
def test_can_enforce_sso_on_cloud_enviroment(self):
self.client.logout()
License.objects.filter(pk=-1).delete() # No instance licenses
self.create_enforced_domain()
self.organization.available_product_features = [{"key": "sso_enforcement", "name": "sso_enforcement"}]
self.organization.save()
with self.settings(**GOOGLE_MOCK_SETTINGS):
response = self.client.post(
"/api/login",
{"email": self.CONFIG_EMAIL, "password": self.CONFIG_PASSWORD},
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(
response.json(),
{
"type": "validation_error",
"code": "sso_enforced",
"detail": "You can only login with SSO for this account (google-oauth2).",
"attr": None,
},
)
def test_cannot_reset_password_with_enforced_sso(self):
self.create_enforced_domain()
with self.settings(
**GOOGLE_MOCK_SETTINGS,
EMAIL_HOST="localhost",
SITE_URL="https://my.posthog.net",
):
response = self.client.post("/api/reset/", {"email": "i_dont_exist@posthog.com"})
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(
response.json(),
{
"type": "validation_error",
"code": "sso_enforced",
"detail": "Password reset is disabled because SSO login is enforced for this domain.",
"attr": None,
},
)
self.assertEqual(len(mail.outbox), 0)
@patch("posthog.models.organization_domain.logger.warning")
def test_cannot_enforce_sso_without_a_license(self, mock_warning):
self.client.logout()
self.license.valid_until = timezone.now() - datetime.timedelta(days=1)
self.license.save()
self.create_enforced_domain()
# Enforcement is ignored
with self.settings(**GOOGLE_MOCK_SETTINGS):
response = self.client.post(
"/api/login",
{"email": self.CONFIG_EMAIL, "password": self.CONFIG_PASSWORD},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json(), {"success": True})
# Attempting to use SAML fails
with self.settings(**GOOGLE_MOCK_SETTINGS):
response = self.client.get("/login/google-oauth2/")
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
self.assertIn("/login?error_code=improperly_configured_sso", response.headers["Location"])
# Ensure warning is properly logged for debugging
mock_warning.assert_called_with(
"🤑🚪 SSO is enforced for domain posthog.com but the organization does not have the proper license.",
domain="posthog.com",
organization=str(self.organization.id),
)
def test_login_with_sso_resets_session(self):
with self.settings(**GOOGLE_MOCK_SETTINGS):
first_key = self.client.session.session_key
self.client.post("/login/google-oauth2/", {})
second_key = self.client.session.session_key
self.assertNotEqual(first_key, second_key)
def test_existing_session_remains_valid_when_sso_enforced(self):
"""Test that existing password-authenticated sessions remain valid after SSO is enforced"""
self.client.logout()
# Step 1: User logs in with password (no SSO enforcement yet)
response = self.client.post(
"/api/login",
{"email": self.CONFIG_EMAIL, "password": self.CONFIG_PASSWORD},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
# Step 2: Verify user can access protected endpoints
response = self.client.get("/api/users/@me/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json()["email"], self.CONFIG_EMAIL)
# Step 3: Admin enforces SSO for the domain
self.create_enforced_domain()
# Step 4: User's existing session should still work
response = self.client.get("/api/users/@me/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json()["email"], self.CONFIG_EMAIL)
# Step 5: User logs out
self.client.logout()
# Step 6: User can no longer log in with password
with self.settings(**GOOGLE_MOCK_SETTINGS):
response = self.client.post(
"/api/login",
{"email": self.CONFIG_EMAIL, "password": self.CONFIG_PASSWORD},
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(
response.json(),
{
"type": "validation_error",
"code": "sso_enforced",
"detail": "You can only login with SSO for this account (google-oauth2).",
"attr": None,
},
)
@pytest.mark.skip_on_multitenancy
@override_settings(**SAML_MOCK_SETTINGS)
class TestEESAMLAuthenticationAPI(APILicensedTest):
CONFIG_AUTO_LOGIN = False
organization_domain: OrganizationDomain = None # type: ignore
@classmethod
def setUpTestData(cls):
super().setUpTestData()
cls.organization_domain = OrganizationDomain.objects.create(
domain="posthog.com",
verified_at=timezone.now(),
organization=cls.organization,
jit_provisioning_enabled=True,
saml_entity_id="http://www.okta.com/exk1ijlhixJxpyEBZ5d7",
saml_acs_url="https://idp.hogflix.io/saml",
saml_x509_cert="""MIIDqDCCApCgAwIBAgIGAXtoc3o9MA0GCSqGSIb3DQEBCwUAMIGUMQswCQYDVQQGEwJVUzETMBEG
A1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwNU2FuIEZyYW5jaXNjbzENMAsGA1UECgwET2t0YTEU
MBIGA1UECwwLU1NPUHJvdmlkZXIxFTATBgNVBAMMDGRldi0xMzU1NDU1NDEcMBoGCSqGSIb3DQEJ
ARYNaW5mb0Bva3RhLmNvbTAeFw0yMTA4MjExMTIyMjNaFw0zMTA4MjExMTIzMjNaMIGUMQswCQYD
VQQGEwJVUzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwNU2FuIEZyYW5jaXNjbzENMAsG
A1UECgwET2t0YTEUMBIGA1UECwwLU1NPUHJvdmlkZXIxFTATBgNVBAMMDGRldi0xMzU1NDU1NDEc
MBoGCSqGSIb3DQEJARYNaW5mb0Bva3RhLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoC
ggEBAMb1IcGzor7mGsGR0AsyzQaT0O9S1SVvdkG3z2duEU/I/a4fvaECm9xvVH7TY+RwwXcnkMst
+ZZJVkTtnUGLn0oSbcwJ1iJwWNOctaNlaJtPDLvJTJpFB857D2tU01/zPn8UpBebX8tJSIcvnvyO
Iblums97f9tlsI9GHqX5N1e1TxRg6FB2ba46mgb0EdzLtPxdYDVf8b5+V0EWp0fu5nbu5T4T+1Tq
IVj2F1xwFTdsHnzh7FP92ohRRl8WQuC1BjAJTagGmgtfxQk2MW0Ti7Dl0Ejcwcjp7ezbyOgWLBmA
fJ/Sg/MyEX11+4H+VQ8bGwIYtTM2Hc+W6gnhg4IdIfcCAwEAATANBgkqhkiG9w0BAQsFAAOCAQEA
Ef8AeVm+rbrDqil8GwZz/6mTeSHeJgsYZhJqCsaVkRPe03+NO93fRt28vlDQoz9alzA1I1ikjmfB
W/+x2dFPThR1/G4zGfF5pwU13gW1fse0/bO564f6LrmWYawL8SzwGbtelc9DxPN1X5g8Qk+j4DNm
jSjV4Oxsv3ogajnnGYGv22iBgS1qccK/cg41YkpgfP36HbiwA10xjUMv5zs97Ljep4ejp6yoKrGL
dcKmj4EG6bfcI3KY6wK46JoogXZdHDaFP+WOJNj/pJ165hYsYLcqkJktj/rEgGQmqAXWPOXHmFJb
5FPleoJTchctnzUw+QfmSsLWQ838/lUQsN7FsQ==""",
)
# SAML Metadata
def test_can_get_saml_metadata(self):
self.client.force_login(self.user)
OrganizationMembership.objects.filter(organization=self.organization, user=self.user).update(
level=OrganizationMembership.Level.ADMIN
)
response = self.client.get("/api/saml/metadata/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertTrue("/complete/saml/" in response.content.decode())
def test_need_to_be_authenticated_to_get_saml_metadata(self):
response = self.client.get("/api/saml/metadata/")
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertEqual(response.json(), self.unauthenticated_response())
def test_only_admins_can_get_saml_metadata(self):
self.client.force_login(self.user)
response = self.client.get("/api/saml/metadata/")
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(
response.json(),
self.permission_denied_response("You need to be an administrator or owner to access this resource."),
)
# Login precheck
def test_login_precheck_with_available_but_unenforced_saml(self):
response = self.client.post(
"/api/login/precheck", {"email": "helloworld@posthog.com"}
) # Note Google OAuth is not configured
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json(), {"sso_enforcement": None, "saml_available": True})
# Initiate SAML flow
def test_can_initiate_saml_flow(self):
response = self.client.get("/login/saml/?email=hellohello@posthog.com")
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
# Assert user is redirected to the IdP's login page
location = response.headers["Location"]
self.assertIn("https://idp.hogflix.io/saml?SAMLRequest=", location)
def test_cannot_initiate_saml_flow_without_target_email_address(self):
"""
We need the email address to know how to route the SAML request.
"""
with self.assertRaises(AuthMissingParameter) as e:
self.client.get("/login/saml/")
self.assertEqual(str(e.exception), "Missing needed parameter email")
def test_cannot_initiate_saml_flow_for_unconfigured_domain(self):
"""
SAML settings have not been configured for the domain.
"""
with self.assertRaises(AuthFailed) as e:
self.client.get("/login/saml/?email=hellohello@gmail.com")
self.assertEqual(
str(e.exception),
"Authentication failed: SAML not configured for this user.",
)
def test_cannot_initiate_saml_flow_for_unverified_domain(self):
"""
Domain is unverified.
"""
self.organization_domain.verified_at = None
self.organization_domain.save()
with self.assertRaises(AuthFailed) as e:
self.client.get("/login/saml/?email=hellohello@gmail.com")
self.assertEqual(
str(e.exception),
"Authentication failed: SAML not configured for this user.",
)
# Finish SAML flow (i.e. actual log in)
@freeze_time("2021-08-25T22:09:14.252Z") # Ensures the SAML timestamp validation passes
def test_can_login_with_saml(self):
user = User.objects.create(email="engineering@posthog.com", distinct_id=str(uuid.uuid4()))
response = self.client.get("/login/saml/?email=engineering@posthog.com")
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
_session = self.client.session
_session.update({"saml_state": "ONELOGIN_87856a50b5490e643b1ebef9cb5bf6e78225a3c6"})
_session.save()
with open(
os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response"),
encoding="utf_8",
) as f:
saml_response = f.read()
response = self.client.post(
"/complete/saml/",
{
"SAMLResponse": saml_response,
"RelayState": str(self.organization_domain.id),
},
follow=True,
format="multipart",
)
self.assertEqual(response.status_code, status.HTTP_200_OK) # because `follow=True`
self.assertRedirects(response, "/") # redirect to the home page
# Ensure proper user was assigned
_session = self.client.session
self.assertEqual(_session.get("_auth_user_id"), str(user.pk))
# Test logged in request
response = self.client.get("/api/users/@me/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
@freeze_time("2021-08-25T23:37:55.345Z")
def test_saml_jit_provisioning_and_assertion_with_different_attribute_names(self):
"""
Tests JIT provisioning for creating a user account on the fly.
In addition, tests that the user can log in when the SAML response contains attribute names in one of their alternative forms.
For example in this case we receive the user's first name at `urn:oid:2.5.4.42` instead of `first_name`.
"""
response = self.client.get("/login/saml/?email=engineering@posthog.com")
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
_session = self.client.session
_session.update({"saml_state": "ONELOGIN_87856a50b5490e643b1ebef9cb5bf6e78225a3c6"})
_session.save()
with open(
os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response_alt_attribute_names"),
encoding="utf_8",
) as f:
saml_response = f.read()
user_count = User.objects.count()
response = self.client.post(
"/complete/saml/",
{
"SAMLResponse": saml_response,
"RelayState": str(self.organization_domain.id),
},
format="multipart",
follow=True,
)
self.assertEqual(response.status_code, status.HTTP_200_OK) # because `follow=True`
self.assertRedirects(response, "/") # redirect to the home page
# User is created
self.assertEqual(User.objects.count(), user_count + 1)
user = cast(User, User.objects.last())
self.assertEqual(user.first_name, "PostHog")
self.assertEqual(user.email, "engineering@posthog.com")
self.assertEqual(user.organization, self.organization)
self.assertEqual(user.team, self.team)
self.assertEqual(user.organization_memberships.count(), 1)
self.assertEqual(
cast(OrganizationMembership, user.organization_memberships.first()).level,
OrganizationMembership.Level.MEMBER,
)
_session = self.client.session
self.assertEqual(_session.get("_auth_user_id"), str(user.pk))
@freeze_time("2021-08-25T23:37:55.345Z")
def test_saml_jit_provisioning_with_case_insensitive_domain(self):
"""
Tests that JIT provisioning works with case-insensitive domain matching.
This verifies that users with email domains that differ only in case from
the verified domain in the system can still be provisioned automatically.
"""
# Create a new domain with uppercase characters
original_domain = self.organization_domain.domain
uppercase_email = f"engineering@{original_domain.upper()}"
response = self.client.get(f"/login/saml/?email={uppercase_email}")
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
_session = self.client.session
_session.update({"saml_state": "ONELOGIN_87856a50b5490e643b1ebef9cb5bf6e78225a3c6"})
_session.save()
with open(
os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response_alt_attribute_names"),
encoding="utf_8",
) as f:
saml_response = f.read()
user_count = User.objects.count()
response = self.client.post(
"/complete/saml/",
{
"SAMLResponse": saml_response,
"RelayState": str(self.organization_domain.id),
},
format="multipart",
follow=True,
)
self.assertEqual(response.status_code, status.HTTP_200_OK) # because `follow=True`
self.assertRedirects(response, "/") # redirect to the home page
# User is created despite the case difference in domain
self.assertEqual(User.objects.count(), user_count + 1)
user = cast(User, User.objects.last())
self.assertEqual(user.email, uppercase_email.lower()) # The SSO middleware will make this lowercase
self.assertEqual(user.organization, self.organization)
self.assertEqual(user.team, self.team)
self.assertEqual(user.organization_memberships.count(), 1)
self.assertEqual(
cast(OrganizationMembership, user.organization_memberships.first()).level,
OrganizationMembership.Level.MEMBER,
)
_session = self.client.session
self.assertEqual(_session.get("_auth_user_id"), str(user.pk))
@freeze_time("2021-08-25T22:09:14.252Z")
def test_cannot_login_with_improperly_signed_payload(self):
self.organization_domain.saml_x509_cert = """MIIDPjCCAiYCCQC864/0fftWQTANBgkqhkiG9w0BAQsFADBhMQswCQYDVQQGEwJV
UzELMAkGA1UECAwCVVMxCzAJBgNVBAcMAlVTMQswCQYDVQQKDAJVUzELMAkGA1UE
CwwCVVMxCzAJBgNVBAMMAlVTMREwDwYJKoZIhvcNAQkBFgJVUzAeFw0yMTA4MjYw
MDAxMzNaFw0zMTA4MjYwMDAxMzNaMGExCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJV
UzELMAkGA1UEBwwCVVMxCzAJBgNVBAoMAlVTMQswCQYDVQQLDAJVUzELMAkGA1UE
AwwCVVMxETAPBgkqhkiG9w0BCQEWAlVTMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A
MIIBCgKCAQEA25s1++GpP9vcXKJ+SN/xdlvPYLir3yMZd/bRfolygQ4BbuzCbqKv
04AGzKfwV11HXxjtQAU/KDtXuVRa+3vZroWcK01GL1C1aH/x0Q2Wy4XZ8Ooi7NlF
MME6vbCIBmXuo4TNouE/VFTz6ntwDNopIdlGDq4M60tFeoT99eDD4OhoCSaIo0aH
2s14CzF0sec3W742yuMHCVyTDrxFzkjMel/CdoNzysvwrqvkGYtLYJn2GSUIoCpG
y6N5CaVkNpAinNSeHKP9qN/z9hSsDNgz0QuTwZ2BxfDWtwJmRJzdQ3Oeq6RlniNY
BBI71zpuQhPeAlyoBg0wG+2ikiCllGug7wIDAQABMA0GCSqGSIb3DQEBCwUAA4IB
AQB8ytXAmU4oYjANiEJVVO5LZUCx3OrY/P1OX73eoXi624yj7xvhaa7whlk1SSL/
2ks8NZNLBFJbUwShdpzR2X+7AlvsLHmodAMq2Oj5x8O+mFB/6DBl0r40NAAsuzVw
2shE4kRi4RXVB0KiyBuExry5YSVTUu8spG4/oTQYJNZFZoSfsHS2mTyprBqqca1j
yh4jGarFborxwACgg6fCiMbHVq8qlcSkRvSW03u89s3Y4mxhMX3F4AZb56ddyfMk
LERK8jfXCMVmWPTy830CtQaZX2AJyBwHG4ElP2BOZNbFAvGzrKaBmK2Ym/OJxkhx
YotAcSbU3p5bzd11wpyebYHB"""
self.organization_domain.save()
response = self.client.get("/login/saml/?email=engineering@posthog.com")
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
_session = self.client.session
_session.update({"saml_state": "ONELOGIN_87856a50b5490e643b1ebef9cb5bf6e78225a3c6"})
_session.save()
with open(
os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response"),
encoding="utf_8",
) as f:
saml_response = f.read()
user_count = User.objects.count()
with self.assertRaises(AuthFailed) as e:
response = self.client.post(
"/complete/saml/",
{
"SAMLResponse": saml_response,
"RelayState": str(self.organization_domain.id),
},
format="multipart",
follow=True,
)
self.assertIn("Signature validation failed. SAML Response rejected", str(e.exception))
self.assertEqual(User.objects.count(), user_count)
# Test logged in request fails
response = self.client.get("/api/users/@me/")
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
@freeze_time("2021-08-25T22:09:14.252Z")
def test_cannot_signup_with_saml_if_jit_provisioning_is_disabled(self):
self.organization_domain.jit_provisioning_enabled = False
self.organization_domain.save()
response = self.client.get("/login/saml/?email=engineering@posthog.com")
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
_session = self.client.session
_session.update({"saml_state": "ONELOGIN_87856a50b5490e643b1ebef9cb5bf6e78225a3c6"})
_session.save()
with open(
os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response"),
encoding="utf_8",
) as f:
saml_response = f.read()
user_count = User.objects.count()
response = self.client.post(
"/complete/saml/",
{
"SAMLResponse": saml_response,
"RelayState": str(self.organization_domain.id),
},
format="multipart",
follow=True,
)
self.assertEqual(response.status_code, status.HTTP_200_OK) # because `follow=True`
self.assertRedirects(response, "/login?error_code=jit_not_enabled") # show the appropriate login error
# User is created
self.assertEqual(User.objects.count(), user_count)
# Test logged in request fails
response = self.client.get("/api/users/@me/")
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
@freeze_time("2021-08-25T23:53:51.000Z")
def test_cannot_create_account_without_first_name_in_payload(self):
response = self.client.get("/login/saml/?email=engineering@posthog.com")
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
_session = self.client.session
_session.update({"saml_state": "ONELOGIN_87856a50b5490e643b1ebef9cb5bf6e78225a3c6"})
_session.save()
with open(
os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response_no_first_name"),
encoding="utf_8",
) as f:
saml_response = f.read()
user_count = User.objects.count()
with self.assertRaises(ValidationError) as e:
response = self.client.post(
"/complete/saml/",
{
"SAMLResponse": saml_response,
"RelayState": str(self.organization_domain.id),
},
format="multipart",
follow=True,
)
self.assertEqual(
str(e.exception),
"{'name': ['This field is required and was not provided by the IdP.']}",
)
self.assertEqual(User.objects.count(), user_count)
@freeze_time("2021-08-25T22:09:14.252Z")
def test_cannot_login_with_saml_on_unverified_domain(self):
User.objects.create(email="engineering@posthog.com", distinct_id=str(uuid.uuid4()))
response = self.client.get("/login/saml/?email=engineering@posthog.com")
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
# Note we "unverify" the domain after the initial request because we want to test the actual login process (not SAML initiation)
self.organization_domain.verified_at = None
self.organization_domain.save()
_session = self.client.session
_session.update({"saml_state": "ONELOGIN_87856a50b5490e643b1ebef9cb5bf6e78225a3c6"})
_session.save()
with open(
os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response"),
encoding="utf_8",
) as f:
saml_response = f.read()
with self.assertRaises(AuthFailed) as e:
response = self.client.post(
"/complete/saml/",
{
"SAMLResponse": saml_response,
"RelayState": str(self.organization_domain.id),
},
follow=True,
format="multipart",
)
self.assertEqual(
str(e.exception),
"Authentication failed: Authentication request is invalid. Invalid RelayState.",
)
# Assert user is not logged in
response = self.client.get("/api/users/@me/")
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_saml_can_be_enforced(self):
User.objects.create_and_join(
organization=self.organization,
email="engineering@posthog.com",
password=self.CONFIG_PASSWORD,
)
# Can log in regularly with SAML configured
response = self.client.post(
"/api/login",
{"email": "engineering@posthog.com", "password": self.CONFIG_PASSWORD},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json(), {"success": True})
# Forcing only SAML disables regular API password login
self.organization_domain.sso_enforcement = "saml"
self.organization_domain.save()
response = self.client.post(
"/api/login",
{"email": "engineering@posthog.com", "password": self.CONFIG_PASSWORD},
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(
response.json(),
{
"type": "validation_error",
"code": "sso_enforced",
"detail": "You can only login with SSO for this account (saml).",
"attr": None,
},
)
# Login precheck returns SAML info
response = self.client.post("/api/login/precheck", {"email": "engineering@posthog.com"})
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json(), {"sso_enforcement": "saml", "saml_available": True})
def test_cannot_use_saml_without_enterprise_license(self):
self.organization.available_product_features = [
{"key": AvailableFeature.SSO_ENFORCEMENT, "name": AvailableFeature.SSO_ENFORCEMENT}
]
self.organization.save()
# Enforcement is ignored
self.organization_domain.sso_enforcement = "saml"
self.organization_domain.save()
response = self.client.post("/api/login/precheck", {"email": self.CONFIG_EMAIL})
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json(), {"sso_enforcement": None, "saml_available": False})
# Cannot start SAML flow
with self.assertRaises(AuthFailed) as e:
response = self.client.get("/login/saml/?email=engineering@posthog.com")
self.assertEqual(
str(e.exception),
"Authentication failed: Your organization does not have the required license to use SAML.",
)
# Attempting to use SAML fails
_session = self.client.session
_session.update({"saml_state": "ONELOGIN_87856a50b5490e643b1ebef9cb5bf6e78225a3c6"})
_session.save()
with open(
os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response"),
encoding="utf_8",
) as f:
saml_response = f.read()
with self.assertRaises(AuthFailed) as e:
response = self.client.post(
"/complete/saml/",
{
"SAMLResponse": saml_response,
"RelayState": str(self.organization_domain.id),
},
follow=True,
format="multipart",
)
self.assertEqual(
str(e.exception),
"Authentication failed: Your organization does not have the required license to use SAML.",
)
# Remove after we figure out saml / xmlsec issues
# Test login with SAML on dev prod before removing
def test_xmlsec_and_lxml(self):
import lxml
import xmlsec
assert "1.3.14" == xmlsec.__version__
assert "5.2.1" == lxml.__version__
class TestCustomGoogleOAuth2(APILicensedTest):
def setUp(self):
super().setUp()
self.google_oauth = CustomGoogleOAuth2()
self.details = {"email": "test@posthog.com"}
self.sub = "google-oauth2|123456789"
def test_auth_extra_arguments_without_email(self):
"""Test that auth_extra_arguments returns base arguments when no email is provided."""
# Mock strategy to return empty GET parameters
mock_request = type("MockRequest", (), {})()
mock_request.GET = {}
mock_strategy = type("MockStrategy", (), {})()
mock_strategy.request = mock_request
mock_strategy.setting = lambda name, default=None, backend=None: default
self.google_oauth.strategy = mock_strategy
extra_args = self.google_oauth.auth_extra_arguments()
# Should only contain base arguments from parent class, no login_hint
self.assertNotIn("login_hint", extra_args)
def test_auth_extra_arguments_with_email(self):
"""Test that auth_extra_arguments adds login_hint when email is provided."""
# Mock strategy to return email in GET parameters
mock_request = type("MockRequest", (), {})()
mock_request.GET = {"email": "test@posthog.com"}
mock_strategy = type("MockStrategy", (), {})()
mock_strategy.request = mock_request
mock_strategy.setting = lambda name, default=None, backend=None: default
self.google_oauth.strategy = mock_strategy
extra_args = self.google_oauth.auth_extra_arguments()
self.assertEqual(extra_args["login_hint"], "test@posthog.com")
def test_get_user_id_existing_user_with_sub(self):
"""Test that a user with sub as uid continues using that sub."""
# Create user with sub as uid
UserSocialAuth.objects.create(provider="google-oauth2", uid=self.sub, user=self.user)
response = {"email": "test@posthog.com", "sub": self.sub}
uid = self.google_oauth.get_user_id(self.details, response)
self.assertEqual(uid, self.sub)
# Verify no migration occurred (count should be 1)
self.assertEqual(UserSocialAuth.objects.filter(provider="google-oauth2").count(), 1)
# Verify uid is still sub
self.assertEqual(UserSocialAuth.objects.get(provider="google-oauth2").uid, self.sub)
def test_get_user_id_migrates_email_to_sub(self):
"""Test that a user with email as uid gets migrated to using sub."""
# Create user with email as uid (legacy format)
social_auth = UserSocialAuth.objects.create(provider="google-oauth2", uid="test@posthog.com", user=self.user)
response = {"email": "test@posthog.com", "sub": self.sub}
uid = self.google_oauth.get_user_id(self.details, response)
self.assertEqual(uid, self.sub)
# Verify the uid was updated
social_auth.refresh_from_db()
self.assertEqual(social_auth.uid, self.sub)
def test_get_user_id_new_user_uses_sub(self):
"""Test that a new user gets sub as uid."""
response = {"email": "test@posthog.com", "sub": self.sub}
uid = self.google_oauth.get_user_id(self.details, response)
self.assertEqual(uid, self.sub)
# Verify no UserSocialAuth objects were created
self.assertEqual(UserSocialAuth.objects.filter(provider="google-oauth2").count(), 0)
def test_get_user_id_missing_sub_raises_error(self):
"""Test that missing sub in response raises ValueError."""
response = {
"email": "test@posthog.com",
# no sub provided
}
with self.assertRaises(ValueError) as e:
self.google_oauth.get_user_id(self.details, response)
self.assertEqual(str(e.exception), "Google OAuth response missing 'sub' claim")
class TestSSOEnforcement(APILicensedTest):
"""Test SSO enforcement across different auth methods"""
CONFIG_AUTO_LOGIN = False
def setUp(self):
super().setUp()
# Create a user with the domain we'll be testing
self.test_email = "test@testdomain.com"
self.test_user = User.objects.create_and_join(self.organization, self.test_email, self.CONFIG_PASSWORD)
@override_settings(**SAML_MOCK_SETTINGS, **GOOGLE_MOCK_SETTINGS)
def test_cannot_use_social_auth_when_saml_is_required(self):
"""Test that Google OAuth2 fails when SAML is enforced"""
from ee.api.authentication import social_auth_allowed
# Create domain with SAML enforcement
OrganizationDomain.objects.create(
domain="testdomain.com",
organization=self.organization,
verified_at=timezone.now(),
sso_enforcement="saml",
saml_entity_id="http://www.okta.com/test",
saml_acs_url="https://idp.test.io/saml",
saml_x509_cert="test_cert",
)
# Test that Google OAuth2 is blocked
with self.assertRaises(AuthFailed) as context:
social_auth_allowed(
backend=type("MockBackend", (), {"name": "google-oauth2"})(),
details={"email": self.test_email},
response={},
)
self.assertEqual(context.exception.args[0], "saml_sso_enforced")
# Test that GitHub is also blocked
with self.assertRaises(AuthFailed) as context:
social_auth_allowed(
backend=type("MockBackend", (), {"name": "github"})(), details={"email": self.test_email}, response={}
)
self.assertEqual(context.exception.args[0], "saml_sso_enforced")
@override_settings(**SAML_MOCK_SETTINGS, **GOOGLE_MOCK_SETTINGS)
def test_other_social_auth_blocked_with_google_enforcement(self):
"""Test other social auth methods are blocked, including saml, when Google is enforced"""
from ee.api.authentication import social_auth_allowed
OrganizationDomain.objects.create(
domain="testdomain.com",
organization=self.organization,
verified_at=timezone.now(),
sso_enforcement="google-oauth2",
)
# Test that GitHub auth is blocked
with self.assertRaises(AuthFailed) as context:
social_auth_allowed(
backend=type("MockBackend", (), {"name": "github"})(), details={"email": self.test_email}, response={}
)
self.assertEqual(context.exception.args[0], "google_sso_enforced")
# Test that SAML auth is blocked
with self.assertRaises(AuthFailed) as context:
social_auth_allowed(
backend=type("MockBackend", (), {"name": "saml"})(),
details={"email": self.test_email},
response={},
)
self.assertEqual(context.exception.args[0], "google_sso_enforced")
# Test that Google OAuth2 is allowed
try:
social_auth_allowed(
backend=type("MockBackend", (), {"name": "google-oauth2"})(),
details={"email": self.test_email},
response={},
)
except AuthFailed:
self.fail("Google OAuth2 should be allowed when Google OAuth2 is enforced")
@freeze_time("2021-08-25T22:09:14.252Z") # Same timestamp as other SAML tests using this fixture
@override_settings(**SAML_MOCK_SETTINGS, **GOOGLE_MOCK_SETTINGS)
def test_saml_auth_flow_blocked_when_google_oauth2_enforced(self):
"""Integration test: Verify SAML auth flow is blocked when Google OAuth2 is enforced"""
OrganizationDomain.objects.create(
domain="posthog.com",
organization=self.organization,
verified_at=timezone.now(),
sso_enforcement="google-oauth2",
)
# Create SAML configuration for the same organization (needed for RelayState)
org_domain_saml = OrganizationDomain.objects.create(
domain="saml-posthog.com", # Different domain for SAML config
organization=self.organization,
verified_at=timezone.now(),
saml_entity_id="http://www.okta.com/exk1ijlhixJxpyEBZ5d7",
saml_acs_url="https://my.posthog.app/complete/saml/",
saml_x509_cert="""MIIDqDCCApCgAwIBAgIGAXtoc3o9MA0GCSqGSIb3DQEBCwUAMIGUMQswCQYDVQQGEwJVUzETMBEG
A1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwNU2FuIEZyYW5jaXNjbzENMAsGA1UECgwET2t0YTEU
MBIGA1UECwwLU1NPUHJvdmlkZXIxFTATBgNVBAMMDGRldi0xMzU1NDU1NDEcMBoGCSqGSIb3DQEJ
ARYNaW5mb0Bva3RhLmNvbTAeFw0yMTA4MjExMTIyMjNaFw0zMTA4MjExMTIzMjNaMIGUMQswCQYD
VQQGEwJVUzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwNU2FuIEZyYW5jaXNjbzENMAsG
A1UECgwET2t0YTEUMBIGA1UECwwLU1NPUHJvdmlkZXIxFTATBgNVBAMMDGRldi0xMzU1NDU1NDEc
MBoGCSqGSIb3DQEJARYNaW5mb0Bva3RhLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoC
ggEBAMb1IcGzor7mGsGR0AsyzQaT0O9S1SVvdkG3z2duEU/I/a4fvaECm9xvVH7TY+RwwXcnkMst
+ZZJVkTtnUGLn0oSbcwJ1iJwWNOctaNlaJtPDLvJTJpFB857D2tU01/zPn8UpBebX8tJSIcvnvyO
Iblums97f9tlsI9GHqX5N1e1TxRg6FB2ba46mgb0EdzLtPxdYDVf8b5+V0EWp0fu5nbu5T4T+1Tq
IVj2F1xwFTdsHnzh7FP92ohRRl8WQuC1BjAJTagGmgtfxQk2MW0Ti7Dl0Ejcwcjp7ezbyOgWLBmA
fJ/Sg/MyEX11+4H+VQ8bGwIYtTM2Hc+W6gnhg4IdIfcCAwEAATANBgkqhkiG9w0BAQsFAAOCAQEA
Ef8AeVm+rbrDqil8GwZz/6mTeSHeJgsYZhJqCsaVkRPe03+NO93fRt28vlDQoz9alzA1I1ikjmfB
W/+x2dFPThR1/G4zGfF5pwU13gW1fse0/bO564f6LrmWYawL8SzwGbtelc9DxPN1X5g8Qk+j4DNm
jSjV4Oxsv3ogajnnGYGv22iBgS1qccK/cg41YkpgfP36HbiwA10xjUMv5zs97Ljep4ejp6yoKrGL
dcKmj4EG6bfcI3KY6wK46JoogXZdHDaFP+WOJNj/pJ165hYsYLcqkJktj/rEgGQmqAXWPOXHmFJb
5FPleoJTchctnzUw+QfmSsLWQ838/lUQsN7FsQ==""",
)
# Set the SAML state in session (required for SAML authentication)
_session = self.client.session
_session.update({"saml_state": "ONELOGIN_87856a50b5490e643b1ebef9cb5bf6e78225a3c6"})
_session.save()
with open(
os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response"),
encoding="utf_8",
) as f:
saml_response = f.read()
# Attempt to complete SAML authentication
# The SAML response contains test@posthog.com which has google-oauth2 enforcement
response = self.client.post(
"/complete/saml/",
{
"SAMLResponse": saml_response,
"RelayState": str(org_domain_saml.id),
},
format="multipart",
follow=False, # Don't follow redirects so we can check the redirect URL
)
# Should be redirected to login with SSO enforcement error
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
self.assertIn("/login?error_code=google_sso_enforced", response.headers["Location"])