Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(jupyter): move ZeroMQ server to a separate thread #24373

Merged
merged 6 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
renames
  • Loading branch information
bartlomieju committed Jul 1, 2024
commit 40973a3293a498009ecf312fb13e2a87377de0fc
86 changes: 40 additions & 46 deletions cli/tools/jupyter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 152,12 @@ pub async fn kernel(
let (startup_data_tx, startup_data_rx) =
oneshot::channel::<server::StartupData>();

let mut repl_session_proxy = ReplSessionProxy {
let mut repl_session_proxy = JupyterReplSession {
repl_session,
rx: rx1,
tx: tx2,
};
let repl_session_proxy_channels =
ReplSessionProxyChannels { tx: tx1, rx: rx2 };
let repl_session_proxy_channels = JupyterReplProxy { tx: tx1, rx: rx2 };

let join_handle = std::thread::spawn(move || {
let fut = server::JupyterServer::start(
Expand All @@ -174,7 173,6 @@ pub async fn kernel(
let Ok(startup_data) = startup_data_rx.await else {
bail!("Failed to acquire startup data");
};
// Store `iopub_connection` in the op state so it's accessible to the runtime API.
{
let op_state_rc =
repl_session_proxy.repl_session.worker.js_runtime.op_state();
Expand All @@ -197,7 195,7 @@ pub async fn kernel(
Ok(())
}

pub enum ReplSessionProxyRequest {
pub enum JupyterReplRequest {
LspCompletions {
line_text: String,
position: usize,
Expand All @@ -222,7 220,7 @@ pub enum ReplSessionProxyRequest {
},
}

pub enum ReplSessionProxyResponse {
pub enum JupyterReplResponse {
LspCompletions(Vec<ReplCompletionItem>),
JsGetProperties(Option<cdp::GetPropertiesResponse>),
JsEvaluate(Option<cdp::EvaluateResponse>),
Expand All @@ -232,23 230,22 @@ pub enum ReplSessionProxyResponse {
JsCallFunctionOn(Option<cdp::CallFunctionOnResponse>),
}

pub struct ReplSessionProxyChannels {
tx: mpsc::UnboundedSender<ReplSessionProxyRequest>,
rx: mpsc::UnboundedReceiver<ReplSessionProxyResponse>,
pub struct JupyterReplProxy {
tx: mpsc::UnboundedSender<JupyterReplRequest>,
rx: mpsc::UnboundedReceiver<JupyterReplResponse>,
}

impl ReplSessionProxyChannels {
impl JupyterReplProxy {
pub async fn lsp_completions(
&mut self,
line_text: String,
position: usize,
) -> Vec<ReplCompletionItem> {
let _ = self.tx.send(ReplSessionProxyRequest::LspCompletions {
let _ = self.tx.send(JupyterReplRequest::LspCompletions {
line_text,
position,
});
let Some(ReplSessionProxyResponse::LspCompletions(resp)) =
self.rx.recv().await
let Some(JupyterReplResponse::LspCompletions(resp)) = self.rx.recv().await
else {
unreachable!()
};
Expand All @@ -261,9 258,8 @@ impl ReplSessionProxyChannels {
) -> Option<cdp::GetPropertiesResponse> {
let _ = self
.tx
.send(ReplSessionProxyRequest::JsGetProperties { object_id });
let Some(ReplSessionProxyResponse::JsGetProperties(resp)) =
self.rx.recv().await
.send(JupyterReplRequest::JsGetProperties { object_id });
let Some(JupyterReplResponse::JsGetProperties(resp)) = self.rx.recv().await
else {
unreachable!()
};
Expand All @@ -274,8 270,8 @@ impl ReplSessionProxyChannels {
&mut self,
expr: String,
) -> Option<cdp::EvaluateResponse> {
let _ = self.tx.send(ReplSessionProxyRequest::JsEvaluate { expr });
let Some(ReplSessionProxyResponse::JsEvaluate(resp)) = self.rx.recv().await
let _ = self.tx.send(JupyterReplRequest::JsEvaluate { expr });
let Some(JupyterReplResponse::JsEvaluate(resp)) = self.rx.recv().await
else {
unreachable!()
};
Expand All @@ -285,10 281,8 @@ impl ReplSessionProxyChannels {
pub async fn global_lexical_scope_names(
&mut self,
) -> cdp::GlobalLexicalScopeNamesResponse {
let _ = self
.tx
.send(ReplSessionProxyRequest::JsGlobalLexicalScopeNames);
let Some(ReplSessionProxyResponse::JsGlobalLexicalScopeNames(resp)) =
let _ = self.tx.send(JupyterReplRequest::JsGlobalLexicalScopeNames);
let Some(JupyterReplResponse::JsGlobalLexicalScopeNames(resp)) =
self.rx.recv().await
else {
unreachable!()
Expand All @@ -302,8 296,8 @@ impl ReplSessionProxyChannels {
) -> Result<repl::TsEvaluateResponse, AnyError> {
let _ = self
.tx
.send(ReplSessionProxyRequest::JsEvaluateLineWithObjectWrapping { line });
let Some(ReplSessionProxyResponse::JsEvaluateLineWithObjectWrapping(resp)) =
.send(JupyterReplRequest::JsEvaluateLineWithObjectWrapping { line });
let Some(JupyterReplResponse::JsEvaluateLineWithObjectWrapping(resp)) =
self.rx.recv().await
else {
unreachable!()
Expand All @@ -316,11 310,11 @@ impl ReplSessionProxyChannels {
function_declaration: String,
args: Vec<cdp::RemoteObject>,
) -> Result<cdp::CallFunctionOnResponse, AnyError> {
let _ = self.tx.send(ReplSessionProxyRequest::JsCallFunctionOnArgs {
let _ = self.tx.send(JupyterReplRequest::JsCallFunctionOnArgs {
function_declaration,
args,
});
let Some(ReplSessionProxyResponse::JsCallFunctionOnArgs(resp)) =
let Some(JupyterReplResponse::JsCallFunctionOnArgs(resp)) =
self.rx.recv().await
else {
unreachable!()
Expand All @@ -336,8 330,8 @@ impl ReplSessionProxyChannels {
) -> Option<cdp::CallFunctionOnResponse> {
let _ = self
.tx
.send(ReplSessionProxyRequest::JsCallFunctionOn { arg0, arg1 });
let Some(ReplSessionProxyResponse::JsCallFunctionOn(resp)) =
.send(JupyterReplRequest::JsCallFunctionOn { arg0, arg1 });
let Some(JupyterReplResponse::JsCallFunctionOn(resp)) =
self.rx.recv().await
else {
unreachable!()
Expand All @@ -346,53 340,53 @@ impl ReplSessionProxyChannels {
}
}

pub struct ReplSessionProxy {
pub struct JupyterReplSession {
repl_session: repl::ReplSession,
rx: mpsc::UnboundedReceiver<ReplSessionProxyRequest>,
tx: mpsc::UnboundedSender<ReplSessionProxyResponse>,
rx: mpsc::UnboundedReceiver<JupyterReplRequest>,
tx: mpsc::UnboundedSender<JupyterReplResponse>,
}

impl ReplSessionProxy {
impl JupyterReplSession {
pub async fn start(&mut self) {
loop {
let Some(msg) = self.rx.recv().await else {
break;
};
let resp = match msg {
ReplSessionProxyRequest::LspCompletions {
JupyterReplRequest::LspCompletions {
line_text,
position,
} => ReplSessionProxyResponse::LspCompletions(
} => JupyterReplResponse::LspCompletions(
self.lsp_completions(&line_text, position).await,
),
ReplSessionProxyRequest::JsGetProperties { object_id } => {
ReplSessionProxyResponse::JsGetProperties(
JupyterReplRequest::JsGetProperties { object_id } => {
JupyterReplResponse::JsGetProperties(
self.get_properties(object_id).await,
)
}
ReplSessionProxyRequest::JsEvaluate { expr } => {
ReplSessionProxyResponse::JsEvaluate(self.evaluate(expr).await)
JupyterReplRequest::JsEvaluate { expr } => {
JupyterReplResponse::JsEvaluate(self.evaluate(expr).await)
}
ReplSessionProxyRequest::JsGlobalLexicalScopeNames => {
ReplSessionProxyResponse::JsGlobalLexicalScopeNames(
JupyterReplRequest::JsGlobalLexicalScopeNames => {
JupyterReplResponse::JsGlobalLexicalScopeNames(
self.global_lexical_scope_names().await,
)
}
ReplSessionProxyRequest::JsEvaluateLineWithObjectWrapping { line } => {
ReplSessionProxyResponse::JsEvaluateLineWithObjectWrapping(
JupyterReplRequest::JsEvaluateLineWithObjectWrapping { line } => {
JupyterReplResponse::JsEvaluateLineWithObjectWrapping(
self.evaluate_line_with_object_wrapping(&line).await,
)
}
ReplSessionProxyRequest::JsCallFunctionOnArgs {
JupyterReplRequest::JsCallFunctionOnArgs {
function_declaration,
args,
} => ReplSessionProxyResponse::JsCallFunctionOnArgs(
} => JupyterReplResponse::JsCallFunctionOnArgs(
self
.call_function_on_args(function_declaration, &args)
.await,
),
ReplSessionProxyRequest::JsCallFunctionOn { arg0, arg1 } => {
ReplSessionProxyResponse::JsCallFunctionOn(
JupyterReplRequest::JsCallFunctionOn { arg0, arg1 } => {
JupyterReplResponse::JsCallFunctionOn(
self.call_function_on(arg0, arg1).await,
)
}
Expand Down
18 changes: 9 additions & 9 deletions cli/tools/jupyter/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 31,13 @@ use jupyter_runtime::ReplyError;
use jupyter_runtime::ReplyStatus;
use jupyter_runtime::StreamContent;

use super::ReplSessionProxyChannels;
use super::JupyterReplProxy;

pub struct JupyterServer {
execution_count: usize,
last_execution_request: Arc<Mutex<Option<JupyterMessage>>>,
iopub_connection: Arc<Mutex<KernelIoPubConnection>>,
repl_session_proxy: ReplSessionProxyChannels,
repl_session_proxy: JupyterReplProxy,
}

pub struct StartupData {
Expand All @@ -49,7 49,7 @@ impl JupyterServer {
pub async fn start(
connection_info: ConnectionInfo,
mut stdio_rx: mpsc::UnboundedReceiver<StreamContent>,
repl_session_proxy: ReplSessionProxyChannels,
repl_session_proxy: JupyterReplProxy,
setup_tx: oneshot::Sender<StartupData>,
) -> Result<(), AnyError> {
let mut heartbeat =
Expand Down Expand Up @@ -643,7 643,7 @@ fn kernel_info() -> messaging::KernelInfoReply {
}

async fn publish_result(
repl_session_proxy: &mut ReplSessionProxyChannels,
repl_session_proxy: &mut JupyterReplProxy,
evaluate_result: &cdp::RemoteObject,
execution_count: usize,
) -> Result<Option<HashMap<String, serde_json::Value>>, AnyError> {
Expand Down Expand Up @@ -696,14 696,14 @@ fn is_word_boundary(c: char) -> bool {

// TODO(bartlomieju): dedup with repl::editor
async fn get_global_lexical_scope_names(
repl_session_proxy: &mut ReplSessionProxyChannels,
repl_session_proxy: &mut JupyterReplProxy,
) -> Vec<String> {
repl_session_proxy.global_lexical_scope_names().await.names
}

// TODO(bartlomieju): dedup with repl::editor
async fn get_expression_property_names(
repl_session_proxy: &mut ReplSessionProxyChannels,
repl_session_proxy: &mut JupyterReplProxy,
expr: &str,
) -> Vec<String> {
// try to get the properties from the expression
Expand Down Expand Up @@ -733,7 733,7 @@ async fn get_expression_property_names(

// TODO(bartlomieju): dedup with repl::editor
async fn get_expression_type(
repl_session_proxy: &mut ReplSessionProxyChannels,
repl_session_proxy: &mut JupyterReplProxy,
expr: &str,
) -> Option<String> {
evaluate_expression(repl_session_proxy, expr)
Expand All @@ -743,7 743,7 @@ async fn get_expression_type(

// TODO(bartlomieju): dedup with repl::editor
async fn get_object_expr_properties(
repl_session_proxy: &mut ReplSessionProxyChannels,
repl_session_proxy: &mut JupyterReplProxy,
object_expr: &str,
) -> Option<Vec<String>> {
let evaluate_result =
Expand All @@ -763,7 763,7 @@ async fn get_object_expr_properties(

// TODO(bartlomieju): dedup with repl::editor
async fn evaluate_expression(
repl_session_proxy: &mut ReplSessionProxyChannels,
repl_session_proxy: &mut JupyterReplProxy,
expr: &str,
) -> Option<cdp::EvaluateResponse> {
let evaluate_response = repl_session_proxy.evaluate(expr.to_string()).await?;
Expand Down