feat: posthog-cli interactive login & typescript schema download (#39903)
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Claude <noreply@anthropic.com>
2
cli/Cargo.lock
generated
@@ -1520,7 +1520,7 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
|
||||
|
||||
[[package]]
|
||||
name = "posthog-cli"
|
||||
version = "0.5.7"
|
||||
version = "0.5.8"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"chrono",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "posthog-cli"
|
||||
version = "0.5.7"
|
||||
version = "0.5.8"
|
||||
authors = [
|
||||
"David <david@posthog.com>",
|
||||
"Olly <oliver@posthog.com>",
|
||||
|
||||
@@ -68,6 +68,24 @@ pub enum ExpCommand {
|
||||
#[command(subcommand)]
|
||||
cmd: HermesSubcommand,
|
||||
},
|
||||
|
||||
/// Download event definitions and generate typed SDK
|
||||
Schema {
|
||||
#[command(subcommand)]
|
||||
cmd: SchemaCommand,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Subcommand)]
|
||||
pub enum SchemaCommand {
|
||||
/// Download event definitions and generate typed SDK
|
||||
Pull {
|
||||
/// Output path for generated definitions (stored in posthog.json for future runs)
|
||||
#[arg(short, long)]
|
||||
output: Option<String>,
|
||||
},
|
||||
/// Show current schema sync status
|
||||
Status,
|
||||
}
|
||||
|
||||
impl Cli {
|
||||
@@ -100,7 +118,7 @@ impl Cli {
|
||||
match self.command {
|
||||
Commands::Login => {
|
||||
// Notably login doesn't have a context set up going it - it sets one up
|
||||
crate::login::login()?;
|
||||
crate::login::login(self.host)?;
|
||||
}
|
||||
Commands::Sourcemap { cmd } => match cmd {
|
||||
SourcemapCommand::Inject(input_args) => {
|
||||
@@ -136,6 +154,14 @@ impl Cli {
|
||||
crate::sourcemaps::hermes::clone::clone(&args)?;
|
||||
}
|
||||
},
|
||||
ExpCommand::Schema { cmd } => match cmd {
|
||||
SchemaCommand::Pull { output } => {
|
||||
crate::experimental::schema::pull(self.host, output)?;
|
||||
}
|
||||
SchemaCommand::Status => {
|
||||
crate::experimental::schema::status()?;
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -2,5 +2,6 @@
|
||||
// Things in here should be considered unstable and possibly broken
|
||||
|
||||
pub mod query;
|
||||
pub mod schema;
|
||||
pub mod tasks;
|
||||
pub mod tui;
|
||||
|
||||
366
cli/src/experimental/schema.rs
Normal file
@@ -0,0 +1,366 @@
|
||||
use anyhow::{Context, Result};
|
||||
use inquire::{Select, Text};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use tracing::info;
|
||||
|
||||
use crate::invocation_context::context;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
enum Language {
|
||||
TypeScript,
|
||||
}
|
||||
|
||||
impl Language {
|
||||
/// Get the language identifier used in API URLs
|
||||
fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Language::TypeScript => "typescript",
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the display name for the language
|
||||
fn display_name(&self) -> &'static str {
|
||||
match self {
|
||||
Language::TypeScript => "TypeScript",
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the default output filename for this language
|
||||
fn default_output_path(&self) -> &'static str {
|
||||
match self {
|
||||
Language::TypeScript => "posthog-typed.ts",
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all available languages
|
||||
fn all() -> Vec<Language> {
|
||||
vec![Language::TypeScript]
|
||||
}
|
||||
|
||||
/// Parse a language from a string identifier
|
||||
fn from_str(s: &str) -> Option<Language> {
|
||||
match s {
|
||||
"typescript" => Some(Language::TypeScript),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
||||
struct SchemaConfig {
|
||||
languages: HashMap<String, LanguageConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
struct LanguageConfig {
|
||||
output_path: String,
|
||||
schema_hash: String,
|
||||
updated_at: String,
|
||||
event_count: usize,
|
||||
}
|
||||
|
||||
impl SchemaConfig {
|
||||
/// Load config from posthog.json, returns empty config if file doesn't exist or is invalid
|
||||
fn load() -> Self {
|
||||
let content = fs::read_to_string("posthog.json").ok();
|
||||
content
|
||||
.and_then(|c| serde_json::from_str(&c).ok())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Save config to posthog.json
|
||||
fn save(&self) -> Result<()> {
|
||||
let json =
|
||||
serde_json::to_string_pretty(self).context("Failed to serialize schema config")?;
|
||||
fs::write("posthog.json", json).context("Failed to write posthog.json")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get language config for a specific language
|
||||
fn get_language(&self, language: Language) -> Option<&LanguageConfig> {
|
||||
self.languages.get(language.as_str())
|
||||
}
|
||||
|
||||
/// Get output path for a language
|
||||
fn get_output_path(&self, language: Language) -> Option<String> {
|
||||
self.languages
|
||||
.get(language.as_str())
|
||||
.map(|l| l.output_path.clone())
|
||||
}
|
||||
|
||||
/// Update language config, preserving other languages
|
||||
fn update_language(
|
||||
&mut self,
|
||||
language: Language,
|
||||
output_path: String,
|
||||
schema_hash: String,
|
||||
event_count: usize,
|
||||
) {
|
||||
use chrono::Utc;
|
||||
|
||||
self.languages.insert(
|
||||
language.as_str().to_string(),
|
||||
LanguageConfig {
|
||||
output_path,
|
||||
schema_hash,
|
||||
updated_at: Utc::now().to_rfc3339(),
|
||||
event_count,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DefinitionsResponse {
|
||||
content: String,
|
||||
event_count: usize,
|
||||
schema_hash: String,
|
||||
}
|
||||
|
||||
pub fn pull(_host: Option<String>, output_override: Option<String>) -> Result<()> {
|
||||
// Select language
|
||||
let language = select_language()?;
|
||||
|
||||
info!(
|
||||
"Fetching {} definitions from PostHog...",
|
||||
language.display_name()
|
||||
);
|
||||
|
||||
// Load credentials
|
||||
let token = context().token.clone();
|
||||
let host = token.get_host();
|
||||
|
||||
// Determine output path
|
||||
let output_path = determine_output_path(language, output_override)?;
|
||||
|
||||
// Fetch definitions from the server
|
||||
let response = fetch_definitions(&host, &token.env_id, &token.token, language)?;
|
||||
|
||||
info!(
|
||||
"✓ Fetched {} definitions for {} events",
|
||||
language.display_name(),
|
||||
response.event_count
|
||||
);
|
||||
|
||||
// Check if schema has changed for this language
|
||||
let config = SchemaConfig::load();
|
||||
if let Some(lang_config) = config.get_language(language) {
|
||||
if lang_config.schema_hash == response.schema_hash {
|
||||
info!(
|
||||
"Schema unchanged for {} (hash: {})",
|
||||
language.as_str(),
|
||||
response.schema_hash
|
||||
);
|
||||
println!(
|
||||
"\n✓ {} schema is already up to date!",
|
||||
language.display_name()
|
||||
);
|
||||
println!(" No changes detected - skipping file write.");
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
// Write TypeScript definitions to file
|
||||
info!("Writing {}...", output_path);
|
||||
|
||||
// Create parent directories if they don't exist
|
||||
if let Some(parent) = Path::new(&output_path).parent() {
|
||||
if !parent.as_os_str().is_empty() {
|
||||
fs::create_dir_all(parent)
|
||||
.context(format!("Failed to create directory {}", parent.display()))?;
|
||||
}
|
||||
}
|
||||
|
||||
fs::write(&output_path, &response.content).context(format!("Failed to write {output_path}"))?;
|
||||
info!("✓ Generated {}", output_path);
|
||||
|
||||
// Update schema configuration for this language
|
||||
info!("Updating posthog.json...");
|
||||
let mut config = SchemaConfig::load();
|
||||
config.update_language(
|
||||
language,
|
||||
output_path.clone(),
|
||||
response.schema_hash,
|
||||
response.event_count,
|
||||
);
|
||||
config.save()?;
|
||||
info!("✓ Updated posthog.json");
|
||||
|
||||
println!("\n✓ Schema sync complete!");
|
||||
println!("\nNext steps:");
|
||||
println!(" 1. Import PostHog from your generated module:");
|
||||
println!(" import posthog from './{output_path}'");
|
||||
println!(" 2. Use typed events with autocomplete and type safety:");
|
||||
println!(" posthog.captureTyped('event_name', {{ property: 'value' }})");
|
||||
println!(" 3. Or use regular capture() for flexibility:");
|
||||
println!(" posthog.capture('dynamic_event', {{ any: 'data' }})");
|
||||
println!();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn determine_output_path(language: Language, output_override: Option<String>) -> Result<String> {
|
||||
// If CLI override is provided, use it (and normalize it)
|
||||
if let Some(path) = output_override {
|
||||
return Ok(normalize_output_path(&path, language));
|
||||
}
|
||||
|
||||
// Check if posthog.json exists and has an output_path for this language
|
||||
let config = SchemaConfig::load();
|
||||
if let Some(path) = config.get_output_path(language) {
|
||||
return Ok(path);
|
||||
}
|
||||
|
||||
// Prompt user for output path
|
||||
let default_filename = language.default_output_path();
|
||||
let current_dir = std::env::current_dir()
|
||||
.ok()
|
||||
.and_then(|p| p.to_str().map(String::from))
|
||||
.unwrap_or_else(|| ".".to_string());
|
||||
|
||||
let help_message = format!(
|
||||
"Your app will import PostHog from this file, so it should be accessible \
|
||||
throughout your codebase (e.g., src/lib/, app/lib/, or your project root). \
|
||||
This path will be saved in posthog.json and can be changed later. \
|
||||
Current directory: {current_dir}"
|
||||
);
|
||||
|
||||
let path = Text::new(&format!(
|
||||
"Where should we save the {} typed PostHog module?",
|
||||
language.display_name()
|
||||
))
|
||||
.with_default(default_filename)
|
||||
.with_help_message(&help_message)
|
||||
.prompt()
|
||||
.unwrap_or(default_filename.to_string());
|
||||
|
||||
Ok(normalize_output_path(&path, language))
|
||||
}
|
||||
|
||||
fn normalize_output_path(path: &str, language: Language) -> String {
|
||||
let path_obj = Path::new(path);
|
||||
|
||||
// If it's a directory (existing or ends with slash), append default filename
|
||||
let should_append_filename =
|
||||
(path_obj.exists() && path_obj.is_dir()) || path.ends_with('/') || path.ends_with('\\');
|
||||
|
||||
if should_append_filename {
|
||||
path_obj
|
||||
.join(language.default_output_path())
|
||||
.to_string_lossy()
|
||||
.into_owned()
|
||||
} else {
|
||||
path.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn status() -> Result<()> {
|
||||
// Check authentication
|
||||
println!("\nPostHog Schema Sync Status\n");
|
||||
|
||||
println!("Authentication:");
|
||||
let token = context().token.clone();
|
||||
println!(" ✓ Authenticated");
|
||||
println!(" Host: {}", token.get_host());
|
||||
println!(" Project ID: {}", token.env_id);
|
||||
let masked_token = format!(
|
||||
"{}****{}",
|
||||
&token.token[..4],
|
||||
&token.token[token.token.len() - 4..]
|
||||
);
|
||||
println!(" Token: {masked_token}");
|
||||
|
||||
println!();
|
||||
|
||||
// Check schema status
|
||||
println!("Schema:");
|
||||
let config = SchemaConfig::load();
|
||||
|
||||
if config.languages.is_empty() {
|
||||
println!(" ✗ No schemas synced");
|
||||
println!(" Run: posthog-cli exp schema pull");
|
||||
} else {
|
||||
println!(" ✓ Schemas synced\n");
|
||||
|
||||
for (language_str, lang_config) in &config.languages {
|
||||
// Parse language to get display name, fallback to raw string if unknown
|
||||
let display = Language::from_str(language_str)
|
||||
.map(|l| l.display_name())
|
||||
.unwrap_or(language_str.as_str());
|
||||
|
||||
println!(" {display}:");
|
||||
println!(" Hash: {}", lang_config.schema_hash);
|
||||
println!(" Updated: {}", lang_config.updated_at);
|
||||
println!(" Events: {}", lang_config.event_count);
|
||||
|
||||
if Path::new(&lang_config.output_path).exists() {
|
||||
println!(" File: ✓ {}", lang_config.output_path);
|
||||
} else {
|
||||
println!(" File: ! {} (missing)", lang_config.output_path);
|
||||
}
|
||||
println!();
|
||||
}
|
||||
}
|
||||
|
||||
println!();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn fetch_definitions(
|
||||
host: &str,
|
||||
env_id: &str,
|
||||
token: &str,
|
||||
language: Language,
|
||||
) -> Result<DefinitionsResponse> {
|
||||
let url = format!(
|
||||
"{}/api/projects/{}/event_definitions/{}/",
|
||||
host,
|
||||
env_id,
|
||||
language.as_str()
|
||||
);
|
||||
|
||||
let client = &context().client;
|
||||
let response = client.get(&url).bearer_auth(token).send().context(format!(
|
||||
"Failed to fetch {} definitions",
|
||||
language.display_name()
|
||||
))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Failed to fetch {} definitions: HTTP {}",
|
||||
language.display_name(),
|
||||
response.status()
|
||||
));
|
||||
}
|
||||
|
||||
let json: DefinitionsResponse = response.json().context(format!(
|
||||
"Failed to parse {} definitions response",
|
||||
language.display_name()
|
||||
))?;
|
||||
|
||||
Ok(json)
|
||||
}
|
||||
|
||||
fn select_language() -> Result<Language> {
|
||||
let languages = Language::all();
|
||||
|
||||
if languages.len() == 1 {
|
||||
return Ok(languages[0]);
|
||||
}
|
||||
|
||||
let language_strs: Vec<&str> = languages.iter().map(|l| l.display_name()).collect();
|
||||
let selected = Select::new("Which language would you like to download?", language_strs)
|
||||
.prompt()
|
||||
.context("Failed to select language")?;
|
||||
|
||||
// Find the language that matches the selected display name
|
||||
languages
|
||||
.into_iter()
|
||||
.find(|l| l.display_name() == selected)
|
||||
.ok_or_else(|| anyhow::anyhow!("Invalid language selection"))
|
||||
}
|
||||
271
cli/src/login.rs
@@ -1,5 +1,8 @@
|
||||
use anyhow::Error;
|
||||
use inquire::Text;
|
||||
use anyhow::{Context, Error};
|
||||
use inquire::{Select, Text};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
use tracing::info;
|
||||
|
||||
use crate::{
|
||||
@@ -7,14 +10,265 @@ use crate::{
|
||||
utils::auth::{host_validator, token_validator, CredentialProvider, HomeDirProvider, Token},
|
||||
};
|
||||
|
||||
pub fn login() -> Result<(), Error> {
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DeviceCodeResponse {
|
||||
device_code: String,
|
||||
user_code: String,
|
||||
verification_uri_complete: String,
|
||||
expires_in: u64,
|
||||
interval: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct PollRequest {
|
||||
device_code: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct PollResponse {
|
||||
status: String,
|
||||
personal_api_key: Option<String>,
|
||||
project_id: Option<String>,
|
||||
}
|
||||
|
||||
pub fn login(host_override: Option<String>) -> Result<(), Error> {
|
||||
login_with_use_cases(host_override, vec!["schema", "error_tracking"])
|
||||
}
|
||||
|
||||
pub fn login_with_use_cases(
|
||||
host_override: Option<String>,
|
||||
use_cases: Vec<&str>,
|
||||
) -> Result<(), Error> {
|
||||
let host = if let Some(override_host) = host_override {
|
||||
// Strip trailing slashes to avoid double slashes in URLs
|
||||
override_host.trim_end_matches('/').to_string()
|
||||
} else {
|
||||
// Prompt user to select region or manual login
|
||||
let options = vec!["US", "EU", "Manual"];
|
||||
let selection = Select::new("Select your PostHog region:", options)
|
||||
.with_help_message("Choose the region where your PostHog data is hosted, or 'Manual' to enter your own details")
|
||||
.prompt()?;
|
||||
|
||||
match selection {
|
||||
"US" => "https://us.posthog.com".to_string(),
|
||||
"EU" => "https://eu.posthog.com".to_string(),
|
||||
"Manual" => {
|
||||
return manual_login();
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
};
|
||||
|
||||
info!("🔐 Starting OAuth Device Flow authentication...");
|
||||
info!("Connecting to: {}", host);
|
||||
|
||||
// Step 1: Request device code
|
||||
let device_data = request_device_code(&host)?;
|
||||
|
||||
// Add use_cases parameter to the verification URL
|
||||
let use_cases_param = use_cases.join(",");
|
||||
let verification_url = if device_data.verification_uri_complete.contains('?') {
|
||||
format!(
|
||||
"{}&use_cases={}",
|
||||
device_data.verification_uri_complete, use_cases_param
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"{}?use_cases={}",
|
||||
device_data.verification_uri_complete, use_cases_param
|
||||
)
|
||||
};
|
||||
|
||||
println!();
|
||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
||||
println!(" 📱 Authorization Required");
|
||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
||||
println!();
|
||||
println!("To authenticate, visit this URL in your browser:");
|
||||
println!(" {verification_url}");
|
||||
println!();
|
||||
println!("Your authorization code:");
|
||||
println!(" ✨ {} ✨", device_data.user_code);
|
||||
println!();
|
||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
||||
println!();
|
||||
|
||||
// Step 2: Try to open browser
|
||||
if let Err(e) = open_browser(&verification_url) {
|
||||
info!("Could not open browser automatically: {}", e);
|
||||
info!("Please open the URL manually");
|
||||
} else {
|
||||
info!("✓ Opened browser for authorization");
|
||||
}
|
||||
|
||||
// Step 3: Poll for authorization
|
||||
info!("Waiting for authorization...");
|
||||
let poll_response = poll_for_authorization(
|
||||
&host,
|
||||
&device_data.device_code,
|
||||
device_data.interval,
|
||||
device_data.expires_in,
|
||||
)?;
|
||||
|
||||
info!("✓ Successfully authenticated!");
|
||||
|
||||
// Step 4: Save credentials
|
||||
let token = Token {
|
||||
host: Some(host),
|
||||
token: poll_response.personal_api_key.unwrap(),
|
||||
env_id: poll_response.project_id.unwrap(),
|
||||
};
|
||||
let provider = HomeDirProvider;
|
||||
provider.store_credentials(token)?;
|
||||
|
||||
info!("Token saved to: {}", provider.report_location());
|
||||
|
||||
complete_login(&provider, "interactive_login")
|
||||
}
|
||||
|
||||
fn request_device_code(host: &str) -> Result<DeviceCodeResponse, Error> {
|
||||
let client = reqwest::blocking::Client::new();
|
||||
let url = format!("{host}/api/cli-auth/device-code/");
|
||||
|
||||
let response = client
|
||||
.post(&url)
|
||||
.header("Content-Type", "application/json")
|
||||
.send()
|
||||
.context("Failed to request device code")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Failed to request device code: HTTP {}",
|
||||
response.status()
|
||||
));
|
||||
}
|
||||
|
||||
let device_data: DeviceCodeResponse = response
|
||||
.json()
|
||||
.context("Failed to parse device code response")?;
|
||||
|
||||
Ok(device_data)
|
||||
}
|
||||
|
||||
fn open_browser(url: &str) -> Result<(), Error> {
|
||||
// Try to open browser using platform-specific commands
|
||||
#[cfg(target_os = "macos")]
|
||||
{
|
||||
std::process::Command::new("open")
|
||||
.arg(url)
|
||||
.spawn()
|
||||
.context("Failed to open browser")?;
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
{
|
||||
std::process::Command::new("xdg-open")
|
||||
.arg(url)
|
||||
.spawn()
|
||||
.context("Failed to open browser")?;
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
std::process::Command::new("cmd")
|
||||
.args(&["/C", "start", url])
|
||||
.spawn()
|
||||
.context("Failed to open browser")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn poll_for_authorization(
|
||||
host: &str,
|
||||
device_code: &str,
|
||||
interval_seconds: u64,
|
||||
expires_in_seconds: u64,
|
||||
) -> Result<PollResponse, Error> {
|
||||
let client = reqwest::blocking::Client::new();
|
||||
let url = format!("{host}/api/cli-auth/poll/");
|
||||
let max_attempts = (expires_in_seconds / interval_seconds) + 1;
|
||||
let poll_interval = Duration::from_secs(interval_seconds);
|
||||
|
||||
for attempt in 1..=max_attempts {
|
||||
thread::sleep(poll_interval);
|
||||
|
||||
let request = PollRequest {
|
||||
device_code: device_code.to_string(),
|
||||
};
|
||||
|
||||
let response = client
|
||||
.post(&url)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&request)
|
||||
.send()
|
||||
.context("Failed to poll for authorization")?;
|
||||
|
||||
let status_code = response.status();
|
||||
|
||||
if status_code.as_u16() == 202 {
|
||||
// Still pending
|
||||
if attempt % 3 == 0 {
|
||||
info!(
|
||||
"Still waiting for authorization... (attempt {}/{})",
|
||||
attempt, max_attempts
|
||||
);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Parse response body for both success and error cases
|
||||
let poll_response: PollResponse =
|
||||
response.json().context("Failed to parse poll response")?;
|
||||
|
||||
if status_code.is_success() && poll_response.status == "authorized" {
|
||||
return Ok(poll_response);
|
||||
}
|
||||
|
||||
if status_code.as_u16() == 400 && poll_response.status == "expired" {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Authorization code expired. Please try again."
|
||||
));
|
||||
}
|
||||
|
||||
return Err(anyhow::anyhow!(
|
||||
"Unexpected response during polling: HTTP {} - status: {}",
|
||||
status_code,
|
||||
poll_response.status
|
||||
));
|
||||
}
|
||||
|
||||
Err(anyhow::anyhow!(
|
||||
"Authorization timed out. Please try again."
|
||||
))
|
||||
}
|
||||
|
||||
fn complete_login(provider: &HomeDirProvider, command_name: &str) -> Result<(), Error> {
|
||||
// Login is the only command that doesn't have a context coming in - because it modifies the context
|
||||
init_context(None, false)?;
|
||||
context().capture_command_invoked(command_name);
|
||||
|
||||
println!();
|
||||
println!("🎉 Authentication complete!");
|
||||
println!("Credentials saved to: {}", provider.report_location());
|
||||
println!();
|
||||
println!("You can now use the CLI:");
|
||||
println!(" posthog-cli schema pull");
|
||||
println!();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn manual_login() -> Result<(), Error> {
|
||||
info!("🔐 Manual login...");
|
||||
|
||||
let host = Text::new("Enter the PostHog host URL")
|
||||
.with_default("https://us.posthog.com")
|
||||
.with_validator(host_validator)
|
||||
.prompt()?;
|
||||
|
||||
let env_id =
|
||||
Text::new("Enter your project ID (the number in your posthog homepage url)").prompt()?;
|
||||
Text::new("Enter your project ID (the number in your PostHog homepage URL)").prompt()?;
|
||||
|
||||
let token = Text::new(
|
||||
"Enter your personal API token",
|
||||
@@ -24,17 +278,14 @@ pub fn login() -> Result<(), Error> {
|
||||
.prompt()?;
|
||||
|
||||
let token = Token {
|
||||
host: Some(host),
|
||||
host: Some(host.trim_end_matches('/').to_string()),
|
||||
token,
|
||||
env_id,
|
||||
};
|
||||
let provider = HomeDirProvider;
|
||||
provider.store_credentials(token)?;
|
||||
|
||||
info!("Token saved to: {}", provider.report_location());
|
||||
|
||||
// Login is the only command that doesn't have a context coming in - because it modifies the context
|
||||
init_context(None, false)?;
|
||||
context().capture_command_invoked("interactive_login");
|
||||
|
||||
Ok(())
|
||||
complete_login(&provider, "manual_login")
|
||||
}
|
||||
|
||||
@@ -22,7 +22,28 @@ fn main() {
|
||||
|
||||
match cmd::Cli::run() {
|
||||
Ok(_) => info!("All done, happy hogging!"),
|
||||
Err(_) => {
|
||||
Err(e) => {
|
||||
match e.exception_id {
|
||||
Some(id) => {
|
||||
eprintln!("Oops! {}", e.inner);
|
||||
eprintln!();
|
||||
eprintln!("Exception ID: {id}");
|
||||
}
|
||||
None => {
|
||||
eprintln!("Oops! {}", e.inner);
|
||||
|
||||
let mut source = e.inner.source();
|
||||
if source.is_some() {
|
||||
eprintln!("\nCaused by:");
|
||||
let mut index = 0;
|
||||
while let Some(err) = source {
|
||||
eprintln!(" {index}: {err}");
|
||||
source = err.source();
|
||||
index += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
|
Before Width: | Height: | Size: 94 KiB After Width: | Height: | Size: 95 KiB |
|
Before Width: | Height: | Size: 94 KiB After Width: | Height: | Size: 95 KiB |
|
Before Width: | Height: | Size: 41 KiB After Width: | Height: | Size: 41 KiB |
|
Before Width: | Height: | Size: 40 KiB After Width: | Height: | Size: 40 KiB |
@@ -187,3 +187,20 @@ export const getMinimumEquivalentScopes = (scopes: string[]): string[] => {
|
||||
return `${object}:${action}`
|
||||
})
|
||||
}
|
||||
|
||||
/** Convert scopes array format to object format for easier UI manipulation */
|
||||
export const scopesArrayToObject = (scopes: string[]): Record<string, string> => {
|
||||
const result: Record<string, string> = {}
|
||||
scopes.forEach((scope) => {
|
||||
const [key, action] = scope.split(':')
|
||||
if (key && action) {
|
||||
result[key] = action
|
||||
}
|
||||
})
|
||||
return result
|
||||
}
|
||||
|
||||
/** Convert scopes object format back to array format for API submission */
|
||||
export const scopesObjectToArray = (scopesObj: Record<string, string>): string[] => {
|
||||
return Object.entries(scopesObj).map(([key, action]) => `${key}:${action}`)
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ const pathsWithoutProjectId = [
|
||||
'oauth',
|
||||
'shared',
|
||||
'embedded',
|
||||
'cli',
|
||||
'render_query',
|
||||
]
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ export const appScenes: Record<Scene | string, () => any> = {
|
||||
[Scene.BillingSection]: () => import('./billing/BillingSection'),
|
||||
[Scene.Billing]: () => import('./billing/Billing'),
|
||||
[Scene.Canvas]: () => import('./notebooks/NotebookCanvasScene'),
|
||||
[Scene.CLIAuthorize]: () => import('./authentication/CLIAuthorize'),
|
||||
[Scene.Cohort]: () => import('./cohorts/Cohort'),
|
||||
[Scene.CohortCalculationHistory]: () => import('./cohorts/CohortCalculationHistory'),
|
||||
[Scene.Cohorts]: () => import('./cohorts/Cohorts'),
|
||||
|
||||
293
frontend/src/scenes/authentication/CLIAuthorize.tsx
Normal file
@@ -0,0 +1,293 @@
|
||||
import { useActions, useValues } from 'kea'
|
||||
import { Form } from 'kea-forms'
|
||||
import { Fragment, useState } from 'react'
|
||||
|
||||
import { IconCode, IconGear, IconWarning } from '@posthog/icons'
|
||||
import { IconInfo } from '@posthog/icons'
|
||||
import { LemonButton, LemonInput, LemonSegmentedButton, LemonSelect, Link, Tooltip } from '@posthog/lemon-ui'
|
||||
|
||||
import { BridgePage } from 'lib/components/BridgePage/BridgePage'
|
||||
import { LemonBanner } from 'lib/lemon-ui/LemonBanner'
|
||||
import { LemonField } from 'lib/lemon-ui/LemonField'
|
||||
import { LemonModal } from 'lib/lemon-ui/LemonModal'
|
||||
import { IconErrorOutline } from 'lib/lemon-ui/icons'
|
||||
import { API_SCOPES } from 'lib/scopes'
|
||||
import { capitalizeFirstLetter } from 'lib/utils'
|
||||
import { SceneExport } from 'scenes/sceneTypes'
|
||||
import { urls } from 'scenes/urls'
|
||||
|
||||
import { cliAuthorizeLogic } from './cliAuthorizeLogic'
|
||||
|
||||
export const scene: SceneExport = {
|
||||
component: CLIAuthorize,
|
||||
logic: cliAuthorizeLogic,
|
||||
}
|
||||
|
||||
function ScopesList({
|
||||
scopes,
|
||||
formScopeRadioValues,
|
||||
displayScopeValues,
|
||||
setScopeRadioValue,
|
||||
showAll = false,
|
||||
}: {
|
||||
scopes: typeof API_SCOPES
|
||||
formScopeRadioValues: Record<string, string>
|
||||
displayScopeValues?: Record<string, string>
|
||||
setScopeRadioValue: (key: string, action: string) => void
|
||||
showAll?: boolean
|
||||
}): JSX.Element {
|
||||
// Use displayScopeValues for filtering if provided, otherwise use formScopeRadioValues
|
||||
const filterValues = displayScopeValues ?? formScopeRadioValues
|
||||
const visibleScopes = showAll
|
||||
? scopes
|
||||
: scopes.filter((scope) => filterValues[scope.key] && filterValues[scope.key] !== 'none')
|
||||
|
||||
if (!showAll && visibleScopes.length === 0) {
|
||||
return (
|
||||
<div className="text-muted text-sm italic py-2">
|
||||
No scopes selected. Click "Manage scopes" to select permissions.
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div>
|
||||
{visibleScopes.map(({ key, disabledActions, warnings, info }) => {
|
||||
return (
|
||||
<Fragment key={key}>
|
||||
<div className="flex items-center justify-between gap-2 min-h-8">
|
||||
<div className="flex items-center gap-1">
|
||||
<b>{capitalizeFirstLetter(key.replace(/_/g, ' '))}</b>
|
||||
|
||||
{info ? (
|
||||
<Tooltip title={info}>
|
||||
<IconInfo className="text-secondary text-base" />
|
||||
</Tooltip>
|
||||
) : null}
|
||||
</div>
|
||||
<LemonSegmentedButton
|
||||
onChange={(value) => setScopeRadioValue(key, value)}
|
||||
value={formScopeRadioValues[key] ?? 'none'}
|
||||
options={[
|
||||
{ label: 'No access', value: 'none' },
|
||||
{
|
||||
label: 'Read',
|
||||
value: 'read',
|
||||
disabledReason: disabledActions?.includes('read')
|
||||
? 'Does not apply to this resource'
|
||||
: undefined,
|
||||
},
|
||||
{
|
||||
label: 'Write',
|
||||
value: 'write',
|
||||
disabledReason: disabledActions?.includes('write')
|
||||
? 'Does not apply to this resource'
|
||||
: undefined,
|
||||
},
|
||||
]}
|
||||
size="xsmall"
|
||||
/>
|
||||
</div>
|
||||
{warnings?.[formScopeRadioValues[key]] && (
|
||||
<div className="flex items-start gap-2 text-xs italic pb-2">
|
||||
<IconWarning className="text-base text-secondary mt-0.5" />
|
||||
<span>{warnings[formScopeRadioValues[key]]}</span>
|
||||
</div>
|
||||
)}
|
||||
</Fragment>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export function CLIAuthorize(): JSX.Element {
|
||||
const {
|
||||
authorize,
|
||||
isSuccess,
|
||||
projects,
|
||||
projectsLoading,
|
||||
isAuthorizeSubmitting,
|
||||
formScopeRadioValues,
|
||||
displayedScopeValues,
|
||||
missingSchemaScopes,
|
||||
missingErrorTrackingScopes,
|
||||
} = useValues(cliAuthorizeLogic)
|
||||
const { setAuthorizeValue, setScopeRadioValue, updateDisplayedScopeSnapshot } = useActions(cliAuthorizeLogic)
|
||||
const [isScopesModalOpen, setIsScopesModalOpen] = useState(false)
|
||||
|
||||
const handleOpenModal = (): void => {
|
||||
setIsScopesModalOpen(true)
|
||||
}
|
||||
|
||||
const handleCloseModal = (): void => {
|
||||
updateDisplayedScopeSnapshot()
|
||||
setIsScopesModalOpen(false)
|
||||
}
|
||||
|
||||
return (
|
||||
<BridgePage
|
||||
view="login"
|
||||
{...(!isSuccess
|
||||
? {
|
||||
hedgehog: true as const,
|
||||
message: (
|
||||
<>
|
||||
Authorize
|
||||
<br />
|
||||
PostHog CLI
|
||||
</>
|
||||
),
|
||||
}
|
||||
: { hedgehog: false as const })}
|
||||
>
|
||||
{isSuccess ? (
|
||||
<div className="text-center space-y-4">
|
||||
<h2>CLI Authorization Complete</h2>
|
||||
<LemonBanner type="success">
|
||||
<div className="space-y-2">
|
||||
<p className="font-semibold">Your CLI has been authorized successfully!</p>
|
||||
<p>You can now close this window and return to your terminal.</p>
|
||||
</div>
|
||||
</LemonBanner>
|
||||
<div className="text-muted text-sm mt-4">
|
||||
<p>
|
||||
A Personal API Key has been created for your CLI. You can manage your API keys in{' '}
|
||||
<Link to={urls.settings('user-api-keys')} className="font-semibold">
|
||||
Settings → Personal API Keys
|
||||
</Link>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div className="space-y-4">
|
||||
<h2>Authorize CLI Access</h2>
|
||||
<p className="text-muted text-sm">
|
||||
The PostHog CLI should have displayed a 9-character code (e.g., ABCD-1234). Enter it below to
|
||||
authorize your CLI.
|
||||
</p>
|
||||
<Form logic={cliAuthorizeLogic} formKey="authorize" enableFormOnSubmit className="space-y-4">
|
||||
<LemonField name="userCode" label="Authorization Code">
|
||||
<LemonInput
|
||||
className="ph-ignore-input font-mono text-lg tracking-wider"
|
||||
autoFocus
|
||||
data-attr="cli-auth-code"
|
||||
placeholder="ABCD-1234"
|
||||
maxLength={9}
|
||||
value={authorize.userCode}
|
||||
onChange={(value) => setAuthorizeValue('userCode', value.toUpperCase())}
|
||||
autoComplete="off"
|
||||
autoCorrect="off"
|
||||
autoCapitalize="characters"
|
||||
spellCheck={false}
|
||||
/>
|
||||
</LemonField>
|
||||
<LemonField name="projectId" label="Project">
|
||||
<LemonSelect
|
||||
data-attr="cli-project-select"
|
||||
placeholder="Select a project"
|
||||
value={authorize.projectId}
|
||||
onChange={(value) => setAuthorizeValue('projectId', value)}
|
||||
options={projects.map((project: { id: number; name: string }) => ({
|
||||
label: project.name,
|
||||
value: project.id,
|
||||
}))}
|
||||
loading={projectsLoading}
|
||||
/>
|
||||
</LemonField>
|
||||
|
||||
<div className="mt-4 mb-2">
|
||||
<div className="flex items-center justify-between mb-2">
|
||||
<h3>Scopes</h3>
|
||||
<LemonButton
|
||||
type="secondary"
|
||||
size="small"
|
||||
icon={<IconGear />}
|
||||
onClick={handleOpenModal}
|
||||
>
|
||||
Manage scopes
|
||||
</LemonButton>
|
||||
</div>
|
||||
<p className="text-muted text-sm mb-2">
|
||||
Selected permissions for the CLI. Only grant the scopes you need.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<LemonField name="scopes">
|
||||
{({ error }) => (
|
||||
<>
|
||||
{error && (
|
||||
<div className="text-danger flex items-center gap-1 text-sm mb-2">
|
||||
<IconErrorOutline className="text-xl" /> {error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<ScopesList
|
||||
scopes={API_SCOPES}
|
||||
formScopeRadioValues={formScopeRadioValues}
|
||||
displayScopeValues={displayedScopeValues}
|
||||
setScopeRadioValue={setScopeRadioValue}
|
||||
showAll={false}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
</LemonField>
|
||||
|
||||
<LemonModal
|
||||
title="Manage CLI Scopes"
|
||||
description="Select which permissions to grant to the CLI. Only select the scopes you need."
|
||||
isOpen={isScopesModalOpen}
|
||||
onClose={handleCloseModal}
|
||||
footer={
|
||||
<LemonButton type="primary" onClick={handleCloseModal}>
|
||||
Done
|
||||
</LemonButton>
|
||||
}
|
||||
>
|
||||
<div className="max-h-96 overflow-y-auto">
|
||||
<ScopesList
|
||||
scopes={API_SCOPES}
|
||||
formScopeRadioValues={formScopeRadioValues}
|
||||
setScopeRadioValue={setScopeRadioValue}
|
||||
showAll={true}
|
||||
/>
|
||||
</div>
|
||||
</LemonModal>
|
||||
|
||||
{(missingSchemaScopes || missingErrorTrackingScopes) && (
|
||||
<div className="space-y-2 mt-2">
|
||||
{missingSchemaScopes && (
|
||||
<LemonBanner type="warning">
|
||||
<b>Schema management unavailable:</b> The CLI needs both{' '}
|
||||
<code>event_definition</code> and <code>property_definition</code> permissions
|
||||
(read or write) to manage schemas.
|
||||
</LemonBanner>
|
||||
)}
|
||||
{missingErrorTrackingScopes && (
|
||||
<LemonBanner type="warning">
|
||||
<b>Error tracking unavailable:</b> The CLI needs <code>error_tracking</code>{' '}
|
||||
permissions (read or write) to manage error tracking.
|
||||
</LemonBanner>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<LemonButton
|
||||
type="primary"
|
||||
status="alt"
|
||||
htmlType="submit"
|
||||
data-attr="cli-authorize-submit"
|
||||
fullWidth
|
||||
center
|
||||
loading={isAuthorizeSubmitting}
|
||||
size="large"
|
||||
icon={<IconCode />}
|
||||
>
|
||||
Authorize CLI
|
||||
</LemonButton>
|
||||
</Form>
|
||||
</div>
|
||||
)}
|
||||
</BridgePage>
|
||||
)
|
||||
}
|
||||
242
frontend/src/scenes/authentication/cliAuthorizeLogic.ts
Normal file
@@ -0,0 +1,242 @@
|
||||
import { actions, afterMount, kea, listeners, path, reducers, selectors } from 'kea'
|
||||
import { forms } from 'kea-forms'
|
||||
import { loaders } from 'kea-loaders'
|
||||
import { urlToAction } from 'kea-router'
|
||||
|
||||
import api from 'lib/api'
|
||||
import { scopesArrayToObject } from 'lib/scopes'
|
||||
|
||||
import type { cliAuthorizeLogicType } from './cliAuthorizeLogicType'
|
||||
|
||||
export type CLIUseCase = 'schema' | 'error_tracking'
|
||||
|
||||
export interface CLIAuthorizeForm {
|
||||
userCode: string
|
||||
projectId: number | null
|
||||
scopes: string[]
|
||||
}
|
||||
|
||||
// Map use cases to their required scopes
|
||||
const USE_CASE_SCOPES: Record<CLIUseCase, string[]> = {
|
||||
schema: ['event_definition:read', 'property_definition:read'],
|
||||
error_tracking: ['error_tracking:write'],
|
||||
}
|
||||
|
||||
// Default use cases when none are specified
|
||||
const DEFAULT_USE_CASES: CLIUseCase[] = ['schema', 'error_tracking']
|
||||
|
||||
function getDefaultScopesForUseCases(useCases: CLIUseCase[]): string[] {
|
||||
const scopesSet = new Set<string>()
|
||||
for (const useCase of useCases) {
|
||||
const scopes = USE_CASE_SCOPES[useCase] || []
|
||||
scopes.forEach((scope) => scopesSet.add(scope))
|
||||
}
|
||||
return Array.from(scopesSet)
|
||||
}
|
||||
|
||||
// Pre-compute default scopes
|
||||
const DEFAULT_SCOPES = getDefaultScopesForUseCases(DEFAULT_USE_CASES)
|
||||
|
||||
export const cliAuthorizeLogic = kea<cliAuthorizeLogicType>([
|
||||
path(['scenes', 'authentication', 'cliAuthorizeLogic']),
|
||||
actions({
|
||||
setSuccess: (success: boolean) => ({ success }),
|
||||
setScopeRadioValue: (key: string, action: string) => ({ key, action }),
|
||||
setRequestedUseCases: (useCases: CLIUseCase[]) => ({ useCases }),
|
||||
updateDisplayedScopeSnapshot: true,
|
||||
setDisplayedScopeValues: (values: Record<string, string>) => ({ values }),
|
||||
}),
|
||||
reducers({
|
||||
isSuccess: [
|
||||
false,
|
||||
{
|
||||
setSuccess: (_, { success }) => success,
|
||||
},
|
||||
],
|
||||
requestedUseCases: [
|
||||
DEFAULT_USE_CASES,
|
||||
{
|
||||
setRequestedUseCases: (_, { useCases }) => useCases,
|
||||
},
|
||||
],
|
||||
displayedScopeValues: [
|
||||
{} as Record<string, string>,
|
||||
{
|
||||
setDisplayedScopeValues: (_, { values }) => values,
|
||||
},
|
||||
],
|
||||
}),
|
||||
loaders(() => ({
|
||||
projects: [
|
||||
[] as { id: number; name: string }[],
|
||||
{
|
||||
loadProjects: async () => {
|
||||
const response = await api.get('api/projects/')
|
||||
return response.results || []
|
||||
},
|
||||
},
|
||||
],
|
||||
})),
|
||||
forms(() => ({
|
||||
authorize: {
|
||||
defaults: {
|
||||
userCode: '',
|
||||
projectId: null,
|
||||
scopes: DEFAULT_SCOPES,
|
||||
} as CLIAuthorizeForm,
|
||||
errors: ({ userCode, projectId, scopes }) => ({
|
||||
userCode: !userCode
|
||||
? 'Please enter the code from your terminal'
|
||||
: userCode.length !== 9
|
||||
? 'Code must be 9 characters (XXXX-XXXX)'
|
||||
: undefined,
|
||||
projectId: !projectId ? 'Please select a project' : undefined,
|
||||
scopes: !scopes?.length ? ('Your API key needs at least one scope' as any) : undefined,
|
||||
}),
|
||||
submit: async ({ userCode, projectId, scopes }) => {
|
||||
try {
|
||||
const response = await api.create('api/cli-auth/authorize/', {
|
||||
user_code: userCode.toUpperCase().replace(/\s/g, ''),
|
||||
project_id: projectId,
|
||||
scopes: scopes,
|
||||
})
|
||||
return response
|
||||
} catch (error: any) {
|
||||
const errorCode = error?.code || error?.error
|
||||
if (errorCode === 'invalid_code') {
|
||||
throw { userCode: 'Invalid or expired code. Please try again.' }
|
||||
} else if (errorCode === 'expired') {
|
||||
throw { userCode: 'This code has expired. Please request a new code in your terminal.' }
|
||||
} else if (errorCode === 'access_denied') {
|
||||
throw { projectId: 'You do not have access to this project.' }
|
||||
} else if (errorCode === 'invalid_project') {
|
||||
throw { projectId: 'Project not found.' }
|
||||
} else {
|
||||
throw { userCode: 'An error occurred. Please try again.' }
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
})),
|
||||
selectors(() => ({
|
||||
formScopeRadioValues: [
|
||||
(s) => [s.authorize],
|
||||
(authorize): Record<string, string> => {
|
||||
if (!authorize || !authorize.scopes) {
|
||||
return {}
|
||||
}
|
||||
return scopesArrayToObject(authorize.scopes)
|
||||
},
|
||||
],
|
||||
missingSchemaScopes: [
|
||||
(s) => [s.authorize, s.requestedUseCases],
|
||||
(authorize, requestedUseCases): boolean => {
|
||||
// Only show warning if schema use case was requested
|
||||
if (!requestedUseCases.includes('schema')) {
|
||||
return false
|
||||
}
|
||||
if (!authorize || !authorize.scopes) {
|
||||
return false
|
||||
}
|
||||
// Warn if missing BOTH event_definition (read or write) AND property_definition (read or write)
|
||||
// Note: write permissions include read, so having write is sufficient
|
||||
const hasEventDefinition =
|
||||
authorize.scopes.includes('event_definition:read') ||
|
||||
authorize.scopes.includes('event_definition:write')
|
||||
const hasPropertyDefinition =
|
||||
authorize.scopes.includes('property_definition:read') ||
|
||||
authorize.scopes.includes('property_definition:write')
|
||||
return !hasEventDefinition || !hasPropertyDefinition
|
||||
},
|
||||
],
|
||||
missingErrorTrackingScopes: [
|
||||
(s) => [s.authorize, s.requestedUseCases],
|
||||
(authorize, requestedUseCases): boolean => {
|
||||
// Only show warning if error_tracking use case was requested
|
||||
if (!requestedUseCases.includes('error_tracking')) {
|
||||
return false
|
||||
}
|
||||
if (!authorize || !authorize.scopes) {
|
||||
return false
|
||||
}
|
||||
// Warn if missing error_tracking entirely (neither read nor write)
|
||||
// Note: write permissions include read, so having write is sufficient
|
||||
return (
|
||||
!authorize.scopes.includes('error_tracking:read') &&
|
||||
!authorize.scopes.includes('error_tracking:write')
|
||||
)
|
||||
},
|
||||
],
|
||||
})),
|
||||
listeners(({ actions, values }) => ({
|
||||
loadProjectsSuccess: () => {
|
||||
// Set default project to first project if not already set
|
||||
if (values.projects.length > 0 && !values.authorize.projectId) {
|
||||
actions.setAuthorizeValue('projectId', values.projects[0].id)
|
||||
}
|
||||
},
|
||||
setAuthorizeValue: (payload) => {
|
||||
// Initialize displayed scope values when scopes are first set
|
||||
if (payload.name === 'scopes' && Object.keys(values.displayedScopeValues).length === 0) {
|
||||
// Directly compute scope values from the scopes array being set
|
||||
const scopesArray = payload.value as string[]
|
||||
const scopeValues = scopesArrayToObject(scopesArray)
|
||||
actions.setDisplayedScopeValues(scopeValues)
|
||||
}
|
||||
},
|
||||
updateDisplayedScopeSnapshot: () => {
|
||||
// Update displayed scope values with current form values
|
||||
actions.setDisplayedScopeValues(values.formScopeRadioValues)
|
||||
},
|
||||
submitAuthorizeSuccess: () => {
|
||||
actions.setSuccess(true)
|
||||
},
|
||||
submitAuthorizeFailure: () => {
|
||||
// Error handling is done in the form errors
|
||||
},
|
||||
setRequestedUseCases: ({ useCases }) => {
|
||||
// Update scopes when requested use cases change
|
||||
const newScopes = getDefaultScopesForUseCases(useCases)
|
||||
actions.setAuthorizeValue('scopes', newScopes)
|
||||
// Directly compute and update displayed scope values
|
||||
const scopeValues = scopesArrayToObject(newScopes)
|
||||
actions.setDisplayedScopeValues(scopeValues)
|
||||
},
|
||||
setScopeRadioValue: ({ key, action }) => {
|
||||
if (!values.authorize || !values.authorize.scopes) {
|
||||
return
|
||||
}
|
||||
|
||||
// Remove existing scope with this key
|
||||
const filteredScopes = values.authorize.scopes.filter((scope) => !scope.startsWith(`${key}:`))
|
||||
|
||||
// Add new scope if not 'none'
|
||||
const newScopes = action === 'none' ? filteredScopes : [...filteredScopes, `${key}:${action}`]
|
||||
|
||||
actions.setAuthorizeValue('scopes', newScopes)
|
||||
},
|
||||
})),
|
||||
urlToAction(({ actions }) => ({
|
||||
'/cli/authorize': (_, searchParams) => {
|
||||
const code = searchParams.code
|
||||
if (code) {
|
||||
// Set the form field value directly
|
||||
actions.setAuthorizeValue('userCode', code)
|
||||
}
|
||||
|
||||
// Parse use_cases from URL (comma-separated)
|
||||
const useCasesParam = searchParams.use_cases
|
||||
if (useCasesParam) {
|
||||
const useCases = useCasesParam.split(',').filter((uc: string): uc is CLIUseCase => {
|
||||
return uc === 'schema' || uc === 'error_tracking'
|
||||
})
|
||||
if (useCases.length > 0) {
|
||||
actions.setRequestedUseCases(useCases)
|
||||
}
|
||||
}
|
||||
},
|
||||
})),
|
||||
afterMount(({ actions }) => {
|
||||
actions.loadProjects()
|
||||
}),
|
||||
])
|
||||
@@ -20,6 +20,7 @@ export enum Scene {
|
||||
BillingAuthorizationStatus = 'BillingAuthorizationStatus',
|
||||
BillingSection = 'BillingSection',
|
||||
Canvas = 'Canvas',
|
||||
CLIAuthorize = 'CLIAuthorize',
|
||||
Cohort = 'Cohort',
|
||||
CohortCalculationHistory = 'CohortCalculationHistory',
|
||||
Cohorts = 'Cohorts',
|
||||
|
||||
@@ -65,6 +65,12 @@ export const sceneConfigurations: Record<Scene | string, SceneConfig> = {
|
||||
defaultDocsPath: '/blog/introducing-notebooks',
|
||||
hideProjectNotice: true,
|
||||
},
|
||||
[Scene.CLIAuthorize]: {
|
||||
name: 'Authorize CLI',
|
||||
projectBased: false,
|
||||
organizationBased: false,
|
||||
layout: 'plain',
|
||||
},
|
||||
[Scene.Cohort]: { projectBased: true, name: 'Cohort', defaultDocsPath: '/docs/data/cohorts' },
|
||||
[Scene.CohortCalculationHistory]: { projectBased: true, name: 'Cohort Calculation History' },
|
||||
[Scene.Cohorts]: {
|
||||
@@ -674,6 +680,7 @@ export const routes: Record<string, [Scene | string, string]> = {
|
||||
[urls.site(':url')]: [Scene.Site, 'site'],
|
||||
[urls.login()]: [Scene.Login, 'login'],
|
||||
[urls.login2FA()]: [Scene.Login2FA, 'login2FA'],
|
||||
[urls.cliAuthorize()]: [Scene.CLIAuthorize, 'cliAuthorize'],
|
||||
[urls.emailMFAVerify()]: [Scene.EmailMFAVerify, 'emailMFAVerify'],
|
||||
[urls.preflight()]: [Scene.PreflightCheck, 'preflight'],
|
||||
[urls.signup()]: [Scene.Signup, 'signup'],
|
||||
|
||||
@@ -9,6 +9,7 @@ import api from 'lib/api'
|
||||
import { CodeSnippet } from 'lib/components/CodeSnippet'
|
||||
import { OrganizationMembershipLevel } from 'lib/constants'
|
||||
import { lemonToast } from 'lib/lemon-ui/LemonToast/LemonToast'
|
||||
import { scopesArrayToObject, scopesObjectToArray } from 'lib/scopes'
|
||||
import { hasMembershipLevelOrHigher, organizationAllowsPersonalApiKeysForMembers } from 'lib/utils/permissioning'
|
||||
import { urls } from 'scenes/urls'
|
||||
import { userLogic } from 'scenes/userLogic'
|
||||
@@ -138,14 +139,7 @@ export const personalAPIKeysLogic = kea<personalAPIKeysLogicType>([
|
||||
formScopeRadioValues: [
|
||||
(s) => [s.editingKey],
|
||||
(editingKey): Record<string, string> => {
|
||||
const result: Record<string, string> = {}
|
||||
|
||||
editingKey.scopes.forEach((scope) => {
|
||||
const [key, action] = scope.split(':')
|
||||
result[key] = action
|
||||
})
|
||||
|
||||
return result
|
||||
return scopesArrayToObject(editingKey.scopes)
|
||||
},
|
||||
],
|
||||
allAccessSelected: [
|
||||
@@ -374,11 +368,19 @@ export const personalAPIKeysLogic = kea<personalAPIKeysLogicType>([
|
||||
},
|
||||
|
||||
setScopeRadioValue: ({ key, action }) => {
|
||||
const newScopes = values.editingKey.scopes.filter((scope) => !scope.startsWith(key))
|
||||
if (action !== 'none') {
|
||||
newScopes.push(`${key}:${action}`)
|
||||
// Convert current scopes array to object for easier manipulation
|
||||
const scopesObject = scopesArrayToObject(values.editingKey.scopes)
|
||||
|
||||
// Update the specific scope
|
||||
if (action === 'none') {
|
||||
delete scopesObject[key]
|
||||
} else {
|
||||
scopesObject[key] = action
|
||||
}
|
||||
|
||||
// Convert back to array format
|
||||
const newScopes = scopesObjectToArray(scopesObject)
|
||||
|
||||
actions.setEditingKeyValue('scopes', newScopes)
|
||||
},
|
||||
|
||||
|
||||
@@ -93,6 +93,7 @@ export const urls = {
|
||||
login: (): string => '/login',
|
||||
login2FA: (): string => '/login/2fa',
|
||||
login2FASetup: (): string => '/login/2fa_setup',
|
||||
cliAuthorize: (): string => '/cli/authorize',
|
||||
emailMFAVerify: (): string => '/login/verify',
|
||||
liveDebugger: (): string => '/live-debugger',
|
||||
passwordReset: (): string => '/reset',
|
||||
|
||||
@@ -68,6 +68,7 @@ from . import (
|
||||
app_metrics,
|
||||
async_migration,
|
||||
authentication,
|
||||
cli_auth,
|
||||
comments,
|
||||
dead_letter_queue,
|
||||
debug_ch_queries,
|
||||
@@ -521,6 +522,7 @@ router.register(r"login/email-mfa", authentication.EmailMFAViewSet, "login_email
|
||||
router.register(r"reset", authentication.PasswordResetViewSet, "password_reset")
|
||||
router.register(r"users", user.UserViewSet, "users")
|
||||
router.register(r"personal_api_keys", personal_api_key.PersonalAPIKeyViewSet, "personal_api_keys")
|
||||
router.register(r"cli-auth", cli_auth.CLIAuthViewSet, "cli_auth")
|
||||
router.register(r"instance_status", instance_status.InstanceStatusViewSet, "instance_status")
|
||||
router.register(r"dead_letter_queue", dead_letter_queue.DeadLetterQueueViewSet, "dead_letter_queue")
|
||||
router.register(r"async_migrations", async_migration.AsyncMigrationsViewset, "async_migrations")
|
||||
|
||||
355
posthog/api/cli_auth.py
Normal file
@@ -0,0 +1,355 @@
|
||||
"""
|
||||
CLI Authentication API using OAuth2 Device Flow
|
||||
|
||||
This implements the device authorization flow (RFC 8628) for the PostHog CLI.
|
||||
Users can authenticate without copying/pasting API keys.
|
||||
|
||||
Flow:
|
||||
1. CLI requests device code
|
||||
2. User opens browser and authorizes
|
||||
3. CLI polls for completion
|
||||
4. Returns Personal API Key
|
||||
"""
|
||||
|
||||
import string
|
||||
import secrets
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.utils import timezone
|
||||
|
||||
from rest_framework import serializers, status, viewsets
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.permissions import AllowAny, IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
|
||||
from posthog.auth import SessionAuthentication
|
||||
from posthog.models import PersonalAPIKey, User
|
||||
from posthog.models.personal_api_key import hash_key_value
|
||||
from posthog.models.utils import generate_random_token_personal, mask_key_value
|
||||
|
||||
|
||||
class CLIAuthSessionAuthentication(SessionAuthentication):
|
||||
"""
|
||||
Custom session authentication for CLI authorization.
|
||||
|
||||
The user_code serves as an additional authorization token beyond CSRF,
|
||||
so we can safely skip CSRF validation for this specific endpoint.
|
||||
"""
|
||||
|
||||
def authenticate(self, request):
|
||||
"""Authenticate using session authentication."""
|
||||
return super().authenticate(request)
|
||||
|
||||
def enforce_csrf(self, request):
|
||||
"""Skip CSRF enforcement - the user_code acts as the authorization token."""
|
||||
return None
|
||||
|
||||
|
||||
# Device code lives for 10 minutes
|
||||
DEVICE_CODE_EXPIRY_SECONDS = 600
|
||||
|
||||
# CLI polling interval (5 seconds)
|
||||
CLI_POLL_INTERVAL_SECONDS = 5
|
||||
|
||||
# Scopes granted to CLI
|
||||
CLI_SCOPES = [
|
||||
"event_definition:read",
|
||||
"property_definition:read",
|
||||
"error_tracking:write",
|
||||
]
|
||||
|
||||
|
||||
def generate_user_code() -> str:
|
||||
"""Generate a human-readable code like 'ABCD-1234'"""
|
||||
letters = "".join(secrets.choice(string.ascii_uppercase) for _ in range(4))
|
||||
numbers = "".join(secrets.choice(string.digits) for _ in range(4))
|
||||
return f"{letters}-{numbers}"
|
||||
|
||||
|
||||
def generate_device_code() -> str:
|
||||
"""Generate a secure random device code"""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def get_device_cache_key(device_code: str) -> str:
|
||||
"""Get cache key for device code"""
|
||||
return f"cli_device:{device_code}"
|
||||
|
||||
|
||||
def get_user_code_cache_key(user_code: str) -> str:
|
||||
"""Get cache key for user code"""
|
||||
return f"cli_user_code:{user_code}"
|
||||
|
||||
|
||||
class DeviceCodeRequestSerializer(serializers.Serializer):
|
||||
"""Request to initiate device authorization flow"""
|
||||
|
||||
pass # No input required
|
||||
|
||||
|
||||
class DeviceCodeResponseSerializer(serializers.Serializer):
|
||||
"""Response containing device and user codes"""
|
||||
|
||||
device_code = serializers.CharField(help_text="Code for CLI to poll with")
|
||||
user_code = serializers.CharField(help_text="Code for user to enter in browser")
|
||||
verification_uri = serializers.CharField(help_text="URL for user to visit")
|
||||
verification_uri_complete = serializers.CharField(help_text="URL with code pre-filled")
|
||||
expires_in = serializers.IntegerField(help_text="Seconds until code expires")
|
||||
interval = serializers.IntegerField(help_text="Polling interval in seconds")
|
||||
|
||||
|
||||
class DeviceAuthorizationSerializer(serializers.Serializer):
|
||||
"""User authorizes the device code"""
|
||||
|
||||
user_code = serializers.CharField(max_length=9, help_text="The user code displayed in CLI")
|
||||
project_id = serializers.IntegerField(help_text="The project to authorize CLI access for")
|
||||
scopes = serializers.ListField(
|
||||
child=serializers.CharField(),
|
||||
required=False,
|
||||
help_text="Scopes to grant to the CLI (defaults to CLI_SCOPES)",
|
||||
)
|
||||
|
||||
|
||||
class DevicePollSerializer(serializers.Serializer):
|
||||
"""CLI polls for authorization status"""
|
||||
|
||||
device_code = serializers.CharField(help_text="Device code from initial request")
|
||||
|
||||
|
||||
class DevicePollResponseSerializer(serializers.Serializer):
|
||||
"""Response to poll request"""
|
||||
|
||||
status = serializers.ChoiceField(choices=["pending", "authorized", "expired"])
|
||||
personal_api_key = serializers.CharField(required=False, help_text="The API key (only if authorized)")
|
||||
label = serializers.CharField(required=False, help_text="Label of the created key") # type: ignore[assignment]
|
||||
project_id = serializers.CharField(required=False, help_text="The project ID (only if authorized)")
|
||||
|
||||
|
||||
class CLIAuthViewSet(viewsets.ViewSet):
|
||||
"""
|
||||
OAuth2 Device Authorization Flow for CLI authentication
|
||||
|
||||
Endpoints:
|
||||
- POST /api/cli-auth/device-code/ (no auth required)
|
||||
- POST /api/cli-auth/authorize/ (session auth required)
|
||||
- POST /api/cli-auth/poll/ (no auth required)
|
||||
"""
|
||||
|
||||
def get_permissions(self):
|
||||
"""Authorize endpoint requires auth, others don't"""
|
||||
if getattr(self, "action", None) == "authorize":
|
||||
return [IsAuthenticated()]
|
||||
return [AllowAny()]
|
||||
|
||||
def get_authenticators(self):
|
||||
"""Only use session auth for browser-based authorization"""
|
||||
action = getattr(self, "action", None)
|
||||
|
||||
# Check both action and URL path since action might not be set yet
|
||||
if action == "authorize" or (hasattr(self, "request") and "authorize" in self.request.path):
|
||||
return [CLIAuthSessionAuthentication()]
|
||||
|
||||
return []
|
||||
|
||||
@action(methods=["POST"], detail=False, url_path="device-code")
|
||||
def device_code(self, request):
|
||||
"""
|
||||
Step 1: CLI requests device code
|
||||
|
||||
Returns device code for polling and user code for browser authorization.
|
||||
"""
|
||||
device_code = generate_device_code()
|
||||
user_code = generate_user_code()
|
||||
|
||||
# Store in cache with expiry
|
||||
device_cache_key = get_device_cache_key(device_code)
|
||||
cache.set(
|
||||
device_cache_key,
|
||||
{
|
||||
"user_code": user_code,
|
||||
"status": "pending",
|
||||
"created_at": timezone.now().isoformat(),
|
||||
},
|
||||
timeout=DEVICE_CODE_EXPIRY_SECONDS,
|
||||
)
|
||||
|
||||
# Also create reverse lookup (user_code -> device_code) for authorization
|
||||
user_code_cache_key = get_user_code_cache_key(user_code)
|
||||
cache.set(user_code_cache_key, device_code, timeout=DEVICE_CODE_EXPIRY_SECONDS)
|
||||
|
||||
# Get the base URL for verification
|
||||
# In production this would be the actual domain
|
||||
base_url = request.build_absolute_uri("/").rstrip("/")
|
||||
|
||||
response_data = {
|
||||
"device_code": device_code,
|
||||
"user_code": user_code,
|
||||
"verification_uri": f"{base_url}/cli/authorize",
|
||||
"verification_uri_complete": f"{base_url}/cli/authorize?code={user_code}",
|
||||
"expires_in": DEVICE_CODE_EXPIRY_SECONDS,
|
||||
"interval": CLI_POLL_INTERVAL_SECONDS,
|
||||
}
|
||||
|
||||
serializer = DeviceCodeResponseSerializer(response_data)
|
||||
return Response(serializer.data, status=status.HTTP_200_OK)
|
||||
|
||||
@action(methods=["POST"], detail=False, url_path="authorize")
|
||||
def authorize(self, request):
|
||||
"""
|
||||
Step 2: User authorizes in browser
|
||||
|
||||
Requires authenticated session. Creates a Personal API Key and marks
|
||||
the device code as authorized.
|
||||
|
||||
The user_code itself acts as a single-use authorization token,
|
||||
providing additional security beyond CSRF tokens.
|
||||
"""
|
||||
serializer = DeviceAuthorizationSerializer(data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
|
||||
user_code = serializer.validated_data["user_code"]
|
||||
project_id = serializer.validated_data["project_id"]
|
||||
scopes = serializer.validated_data.get("scopes", CLI_SCOPES)
|
||||
|
||||
# Validate that at least one scope is provided
|
||||
if not scopes or len(scopes) == 0:
|
||||
return Response(
|
||||
{"error": "invalid_request", "error_description": "At least one scope is required"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
# Look up device code from user code
|
||||
user_code_cache_key = get_user_code_cache_key(user_code)
|
||||
device_code = cache.get(user_code_cache_key)
|
||||
if not device_code:
|
||||
return Response(
|
||||
{"error": "invalid_code", "error_description": "User code not found or expired"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
# Get device code data
|
||||
device_cache_key = get_device_cache_key(device_code)
|
||||
device_data = cache.get(device_cache_key)
|
||||
if not device_data:
|
||||
return Response(
|
||||
{"error": "expired", "error_description": "Device code expired"}, status=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# Prevent duplicate authorization (race condition)
|
||||
if device_data.get("status") == "authorized":
|
||||
return Response(
|
||||
{"error": "already_authorized", "error_description": "This code has already been authorized"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
# Verify user has access to the project
|
||||
user: User = request.user
|
||||
from posthog.models import Team
|
||||
|
||||
try:
|
||||
team = Team.objects.get(id=project_id)
|
||||
# Check if user has access to this team's organization
|
||||
if not user.organization_memberships.filter(organization=team.organization).exists():
|
||||
return Response(
|
||||
{"error": "access_denied", "error_description": "You do not have access to this project"},
|
||||
status=status.HTTP_403_FORBIDDEN,
|
||||
)
|
||||
except Team.DoesNotExist:
|
||||
return Response(
|
||||
{"error": "invalid_project", "error_description": "Project not found"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
# Create Personal API Key for the CLI
|
||||
api_key_value = generate_random_token_personal()
|
||||
mask_value = mask_key_value(api_key_value)
|
||||
secure_value = hash_key_value(api_key_value)
|
||||
|
||||
# Label max length is 40 chars, so truncate if needed
|
||||
timestamp = timezone.now().strftime("%Y-%m-%d %H:%M")
|
||||
max_team_name_len = 40 - len("CLI - ") - len(f" - {timestamp}")
|
||||
team_name_truncated = team.name[:max_team_name_len] if len(team.name) > max_team_name_len else team.name
|
||||
label = f"CLI - {team_name_truncated} - {timestamp}"
|
||||
|
||||
PersonalAPIKey.objects.create(
|
||||
user=user,
|
||||
label=label,
|
||||
secure_value=secure_value,
|
||||
mask_value=mask_value,
|
||||
scopes=scopes,
|
||||
)
|
||||
|
||||
# Mark device as authorized and store the API key
|
||||
device_data["status"] = "authorized"
|
||||
device_data["personal_api_key"] = api_key_value
|
||||
device_data["label"] = label
|
||||
device_data["project_id"] = str(project_id)
|
||||
device_data["authorized_at"] = timezone.now().isoformat()
|
||||
device_data["user_id"] = user.id
|
||||
|
||||
# Update cache with longer TTL to ensure CLI can poll
|
||||
cache.set(device_cache_key, device_data, timeout=60) # 1 minute to retrieve
|
||||
|
||||
return Response(
|
||||
{
|
||||
"status": "success",
|
||||
"label": label,
|
||||
"mask_value": mask_value,
|
||||
},
|
||||
status=status.HTTP_200_OK,
|
||||
)
|
||||
|
||||
@action(methods=["POST"], detail=False, url_path="poll")
|
||||
def poll(self, request):
|
||||
"""
|
||||
Step 3: CLI polls for authorization status
|
||||
|
||||
Returns:
|
||||
- 202: Still pending (keep polling)
|
||||
- 200: Authorized (includes API key)
|
||||
- 400: Expired or invalid
|
||||
"""
|
||||
serializer = DevicePollSerializer(data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
|
||||
device_code = serializer.validated_data["device_code"]
|
||||
|
||||
# Look up device code
|
||||
device_cache_key = get_device_cache_key(device_code)
|
||||
device_data = cache.get(device_cache_key)
|
||||
|
||||
if not device_data:
|
||||
return Response(
|
||||
{"status": "expired", "error": "expired_token", "error_description": "Device code expired"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
if device_data["status"] == "pending":
|
||||
# Still waiting for authorization
|
||||
return Response(
|
||||
{"status": "pending"},
|
||||
status=status.HTTP_202_ACCEPTED, # Indicates to keep polling
|
||||
)
|
||||
|
||||
if device_data["status"] == "authorized":
|
||||
# Success! Return the API key
|
||||
response_data = {
|
||||
"status": "authorized",
|
||||
"personal_api_key": device_data["personal_api_key"],
|
||||
"label": device_data["label"],
|
||||
"project_id": device_data["project_id"],
|
||||
}
|
||||
|
||||
# Clean up - key has been retrieved
|
||||
cache.delete(device_cache_key)
|
||||
user_code_cache_key = get_user_code_cache_key(device_data["user_code"])
|
||||
cache.delete(user_code_cache_key)
|
||||
|
||||
response_serializer = DevicePollResponseSerializer(response_data)
|
||||
return Response(response_serializer.data, status=status.HTTP_200_OK)
|
||||
|
||||
# Unknown status
|
||||
return Response(
|
||||
{"error": "invalid_request", "error_description": "Invalid device code status"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
@@ -1,8 +1,10 @@
|
||||
import json
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, Optional, cast
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.db.models import Manager
|
||||
from django.db.models import Manager, Q
|
||||
|
||||
from loginas.utils import is_impersonated_session
|
||||
from rest_framework import mixins, request, response, serializers, status, viewsets
|
||||
@@ -16,7 +18,7 @@ from posthog.constants import AvailableFeature, EventDefinitionType
|
||||
from posthog.event_usage import report_user_action
|
||||
from posthog.exceptions import EnterpriseFeatureException
|
||||
from posthog.filters import TermSearchFilterBackend, term_search_filter_sql
|
||||
from posthog.models import EventDefinition, Team
|
||||
from posthog.models import EventDefinition, EventSchema, Team
|
||||
from posthog.models.activity_logging.activity_log import Detail, log_activity
|
||||
from posthog.models.user import User
|
||||
from posthog.models.utils import UUIDT
|
||||
@@ -261,6 +263,269 @@ class EventDefinitionViewSet(
|
||||
)
|
||||
return response.Response(status=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
# Version of the TypeScript generator - increment when changing the structure
|
||||
# This ensures clients update even when schemas don't change
|
||||
TYPESCRIPT_GENERATOR_VERSION = "1.0.0"
|
||||
|
||||
@action(detail=False, methods=["GET"], url_path="typescript", required_scopes=["event_definition:read"])
|
||||
def typescript_definitions(self, *args, **kwargs):
|
||||
"""Generate TypeScript definitions from event schemas"""
|
||||
# System events that users should be able to manually capture
|
||||
# These are commonly used in user code and should be typed
|
||||
included_system_events = [
|
||||
"$pageview", # Manually captured in SPAs (React Router, Vue Router, etc.)
|
||||
"$pageleave", # Sometimes manually captured alongside $pageview
|
||||
"$screen", # Manually captured in mobile apps (iOS, Android, React Native, Flutter)
|
||||
]
|
||||
|
||||
# Fetch event definitions: either non-system events or explicitly included system events
|
||||
event_definitions = (
|
||||
EventDefinition.objects.filter(team__project_id=self.project_id)
|
||||
.filter(
|
||||
Q(name__in=included_system_events) # Include whitelisted system events
|
||||
| ~Q(name__startswith="$") # Include all non-system events
|
||||
)
|
||||
.order_by("name")
|
||||
)
|
||||
|
||||
# Fetch all event schemas with their property groups
|
||||
event_schemas = (
|
||||
EventSchema.objects.filter(event_definition__team__project_id=self.project_id)
|
||||
.select_related("property_group")
|
||||
.prefetch_related("property_group__properties")
|
||||
)
|
||||
|
||||
# Build a mapping of event_definition_id -> property group properties
|
||||
schema_map: dict[str, list[Any]] = {}
|
||||
for event_schema in event_schemas:
|
||||
event_id = str(event_schema.event_definition_id)
|
||||
if event_id not in schema_map:
|
||||
schema_map[event_id] = []
|
||||
schema_map[event_id].extend(event_schema.property_group.properties.all())
|
||||
|
||||
# Calculate deterministic hash based on schema data AND generator version
|
||||
# This ensures clients update when either schemas or generator structure changes
|
||||
schema_data = []
|
||||
for event_def in event_definitions:
|
||||
properties = schema_map.get(str(event_def.id), [])
|
||||
prop_data = [(p.name, p.property_type, p.is_required) for p in properties]
|
||||
schema_data.append((event_def.name, sorted(prop_data)))
|
||||
|
||||
# Include version in hash calculation
|
||||
hash_input = {
|
||||
"version": self.TYPESCRIPT_GENERATOR_VERSION,
|
||||
"schemas": schema_data,
|
||||
}
|
||||
schema_hash = hashlib.sha256(json.dumps(hash_input, sort_keys=True).encode()).hexdigest()[:32]
|
||||
|
||||
# Generate TypeScript definitions
|
||||
ts_content = self._generate_typescript(event_definitions, schema_map)
|
||||
|
||||
return response.Response(
|
||||
{
|
||||
"content": ts_content,
|
||||
"event_count": len(event_definitions),
|
||||
"schema_hash": schema_hash,
|
||||
"generator_version": self.TYPESCRIPT_GENERATOR_VERSION,
|
||||
}
|
||||
)
|
||||
|
||||
def _generate_typescript(self, event_definitions, schema_map):
|
||||
"""Generate complete TypeScript module with type definitions and exports"""
|
||||
# Generate file header
|
||||
header = f"""/**
|
||||
* GENERATED FILE - DO NOT EDIT
|
||||
*
|
||||
* This file was auto-generated by PostHog
|
||||
* Generated at: {datetime.now().isoformat()}
|
||||
* Generator version: {self.TYPESCRIPT_GENERATOR_VERSION}
|
||||
*
|
||||
* Provides capture() for type-safe events and captureRaw() for flexibility
|
||||
*/
|
||||
import originalPostHog from 'posthog-js'
|
||||
import type {{ CaptureOptions, CaptureResult, PostHog as OriginalPostHog, Properties }} from 'posthog-js'
|
||||
"""
|
||||
|
||||
# Generate event schemas interface
|
||||
event_schemas_lines = [
|
||||
"// Define event schemas with their required and optional fields",
|
||||
"interface EventSchemas {",
|
||||
]
|
||||
|
||||
for event_def in event_definitions:
|
||||
properties = schema_map.get(str(event_def.id), [])
|
||||
# Escape event name for use as object key
|
||||
event_name = event_def.name.replace("'", "\\'")
|
||||
|
||||
if not properties:
|
||||
event_schemas_lines.append(f" '{event_name}': Record<string, any>")
|
||||
else:
|
||||
event_schemas_lines.append(f" '{event_name}': {{")
|
||||
for prop in properties:
|
||||
ts_type = self._map_property_type(prop.property_type)
|
||||
optional_marker = "" if prop.is_required else "?"
|
||||
# Always quote property names (simpler and handles all edge cases)
|
||||
prop_name = f"'{prop.name.replace("'", "\\'")}'"
|
||||
event_schemas_lines.append(f" {prop_name}{optional_marker}: {ts_type}")
|
||||
event_schemas_lines.append(" }")
|
||||
|
||||
event_schemas_lines.append("}")
|
||||
event_schemas = "\n".join(event_schemas_lines)
|
||||
|
||||
# Generate type aliases
|
||||
type_aliases = """
|
||||
// Type alias for all valid event names
|
||||
export type EventName = keyof EventSchemas
|
||||
|
||||
// Type helper to get properties for a specific event
|
||||
// Intersects the schema with Record<string, any> to allow additional properties
|
||||
export type EventProperties<K extends EventName> = EventSchemas[K] & Record<string, any>
|
||||
|
||||
// Helper type to check if a type has required properties
|
||||
type HasRequiredProperties<K extends EventName> = {} extends EventSchemas[K] ? false : true
|
||||
|
||||
// Helper to detect if T is exactly 'string' (not a literal)
|
||||
type IsExactlyString<T> = string extends T ? (T extends string ? true : false) : false
|
||||
"""
|
||||
|
||||
# Generate TypedPostHog interface
|
||||
typed_posthog_interface = """
|
||||
// Enhanced PostHog interface with typed capture
|
||||
interface TypedPostHog extends Omit<OriginalPostHog, 'capture'> {
|
||||
/**
|
||||
* Type-safe capture for defined events, or flexible capture for undefined events
|
||||
*
|
||||
* Note: For defined events, wrap properties in a variable to allow additional properties:
|
||||
* const props = { file_size_b: 100, extra: 'data' }
|
||||
* posthog.capture('downloaded_file', props)
|
||||
*
|
||||
* @example
|
||||
* // Defined event with type safety
|
||||
* posthog.capture('uploaded_file', {
|
||||
* file_name: 'test.txt',
|
||||
* file_size_b: 100
|
||||
* })
|
||||
*
|
||||
* @example
|
||||
* // For events with all optional properties, properties argument is optional
|
||||
* posthog.capture('logged_out') // no properties needed
|
||||
*
|
||||
* @example
|
||||
* // Undefined events work with arbitrary properties
|
||||
* posthog.capture('custom_event', { whatever: 'data' })
|
||||
* posthog.capture('another_event') // or no properties
|
||||
*/
|
||||
// Overload 1: For known events (specific EventName literals)
|
||||
// This should match first for all known event names
|
||||
capture<K extends EventName>(
|
||||
event_name: K,
|
||||
...args: HasRequiredProperties<K> extends true
|
||||
? [properties: EventProperties<K>, options?: CaptureOptions]
|
||||
: [properties?: EventProperties<K>, options?: CaptureOptions]
|
||||
): CaptureResult | undefined
|
||||
|
||||
// Overload 2: For undefined events and blocking string variables
|
||||
// Only matches if event_name is NOT a known EventName
|
||||
// The conditional type rejects broad string type
|
||||
capture<T extends string>(
|
||||
event_name: IsExactlyString<T> extends true ? never : (T extends EventName ? never : T),
|
||||
properties?: Properties | null,
|
||||
options?: CaptureOptions
|
||||
): CaptureResult | undefined
|
||||
|
||||
/**
|
||||
* Raw capture for any event (original behavior, no type checking)
|
||||
*
|
||||
* Use capture() for type-safe defined events or flexible undefined events.
|
||||
* Use captureRaw() only when you need to bypass all type checking.
|
||||
*
|
||||
* @example
|
||||
* posthog.captureRaw('Any Event Name', { whatever: 'data' })
|
||||
*/
|
||||
captureRaw(event_name: string, properties?: Properties | null, options?: CaptureOptions): CaptureResult | undefined
|
||||
}
|
||||
"""
|
||||
|
||||
# Generate implementation
|
||||
implementation = """
|
||||
// Create the implementation
|
||||
const createTypedPostHog = (original: OriginalPostHog): TypedPostHog => {
|
||||
// Create the enhanced PostHog object
|
||||
const enhanced: TypedPostHog = Object.create(original)
|
||||
|
||||
// Add capture method (type-safe for defined events, flexible for undefined)
|
||||
enhanced.capture = function (event_name: string, ...args: any[]): CaptureResult | undefined {
|
||||
const [properties, options] = args
|
||||
return original.capture(event_name, properties, options)
|
||||
}
|
||||
|
||||
// Add captureRaw method for untyped/flexible event tracking
|
||||
enhanced.captureRaw = function (
|
||||
event_name: string,
|
||||
properties?: Properties | null,
|
||||
options?: CaptureOptions
|
||||
): CaptureResult | undefined {
|
||||
return original.capture(event_name, properties, options)
|
||||
}
|
||||
|
||||
// Proxy to delegate all other properties/methods to the original
|
||||
return new Proxy(enhanced, {
|
||||
get(target, prop) {
|
||||
if (prop in target) {
|
||||
return (target as any)[prop]
|
||||
}
|
||||
return (original as any)[prop]
|
||||
},
|
||||
set(target, prop, value) {
|
||||
;(original as any)[prop] = value
|
||||
return true
|
||||
},
|
||||
})
|
||||
}
|
||||
"""
|
||||
|
||||
# Generate exports
|
||||
exports = """
|
||||
// Create and export the typed instance
|
||||
const posthog = createTypedPostHog(originalPostHog as OriginalPostHog)
|
||||
|
||||
export default posthog
|
||||
export { posthog }
|
||||
export type { EventSchemas, TypedPostHog }
|
||||
|
||||
// Re-export everything else from posthog-js
|
||||
export * from 'posthog-js'
|
||||
|
||||
/**
|
||||
* USAGE GUIDE
|
||||
* ===========
|
||||
*
|
||||
* For type-safe defined events (recommended):
|
||||
* posthog.capture('uploaded_file', { file_name: 'test.txt', file_size_b: 100 })
|
||||
*
|
||||
* For undefined events (flexible):
|
||||
* posthog.capture('Custom Event', { whatever: 'data' })
|
||||
*
|
||||
* For bypassing all type checking (rare):
|
||||
* posthog.captureRaw('Any Event', { whatever: 'data' })
|
||||
*/
|
||||
"""
|
||||
|
||||
# Combine all sections
|
||||
return header + event_schemas + type_aliases + typed_posthog_interface + implementation + exports
|
||||
|
||||
def _map_property_type(self, property_type: str) -> str:
|
||||
"""Map PostHog property types to TypeScript types"""
|
||||
type_map = {
|
||||
"String": "string",
|
||||
"Numeric": "number",
|
||||
"Boolean": "boolean",
|
||||
"DateTime": "string | Date",
|
||||
"Array": "any[]",
|
||||
"Object": "Record<string, any>",
|
||||
}
|
||||
return type_map.get(property_type, "any")
|
||||
|
||||
@action(detail=True, methods=["GET"], url_path="metrics")
|
||||
def metrics_totals(self, *args, **kwargs):
|
||||
instance: EventDefinition = self.get_object()
|
||||
|
||||
476
posthog/api/test/test_cli_auth.py
Normal file
@@ -0,0 +1,476 @@
|
||||
from datetime import timedelta
|
||||
|
||||
from freezegun import freeze_time
|
||||
from posthog.test.base import APIBaseTest
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.utils import timezone
|
||||
|
||||
from rest_framework import status
|
||||
|
||||
from posthog.api.cli_auth import CLI_SCOPES, DEVICE_CODE_EXPIRY_SECONDS, get_device_cache_key, get_user_code_cache_key
|
||||
from posthog.models import PersonalAPIKey, Team, User
|
||||
from posthog.models.organization import Organization
|
||||
from posthog.models.personal_api_key import hash_key_value
|
||||
|
||||
|
||||
class TestCLIAuthDeviceCodeEndpoint(APIBaseTest):
|
||||
"""
|
||||
Tests for the device code request endpoint (step 1 of OAuth device flow)
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
cache.clear() # Clear cache before each test
|
||||
|
||||
def test_device_code_request_returns_correct_data(self):
|
||||
"""Test that requesting a device code returns all required fields"""
|
||||
response = self.client.post("/api/cli-auth/device-code/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
data = response.json()
|
||||
|
||||
# Check all required fields are present
|
||||
self.assertIn("device_code", data)
|
||||
self.assertIn("user_code", data)
|
||||
self.assertIn("verification_uri", data)
|
||||
self.assertIn("verification_uri_complete", data)
|
||||
self.assertIn("expires_in", data)
|
||||
self.assertIn("interval", data)
|
||||
|
||||
# Check values are correct
|
||||
self.assertEqual(data["expires_in"], DEVICE_CODE_EXPIRY_SECONDS)
|
||||
self.assertIn("/cli/authorize", data["verification_uri"])
|
||||
self.assertIn(data["user_code"], data["verification_uri_complete"])
|
||||
|
||||
# Verify user code format (XXXX-XXXX)
|
||||
user_code = data["user_code"]
|
||||
self.assertEqual(len(user_code), 9)
|
||||
self.assertEqual(user_code[4], "-")
|
||||
self.assertTrue(user_code[:4].isalpha())
|
||||
self.assertTrue(user_code[5:].isdigit())
|
||||
|
||||
def test_device_code_is_stored_in_cache(self):
|
||||
"""Test that device code and user code are properly stored in cache"""
|
||||
response = self.client.post("/api/cli-auth/device-code/")
|
||||
data = response.json()
|
||||
|
||||
device_code = data["device_code"]
|
||||
user_code = data["user_code"]
|
||||
|
||||
# Check device code is in cache
|
||||
device_cache_key = get_device_cache_key(device_code)
|
||||
device_data = cache.get(device_cache_key)
|
||||
self.assertIsNotNone(device_data)
|
||||
self.assertEqual(device_data["user_code"], user_code)
|
||||
self.assertEqual(device_data["status"], "pending")
|
||||
|
||||
# Check user code reverse lookup is in cache
|
||||
user_code_cache_key = get_user_code_cache_key(user_code)
|
||||
cached_device_code = cache.get(user_code_cache_key)
|
||||
self.assertEqual(cached_device_code, device_code)
|
||||
|
||||
def test_device_code_works_without_authentication(self):
|
||||
"""Test that device code endpoint works for unauthenticated requests"""
|
||||
self.client.logout()
|
||||
response = self.client.post("/api/cli-auth/device-code/")
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
|
||||
class TestCLIAuthAuthorizeEndpoint(APIBaseTest):
|
||||
"""
|
||||
Tests for the authorization endpoint (step 2 of OAuth device flow)
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
cache.clear()
|
||||
|
||||
# Create a device code for testing
|
||||
response = self.client.post("/api/cli-auth/device-code/")
|
||||
self.device_data = response.json()
|
||||
self.device_code = self.device_data["device_code"]
|
||||
self.user_code = self.device_data["user_code"]
|
||||
|
||||
def test_successful_authorization_creates_api_key(self):
|
||||
"""Test that successful authorization creates a Personal API Key"""
|
||||
initial_key_count = PersonalAPIKey.objects.filter(user=self.user).count()
|
||||
|
||||
response = self.client.post(
|
||||
"/api/cli-auth/authorize/",
|
||||
{"user_code": self.user_code, "project_id": self.team.id},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
data = response.json()
|
||||
self.assertEqual(data["status"], "success")
|
||||
self.assertIn("label", data)
|
||||
self.assertIn("mask_value", data)
|
||||
|
||||
# Check that API key was created
|
||||
new_key_count = PersonalAPIKey.objects.filter(user=self.user).count()
|
||||
self.assertEqual(new_key_count, initial_key_count + 1)
|
||||
|
||||
# Check the created key has correct scopes
|
||||
api_key = PersonalAPIKey.objects.filter(user=self.user).order_by("-created_at").first()
|
||||
assert api_key is not None
|
||||
self.assertEqual(api_key.scopes, CLI_SCOPES)
|
||||
self.assertIn("CLI -", api_key.label)
|
||||
|
||||
def test_authorization_requires_authentication(self):
|
||||
"""Test that authorization endpoint requires user to be logged in"""
|
||||
self.client.logout()
|
||||
|
||||
response = self.client.post(
|
||||
"/api/cli-auth/authorize/",
|
||||
{"user_code": self.user_code, "project_id": self.team.id},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
def test_authorization_rejects_invalid_user_code(self):
|
||||
"""Test that authorization fails with invalid user code"""
|
||||
response = self.client.post(
|
||||
"/api/cli-auth/authorize/",
|
||||
{"user_code": "XXXX-9999", "project_id": self.team.id},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
data = response.json()
|
||||
self.assertIn("error", data)
|
||||
self.assertEqual(data["error"], "invalid_code")
|
||||
|
||||
def test_authorization_rejects_expired_user_code(self):
|
||||
"""Test that authorization fails with expired user code"""
|
||||
# Wait for the code to expire
|
||||
with freeze_time(timezone.now() + timedelta(seconds=DEVICE_CODE_EXPIRY_SECONDS + 1)):
|
||||
response = self.client.post(
|
||||
"/api/cli-auth/authorize/",
|
||||
{"user_code": self.user_code, "project_id": self.team.id},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertEqual(response.json()["error"], "invalid_code")
|
||||
|
||||
def test_authorization_rejects_user_without_team_access(self):
|
||||
"""Test that user cannot authorize for a team they don't have access to"""
|
||||
# Create another organization and team
|
||||
other_org = Organization.objects.create(name="Other Org")
|
||||
other_team = Team.objects.create(organization=other_org, name="Other Team")
|
||||
|
||||
response = self.client.post(
|
||||
"/api/cli-auth/authorize/",
|
||||
{"user_code": self.user_code, "project_id": other_team.id},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|
||||
self.assertEqual(response.json()["error"], "access_denied")
|
||||
|
||||
# Verify no API key was created
|
||||
api_keys = PersonalAPIKey.objects.filter(user=self.user).count()
|
||||
self.assertEqual(api_keys, 0)
|
||||
|
||||
def test_authorization_rejects_nonexistent_project(self):
|
||||
"""Test that authorization fails with non-existent project ID"""
|
||||
response = self.client.post(
|
||||
"/api/cli-auth/authorize/",
|
||||
{"user_code": self.user_code, "project_id": 99999},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertEqual(response.json()["error"], "invalid_project")
|
||||
|
||||
def test_authorization_updates_cache_with_api_key(self):
|
||||
"""Test that authorization updates the cache with the API key"""
|
||||
response = self.client.post(
|
||||
"/api/cli-auth/authorize/",
|
||||
{"user_code": self.user_code, "project_id": self.team.id},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
# Check cache was updated
|
||||
device_cache_key = get_device_cache_key(self.device_code)
|
||||
device_data = cache.get(device_cache_key)
|
||||
|
||||
self.assertEqual(device_data["status"], "authorized")
|
||||
self.assertIn("personal_api_key", device_data)
|
||||
self.assertEqual(device_data["project_id"], str(self.team.id))
|
||||
self.assertEqual(device_data["user_id"], self.user.id)
|
||||
|
||||
def test_multiple_users_can_authorize_different_codes(self):
|
||||
"""Test that multiple users can authorize different device codes concurrently"""
|
||||
# Create another user
|
||||
other_user = User.objects.create_and_join(self.organization, "other@posthog.com", "password123")
|
||||
|
||||
# Create device code for other user
|
||||
response2 = self.client.post("/api/cli-auth/device-code/")
|
||||
user_code2 = response2.json()["user_code"]
|
||||
|
||||
# First user authorizes their code
|
||||
response = self.client.post(
|
||||
"/api/cli-auth/authorize/",
|
||||
{"user_code": self.user_code, "project_id": self.team.id},
|
||||
)
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
# Switch to other user
|
||||
self.client.force_login(other_user)
|
||||
|
||||
# Other user authorizes their code
|
||||
response = self.client.post(
|
||||
"/api/cli-auth/authorize/",
|
||||
{"user_code": user_code2, "project_id": self.team.id},
|
||||
)
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
# Both users should have their own API keys
|
||||
self.assertEqual(PersonalAPIKey.objects.filter(user=self.user).count(), 1)
|
||||
self.assertEqual(PersonalAPIKey.objects.filter(user=other_user).count(), 1)
|
||||
|
||||
def test_authorization_prevents_duplicate_api_keys_from_race_condition(self):
|
||||
"""Test that attempting to authorize the same code twice does not create duplicate API keys"""
|
||||
initial_key_count = PersonalAPIKey.objects.filter(user=self.user).count()
|
||||
|
||||
# First authorization succeeds
|
||||
response1 = self.client.post(
|
||||
"/api/cli-auth/authorize/",
|
||||
{"user_code": self.user_code, "project_id": self.team.id},
|
||||
)
|
||||
self.assertEqual(response1.status_code, status.HTTP_200_OK)
|
||||
|
||||
# Second authorization attempt should fail (code already authorized)
|
||||
response2 = self.client.post(
|
||||
"/api/cli-auth/authorize/",
|
||||
{"user_code": self.user_code, "project_id": self.team.id},
|
||||
)
|
||||
self.assertEqual(response2.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertEqual(response2.json()["error"], "already_authorized")
|
||||
|
||||
# Verify only one API key was created
|
||||
final_key_count = PersonalAPIKey.objects.filter(user=self.user).count()
|
||||
self.assertEqual(final_key_count, initial_key_count + 1)
|
||||
|
||||
|
||||
class TestCLIAuthPollEndpoint(APIBaseTest):
|
||||
"""
|
||||
Tests for the poll endpoint (step 3 of OAuth device flow)
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
cache.clear()
|
||||
|
||||
# Create and authorize a device code
|
||||
response = self.client.post("/api/cli-auth/device-code/")
|
||||
self.device_data = response.json()
|
||||
self.device_code = self.device_data["device_code"]
|
||||
self.user_code = self.device_data["user_code"]
|
||||
|
||||
def test_poll_returns_pending_before_authorization(self):
|
||||
"""Test that polling returns pending status before user authorizes"""
|
||||
response = self.client.post("/api/cli-auth/poll/", {"device_code": self.device_code})
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED)
|
||||
data = response.json()
|
||||
self.assertEqual(data["status"], "pending")
|
||||
|
||||
def test_poll_returns_api_key_after_authorization(self):
|
||||
"""Test that polling returns API key after user authorizes"""
|
||||
# Authorize the code
|
||||
self.client.post(
|
||||
"/api/cli-auth/authorize/",
|
||||
{"user_code": self.user_code, "project_id": self.team.id},
|
||||
)
|
||||
|
||||
# Poll for the result
|
||||
response = self.client.post("/api/cli-auth/poll/", {"device_code": self.device_code})
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
data = response.json()
|
||||
self.assertEqual(data["status"], "authorized")
|
||||
self.assertIn("personal_api_key", data)
|
||||
self.assertIn("label", data)
|
||||
self.assertEqual(data["project_id"], str(self.team.id))
|
||||
|
||||
# Verify the API key is valid
|
||||
api_key = data["personal_api_key"]
|
||||
self.assertTrue(api_key.startswith("phx_"))
|
||||
|
||||
def test_poll_returns_expired_for_old_code(self):
|
||||
"""Test that polling returns expired for old device codes"""
|
||||
with freeze_time(timezone.now() + timedelta(seconds=DEVICE_CODE_EXPIRY_SECONDS + 1)):
|
||||
response = self.client.post("/api/cli-auth/poll/", {"device_code": self.device_code})
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
data = response.json()
|
||||
self.assertEqual(data["status"], "expired")
|
||||
|
||||
def test_poll_returns_expired_for_nonexistent_code(self):
|
||||
"""Test that polling returns expired for non-existent device codes"""
|
||||
response = self.client.post("/api/cli-auth/poll/", {"device_code": "nonexistent_code"})
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
data = response.json()
|
||||
self.assertEqual(data["status"], "expired")
|
||||
|
||||
def test_poll_cleans_up_cache_after_successful_retrieval(self):
|
||||
"""Test that cache is cleaned up after API key is retrieved"""
|
||||
# Authorize the code
|
||||
self.client.post(
|
||||
"/api/cli-auth/authorize/",
|
||||
{"user_code": self.user_code, "project_id": self.team.id},
|
||||
)
|
||||
|
||||
# Poll for the result
|
||||
self.client.post("/api/cli-auth/poll/", {"device_code": self.device_code})
|
||||
|
||||
# Verify cache is cleaned up
|
||||
device_cache_key = get_device_cache_key(self.device_code)
|
||||
user_code_cache_key = get_user_code_cache_key(self.user_code)
|
||||
|
||||
self.assertIsNone(cache.get(device_cache_key))
|
||||
self.assertIsNone(cache.get(user_code_cache_key))
|
||||
|
||||
def test_poll_can_be_called_multiple_times_before_authorization(self):
|
||||
"""Test that poll can be called multiple times while pending"""
|
||||
# Poll multiple times
|
||||
for _ in range(3):
|
||||
response = self.client.post("/api/cli-auth/poll/", {"device_code": self.device_code})
|
||||
self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED)
|
||||
self.assertEqual(response.json()["status"], "pending")
|
||||
|
||||
def test_poll_works_without_authentication(self):
|
||||
"""Test that poll endpoint works for unauthenticated requests"""
|
||||
self.client.logout()
|
||||
response = self.client.post("/api/cli-auth/poll/", {"device_code": self.device_code})
|
||||
self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED)
|
||||
|
||||
def test_poll_returns_api_key_only_once(self):
|
||||
"""Test that API key can only be retrieved once"""
|
||||
# Authorize the code
|
||||
self.client.post(
|
||||
"/api/cli-auth/authorize/",
|
||||
{"user_code": self.user_code, "project_id": self.team.id},
|
||||
)
|
||||
|
||||
# First poll succeeds
|
||||
response1 = self.client.post("/api/cli-auth/poll/", {"device_code": self.device_code})
|
||||
self.assertEqual(response1.status_code, status.HTTP_200_OK)
|
||||
|
||||
# Second poll fails (cache cleaned up)
|
||||
response2 = self.client.post("/api/cli-auth/poll/", {"device_code": self.device_code})
|
||||
self.assertEqual(response2.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertEqual(response2.json()["status"], "expired")
|
||||
|
||||
|
||||
class TestCLIAuthEndToEnd(APIBaseTest):
|
||||
"""
|
||||
End-to-end tests for the complete CLI authentication flow
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
cache.clear()
|
||||
|
||||
def test_complete_authentication_flow(self):
|
||||
"""Test the complete device flow from start to finish"""
|
||||
# Step 1: Request device code (CLI)
|
||||
response = self.client.post("/api/cli-auth/device-code/")
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
device_data = response.json()
|
||||
device_code = device_data["device_code"]
|
||||
user_code = device_data["user_code"]
|
||||
|
||||
# Step 2: User opens browser and authorizes
|
||||
response = self.client.post(
|
||||
"/api/cli-auth/authorize/",
|
||||
{"user_code": user_code, "project_id": self.team.id},
|
||||
)
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
# Step 3: CLI polls and gets API key
|
||||
response = self.client.post("/api/cli-auth/poll/", {"device_code": device_code})
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
api_key = response.json()["personal_api_key"]
|
||||
|
||||
# Step 4: Verify the API key works
|
||||
self.client.logout()
|
||||
response = self.client.get(
|
||||
f"/api/projects/{self.team.pk}/event_definitions/",
|
||||
HTTP_AUTHORIZATION=f"Bearer {api_key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
def test_api_key_has_correct_scopes(self):
|
||||
"""Test that created API key has only the CLI scopes"""
|
||||
# Complete the flow
|
||||
response = self.client.post("/api/cli-auth/device-code/")
|
||||
device_code = response.json()["device_code"]
|
||||
user_code = response.json()["user_code"]
|
||||
|
||||
self.client.post(
|
||||
"/api/cli-auth/authorize/",
|
||||
{"user_code": user_code, "project_id": self.team.id},
|
||||
)
|
||||
|
||||
response = self.client.post("/api/cli-auth/poll/", {"device_code": device_code})
|
||||
api_key_value = response.json()["personal_api_key"]
|
||||
|
||||
# Get the API key from database
|
||||
api_key = PersonalAPIKey.objects.get(secure_value=hash_key_value(api_key_value))
|
||||
|
||||
# Verify scopes
|
||||
self.assertEqual(api_key.scopes, CLI_SCOPES)
|
||||
self.assertIn("event_definition:read", api_key.scopes)
|
||||
self.assertIn("property_definition:read", api_key.scopes)
|
||||
self.assertIn("error_tracking:write", api_key.scopes)
|
||||
|
||||
def test_cross_team_access_is_prevented(self):
|
||||
"""Test that user cannot authorize CLI for a team in a different organization"""
|
||||
# Create a new organization that the user is NOT a member of
|
||||
other_org = Organization.objects.create(name="Other Organization")
|
||||
other_team = Team.objects.create(organization=other_org, name="Other Team")
|
||||
|
||||
# Complete device code request
|
||||
response = self.client.post("/api/cli-auth/device-code/")
|
||||
user_code = response.json()["user_code"]
|
||||
|
||||
# Try to authorize for the other team
|
||||
response = self.client.post(
|
||||
"/api/cli-auth/authorize/",
|
||||
{"user_code": user_code, "project_id": other_team.id},
|
||||
)
|
||||
|
||||
# Should be rejected
|
||||
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|
||||
self.assertEqual(response.json()["error"], "access_denied")
|
||||
|
||||
# Verify no API key was created
|
||||
self.assertEqual(PersonalAPIKey.objects.filter(user=self.user).count(), 0)
|
||||
|
||||
def test_user_can_authorize_for_multiple_teams_in_same_org(self):
|
||||
"""Test that user can authorize CLI for multiple teams in the same organization"""
|
||||
# Create another team in the same organization
|
||||
team2 = Team.objects.create(organization=self.organization, name="Team 2")
|
||||
|
||||
# Authorize for first team
|
||||
response1 = self.client.post("/api/cli-auth/device-code/")
|
||||
user_code1 = response1.json()["user_code"]
|
||||
self.client.post(
|
||||
"/api/cli-auth/authorize/",
|
||||
{"user_code": user_code1, "project_id": self.team.id},
|
||||
)
|
||||
|
||||
# Authorize for second team
|
||||
response2 = self.client.post("/api/cli-auth/device-code/")
|
||||
user_code2 = response2.json()["user_code"]
|
||||
response = self.client.post(
|
||||
"/api/cli-auth/authorize/",
|
||||
{"user_code": user_code2, "project_id": team2.id},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
# Both API keys should be created
|
||||
self.assertEqual(PersonalAPIKey.objects.filter(user=self.user).count(), 2)
|
||||
275
posthog/api/test/test_event_definition_typescript.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""
|
||||
Integration test for TypeScript definition generation.
|
||||
|
||||
Tests the complete flow:
|
||||
1. Create EventDefinitions with EventSchemas
|
||||
2. Generate TypeScript via the typescript_definitions method
|
||||
3. Write generated TypeScript to temp file
|
||||
4. Create a test file that uses the types
|
||||
5. Run TypeScript compiler to verify no errors
|
||||
"""
|
||||
|
||||
import tempfile
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
from posthog.test.base import APIBaseTest
|
||||
|
||||
from rest_framework import status
|
||||
|
||||
from posthog.models import EventDefinition, EventSchema, SchemaPropertyGroup, SchemaPropertyGroupProperty
|
||||
|
||||
|
||||
class TestEventDefinitionTypeScriptGeneration(APIBaseTest):
|
||||
"""
|
||||
Critical integration test ensuring TypeScript generation maintains type safety
|
||||
while allowing additional properties beyond the schema.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
# Create property group with required and optional fields
|
||||
self.property_group = SchemaPropertyGroup.objects.create(
|
||||
team=self.team, project=self.project, name="Test Properties"
|
||||
)
|
||||
|
||||
SchemaPropertyGroupProperty.objects.create(
|
||||
property_group=self.property_group,
|
||||
name="required_field",
|
||||
property_type="Numeric",
|
||||
is_required=True,
|
||||
description="A required numeric field",
|
||||
)
|
||||
|
||||
SchemaPropertyGroupProperty.objects.create(
|
||||
property_group=self.property_group,
|
||||
name="optional_field",
|
||||
property_type="String",
|
||||
is_required=False,
|
||||
description="An optional string field",
|
||||
)
|
||||
|
||||
# Create event definition and link to property group
|
||||
self.event_def = EventDefinition.objects.create(team=self.team, project=self.project, name="test_event")
|
||||
|
||||
EventSchema.objects.create(event_definition=self.event_def, property_group=self.property_group)
|
||||
|
||||
# Create event with all optional fields
|
||||
self.optional_event_def = EventDefinition.objects.create(
|
||||
team=self.team, project=self.project, name="optional_event"
|
||||
)
|
||||
|
||||
optional_property_group = SchemaPropertyGroup.objects.create(
|
||||
team=self.team, project=self.project, name="Optional Properties"
|
||||
)
|
||||
|
||||
SchemaPropertyGroupProperty.objects.create(
|
||||
property_group=optional_property_group,
|
||||
name="optional_only",
|
||||
property_type="String",
|
||||
is_required=False,
|
||||
)
|
||||
|
||||
EventSchema.objects.create(event_definition=self.optional_event_def, property_group=optional_property_group)
|
||||
|
||||
# Create event with no schema (all properties allowed)
|
||||
self.untyped_event_def = EventDefinition.objects.create(
|
||||
team=self.team, project=self.project, name="untyped_event"
|
||||
)
|
||||
|
||||
def _generate_typescript(self) -> str:
|
||||
"""Generate TypeScript definitions by calling the actual API endpoint"""
|
||||
response = self.client.get(f"/api/projects/{self.project.id}/event_definitions/typescript/")
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
return response.json()["content"]
|
||||
|
||||
def test_typescript_allows_additional_properties(self):
|
||||
"""
|
||||
Critical test: Verify that additional properties beyond schema
|
||||
are allowed while required properties are still validated.
|
||||
|
||||
This is the core functionality that prevents "excess property checking"
|
||||
errors in TypeScript while maintaining type safety for required fields.
|
||||
|
||||
Uses the real posthog-js package to ensure compatibility with actual types.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
|
||||
# Generate TypeScript
|
||||
ts_content = self._generate_typescript()
|
||||
|
||||
# Create minimal package.json to install only required dependencies
|
||||
package_json = tmpdir_path / "package.json"
|
||||
package_json.write_text('{"dependencies": {"typescript": "^5.0.0", "posthog-js": "^1.0.0"}}')
|
||||
install_result = subprocess.run(
|
||||
["pnpm", "install", "--no-frozen-lockfile"],
|
||||
cwd=str(tmpdir_path),
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120,
|
||||
)
|
||||
|
||||
if install_result.returncode != 0:
|
||||
self.fail(
|
||||
f"Failed to install dependencies:\n"
|
||||
f"STDOUT: {install_result.stdout}\n"
|
||||
f"STDERR: {install_result.stderr}"
|
||||
)
|
||||
|
||||
# Write generated types (using real posthog-js)
|
||||
types_file = tmpdir_path / "posthog-typed.ts"
|
||||
types_file.write_text(ts_content)
|
||||
|
||||
# Create test file that exercises all type scenarios
|
||||
test_file = tmpdir_path / "test.ts"
|
||||
test_file.write_text(
|
||||
"""
|
||||
import posthog, { EventName } from './posthog-typed'
|
||||
|
||||
// ========================================
|
||||
// TEST 1: Additional properties are allowed
|
||||
// ========================================
|
||||
|
||||
// ✅ Should compile: required field + extra properties (CRITICAL TEST)
|
||||
posthog.capture('test_event', {
|
||||
required_field: 123,
|
||||
optional_field: 'test',
|
||||
extra_property: 'this should be allowed',
|
||||
another_extra: true,
|
||||
nested_extra: { foo: 'bar' }
|
||||
})
|
||||
|
||||
// ✅ Should compile: only required field
|
||||
posthog.capture('test_event', {
|
||||
required_field: 456
|
||||
})
|
||||
|
||||
// ========================================
|
||||
// TEST 2: Required properties are validated
|
||||
// ========================================
|
||||
|
||||
// ❌ Should fail: missing required field
|
||||
// @ts-expect-error
|
||||
posthog.capture('test_event', {
|
||||
optional_field: 'test'
|
||||
})
|
||||
|
||||
// ❌ Should fail: wrong type for required field
|
||||
// @ts-expect-error
|
||||
posthog.capture('test_event', {
|
||||
required_field: 'string not allowed'
|
||||
})
|
||||
|
||||
// ========================================
|
||||
// TEST 3: Events with all optional properties
|
||||
// ========================================
|
||||
|
||||
// ✅ Should compile: no properties needed
|
||||
posthog.capture('optional_event')
|
||||
|
||||
// ✅ Should compile: with properties
|
||||
posthog.capture('optional_event', {
|
||||
optional_only: 'value',
|
||||
extra_field: 123
|
||||
})
|
||||
|
||||
// ========================================
|
||||
// TEST 4: Untyped events accept anything
|
||||
// ========================================
|
||||
|
||||
// ✅ Should compile: any properties
|
||||
posthog.capture('untyped_event', {
|
||||
anything: 'goes',
|
||||
here: 123
|
||||
})
|
||||
|
||||
// ✅ Should compile: no properties
|
||||
posthog.capture('untyped_event')
|
||||
|
||||
// ========================================
|
||||
// TEST 5: Undefined events work flexibly
|
||||
// ========================================
|
||||
|
||||
// ✅ Should compile: custom event with properties
|
||||
posthog.capture('custom_undefined_event', {
|
||||
any: 'properties',
|
||||
work: 'here'
|
||||
})
|
||||
|
||||
// ✅ Should compile: custom event without properties
|
||||
posthog.capture('another_custom_event')
|
||||
|
||||
// ========================================
|
||||
// TEST 6: String variables are blocked
|
||||
// ========================================
|
||||
|
||||
// ❌ Should fail: broad string type not allowed
|
||||
let stringVar: string = 'test_event'
|
||||
// @ts-expect-error
|
||||
posthog.capture(stringVar)
|
||||
|
||||
// ✅ Should compile: EventName type works
|
||||
let typedVar: EventName = 'test_event'
|
||||
posthog.capture(typedVar, { required_field: 789 })
|
||||
|
||||
// ✅ Should compile: const infers literal type
|
||||
const constVar = 'test_event'
|
||||
posthog.capture(constVar, { required_field: 999 })
|
||||
|
||||
// ========================================
|
||||
// TEST 7: captureRaw bypasses all checking
|
||||
// ========================================
|
||||
|
||||
// ✅ Should compile: missing required fields is OK
|
||||
posthog.captureRaw('test_event', {
|
||||
optional_field: 'test'
|
||||
})
|
||||
|
||||
// ✅ Should compile: wrong types are OK
|
||||
posthog.captureRaw('test_event', {
|
||||
required_field: 'string is fine here'
|
||||
})
|
||||
|
||||
// ✅ Should compile: string variables work
|
||||
posthog.captureRaw(stringVar, { any: 'data' })
|
||||
"""
|
||||
)
|
||||
|
||||
# Create tsconfig.json
|
||||
tsconfig_file = tmpdir_path / "tsconfig.json"
|
||||
tsconfig_file.write_text(
|
||||
"""
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2020",
|
||||
"module": "ESNext",
|
||||
"lib": ["ES2020", "DOM"],
|
||||
"strict": true,
|
||||
"esModuleInterop": true,
|
||||
"skipLibCheck": true,
|
||||
"moduleResolution": "node"
|
||||
}
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
# Run TypeScript compiler using pnpm
|
||||
result = subprocess.run(
|
||||
["pnpm", "exec", "tsc", "--noEmit", "--project", str(tsconfig_file)],
|
||||
cwd=str(tmpdir_path),
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
# Assert compilation succeeded
|
||||
self.assertEqual(
|
||||
result.returncode,
|
||||
0,
|
||||
f"TypeScript compilation failed. This indicates the type system is broken.\n\n"
|
||||
f"STDOUT:\n{result.stdout}\n\n"
|
||||
f"STDERR:\n{result.stderr}\n\n"
|
||||
f"Generated TypeScript file location: {types_file}",
|
||||
)
|
||||
@@ -159,6 +159,10 @@ class AutoProjectMiddleware:
|
||||
self.token_allowlist = PROJECT_SWITCHING_TOKEN_ALLOWLIST
|
||||
|
||||
def __call__(self, request: HttpRequest):
|
||||
# Skip project switching for CLI authorization page
|
||||
if request.path.startswith("/cli/authorize"):
|
||||
return self.get_response(request)
|
||||
|
||||
if request.user.is_authenticated:
|
||||
path_parts = request.path.strip("/").split("/")
|
||||
project_id_in_url = None
|
||||
|
||||