fix(macos): fix race condition in protocol handlers (#1537)

* store protocol functions globally instead of directly on webview

* formatting

* more formatting

* more formatting

* formatting

* Apply suggestions from code review

Co-authored-by: Lucas Fernandes Nogueira <lucas@tauri.app>

* change file

* Update .changes/macos-global-protocol-handlers.md

Co-authored-by: Lucas Fernandes Nogueira <lucas@tauri.app>

---------

Co-authored-by: Lucas Fernandes Nogueira <lucas@tauri.app>
This commit is contained in:
Brendan Allan
2025-04-07 22:17:10 +08:00
committed by GitHub
parent 2d753c6482
commit 78b83a0d8a
3 changed files with 46 additions and 34 deletions

View File

@@ -0,0 +1,5 @@
---
wry: patch
---
Moved protocol handler functions to a thread local instead of storing them as ivars to prevent a race condition between webview close and custom protocol handling.

View File

@@ -24,7 +24,7 @@ use objc2_foundation::{
};
use objc2_web_kit::{WKURLSchemeHandler, WKURLSchemeTask};
use crate::{wkwebview::WEBVIEW_IDS, RequestAsyncResponder, WryWebView};
use crate::{wkwebview::WEBVIEW_STATE, RequestAsyncResponder, WryWebView};
pub fn create(name: &str) -> &AnyClass {
unsafe {
@@ -33,8 +33,8 @@ pub fn create(name: &str) -> &AnyClass {
let cls = ClassBuilder::new(scheme_name, NSObject::class());
match cls {
Some(mut cls) => {
cls.add_ivar::<*mut c_void>(c"function");
cls.add_ivar::<*mut c_char>(c"webview_id");
cls.add_ivar::<usize>(c"protocol_index");
cls.add_method(
objc2::sel!(webView:startURLSchemeTask:),
start_task as extern "C" fn(_, _, _, _),
@@ -72,12 +72,16 @@ extern "C" fn start_task(
.ok()
.unwrap_or_default();
let ivar = this.class().instance_variable(c"function").unwrap();
let function: &*mut c_void = ivar.load(this);
if !function.is_null() {
let function = &mut *(*function
as *mut Box<dyn Fn(crate::WebViewId, Request<Vec<u8>>, RequestAsyncResponder)>);
let ivar = this.class().instance_variable(c"protocol_index").unwrap();
let protocol_index: usize = *ivar.load(this);
let function = WEBVIEW_STATE.with_borrow(|v| {
v.get(webview_id)
.and_then(|v| v.protocol_ptrs.get(protocol_index))
.cloned()
});
if let Some(function) = function {
// Get url request
let request = task.request();
let url = request.URL().unwrap();
@@ -143,7 +147,7 @@ extern "C" fn start_task(
};
fn check_webview_id_valid(webview_id: &str) -> crate::Result<()> {
if !WEBVIEW_IDS.lock().unwrap().contains(webview_id) {
if !WEBVIEW_STATE.with_borrow(|s| s.contains_key(webview_id)) {
return Err(crate::Error::CustomProtocolTaskInvalid);
}
Ok(())
@@ -284,15 +288,14 @@ extern "C" fn start_task(
}))
.map_err(|_e| crate::Error::CustomProtocolTaskInvalid)?;
{
let ids = WEBVIEW_IDS.lock().unwrap();
if ids.contains(webview_id) {
WEBVIEW_STATE.with_borrow_mut(|ids| {
if ids.contains_key(webview_id) {
webview.remove_custom_task_key(task_key);
Ok(())
} else {
Err(crate::Error::CustomProtocolTaskInvalid)
}
}
})
}
#[cfg(feature = "tracing")]
@@ -314,6 +317,7 @@ extern "C" fn start_task(
#[cfg(feature = "tracing")]
let _span = tracing::info_span!("wry::custom_protocol::call_handler").entered();
function(
webview_id,
final_request,
@@ -327,7 +331,7 @@ extern "C" fn start_task(
tracing::warn!(
"Either WebView or WebContext instance is dropped! This handler shouldn't be called."
);
}
};
}
}

View File

@@ -72,12 +72,13 @@ use raw_window_handle::{HasWindowHandle, RawWindowHandle};
use std::{
cell::RefCell,
collections::HashSet,
ffi::{c_void, CString},
collections::HashMap,
ffi::{CStr, CString},
net::Ipv4Addr,
os::raw::c_char,
panic::AssertUnwindSafe,
ptr::{null_mut, NonNull},
rc::Rc,
str::{self, FromStr},
sync::{Arc, Mutex},
time::Duration,
@@ -100,7 +101,14 @@ use http::Request;
use crate::util::Counter;
static COUNTER: Counter = Counter::new();
static WEBVIEW_IDS: Lazy<Mutex<HashSet<String>>> = Lazy::new(Default::default);
thread_local! {
static WEBVIEW_STATE: RefCell<HashMap<String, WebViewState>> = Default::default();
}
struct WebViewState {
pub protocol_ptrs: Vec<Rc<dyn Fn(crate::WebViewId, Request<Vec<u8>>, RequestAsyncResponder)>>,
}
#[derive(Debug, Default, Copy, Clone)]
pub struct PrintMargin {
@@ -140,7 +148,6 @@ pub(crate) struct InnerWebView {
#[allow(dead_code)]
// We need this the keep the reference count
ui_delegate: Retained<WryWebViewUIDelegate>,
protocol_ptrs: Vec<*mut Box<dyn Fn(crate::WebViewId, Request<Vec<u8>>, RequestAsyncResponder)>>,
#[cfg(target_os = "macos")]
// We need this to update the traffic light inset
parent_view: Option<Retained<WryWebViewParent>>,
@@ -192,10 +199,6 @@ impl InnerWebView {
.map(|id| id.to_string())
.unwrap_or_else(|| COUNTER.next().to_string());
let mut wv_ids = WEBVIEW_IDS.lock().unwrap();
wv_ids.insert(webview_id.clone());
drop(wv_ids);
// Safety: objc runtime calls are unsafe
unsafe {
let config = WKWebViewConfiguration::new(mtm);
@@ -227,12 +230,15 @@ impl InnerWebView {
for (name, function) in attributes.custom_protocols {
let url_scheme_handler_cls = url_scheme_handler::create(&name);
let handler: *mut AnyObject = objc2::msg_send![url_scheme_handler_cls, new];
let function = Box::into_raw(Box::new(function));
protocol_ptrs.push(function);
let protocol_index = protocol_ptrs.len();
protocol_ptrs.push(Rc::from(function));
let ivar = (*handler).class().instance_variable(c"function").unwrap();
let ivar_delegate = ivar.load_mut(&mut *handler);
*ivar_delegate = function as *mut _ as *mut c_void;
let ivar = (*handler)
.class()
.instance_variable(CStr::from_bytes_with_nul(b"protocol_index\0").unwrap())
.unwrap();
let ivar_delegate: &mut usize = ivar.load_mut(&mut *handler);
*ivar_delegate = protocol_index;
let ivar = (*handler).class().instance_variable(c"webview_id").unwrap();
let ivar_delegate: &mut *mut c_char = ivar.load_mut(&mut *handler);
@@ -249,6 +255,10 @@ impl InnerWebView {
}
}
WEBVIEW_STATE.with_borrow_mut(|wv_ids| {
wv_ids.insert(webview_id.clone(), WebViewState { protocol_ptrs });
});
// WebView and manager
let manager = config.userContentController();
let webview = WryWebView::alloc(mtm).set_ivars(WryWebViewIvars {
@@ -509,7 +519,6 @@ impl InnerWebView {
navigation_policy_delegate,
download_delegate,
ui_delegate,
protocol_ptrs,
is_child,
#[cfg(target_os = "macos")]
parent_view: None,
@@ -1084,7 +1093,7 @@ pub fn platform_webview_version() -> Result<String> {
impl Drop for InnerWebView {
fn drop(&mut self) {
WEBVIEW_IDS.lock().unwrap().remove(&self.id);
WEBVIEW_STATE.with_borrow_mut(|v| v.remove(&self.id));
// We need to drop handler closures here
unsafe {
@@ -1097,12 +1106,6 @@ impl Drop for InnerWebView {
.removeScriptMessageHandlerForName(&ipc);
}
for ptr in self.protocol_ptrs.iter() {
if !ptr.is_null() {
drop(Box::from_raw(*ptr));
}
}
// Remove webview from window's NSView before dropping.
self.webview.removeFromSuperview();
self.webview.retain();