feat: posthog-cli interactive login & typescript schema download (#39903)

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Sandy Spicer
2025-11-02 12:40:09 -08:00
committed by GitHub
parent 2281ef0d1f
commit f7b8756a27
26 changed files with 2634 additions and 27 deletions

2
cli/Cargo.lock generated
View File

@@ -1520,7 +1520,7 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "posthog-cli"
version = "0.5.7"
version = "0.5.8"
dependencies = [
"anyhow",
"chrono",

View File

@@ -1,6 +1,6 @@
[package]
name = "posthog-cli"
version = "0.5.7"
version = "0.5.8"
authors = [
"David <david@posthog.com>",
"Olly <oliver@posthog.com>",

View File

@@ -68,6 +68,24 @@ pub enum ExpCommand {
#[command(subcommand)]
cmd: HermesSubcommand,
},
/// Download event definitions and generate typed SDK
Schema {
#[command(subcommand)]
cmd: SchemaCommand,
},
}
#[derive(Subcommand)]
pub enum SchemaCommand {
/// Download event definitions and generate typed SDK
Pull {
/// Output path for generated definitions (stored in posthog.json for future runs)
#[arg(short, long)]
output: Option<String>,
},
/// Show current schema sync status
Status,
}
impl Cli {
@@ -100,7 +118,7 @@ impl Cli {
match self.command {
Commands::Login => {
// Notably login doesn't have a context set up going it - it sets one up
crate::login::login()?;
crate::login::login(self.host)?;
}
Commands::Sourcemap { cmd } => match cmd {
SourcemapCommand::Inject(input_args) => {
@@ -136,6 +154,14 @@ impl Cli {
crate::sourcemaps::hermes::clone::clone(&args)?;
}
},
ExpCommand::Schema { cmd } => match cmd {
SchemaCommand::Pull { output } => {
crate::experimental::schema::pull(self.host, output)?;
}
SchemaCommand::Status => {
crate::experimental::schema::status()?;
}
},
},
}

View File

@@ -2,5 +2,6 @@
// Things in here should be considered unstable and possibly broken
pub mod query;
pub mod schema;
pub mod tasks;
pub mod tui;

View File

@@ -0,0 +1,366 @@
use anyhow::{Context, Result};
use inquire::{Select, Text};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::Path;
use tracing::info;
use crate::invocation_context::context;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
enum Language {
TypeScript,
}
impl Language {
/// Get the language identifier used in API URLs
fn as_str(&self) -> &'static str {
match self {
Language::TypeScript => "typescript",
}
}
/// Get the display name for the language
fn display_name(&self) -> &'static str {
match self {
Language::TypeScript => "TypeScript",
}
}
/// Get the default output filename for this language
fn default_output_path(&self) -> &'static str {
match self {
Language::TypeScript => "posthog-typed.ts",
}
}
/// Get all available languages
fn all() -> Vec<Language> {
vec![Language::TypeScript]
}
/// Parse a language from a string identifier
fn from_str(s: &str) -> Option<Language> {
match s {
"typescript" => Some(Language::TypeScript),
_ => None,
}
}
}
#[derive(Debug, Serialize, Deserialize, Default)]
struct SchemaConfig {
languages: HashMap<String, LanguageConfig>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
struct LanguageConfig {
output_path: String,
schema_hash: String,
updated_at: String,
event_count: usize,
}
impl SchemaConfig {
/// Load config from posthog.json, returns empty config if file doesn't exist or is invalid
fn load() -> Self {
let content = fs::read_to_string("posthog.json").ok();
content
.and_then(|c| serde_json::from_str(&c).ok())
.unwrap_or_default()
}
/// Save config to posthog.json
fn save(&self) -> Result<()> {
let json =
serde_json::to_string_pretty(self).context("Failed to serialize schema config")?;
fs::write("posthog.json", json).context("Failed to write posthog.json")?;
Ok(())
}
/// Get language config for a specific language
fn get_language(&self, language: Language) -> Option<&LanguageConfig> {
self.languages.get(language.as_str())
}
/// Get output path for a language
fn get_output_path(&self, language: Language) -> Option<String> {
self.languages
.get(language.as_str())
.map(|l| l.output_path.clone())
}
/// Update language config, preserving other languages
fn update_language(
&mut self,
language: Language,
output_path: String,
schema_hash: String,
event_count: usize,
) {
use chrono::Utc;
self.languages.insert(
language.as_str().to_string(),
LanguageConfig {
output_path,
schema_hash,
updated_at: Utc::now().to_rfc3339(),
event_count,
},
);
}
}
#[derive(Debug, Deserialize)]
struct DefinitionsResponse {
content: String,
event_count: usize,
schema_hash: String,
}
pub fn pull(_host: Option<String>, output_override: Option<String>) -> Result<()> {
// Select language
let language = select_language()?;
info!(
"Fetching {} definitions from PostHog...",
language.display_name()
);
// Load credentials
let token = context().token.clone();
let host = token.get_host();
// Determine output path
let output_path = determine_output_path(language, output_override)?;
// Fetch definitions from the server
let response = fetch_definitions(&host, &token.env_id, &token.token, language)?;
info!(
"✓ Fetched {} definitions for {} events",
language.display_name(),
response.event_count
);
// Check if schema has changed for this language
let config = SchemaConfig::load();
if let Some(lang_config) = config.get_language(language) {
if lang_config.schema_hash == response.schema_hash {
info!(
"Schema unchanged for {} (hash: {})",
language.as_str(),
response.schema_hash
);
println!(
"\n{} schema is already up to date!",
language.display_name()
);
println!(" No changes detected - skipping file write.");
return Ok(());
}
}
// Write TypeScript definitions to file
info!("Writing {}...", output_path);
// Create parent directories if they don't exist
if let Some(parent) = Path::new(&output_path).parent() {
if !parent.as_os_str().is_empty() {
fs::create_dir_all(parent)
.context(format!("Failed to create directory {}", parent.display()))?;
}
}
fs::write(&output_path, &response.content).context(format!("Failed to write {output_path}"))?;
info!("✓ Generated {}", output_path);
// Update schema configuration for this language
info!("Updating posthog.json...");
let mut config = SchemaConfig::load();
config.update_language(
language,
output_path.clone(),
response.schema_hash,
response.event_count,
);
config.save()?;
info!("✓ Updated posthog.json");
println!("\n✓ Schema sync complete!");
println!("\nNext steps:");
println!(" 1. Import PostHog from your generated module:");
println!(" import posthog from './{output_path}'");
println!(" 2. Use typed events with autocomplete and type safety:");
println!(" posthog.captureTyped('event_name', {{ property: 'value' }})");
println!(" 3. Or use regular capture() for flexibility:");
println!(" posthog.capture('dynamic_event', {{ any: 'data' }})");
println!();
Ok(())
}
fn determine_output_path(language: Language, output_override: Option<String>) -> Result<String> {
// If CLI override is provided, use it (and normalize it)
if let Some(path) = output_override {
return Ok(normalize_output_path(&path, language));
}
// Check if posthog.json exists and has an output_path for this language
let config = SchemaConfig::load();
if let Some(path) = config.get_output_path(language) {
return Ok(path);
}
// Prompt user for output path
let default_filename = language.default_output_path();
let current_dir = std::env::current_dir()
.ok()
.and_then(|p| p.to_str().map(String::from))
.unwrap_or_else(|| ".".to_string());
let help_message = format!(
"Your app will import PostHog from this file, so it should be accessible \
throughout your codebase (e.g., src/lib/, app/lib/, or your project root). \
This path will be saved in posthog.json and can be changed later. \
Current directory: {current_dir}"
);
let path = Text::new(&format!(
"Where should we save the {} typed PostHog module?",
language.display_name()
))
.with_default(default_filename)
.with_help_message(&help_message)
.prompt()
.unwrap_or(default_filename.to_string());
Ok(normalize_output_path(&path, language))
}
fn normalize_output_path(path: &str, language: Language) -> String {
let path_obj = Path::new(path);
// If it's a directory (existing or ends with slash), append default filename
let should_append_filename =
(path_obj.exists() && path_obj.is_dir()) || path.ends_with('/') || path.ends_with('\\');
if should_append_filename {
path_obj
.join(language.default_output_path())
.to_string_lossy()
.into_owned()
} else {
path.to_string()
}
}
pub fn status() -> Result<()> {
// Check authentication
println!("\nPostHog Schema Sync Status\n");
println!("Authentication:");
let token = context().token.clone();
println!(" ✓ Authenticated");
println!(" Host: {}", token.get_host());
println!(" Project ID: {}", token.env_id);
let masked_token = format!(
"{}****{}",
&token.token[..4],
&token.token[token.token.len() - 4..]
);
println!(" Token: {masked_token}");
println!();
// Check schema status
println!("Schema:");
let config = SchemaConfig::load();
if config.languages.is_empty() {
println!(" ✗ No schemas synced");
println!(" Run: posthog-cli exp schema pull");
} else {
println!(" ✓ Schemas synced\n");
for (language_str, lang_config) in &config.languages {
// Parse language to get display name, fallback to raw string if unknown
let display = Language::from_str(language_str)
.map(|l| l.display_name())
.unwrap_or(language_str.as_str());
println!(" {display}:");
println!(" Hash: {}", lang_config.schema_hash);
println!(" Updated: {}", lang_config.updated_at);
println!(" Events: {}", lang_config.event_count);
if Path::new(&lang_config.output_path).exists() {
println!(" File: ✓ {}", lang_config.output_path);
} else {
println!(" File: ! {} (missing)", lang_config.output_path);
}
println!();
}
}
println!();
Ok(())
}
fn fetch_definitions(
host: &str,
env_id: &str,
token: &str,
language: Language,
) -> Result<DefinitionsResponse> {
let url = format!(
"{}/api/projects/{}/event_definitions/{}/",
host,
env_id,
language.as_str()
);
let client = &context().client;
let response = client.get(&url).bearer_auth(token).send().context(format!(
"Failed to fetch {} definitions",
language.display_name()
))?;
if !response.status().is_success() {
return Err(anyhow::anyhow!(
"Failed to fetch {} definitions: HTTP {}",
language.display_name(),
response.status()
));
}
let json: DefinitionsResponse = response.json().context(format!(
"Failed to parse {} definitions response",
language.display_name()
))?;
Ok(json)
}
fn select_language() -> Result<Language> {
let languages = Language::all();
if languages.len() == 1 {
return Ok(languages[0]);
}
let language_strs: Vec<&str> = languages.iter().map(|l| l.display_name()).collect();
let selected = Select::new("Which language would you like to download?", language_strs)
.prompt()
.context("Failed to select language")?;
// Find the language that matches the selected display name
languages
.into_iter()
.find(|l| l.display_name() == selected)
.ok_or_else(|| anyhow::anyhow!("Invalid language selection"))
}

View File

@@ -1,5 +1,8 @@
use anyhow::Error;
use inquire::Text;
use anyhow::{Context, Error};
use inquire::{Select, Text};
use serde::{Deserialize, Serialize};
use std::thread;
use std::time::Duration;
use tracing::info;
use crate::{
@@ -7,14 +10,265 @@ use crate::{
utils::auth::{host_validator, token_validator, CredentialProvider, HomeDirProvider, Token},
};
pub fn login() -> Result<(), Error> {
#[derive(Debug, Deserialize)]
struct DeviceCodeResponse {
device_code: String,
user_code: String,
verification_uri_complete: String,
expires_in: u64,
interval: u64,
}
#[derive(Debug, Serialize)]
struct PollRequest {
device_code: String,
}
#[derive(Debug, Deserialize)]
struct PollResponse {
status: String,
personal_api_key: Option<String>,
project_id: Option<String>,
}
pub fn login(host_override: Option<String>) -> Result<(), Error> {
login_with_use_cases(host_override, vec!["schema", "error_tracking"])
}
pub fn login_with_use_cases(
host_override: Option<String>,
use_cases: Vec<&str>,
) -> Result<(), Error> {
let host = if let Some(override_host) = host_override {
// Strip trailing slashes to avoid double slashes in URLs
override_host.trim_end_matches('/').to_string()
} else {
// Prompt user to select region or manual login
let options = vec!["US", "EU", "Manual"];
let selection = Select::new("Select your PostHog region:", options)
.with_help_message("Choose the region where your PostHog data is hosted, or 'Manual' to enter your own details")
.prompt()?;
match selection {
"US" => "https://us.posthog.com".to_string(),
"EU" => "https://eu.posthog.com".to_string(),
"Manual" => {
return manual_login();
}
_ => unreachable!(),
}
};
info!("🔐 Starting OAuth Device Flow authentication...");
info!("Connecting to: {}", host);
// Step 1: Request device code
let device_data = request_device_code(&host)?;
// Add use_cases parameter to the verification URL
let use_cases_param = use_cases.join(",");
let verification_url = if device_data.verification_uri_complete.contains('?') {
format!(
"{}&use_cases={}",
device_data.verification_uri_complete, use_cases_param
)
} else {
format!(
"{}?use_cases={}",
device_data.verification_uri_complete, use_cases_param
)
};
println!();
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
println!(" 📱 Authorization Required");
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
println!();
println!("To authenticate, visit this URL in your browser:");
println!(" {verification_url}");
println!();
println!("Your authorization code:");
println!("{}", device_data.user_code);
println!();
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
println!();
// Step 2: Try to open browser
if let Err(e) = open_browser(&verification_url) {
info!("Could not open browser automatically: {}", e);
info!("Please open the URL manually");
} else {
info!("✓ Opened browser for authorization");
}
// Step 3: Poll for authorization
info!("Waiting for authorization...");
let poll_response = poll_for_authorization(
&host,
&device_data.device_code,
device_data.interval,
device_data.expires_in,
)?;
info!("✓ Successfully authenticated!");
// Step 4: Save credentials
let token = Token {
host: Some(host),
token: poll_response.personal_api_key.unwrap(),
env_id: poll_response.project_id.unwrap(),
};
let provider = HomeDirProvider;
provider.store_credentials(token)?;
info!("Token saved to: {}", provider.report_location());
complete_login(&provider, "interactive_login")
}
fn request_device_code(host: &str) -> Result<DeviceCodeResponse, Error> {
let client = reqwest::blocking::Client::new();
let url = format!("{host}/api/cli-auth/device-code/");
let response = client
.post(&url)
.header("Content-Type", "application/json")
.send()
.context("Failed to request device code")?;
if !response.status().is_success() {
return Err(anyhow::anyhow!(
"Failed to request device code: HTTP {}",
response.status()
));
}
let device_data: DeviceCodeResponse = response
.json()
.context("Failed to parse device code response")?;
Ok(device_data)
}
fn open_browser(url: &str) -> Result<(), Error> {
// Try to open browser using platform-specific commands
#[cfg(target_os = "macos")]
{
std::process::Command::new("open")
.arg(url)
.spawn()
.context("Failed to open browser")?;
}
#[cfg(target_os = "linux")]
{
std::process::Command::new("xdg-open")
.arg(url)
.spawn()
.context("Failed to open browser")?;
}
#[cfg(target_os = "windows")]
{
std::process::Command::new("cmd")
.args(&["/C", "start", url])
.spawn()
.context("Failed to open browser")?;
}
Ok(())
}
fn poll_for_authorization(
host: &str,
device_code: &str,
interval_seconds: u64,
expires_in_seconds: u64,
) -> Result<PollResponse, Error> {
let client = reqwest::blocking::Client::new();
let url = format!("{host}/api/cli-auth/poll/");
let max_attempts = (expires_in_seconds / interval_seconds) + 1;
let poll_interval = Duration::from_secs(interval_seconds);
for attempt in 1..=max_attempts {
thread::sleep(poll_interval);
let request = PollRequest {
device_code: device_code.to_string(),
};
let response = client
.post(&url)
.header("Content-Type", "application/json")
.json(&request)
.send()
.context("Failed to poll for authorization")?;
let status_code = response.status();
if status_code.as_u16() == 202 {
// Still pending
if attempt % 3 == 0 {
info!(
"Still waiting for authorization... (attempt {}/{})",
attempt, max_attempts
);
}
continue;
}
// Parse response body for both success and error cases
let poll_response: PollResponse =
response.json().context("Failed to parse poll response")?;
if status_code.is_success() && poll_response.status == "authorized" {
return Ok(poll_response);
}
if status_code.as_u16() == 400 && poll_response.status == "expired" {
return Err(anyhow::anyhow!(
"Authorization code expired. Please try again."
));
}
return Err(anyhow::anyhow!(
"Unexpected response during polling: HTTP {} - status: {}",
status_code,
poll_response.status
));
}
Err(anyhow::anyhow!(
"Authorization timed out. Please try again."
))
}
fn complete_login(provider: &HomeDirProvider, command_name: &str) -> Result<(), Error> {
// Login is the only command that doesn't have a context coming in - because it modifies the context
init_context(None, false)?;
context().capture_command_invoked(command_name);
println!();
println!("🎉 Authentication complete!");
println!("Credentials saved to: {}", provider.report_location());
println!();
println!("You can now use the CLI:");
println!(" posthog-cli schema pull");
println!();
Ok(())
}
fn manual_login() -> Result<(), Error> {
info!("🔐 Manual login...");
let host = Text::new("Enter the PostHog host URL")
.with_default("https://us.posthog.com")
.with_validator(host_validator)
.prompt()?;
let env_id =
Text::new("Enter your project ID (the number in your posthog homepage url)").prompt()?;
Text::new("Enter your project ID (the number in your PostHog homepage URL)").prompt()?;
let token = Text::new(
"Enter your personal API token",
@@ -24,17 +278,14 @@ pub fn login() -> Result<(), Error> {
.prompt()?;
let token = Token {
host: Some(host),
host: Some(host.trim_end_matches('/').to_string()),
token,
env_id,
};
let provider = HomeDirProvider;
provider.store_credentials(token)?;
info!("Token saved to: {}", provider.report_location());
// Login is the only command that doesn't have a context coming in - because it modifies the context
init_context(None, false)?;
context().capture_command_invoked("interactive_login");
Ok(())
complete_login(&provider, "manual_login")
}

View File

@@ -22,7 +22,28 @@ fn main() {
match cmd::Cli::run() {
Ok(_) => info!("All done, happy hogging!"),
Err(_) => {
Err(e) => {
match e.exception_id {
Some(id) => {
eprintln!("Oops! {}", e.inner);
eprintln!();
eprintln!("Exception ID: {id}");
}
None => {
eprintln!("Oops! {}", e.inner);
let mut source = e.inner.source();
if source.is_some() {
eprintln!("\nCaused by:");
let mut index = 0;
while let Some(err) = source {
eprintln!(" {index}: {err}");
source = err.source();
index += 1;
}
}
}
};
std::process::exit(1);
}
}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 94 KiB

After

Width:  |  Height:  |  Size: 95 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 94 KiB

After

Width:  |  Height:  |  Size: 95 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 41 KiB

After

Width:  |  Height:  |  Size: 41 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 40 KiB

After

Width:  |  Height:  |  Size: 40 KiB

View File

@@ -187,3 +187,20 @@ export const getMinimumEquivalentScopes = (scopes: string[]): string[] => {
return `${object}:${action}`
})
}
/** Convert scopes array format to object format for easier UI manipulation */
export const scopesArrayToObject = (scopes: string[]): Record<string, string> => {
const result: Record<string, string> = {}
scopes.forEach((scope) => {
const [key, action] = scope.split(':')
if (key && action) {
result[key] = action
}
})
return result
}
/** Convert scopes object format back to array format for API submission */
export const scopesObjectToArray = (scopesObj: Record<string, string>): string[] => {
return Object.entries(scopesObj).map(([key, action]) => `${key}:${action}`)
}

View File

@@ -15,6 +15,7 @@ const pathsWithoutProjectId = [
'oauth',
'shared',
'embedded',
'cli',
'render_query',
]

View File

@@ -13,6 +13,7 @@ export const appScenes: Record<Scene | string, () => any> = {
[Scene.BillingSection]: () => import('./billing/BillingSection'),
[Scene.Billing]: () => import('./billing/Billing'),
[Scene.Canvas]: () => import('./notebooks/NotebookCanvasScene'),
[Scene.CLIAuthorize]: () => import('./authentication/CLIAuthorize'),
[Scene.Cohort]: () => import('./cohorts/Cohort'),
[Scene.CohortCalculationHistory]: () => import('./cohorts/CohortCalculationHistory'),
[Scene.Cohorts]: () => import('./cohorts/Cohorts'),

View File

@@ -0,0 +1,293 @@
import { useActions, useValues } from 'kea'
import { Form } from 'kea-forms'
import { Fragment, useState } from 'react'
import { IconCode, IconGear, IconWarning } from '@posthog/icons'
import { IconInfo } from '@posthog/icons'
import { LemonButton, LemonInput, LemonSegmentedButton, LemonSelect, Link, Tooltip } from '@posthog/lemon-ui'
import { BridgePage } from 'lib/components/BridgePage/BridgePage'
import { LemonBanner } from 'lib/lemon-ui/LemonBanner'
import { LemonField } from 'lib/lemon-ui/LemonField'
import { LemonModal } from 'lib/lemon-ui/LemonModal'
import { IconErrorOutline } from 'lib/lemon-ui/icons'
import { API_SCOPES } from 'lib/scopes'
import { capitalizeFirstLetter } from 'lib/utils'
import { SceneExport } from 'scenes/sceneTypes'
import { urls } from 'scenes/urls'
import { cliAuthorizeLogic } from './cliAuthorizeLogic'
export const scene: SceneExport = {
component: CLIAuthorize,
logic: cliAuthorizeLogic,
}
function ScopesList({
scopes,
formScopeRadioValues,
displayScopeValues,
setScopeRadioValue,
showAll = false,
}: {
scopes: typeof API_SCOPES
formScopeRadioValues: Record<string, string>
displayScopeValues?: Record<string, string>
setScopeRadioValue: (key: string, action: string) => void
showAll?: boolean
}): JSX.Element {
// Use displayScopeValues for filtering if provided, otherwise use formScopeRadioValues
const filterValues = displayScopeValues ?? formScopeRadioValues
const visibleScopes = showAll
? scopes
: scopes.filter((scope) => filterValues[scope.key] && filterValues[scope.key] !== 'none')
if (!showAll && visibleScopes.length === 0) {
return (
<div className="text-muted text-sm italic py-2">
No scopes selected. Click "Manage scopes" to select permissions.
</div>
)
}
return (
<div>
{visibleScopes.map(({ key, disabledActions, warnings, info }) => {
return (
<Fragment key={key}>
<div className="flex items-center justify-between gap-2 min-h-8">
<div className="flex items-center gap-1">
<b>{capitalizeFirstLetter(key.replace(/_/g, ' '))}</b>
{info ? (
<Tooltip title={info}>
<IconInfo className="text-secondary text-base" />
</Tooltip>
) : null}
</div>
<LemonSegmentedButton
onChange={(value) => setScopeRadioValue(key, value)}
value={formScopeRadioValues[key] ?? 'none'}
options={[
{ label: 'No access', value: 'none' },
{
label: 'Read',
value: 'read',
disabledReason: disabledActions?.includes('read')
? 'Does not apply to this resource'
: undefined,
},
{
label: 'Write',
value: 'write',
disabledReason: disabledActions?.includes('write')
? 'Does not apply to this resource'
: undefined,
},
]}
size="xsmall"
/>
</div>
{warnings?.[formScopeRadioValues[key]] && (
<div className="flex items-start gap-2 text-xs italic pb-2">
<IconWarning className="text-base text-secondary mt-0.5" />
<span>{warnings[formScopeRadioValues[key]]}</span>
</div>
)}
</Fragment>
)
})}
</div>
)
}
export function CLIAuthorize(): JSX.Element {
const {
authorize,
isSuccess,
projects,
projectsLoading,
isAuthorizeSubmitting,
formScopeRadioValues,
displayedScopeValues,
missingSchemaScopes,
missingErrorTrackingScopes,
} = useValues(cliAuthorizeLogic)
const { setAuthorizeValue, setScopeRadioValue, updateDisplayedScopeSnapshot } = useActions(cliAuthorizeLogic)
const [isScopesModalOpen, setIsScopesModalOpen] = useState(false)
const handleOpenModal = (): void => {
setIsScopesModalOpen(true)
}
const handleCloseModal = (): void => {
updateDisplayedScopeSnapshot()
setIsScopesModalOpen(false)
}
return (
<BridgePage
view="login"
{...(!isSuccess
? {
hedgehog: true as const,
message: (
<>
Authorize
<br />
PostHog CLI
</>
),
}
: { hedgehog: false as const })}
>
{isSuccess ? (
<div className="text-center space-y-4">
<h2>CLI Authorization Complete</h2>
<LemonBanner type="success">
<div className="space-y-2">
<p className="font-semibold">Your CLI has been authorized successfully!</p>
<p>You can now close this window and return to your terminal.</p>
</div>
</LemonBanner>
<div className="text-muted text-sm mt-4">
<p>
A Personal API Key has been created for your CLI. You can manage your API keys in{' '}
<Link to={urls.settings('user-api-keys')} className="font-semibold">
Settings Personal API Keys
</Link>
</p>
</div>
</div>
) : (
<div className="space-y-4">
<h2>Authorize CLI Access</h2>
<p className="text-muted text-sm">
The PostHog CLI should have displayed a 9-character code (e.g., ABCD-1234). Enter it below to
authorize your CLI.
</p>
<Form logic={cliAuthorizeLogic} formKey="authorize" enableFormOnSubmit className="space-y-4">
<LemonField name="userCode" label="Authorization Code">
<LemonInput
className="ph-ignore-input font-mono text-lg tracking-wider"
autoFocus
data-attr="cli-auth-code"
placeholder="ABCD-1234"
maxLength={9}
value={authorize.userCode}
onChange={(value) => setAuthorizeValue('userCode', value.toUpperCase())}
autoComplete="off"
autoCorrect="off"
autoCapitalize="characters"
spellCheck={false}
/>
</LemonField>
<LemonField name="projectId" label="Project">
<LemonSelect
data-attr="cli-project-select"
placeholder="Select a project"
value={authorize.projectId}
onChange={(value) => setAuthorizeValue('projectId', value)}
options={projects.map((project: { id: number; name: string }) => ({
label: project.name,
value: project.id,
}))}
loading={projectsLoading}
/>
</LemonField>
<div className="mt-4 mb-2">
<div className="flex items-center justify-between mb-2">
<h3>Scopes</h3>
<LemonButton
type="secondary"
size="small"
icon={<IconGear />}
onClick={handleOpenModal}
>
Manage scopes
</LemonButton>
</div>
<p className="text-muted text-sm mb-2">
Selected permissions for the CLI. Only grant the scopes you need.
</p>
</div>
<LemonField name="scopes">
{({ error }) => (
<>
{error && (
<div className="text-danger flex items-center gap-1 text-sm mb-2">
<IconErrorOutline className="text-xl" /> {error}
</div>
)}
<ScopesList
scopes={API_SCOPES}
formScopeRadioValues={formScopeRadioValues}
displayScopeValues={displayedScopeValues}
setScopeRadioValue={setScopeRadioValue}
showAll={false}
/>
</>
)}
</LemonField>
<LemonModal
title="Manage CLI Scopes"
description="Select which permissions to grant to the CLI. Only select the scopes you need."
isOpen={isScopesModalOpen}
onClose={handleCloseModal}
footer={
<LemonButton type="primary" onClick={handleCloseModal}>
Done
</LemonButton>
}
>
<div className="max-h-96 overflow-y-auto">
<ScopesList
scopes={API_SCOPES}
formScopeRadioValues={formScopeRadioValues}
setScopeRadioValue={setScopeRadioValue}
showAll={true}
/>
</div>
</LemonModal>
{(missingSchemaScopes || missingErrorTrackingScopes) && (
<div className="space-y-2 mt-2">
{missingSchemaScopes && (
<LemonBanner type="warning">
<b>Schema management unavailable:</b> The CLI needs both{' '}
<code>event_definition</code> and <code>property_definition</code> permissions
(read or write) to manage schemas.
</LemonBanner>
)}
{missingErrorTrackingScopes && (
<LemonBanner type="warning">
<b>Error tracking unavailable:</b> The CLI needs <code>error_tracking</code>{' '}
permissions (read or write) to manage error tracking.
</LemonBanner>
)}
</div>
)}
<LemonButton
type="primary"
status="alt"
htmlType="submit"
data-attr="cli-authorize-submit"
fullWidth
center
loading={isAuthorizeSubmitting}
size="large"
icon={<IconCode />}
>
Authorize CLI
</LemonButton>
</Form>
</div>
)}
</BridgePage>
)
}

View File

@@ -0,0 +1,242 @@
import { actions, afterMount, kea, listeners, path, reducers, selectors } from 'kea'
import { forms } from 'kea-forms'
import { loaders } from 'kea-loaders'
import { urlToAction } from 'kea-router'
import api from 'lib/api'
import { scopesArrayToObject } from 'lib/scopes'
import type { cliAuthorizeLogicType } from './cliAuthorizeLogicType'
export type CLIUseCase = 'schema' | 'error_tracking'
export interface CLIAuthorizeForm {
userCode: string
projectId: number | null
scopes: string[]
}
// Map use cases to their required scopes
const USE_CASE_SCOPES: Record<CLIUseCase, string[]> = {
schema: ['event_definition:read', 'property_definition:read'],
error_tracking: ['error_tracking:write'],
}
// Default use cases when none are specified
const DEFAULT_USE_CASES: CLIUseCase[] = ['schema', 'error_tracking']
function getDefaultScopesForUseCases(useCases: CLIUseCase[]): string[] {
const scopesSet = new Set<string>()
for (const useCase of useCases) {
const scopes = USE_CASE_SCOPES[useCase] || []
scopes.forEach((scope) => scopesSet.add(scope))
}
return Array.from(scopesSet)
}
// Pre-compute default scopes
const DEFAULT_SCOPES = getDefaultScopesForUseCases(DEFAULT_USE_CASES)
export const cliAuthorizeLogic = kea<cliAuthorizeLogicType>([
path(['scenes', 'authentication', 'cliAuthorizeLogic']),
actions({
setSuccess: (success: boolean) => ({ success }),
setScopeRadioValue: (key: string, action: string) => ({ key, action }),
setRequestedUseCases: (useCases: CLIUseCase[]) => ({ useCases }),
updateDisplayedScopeSnapshot: true,
setDisplayedScopeValues: (values: Record<string, string>) => ({ values }),
}),
reducers({
isSuccess: [
false,
{
setSuccess: (_, { success }) => success,
},
],
requestedUseCases: [
DEFAULT_USE_CASES,
{
setRequestedUseCases: (_, { useCases }) => useCases,
},
],
displayedScopeValues: [
{} as Record<string, string>,
{
setDisplayedScopeValues: (_, { values }) => values,
},
],
}),
loaders(() => ({
projects: [
[] as { id: number; name: string }[],
{
loadProjects: async () => {
const response = await api.get('api/projects/')
return response.results || []
},
},
],
})),
forms(() => ({
authorize: {
defaults: {
userCode: '',
projectId: null,
scopes: DEFAULT_SCOPES,
} as CLIAuthorizeForm,
errors: ({ userCode, projectId, scopes }) => ({
userCode: !userCode
? 'Please enter the code from your terminal'
: userCode.length !== 9
? 'Code must be 9 characters (XXXX-XXXX)'
: undefined,
projectId: !projectId ? 'Please select a project' : undefined,
scopes: !scopes?.length ? ('Your API key needs at least one scope' as any) : undefined,
}),
submit: async ({ userCode, projectId, scopes }) => {
try {
const response = await api.create('api/cli-auth/authorize/', {
user_code: userCode.toUpperCase().replace(/\s/g, ''),
project_id: projectId,
scopes: scopes,
})
return response
} catch (error: any) {
const errorCode = error?.code || error?.error
if (errorCode === 'invalid_code') {
throw { userCode: 'Invalid or expired code. Please try again.' }
} else if (errorCode === 'expired') {
throw { userCode: 'This code has expired. Please request a new code in your terminal.' }
} else if (errorCode === 'access_denied') {
throw { projectId: 'You do not have access to this project.' }
} else if (errorCode === 'invalid_project') {
throw { projectId: 'Project not found.' }
} else {
throw { userCode: 'An error occurred. Please try again.' }
}
}
},
},
})),
selectors(() => ({
formScopeRadioValues: [
(s) => [s.authorize],
(authorize): Record<string, string> => {
if (!authorize || !authorize.scopes) {
return {}
}
return scopesArrayToObject(authorize.scopes)
},
],
missingSchemaScopes: [
(s) => [s.authorize, s.requestedUseCases],
(authorize, requestedUseCases): boolean => {
// Only show warning if schema use case was requested
if (!requestedUseCases.includes('schema')) {
return false
}
if (!authorize || !authorize.scopes) {
return false
}
// Warn if missing BOTH event_definition (read or write) AND property_definition (read or write)
// Note: write permissions include read, so having write is sufficient
const hasEventDefinition =
authorize.scopes.includes('event_definition:read') ||
authorize.scopes.includes('event_definition:write')
const hasPropertyDefinition =
authorize.scopes.includes('property_definition:read') ||
authorize.scopes.includes('property_definition:write')
return !hasEventDefinition || !hasPropertyDefinition
},
],
missingErrorTrackingScopes: [
(s) => [s.authorize, s.requestedUseCases],
(authorize, requestedUseCases): boolean => {
// Only show warning if error_tracking use case was requested
if (!requestedUseCases.includes('error_tracking')) {
return false
}
if (!authorize || !authorize.scopes) {
return false
}
// Warn if missing error_tracking entirely (neither read nor write)
// Note: write permissions include read, so having write is sufficient
return (
!authorize.scopes.includes('error_tracking:read') &&
!authorize.scopes.includes('error_tracking:write')
)
},
],
})),
listeners(({ actions, values }) => ({
loadProjectsSuccess: () => {
// Set default project to first project if not already set
if (values.projects.length > 0 && !values.authorize.projectId) {
actions.setAuthorizeValue('projectId', values.projects[0].id)
}
},
setAuthorizeValue: (payload) => {
// Initialize displayed scope values when scopes are first set
if (payload.name === 'scopes' && Object.keys(values.displayedScopeValues).length === 0) {
// Directly compute scope values from the scopes array being set
const scopesArray = payload.value as string[]
const scopeValues = scopesArrayToObject(scopesArray)
actions.setDisplayedScopeValues(scopeValues)
}
},
updateDisplayedScopeSnapshot: () => {
// Update displayed scope values with current form values
actions.setDisplayedScopeValues(values.formScopeRadioValues)
},
submitAuthorizeSuccess: () => {
actions.setSuccess(true)
},
submitAuthorizeFailure: () => {
// Error handling is done in the form errors
},
setRequestedUseCases: ({ useCases }) => {
// Update scopes when requested use cases change
const newScopes = getDefaultScopesForUseCases(useCases)
actions.setAuthorizeValue('scopes', newScopes)
// Directly compute and update displayed scope values
const scopeValues = scopesArrayToObject(newScopes)
actions.setDisplayedScopeValues(scopeValues)
},
setScopeRadioValue: ({ key, action }) => {
if (!values.authorize || !values.authorize.scopes) {
return
}
// Remove existing scope with this key
const filteredScopes = values.authorize.scopes.filter((scope) => !scope.startsWith(`${key}:`))
// Add new scope if not 'none'
const newScopes = action === 'none' ? filteredScopes : [...filteredScopes, `${key}:${action}`]
actions.setAuthorizeValue('scopes', newScopes)
},
})),
urlToAction(({ actions }) => ({
'/cli/authorize': (_, searchParams) => {
const code = searchParams.code
if (code) {
// Set the form field value directly
actions.setAuthorizeValue('userCode', code)
}
// Parse use_cases from URL (comma-separated)
const useCasesParam = searchParams.use_cases
if (useCasesParam) {
const useCases = useCasesParam.split(',').filter((uc: string): uc is CLIUseCase => {
return uc === 'schema' || uc === 'error_tracking'
})
if (useCases.length > 0) {
actions.setRequestedUseCases(useCases)
}
}
},
})),
afterMount(({ actions }) => {
actions.loadProjects()
}),
])

View File

@@ -20,6 +20,7 @@ export enum Scene {
BillingAuthorizationStatus = 'BillingAuthorizationStatus',
BillingSection = 'BillingSection',
Canvas = 'Canvas',
CLIAuthorize = 'CLIAuthorize',
Cohort = 'Cohort',
CohortCalculationHistory = 'CohortCalculationHistory',
Cohorts = 'Cohorts',

View File

@@ -65,6 +65,12 @@ export const sceneConfigurations: Record<Scene | string, SceneConfig> = {
defaultDocsPath: '/blog/introducing-notebooks',
hideProjectNotice: true,
},
[Scene.CLIAuthorize]: {
name: 'Authorize CLI',
projectBased: false,
organizationBased: false,
layout: 'plain',
},
[Scene.Cohort]: { projectBased: true, name: 'Cohort', defaultDocsPath: '/docs/data/cohorts' },
[Scene.CohortCalculationHistory]: { projectBased: true, name: 'Cohort Calculation History' },
[Scene.Cohorts]: {
@@ -674,6 +680,7 @@ export const routes: Record<string, [Scene | string, string]> = {
[urls.site(':url')]: [Scene.Site, 'site'],
[urls.login()]: [Scene.Login, 'login'],
[urls.login2FA()]: [Scene.Login2FA, 'login2FA'],
[urls.cliAuthorize()]: [Scene.CLIAuthorize, 'cliAuthorize'],
[urls.emailMFAVerify()]: [Scene.EmailMFAVerify, 'emailMFAVerify'],
[urls.preflight()]: [Scene.PreflightCheck, 'preflight'],
[urls.signup()]: [Scene.Signup, 'signup'],

View File

@@ -9,6 +9,7 @@ import api from 'lib/api'
import { CodeSnippet } from 'lib/components/CodeSnippet'
import { OrganizationMembershipLevel } from 'lib/constants'
import { lemonToast } from 'lib/lemon-ui/LemonToast/LemonToast'
import { scopesArrayToObject, scopesObjectToArray } from 'lib/scopes'
import { hasMembershipLevelOrHigher, organizationAllowsPersonalApiKeysForMembers } from 'lib/utils/permissioning'
import { urls } from 'scenes/urls'
import { userLogic } from 'scenes/userLogic'
@@ -138,14 +139,7 @@ export const personalAPIKeysLogic = kea<personalAPIKeysLogicType>([
formScopeRadioValues: [
(s) => [s.editingKey],
(editingKey): Record<string, string> => {
const result: Record<string, string> = {}
editingKey.scopes.forEach((scope) => {
const [key, action] = scope.split(':')
result[key] = action
})
return result
return scopesArrayToObject(editingKey.scopes)
},
],
allAccessSelected: [
@@ -374,11 +368,19 @@ export const personalAPIKeysLogic = kea<personalAPIKeysLogicType>([
},
setScopeRadioValue: ({ key, action }) => {
const newScopes = values.editingKey.scopes.filter((scope) => !scope.startsWith(key))
if (action !== 'none') {
newScopes.push(`${key}:${action}`)
// Convert current scopes array to object for easier manipulation
const scopesObject = scopesArrayToObject(values.editingKey.scopes)
// Update the specific scope
if (action === 'none') {
delete scopesObject[key]
} else {
scopesObject[key] = action
}
// Convert back to array format
const newScopes = scopesObjectToArray(scopesObject)
actions.setEditingKeyValue('scopes', newScopes)
},

View File

@@ -93,6 +93,7 @@ export const urls = {
login: (): string => '/login',
login2FA: (): string => '/login/2fa',
login2FASetup: (): string => '/login/2fa_setup',
cliAuthorize: (): string => '/cli/authorize',
emailMFAVerify: (): string => '/login/verify',
liveDebugger: (): string => '/live-debugger',
passwordReset: (): string => '/reset',

View File

@@ -68,6 +68,7 @@ from . import (
app_metrics,
async_migration,
authentication,
cli_auth,
comments,
dead_letter_queue,
debug_ch_queries,
@@ -521,6 +522,7 @@ router.register(r"login/email-mfa", authentication.EmailMFAViewSet, "login_email
router.register(r"reset", authentication.PasswordResetViewSet, "password_reset")
router.register(r"users", user.UserViewSet, "users")
router.register(r"personal_api_keys", personal_api_key.PersonalAPIKeyViewSet, "personal_api_keys")
router.register(r"cli-auth", cli_auth.CLIAuthViewSet, "cli_auth")
router.register(r"instance_status", instance_status.InstanceStatusViewSet, "instance_status")
router.register(r"dead_letter_queue", dead_letter_queue.DeadLetterQueueViewSet, "dead_letter_queue")
router.register(r"async_migrations", async_migration.AsyncMigrationsViewset, "async_migrations")

355
posthog/api/cli_auth.py Normal file
View File

@@ -0,0 +1,355 @@
"""
CLI Authentication API using OAuth2 Device Flow
This implements the device authorization flow (RFC 8628) for the PostHog CLI.
Users can authenticate without copying/pasting API keys.
Flow:
1. CLI requests device code
2. User opens browser and authorizes
3. CLI polls for completion
4. Returns Personal API Key
"""
import string
import secrets
from django.core.cache import cache
from django.utils import timezone
from rest_framework import serializers, status, viewsets
from rest_framework.decorators import action
from rest_framework.permissions import AllowAny, IsAuthenticated
from rest_framework.response import Response
from posthog.auth import SessionAuthentication
from posthog.models import PersonalAPIKey, User
from posthog.models.personal_api_key import hash_key_value
from posthog.models.utils import generate_random_token_personal, mask_key_value
class CLIAuthSessionAuthentication(SessionAuthentication):
"""
Custom session authentication for CLI authorization.
The user_code serves as an additional authorization token beyond CSRF,
so we can safely skip CSRF validation for this specific endpoint.
"""
def authenticate(self, request):
"""Authenticate using session authentication."""
return super().authenticate(request)
def enforce_csrf(self, request):
"""Skip CSRF enforcement - the user_code acts as the authorization token."""
return None
# Device code lives for 10 minutes
DEVICE_CODE_EXPIRY_SECONDS = 600
# CLI polling interval (5 seconds)
CLI_POLL_INTERVAL_SECONDS = 5
# Scopes granted to CLI
CLI_SCOPES = [
"event_definition:read",
"property_definition:read",
"error_tracking:write",
]
def generate_user_code() -> str:
"""Generate a human-readable code like 'ABCD-1234'"""
letters = "".join(secrets.choice(string.ascii_uppercase) for _ in range(4))
numbers = "".join(secrets.choice(string.digits) for _ in range(4))
return f"{letters}-{numbers}"
def generate_device_code() -> str:
"""Generate a secure random device code"""
return secrets.token_urlsafe(32)
def get_device_cache_key(device_code: str) -> str:
"""Get cache key for device code"""
return f"cli_device:{device_code}"
def get_user_code_cache_key(user_code: str) -> str:
"""Get cache key for user code"""
return f"cli_user_code:{user_code}"
class DeviceCodeRequestSerializer(serializers.Serializer):
"""Request to initiate device authorization flow"""
pass # No input required
class DeviceCodeResponseSerializer(serializers.Serializer):
"""Response containing device and user codes"""
device_code = serializers.CharField(help_text="Code for CLI to poll with")
user_code = serializers.CharField(help_text="Code for user to enter in browser")
verification_uri = serializers.CharField(help_text="URL for user to visit")
verification_uri_complete = serializers.CharField(help_text="URL with code pre-filled")
expires_in = serializers.IntegerField(help_text="Seconds until code expires")
interval = serializers.IntegerField(help_text="Polling interval in seconds")
class DeviceAuthorizationSerializer(serializers.Serializer):
"""User authorizes the device code"""
user_code = serializers.CharField(max_length=9, help_text="The user code displayed in CLI")
project_id = serializers.IntegerField(help_text="The project to authorize CLI access for")
scopes = serializers.ListField(
child=serializers.CharField(),
required=False,
help_text="Scopes to grant to the CLI (defaults to CLI_SCOPES)",
)
class DevicePollSerializer(serializers.Serializer):
"""CLI polls for authorization status"""
device_code = serializers.CharField(help_text="Device code from initial request")
class DevicePollResponseSerializer(serializers.Serializer):
"""Response to poll request"""
status = serializers.ChoiceField(choices=["pending", "authorized", "expired"])
personal_api_key = serializers.CharField(required=False, help_text="The API key (only if authorized)")
label = serializers.CharField(required=False, help_text="Label of the created key") # type: ignore[assignment]
project_id = serializers.CharField(required=False, help_text="The project ID (only if authorized)")
class CLIAuthViewSet(viewsets.ViewSet):
"""
OAuth2 Device Authorization Flow for CLI authentication
Endpoints:
- POST /api/cli-auth/device-code/ (no auth required)
- POST /api/cli-auth/authorize/ (session auth required)
- POST /api/cli-auth/poll/ (no auth required)
"""
def get_permissions(self):
"""Authorize endpoint requires auth, others don't"""
if getattr(self, "action", None) == "authorize":
return [IsAuthenticated()]
return [AllowAny()]
def get_authenticators(self):
"""Only use session auth for browser-based authorization"""
action = getattr(self, "action", None)
# Check both action and URL path since action might not be set yet
if action == "authorize" or (hasattr(self, "request") and "authorize" in self.request.path):
return [CLIAuthSessionAuthentication()]
return []
@action(methods=["POST"], detail=False, url_path="device-code")
def device_code(self, request):
"""
Step 1: CLI requests device code
Returns device code for polling and user code for browser authorization.
"""
device_code = generate_device_code()
user_code = generate_user_code()
# Store in cache with expiry
device_cache_key = get_device_cache_key(device_code)
cache.set(
device_cache_key,
{
"user_code": user_code,
"status": "pending",
"created_at": timezone.now().isoformat(),
},
timeout=DEVICE_CODE_EXPIRY_SECONDS,
)
# Also create reverse lookup (user_code -> device_code) for authorization
user_code_cache_key = get_user_code_cache_key(user_code)
cache.set(user_code_cache_key, device_code, timeout=DEVICE_CODE_EXPIRY_SECONDS)
# Get the base URL for verification
# In production this would be the actual domain
base_url = request.build_absolute_uri("/").rstrip("/")
response_data = {
"device_code": device_code,
"user_code": user_code,
"verification_uri": f"{base_url}/cli/authorize",
"verification_uri_complete": f"{base_url}/cli/authorize?code={user_code}",
"expires_in": DEVICE_CODE_EXPIRY_SECONDS,
"interval": CLI_POLL_INTERVAL_SECONDS,
}
serializer = DeviceCodeResponseSerializer(response_data)
return Response(serializer.data, status=status.HTTP_200_OK)
@action(methods=["POST"], detail=False, url_path="authorize")
def authorize(self, request):
"""
Step 2: User authorizes in browser
Requires authenticated session. Creates a Personal API Key and marks
the device code as authorized.
The user_code itself acts as a single-use authorization token,
providing additional security beyond CSRF tokens.
"""
serializer = DeviceAuthorizationSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
user_code = serializer.validated_data["user_code"]
project_id = serializer.validated_data["project_id"]
scopes = serializer.validated_data.get("scopes", CLI_SCOPES)
# Validate that at least one scope is provided
if not scopes or len(scopes) == 0:
return Response(
{"error": "invalid_request", "error_description": "At least one scope is required"},
status=status.HTTP_400_BAD_REQUEST,
)
# Look up device code from user code
user_code_cache_key = get_user_code_cache_key(user_code)
device_code = cache.get(user_code_cache_key)
if not device_code:
return Response(
{"error": "invalid_code", "error_description": "User code not found or expired"},
status=status.HTTP_400_BAD_REQUEST,
)
# Get device code data
device_cache_key = get_device_cache_key(device_code)
device_data = cache.get(device_cache_key)
if not device_data:
return Response(
{"error": "expired", "error_description": "Device code expired"}, status=status.HTTP_400_BAD_REQUEST
)
# Prevent duplicate authorization (race condition)
if device_data.get("status") == "authorized":
return Response(
{"error": "already_authorized", "error_description": "This code has already been authorized"},
status=status.HTTP_400_BAD_REQUEST,
)
# Verify user has access to the project
user: User = request.user
from posthog.models import Team
try:
team = Team.objects.get(id=project_id)
# Check if user has access to this team's organization
if not user.organization_memberships.filter(organization=team.organization).exists():
return Response(
{"error": "access_denied", "error_description": "You do not have access to this project"},
status=status.HTTP_403_FORBIDDEN,
)
except Team.DoesNotExist:
return Response(
{"error": "invalid_project", "error_description": "Project not found"},
status=status.HTTP_400_BAD_REQUEST,
)
# Create Personal API Key for the CLI
api_key_value = generate_random_token_personal()
mask_value = mask_key_value(api_key_value)
secure_value = hash_key_value(api_key_value)
# Label max length is 40 chars, so truncate if needed
timestamp = timezone.now().strftime("%Y-%m-%d %H:%M")
max_team_name_len = 40 - len("CLI - ") - len(f" - {timestamp}")
team_name_truncated = team.name[:max_team_name_len] if len(team.name) > max_team_name_len else team.name
label = f"CLI - {team_name_truncated} - {timestamp}"
PersonalAPIKey.objects.create(
user=user,
label=label,
secure_value=secure_value,
mask_value=mask_value,
scopes=scopes,
)
# Mark device as authorized and store the API key
device_data["status"] = "authorized"
device_data["personal_api_key"] = api_key_value
device_data["label"] = label
device_data["project_id"] = str(project_id)
device_data["authorized_at"] = timezone.now().isoformat()
device_data["user_id"] = user.id
# Update cache with longer TTL to ensure CLI can poll
cache.set(device_cache_key, device_data, timeout=60) # 1 minute to retrieve
return Response(
{
"status": "success",
"label": label,
"mask_value": mask_value,
},
status=status.HTTP_200_OK,
)
@action(methods=["POST"], detail=False, url_path="poll")
def poll(self, request):
"""
Step 3: CLI polls for authorization status
Returns:
- 202: Still pending (keep polling)
- 200: Authorized (includes API key)
- 400: Expired or invalid
"""
serializer = DevicePollSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
device_code = serializer.validated_data["device_code"]
# Look up device code
device_cache_key = get_device_cache_key(device_code)
device_data = cache.get(device_cache_key)
if not device_data:
return Response(
{"status": "expired", "error": "expired_token", "error_description": "Device code expired"},
status=status.HTTP_400_BAD_REQUEST,
)
if device_data["status"] == "pending":
# Still waiting for authorization
return Response(
{"status": "pending"},
status=status.HTTP_202_ACCEPTED, # Indicates to keep polling
)
if device_data["status"] == "authorized":
# Success! Return the API key
response_data = {
"status": "authorized",
"personal_api_key": device_data["personal_api_key"],
"label": device_data["label"],
"project_id": device_data["project_id"],
}
# Clean up - key has been retrieved
cache.delete(device_cache_key)
user_code_cache_key = get_user_code_cache_key(device_data["user_code"])
cache.delete(user_code_cache_key)
response_serializer = DevicePollResponseSerializer(response_data)
return Response(response_serializer.data, status=status.HTTP_200_OK)
# Unknown status
return Response(
{"error": "invalid_request", "error_description": "Invalid device code status"},
status=status.HTTP_400_BAD_REQUEST,
)

View File

@@ -1,8 +1,10 @@
import json
import hashlib
from datetime import datetime
from typing import Any, Literal, Optional, cast
from django.core.cache import cache
from django.db.models import Manager
from django.db.models import Manager, Q
from loginas.utils import is_impersonated_session
from rest_framework import mixins, request, response, serializers, status, viewsets
@@ -16,7 +18,7 @@ from posthog.constants import AvailableFeature, EventDefinitionType
from posthog.event_usage import report_user_action
from posthog.exceptions import EnterpriseFeatureException
from posthog.filters import TermSearchFilterBackend, term_search_filter_sql
from posthog.models import EventDefinition, Team
from posthog.models import EventDefinition, EventSchema, Team
from posthog.models.activity_logging.activity_log import Detail, log_activity
from posthog.models.user import User
from posthog.models.utils import UUIDT
@@ -261,6 +263,269 @@ class EventDefinitionViewSet(
)
return response.Response(status=status.HTTP_204_NO_CONTENT)
# Version of the TypeScript generator - increment when changing the structure
# This ensures clients update even when schemas don't change
TYPESCRIPT_GENERATOR_VERSION = "1.0.0"
@action(detail=False, methods=["GET"], url_path="typescript", required_scopes=["event_definition:read"])
def typescript_definitions(self, *args, **kwargs):
"""Generate TypeScript definitions from event schemas"""
# System events that users should be able to manually capture
# These are commonly used in user code and should be typed
included_system_events = [
"$pageview", # Manually captured in SPAs (React Router, Vue Router, etc.)
"$pageleave", # Sometimes manually captured alongside $pageview
"$screen", # Manually captured in mobile apps (iOS, Android, React Native, Flutter)
]
# Fetch event definitions: either non-system events or explicitly included system events
event_definitions = (
EventDefinition.objects.filter(team__project_id=self.project_id)
.filter(
Q(name__in=included_system_events) # Include whitelisted system events
| ~Q(name__startswith="$") # Include all non-system events
)
.order_by("name")
)
# Fetch all event schemas with their property groups
event_schemas = (
EventSchema.objects.filter(event_definition__team__project_id=self.project_id)
.select_related("property_group")
.prefetch_related("property_group__properties")
)
# Build a mapping of event_definition_id -> property group properties
schema_map: dict[str, list[Any]] = {}
for event_schema in event_schemas:
event_id = str(event_schema.event_definition_id)
if event_id not in schema_map:
schema_map[event_id] = []
schema_map[event_id].extend(event_schema.property_group.properties.all())
# Calculate deterministic hash based on schema data AND generator version
# This ensures clients update when either schemas or generator structure changes
schema_data = []
for event_def in event_definitions:
properties = schema_map.get(str(event_def.id), [])
prop_data = [(p.name, p.property_type, p.is_required) for p in properties]
schema_data.append((event_def.name, sorted(prop_data)))
# Include version in hash calculation
hash_input = {
"version": self.TYPESCRIPT_GENERATOR_VERSION,
"schemas": schema_data,
}
schema_hash = hashlib.sha256(json.dumps(hash_input, sort_keys=True).encode()).hexdigest()[:32]
# Generate TypeScript definitions
ts_content = self._generate_typescript(event_definitions, schema_map)
return response.Response(
{
"content": ts_content,
"event_count": len(event_definitions),
"schema_hash": schema_hash,
"generator_version": self.TYPESCRIPT_GENERATOR_VERSION,
}
)
def _generate_typescript(self, event_definitions, schema_map):
"""Generate complete TypeScript module with type definitions and exports"""
# Generate file header
header = f"""/**
* GENERATED FILE - DO NOT EDIT
*
* This file was auto-generated by PostHog
* Generated at: {datetime.now().isoformat()}
* Generator version: {self.TYPESCRIPT_GENERATOR_VERSION}
*
* Provides capture() for type-safe events and captureRaw() for flexibility
*/
import originalPostHog from 'posthog-js'
import type {{ CaptureOptions, CaptureResult, PostHog as OriginalPostHog, Properties }} from 'posthog-js'
"""
# Generate event schemas interface
event_schemas_lines = [
"// Define event schemas with their required and optional fields",
"interface EventSchemas {",
]
for event_def in event_definitions:
properties = schema_map.get(str(event_def.id), [])
# Escape event name for use as object key
event_name = event_def.name.replace("'", "\\'")
if not properties:
event_schemas_lines.append(f" '{event_name}': Record<string, any>")
else:
event_schemas_lines.append(f" '{event_name}': {{")
for prop in properties:
ts_type = self._map_property_type(prop.property_type)
optional_marker = "" if prop.is_required else "?"
# Always quote property names (simpler and handles all edge cases)
prop_name = f"'{prop.name.replace("'", "\\'")}'"
event_schemas_lines.append(f" {prop_name}{optional_marker}: {ts_type}")
event_schemas_lines.append(" }")
event_schemas_lines.append("}")
event_schemas = "\n".join(event_schemas_lines)
# Generate type aliases
type_aliases = """
// Type alias for all valid event names
export type EventName = keyof EventSchemas
// Type helper to get properties for a specific event
// Intersects the schema with Record<string, any> to allow additional properties
export type EventProperties<K extends EventName> = EventSchemas[K] & Record<string, any>
// Helper type to check if a type has required properties
type HasRequiredProperties<K extends EventName> = {} extends EventSchemas[K] ? false : true
// Helper to detect if T is exactly 'string' (not a literal)
type IsExactlyString<T> = string extends T ? (T extends string ? true : false) : false
"""
# Generate TypedPostHog interface
typed_posthog_interface = """
// Enhanced PostHog interface with typed capture
interface TypedPostHog extends Omit<OriginalPostHog, 'capture'> {
/**
* Type-safe capture for defined events, or flexible capture for undefined events
*
* Note: For defined events, wrap properties in a variable to allow additional properties:
* const props = { file_size_b: 100, extra: 'data' }
* posthog.capture('downloaded_file', props)
*
* @example
* // Defined event with type safety
* posthog.capture('uploaded_file', {
* file_name: 'test.txt',
* file_size_b: 100
* })
*
* @example
* // For events with all optional properties, properties argument is optional
* posthog.capture('logged_out') // no properties needed
*
* @example
* // Undefined events work with arbitrary properties
* posthog.capture('custom_event', { whatever: 'data' })
* posthog.capture('another_event') // or no properties
*/
// Overload 1: For known events (specific EventName literals)
// This should match first for all known event names
capture<K extends EventName>(
event_name: K,
...args: HasRequiredProperties<K> extends true
? [properties: EventProperties<K>, options?: CaptureOptions]
: [properties?: EventProperties<K>, options?: CaptureOptions]
): CaptureResult | undefined
// Overload 2: For undefined events and blocking string variables
// Only matches if event_name is NOT a known EventName
// The conditional type rejects broad string type
capture<T extends string>(
event_name: IsExactlyString<T> extends true ? never : (T extends EventName ? never : T),
properties?: Properties | null,
options?: CaptureOptions
): CaptureResult | undefined
/**
* Raw capture for any event (original behavior, no type checking)
*
* Use capture() for type-safe defined events or flexible undefined events.
* Use captureRaw() only when you need to bypass all type checking.
*
* @example
* posthog.captureRaw('Any Event Name', { whatever: 'data' })
*/
captureRaw(event_name: string, properties?: Properties | null, options?: CaptureOptions): CaptureResult | undefined
}
"""
# Generate implementation
implementation = """
// Create the implementation
const createTypedPostHog = (original: OriginalPostHog): TypedPostHog => {
// Create the enhanced PostHog object
const enhanced: TypedPostHog = Object.create(original)
// Add capture method (type-safe for defined events, flexible for undefined)
enhanced.capture = function (event_name: string, ...args: any[]): CaptureResult | undefined {
const [properties, options] = args
return original.capture(event_name, properties, options)
}
// Add captureRaw method for untyped/flexible event tracking
enhanced.captureRaw = function (
event_name: string,
properties?: Properties | null,
options?: CaptureOptions
): CaptureResult | undefined {
return original.capture(event_name, properties, options)
}
// Proxy to delegate all other properties/methods to the original
return new Proxy(enhanced, {
get(target, prop) {
if (prop in target) {
return (target as any)[prop]
}
return (original as any)[prop]
},
set(target, prop, value) {
;(original as any)[prop] = value
return true
},
})
}
"""
# Generate exports
exports = """
// Create and export the typed instance
const posthog = createTypedPostHog(originalPostHog as OriginalPostHog)
export default posthog
export { posthog }
export type { EventSchemas, TypedPostHog }
// Re-export everything else from posthog-js
export * from 'posthog-js'
/**
* USAGE GUIDE
* ===========
*
* For type-safe defined events (recommended):
* posthog.capture('uploaded_file', { file_name: 'test.txt', file_size_b: 100 })
*
* For undefined events (flexible):
* posthog.capture('Custom Event', { whatever: 'data' })
*
* For bypassing all type checking (rare):
* posthog.captureRaw('Any Event', { whatever: 'data' })
*/
"""
# Combine all sections
return header + event_schemas + type_aliases + typed_posthog_interface + implementation + exports
def _map_property_type(self, property_type: str) -> str:
"""Map PostHog property types to TypeScript types"""
type_map = {
"String": "string",
"Numeric": "number",
"Boolean": "boolean",
"DateTime": "string | Date",
"Array": "any[]",
"Object": "Record<string, any>",
}
return type_map.get(property_type, "any")
@action(detail=True, methods=["GET"], url_path="metrics")
def metrics_totals(self, *args, **kwargs):
instance: EventDefinition = self.get_object()

View File

@@ -0,0 +1,476 @@
from datetime import timedelta
from freezegun import freeze_time
from posthog.test.base import APIBaseTest
from django.core.cache import cache
from django.utils import timezone
from rest_framework import status
from posthog.api.cli_auth import CLI_SCOPES, DEVICE_CODE_EXPIRY_SECONDS, get_device_cache_key, get_user_code_cache_key
from posthog.models import PersonalAPIKey, Team, User
from posthog.models.organization import Organization
from posthog.models.personal_api_key import hash_key_value
class TestCLIAuthDeviceCodeEndpoint(APIBaseTest):
"""
Tests for the device code request endpoint (step 1 of OAuth device flow)
"""
def setUp(self):
super().setUp()
cache.clear() # Clear cache before each test
def test_device_code_request_returns_correct_data(self):
"""Test that requesting a device code returns all required fields"""
response = self.client.post("/api/cli-auth/device-code/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
data = response.json()
# Check all required fields are present
self.assertIn("device_code", data)
self.assertIn("user_code", data)
self.assertIn("verification_uri", data)
self.assertIn("verification_uri_complete", data)
self.assertIn("expires_in", data)
self.assertIn("interval", data)
# Check values are correct
self.assertEqual(data["expires_in"], DEVICE_CODE_EXPIRY_SECONDS)
self.assertIn("/cli/authorize", data["verification_uri"])
self.assertIn(data["user_code"], data["verification_uri_complete"])
# Verify user code format (XXXX-XXXX)
user_code = data["user_code"]
self.assertEqual(len(user_code), 9)
self.assertEqual(user_code[4], "-")
self.assertTrue(user_code[:4].isalpha())
self.assertTrue(user_code[5:].isdigit())
def test_device_code_is_stored_in_cache(self):
"""Test that device code and user code are properly stored in cache"""
response = self.client.post("/api/cli-auth/device-code/")
data = response.json()
device_code = data["device_code"]
user_code = data["user_code"]
# Check device code is in cache
device_cache_key = get_device_cache_key(device_code)
device_data = cache.get(device_cache_key)
self.assertIsNotNone(device_data)
self.assertEqual(device_data["user_code"], user_code)
self.assertEqual(device_data["status"], "pending")
# Check user code reverse lookup is in cache
user_code_cache_key = get_user_code_cache_key(user_code)
cached_device_code = cache.get(user_code_cache_key)
self.assertEqual(cached_device_code, device_code)
def test_device_code_works_without_authentication(self):
"""Test that device code endpoint works for unauthenticated requests"""
self.client.logout()
response = self.client.post("/api/cli-auth/device-code/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
class TestCLIAuthAuthorizeEndpoint(APIBaseTest):
"""
Tests for the authorization endpoint (step 2 of OAuth device flow)
"""
def setUp(self):
super().setUp()
cache.clear()
# Create a device code for testing
response = self.client.post("/api/cli-auth/device-code/")
self.device_data = response.json()
self.device_code = self.device_data["device_code"]
self.user_code = self.device_data["user_code"]
def test_successful_authorization_creates_api_key(self):
"""Test that successful authorization creates a Personal API Key"""
initial_key_count = PersonalAPIKey.objects.filter(user=self.user).count()
response = self.client.post(
"/api/cli-auth/authorize/",
{"user_code": self.user_code, "project_id": self.team.id},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
data = response.json()
self.assertEqual(data["status"], "success")
self.assertIn("label", data)
self.assertIn("mask_value", data)
# Check that API key was created
new_key_count = PersonalAPIKey.objects.filter(user=self.user).count()
self.assertEqual(new_key_count, initial_key_count + 1)
# Check the created key has correct scopes
api_key = PersonalAPIKey.objects.filter(user=self.user).order_by("-created_at").first()
assert api_key is not None
self.assertEqual(api_key.scopes, CLI_SCOPES)
self.assertIn("CLI -", api_key.label)
def test_authorization_requires_authentication(self):
"""Test that authorization endpoint requires user to be logged in"""
self.client.logout()
response = self.client.post(
"/api/cli-auth/authorize/",
{"user_code": self.user_code, "project_id": self.team.id},
)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_authorization_rejects_invalid_user_code(self):
"""Test that authorization fails with invalid user code"""
response = self.client.post(
"/api/cli-auth/authorize/",
{"user_code": "XXXX-9999", "project_id": self.team.id},
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
data = response.json()
self.assertIn("error", data)
self.assertEqual(data["error"], "invalid_code")
def test_authorization_rejects_expired_user_code(self):
"""Test that authorization fails with expired user code"""
# Wait for the code to expire
with freeze_time(timezone.now() + timedelta(seconds=DEVICE_CODE_EXPIRY_SECONDS + 1)):
response = self.client.post(
"/api/cli-auth/authorize/",
{"user_code": self.user_code, "project_id": self.team.id},
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.json()["error"], "invalid_code")
def test_authorization_rejects_user_without_team_access(self):
"""Test that user cannot authorize for a team they don't have access to"""
# Create another organization and team
other_org = Organization.objects.create(name="Other Org")
other_team = Team.objects.create(organization=other_org, name="Other Team")
response = self.client.post(
"/api/cli-auth/authorize/",
{"user_code": self.user_code, "project_id": other_team.id},
)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.json()["error"], "access_denied")
# Verify no API key was created
api_keys = PersonalAPIKey.objects.filter(user=self.user).count()
self.assertEqual(api_keys, 0)
def test_authorization_rejects_nonexistent_project(self):
"""Test that authorization fails with non-existent project ID"""
response = self.client.post(
"/api/cli-auth/authorize/",
{"user_code": self.user_code, "project_id": 99999},
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.json()["error"], "invalid_project")
def test_authorization_updates_cache_with_api_key(self):
"""Test that authorization updates the cache with the API key"""
response = self.client.post(
"/api/cli-auth/authorize/",
{"user_code": self.user_code, "project_id": self.team.id},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
# Check cache was updated
device_cache_key = get_device_cache_key(self.device_code)
device_data = cache.get(device_cache_key)
self.assertEqual(device_data["status"], "authorized")
self.assertIn("personal_api_key", device_data)
self.assertEqual(device_data["project_id"], str(self.team.id))
self.assertEqual(device_data["user_id"], self.user.id)
def test_multiple_users_can_authorize_different_codes(self):
"""Test that multiple users can authorize different device codes concurrently"""
# Create another user
other_user = User.objects.create_and_join(self.organization, "other@posthog.com", "password123")
# Create device code for other user
response2 = self.client.post("/api/cli-auth/device-code/")
user_code2 = response2.json()["user_code"]
# First user authorizes their code
response = self.client.post(
"/api/cli-auth/authorize/",
{"user_code": self.user_code, "project_id": self.team.id},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
# Switch to other user
self.client.force_login(other_user)
# Other user authorizes their code
response = self.client.post(
"/api/cli-auth/authorize/",
{"user_code": user_code2, "project_id": self.team.id},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
# Both users should have their own API keys
self.assertEqual(PersonalAPIKey.objects.filter(user=self.user).count(), 1)
self.assertEqual(PersonalAPIKey.objects.filter(user=other_user).count(), 1)
def test_authorization_prevents_duplicate_api_keys_from_race_condition(self):
"""Test that attempting to authorize the same code twice does not create duplicate API keys"""
initial_key_count = PersonalAPIKey.objects.filter(user=self.user).count()
# First authorization succeeds
response1 = self.client.post(
"/api/cli-auth/authorize/",
{"user_code": self.user_code, "project_id": self.team.id},
)
self.assertEqual(response1.status_code, status.HTTP_200_OK)
# Second authorization attempt should fail (code already authorized)
response2 = self.client.post(
"/api/cli-auth/authorize/",
{"user_code": self.user_code, "project_id": self.team.id},
)
self.assertEqual(response2.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response2.json()["error"], "already_authorized")
# Verify only one API key was created
final_key_count = PersonalAPIKey.objects.filter(user=self.user).count()
self.assertEqual(final_key_count, initial_key_count + 1)
class TestCLIAuthPollEndpoint(APIBaseTest):
"""
Tests for the poll endpoint (step 3 of OAuth device flow)
"""
def setUp(self):
super().setUp()
cache.clear()
# Create and authorize a device code
response = self.client.post("/api/cli-auth/device-code/")
self.device_data = response.json()
self.device_code = self.device_data["device_code"]
self.user_code = self.device_data["user_code"]
def test_poll_returns_pending_before_authorization(self):
"""Test that polling returns pending status before user authorizes"""
response = self.client.post("/api/cli-auth/poll/", {"device_code": self.device_code})
self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED)
data = response.json()
self.assertEqual(data["status"], "pending")
def test_poll_returns_api_key_after_authorization(self):
"""Test that polling returns API key after user authorizes"""
# Authorize the code
self.client.post(
"/api/cli-auth/authorize/",
{"user_code": self.user_code, "project_id": self.team.id},
)
# Poll for the result
response = self.client.post("/api/cli-auth/poll/", {"device_code": self.device_code})
self.assertEqual(response.status_code, status.HTTP_200_OK)
data = response.json()
self.assertEqual(data["status"], "authorized")
self.assertIn("personal_api_key", data)
self.assertIn("label", data)
self.assertEqual(data["project_id"], str(self.team.id))
# Verify the API key is valid
api_key = data["personal_api_key"]
self.assertTrue(api_key.startswith("phx_"))
def test_poll_returns_expired_for_old_code(self):
"""Test that polling returns expired for old device codes"""
with freeze_time(timezone.now() + timedelta(seconds=DEVICE_CODE_EXPIRY_SECONDS + 1)):
response = self.client.post("/api/cli-auth/poll/", {"device_code": self.device_code})
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
data = response.json()
self.assertEqual(data["status"], "expired")
def test_poll_returns_expired_for_nonexistent_code(self):
"""Test that polling returns expired for non-existent device codes"""
response = self.client.post("/api/cli-auth/poll/", {"device_code": "nonexistent_code"})
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
data = response.json()
self.assertEqual(data["status"], "expired")
def test_poll_cleans_up_cache_after_successful_retrieval(self):
"""Test that cache is cleaned up after API key is retrieved"""
# Authorize the code
self.client.post(
"/api/cli-auth/authorize/",
{"user_code": self.user_code, "project_id": self.team.id},
)
# Poll for the result
self.client.post("/api/cli-auth/poll/", {"device_code": self.device_code})
# Verify cache is cleaned up
device_cache_key = get_device_cache_key(self.device_code)
user_code_cache_key = get_user_code_cache_key(self.user_code)
self.assertIsNone(cache.get(device_cache_key))
self.assertIsNone(cache.get(user_code_cache_key))
def test_poll_can_be_called_multiple_times_before_authorization(self):
"""Test that poll can be called multiple times while pending"""
# Poll multiple times
for _ in range(3):
response = self.client.post("/api/cli-auth/poll/", {"device_code": self.device_code})
self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED)
self.assertEqual(response.json()["status"], "pending")
def test_poll_works_without_authentication(self):
"""Test that poll endpoint works for unauthenticated requests"""
self.client.logout()
response = self.client.post("/api/cli-auth/poll/", {"device_code": self.device_code})
self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED)
def test_poll_returns_api_key_only_once(self):
"""Test that API key can only be retrieved once"""
# Authorize the code
self.client.post(
"/api/cli-auth/authorize/",
{"user_code": self.user_code, "project_id": self.team.id},
)
# First poll succeeds
response1 = self.client.post("/api/cli-auth/poll/", {"device_code": self.device_code})
self.assertEqual(response1.status_code, status.HTTP_200_OK)
# Second poll fails (cache cleaned up)
response2 = self.client.post("/api/cli-auth/poll/", {"device_code": self.device_code})
self.assertEqual(response2.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response2.json()["status"], "expired")
class TestCLIAuthEndToEnd(APIBaseTest):
"""
End-to-end tests for the complete CLI authentication flow
"""
def setUp(self):
super().setUp()
cache.clear()
def test_complete_authentication_flow(self):
"""Test the complete device flow from start to finish"""
# Step 1: Request device code (CLI)
response = self.client.post("/api/cli-auth/device-code/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
device_data = response.json()
device_code = device_data["device_code"]
user_code = device_data["user_code"]
# Step 2: User opens browser and authorizes
response = self.client.post(
"/api/cli-auth/authorize/",
{"user_code": user_code, "project_id": self.team.id},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
# Step 3: CLI polls and gets API key
response = self.client.post("/api/cli-auth/poll/", {"device_code": device_code})
self.assertEqual(response.status_code, status.HTTP_200_OK)
api_key = response.json()["personal_api_key"]
# Step 4: Verify the API key works
self.client.logout()
response = self.client.get(
f"/api/projects/{self.team.pk}/event_definitions/",
HTTP_AUTHORIZATION=f"Bearer {api_key}",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_api_key_has_correct_scopes(self):
"""Test that created API key has only the CLI scopes"""
# Complete the flow
response = self.client.post("/api/cli-auth/device-code/")
device_code = response.json()["device_code"]
user_code = response.json()["user_code"]
self.client.post(
"/api/cli-auth/authorize/",
{"user_code": user_code, "project_id": self.team.id},
)
response = self.client.post("/api/cli-auth/poll/", {"device_code": device_code})
api_key_value = response.json()["personal_api_key"]
# Get the API key from database
api_key = PersonalAPIKey.objects.get(secure_value=hash_key_value(api_key_value))
# Verify scopes
self.assertEqual(api_key.scopes, CLI_SCOPES)
self.assertIn("event_definition:read", api_key.scopes)
self.assertIn("property_definition:read", api_key.scopes)
self.assertIn("error_tracking:write", api_key.scopes)
def test_cross_team_access_is_prevented(self):
"""Test that user cannot authorize CLI for a team in a different organization"""
# Create a new organization that the user is NOT a member of
other_org = Organization.objects.create(name="Other Organization")
other_team = Team.objects.create(organization=other_org, name="Other Team")
# Complete device code request
response = self.client.post("/api/cli-auth/device-code/")
user_code = response.json()["user_code"]
# Try to authorize for the other team
response = self.client.post(
"/api/cli-auth/authorize/",
{"user_code": user_code, "project_id": other_team.id},
)
# Should be rejected
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.json()["error"], "access_denied")
# Verify no API key was created
self.assertEqual(PersonalAPIKey.objects.filter(user=self.user).count(), 0)
def test_user_can_authorize_for_multiple_teams_in_same_org(self):
"""Test that user can authorize CLI for multiple teams in the same organization"""
# Create another team in the same organization
team2 = Team.objects.create(organization=self.organization, name="Team 2")
# Authorize for first team
response1 = self.client.post("/api/cli-auth/device-code/")
user_code1 = response1.json()["user_code"]
self.client.post(
"/api/cli-auth/authorize/",
{"user_code": user_code1, "project_id": self.team.id},
)
# Authorize for second team
response2 = self.client.post("/api/cli-auth/device-code/")
user_code2 = response2.json()["user_code"]
response = self.client.post(
"/api/cli-auth/authorize/",
{"user_code": user_code2, "project_id": team2.id},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
# Both API keys should be created
self.assertEqual(PersonalAPIKey.objects.filter(user=self.user).count(), 2)

View File

@@ -0,0 +1,275 @@
"""
Integration test for TypeScript definition generation.
Tests the complete flow:
1. Create EventDefinitions with EventSchemas
2. Generate TypeScript via the typescript_definitions method
3. Write generated TypeScript to temp file
4. Create a test file that uses the types
5. Run TypeScript compiler to verify no errors
"""
import tempfile
import subprocess
from pathlib import Path
from posthog.test.base import APIBaseTest
from rest_framework import status
from posthog.models import EventDefinition, EventSchema, SchemaPropertyGroup, SchemaPropertyGroupProperty
class TestEventDefinitionTypeScriptGeneration(APIBaseTest):
"""
Critical integration test ensuring TypeScript generation maintains type safety
while allowing additional properties beyond the schema.
"""
def setUp(self):
super().setUp()
# Create property group with required and optional fields
self.property_group = SchemaPropertyGroup.objects.create(
team=self.team, project=self.project, name="Test Properties"
)
SchemaPropertyGroupProperty.objects.create(
property_group=self.property_group,
name="required_field",
property_type="Numeric",
is_required=True,
description="A required numeric field",
)
SchemaPropertyGroupProperty.objects.create(
property_group=self.property_group,
name="optional_field",
property_type="String",
is_required=False,
description="An optional string field",
)
# Create event definition and link to property group
self.event_def = EventDefinition.objects.create(team=self.team, project=self.project, name="test_event")
EventSchema.objects.create(event_definition=self.event_def, property_group=self.property_group)
# Create event with all optional fields
self.optional_event_def = EventDefinition.objects.create(
team=self.team, project=self.project, name="optional_event"
)
optional_property_group = SchemaPropertyGroup.objects.create(
team=self.team, project=self.project, name="Optional Properties"
)
SchemaPropertyGroupProperty.objects.create(
property_group=optional_property_group,
name="optional_only",
property_type="String",
is_required=False,
)
EventSchema.objects.create(event_definition=self.optional_event_def, property_group=optional_property_group)
# Create event with no schema (all properties allowed)
self.untyped_event_def = EventDefinition.objects.create(
team=self.team, project=self.project, name="untyped_event"
)
def _generate_typescript(self) -> str:
"""Generate TypeScript definitions by calling the actual API endpoint"""
response = self.client.get(f"/api/projects/{self.project.id}/event_definitions/typescript/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
return response.json()["content"]
def test_typescript_allows_additional_properties(self):
"""
Critical test: Verify that additional properties beyond schema
are allowed while required properties are still validated.
This is the core functionality that prevents "excess property checking"
errors in TypeScript while maintaining type safety for required fields.
Uses the real posthog-js package to ensure compatibility with actual types.
"""
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
# Generate TypeScript
ts_content = self._generate_typescript()
# Create minimal package.json to install only required dependencies
package_json = tmpdir_path / "package.json"
package_json.write_text('{"dependencies": {"typescript": "^5.0.0", "posthog-js": "^1.0.0"}}')
install_result = subprocess.run(
["pnpm", "install", "--no-frozen-lockfile"],
cwd=str(tmpdir_path),
capture_output=True,
text=True,
timeout=120,
)
if install_result.returncode != 0:
self.fail(
f"Failed to install dependencies:\n"
f"STDOUT: {install_result.stdout}\n"
f"STDERR: {install_result.stderr}"
)
# Write generated types (using real posthog-js)
types_file = tmpdir_path / "posthog-typed.ts"
types_file.write_text(ts_content)
# Create test file that exercises all type scenarios
test_file = tmpdir_path / "test.ts"
test_file.write_text(
"""
import posthog, { EventName } from './posthog-typed'
// ========================================
// TEST 1: Additional properties are allowed
// ========================================
// ✅ Should compile: required field + extra properties (CRITICAL TEST)
posthog.capture('test_event', {
required_field: 123,
optional_field: 'test',
extra_property: 'this should be allowed',
another_extra: true,
nested_extra: { foo: 'bar' }
})
// ✅ Should compile: only required field
posthog.capture('test_event', {
required_field: 456
})
// ========================================
// TEST 2: Required properties are validated
// ========================================
// ❌ Should fail: missing required field
// @ts-expect-error
posthog.capture('test_event', {
optional_field: 'test'
})
// ❌ Should fail: wrong type for required field
// @ts-expect-error
posthog.capture('test_event', {
required_field: 'string not allowed'
})
// ========================================
// TEST 3: Events with all optional properties
// ========================================
// ✅ Should compile: no properties needed
posthog.capture('optional_event')
// ✅ Should compile: with properties
posthog.capture('optional_event', {
optional_only: 'value',
extra_field: 123
})
// ========================================
// TEST 4: Untyped events accept anything
// ========================================
// ✅ Should compile: any properties
posthog.capture('untyped_event', {
anything: 'goes',
here: 123
})
// ✅ Should compile: no properties
posthog.capture('untyped_event')
// ========================================
// TEST 5: Undefined events work flexibly
// ========================================
// ✅ Should compile: custom event with properties
posthog.capture('custom_undefined_event', {
any: 'properties',
work: 'here'
})
// ✅ Should compile: custom event without properties
posthog.capture('another_custom_event')
// ========================================
// TEST 6: String variables are blocked
// ========================================
// ❌ Should fail: broad string type not allowed
let stringVar: string = 'test_event'
// @ts-expect-error
posthog.capture(stringVar)
// ✅ Should compile: EventName type works
let typedVar: EventName = 'test_event'
posthog.capture(typedVar, { required_field: 789 })
// ✅ Should compile: const infers literal type
const constVar = 'test_event'
posthog.capture(constVar, { required_field: 999 })
// ========================================
// TEST 7: captureRaw bypasses all checking
// ========================================
// ✅ Should compile: missing required fields is OK
posthog.captureRaw('test_event', {
optional_field: 'test'
})
// ✅ Should compile: wrong types are OK
posthog.captureRaw('test_event', {
required_field: 'string is fine here'
})
// ✅ Should compile: string variables work
posthog.captureRaw(stringVar, { any: 'data' })
"""
)
# Create tsconfig.json
tsconfig_file = tmpdir_path / "tsconfig.json"
tsconfig_file.write_text(
"""
{
"compilerOptions": {
"target": "ES2020",
"module": "ESNext",
"lib": ["ES2020", "DOM"],
"strict": true,
"esModuleInterop": true,
"skipLibCheck": true,
"moduleResolution": "node"
}
}
"""
)
# Run TypeScript compiler using pnpm
result = subprocess.run(
["pnpm", "exec", "tsc", "--noEmit", "--project", str(tsconfig_file)],
cwd=str(tmpdir_path),
capture_output=True,
text=True,
timeout=30,
)
# Assert compilation succeeded
self.assertEqual(
result.returncode,
0,
f"TypeScript compilation failed. This indicates the type system is broken.\n\n"
f"STDOUT:\n{result.stdout}\n\n"
f"STDERR:\n{result.stderr}\n\n"
f"Generated TypeScript file location: {types_file}",
)

View File

@@ -159,6 +159,10 @@ class AutoProjectMiddleware:
self.token_allowlist = PROJECT_SWITCHING_TOKEN_ALLOWLIST
def __call__(self, request: HttpRequest):
# Skip project switching for CLI authorization page
if request.path.startswith("/cli/authorize"):
return self.get_response(request)
if request.user.is_authenticated:
path_parts = request.path.strip("/").split("/")
project_id_in_url = None