From 294c5a31a74da3f8879ab42f53577f60df3368ab Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Sat, 27 Apr 2024 08:47:27 +0200 Subject: [PATCH] Add snapshotting mechanism: --- Cargo.lock | 26 ++++++++++- Cargo.toml | 2 +- crates/red_knot/Cargo.toml | 2 +- crates/red_knot/src/cache.rs | 11 +++-- crates/red_knot/src/cancellation.rs | 2 +- crates/red_knot/src/db.rs | 45 +++++++----------- crates/red_knot/src/db/jars.rs | 33 +++++++++++++ crates/red_knot/src/db/query.rs | 6 +++ crates/red_knot/src/db/storage.rs | 61 ++++++++++++++++++++++++ crates/red_knot/src/lint.rs | 36 +++++++------- crates/red_knot/src/main.rs | 9 ++-- crates/red_knot/src/module.rs | 70 +++++++++++++++++----------- crates/red_knot/src/parse.rs | 10 ++-- crates/red_knot/src/program/check.rs | 58 +++++++++++------------ crates/red_knot/src/program/mod.rs | 28 +++++------ crates/red_knot/src/source.rs | 8 ++-- crates/red_knot/src/symbols.rs | 10 ++-- crates/red_knot/src/types/eval.rs | 34 +++++++------- crates/ruff_server/Cargo.toml | 2 +- 19 files changed, 289 insertions(+), 164 deletions(-) create mode 100644 crates/red_knot/src/db/jars.rs create mode 100644 crates/red_knot/src/db/query.rs create mode 100644 crates/red_knot/src/db/storage.rs diff --git a/Cargo.lock b/Cargo.lock index 64be7fff07315c..71ad25fce5b811 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -501,6 +501,19 @@ dependencies = [ "itertools 0.10.5", ] +[[package]] +name = "crossbeam" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1137cd7e7fc0fb5d3c5a8678be38ec56e819125d8d7907411fe24ccb943faca8" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + [[package]] name = "crossbeam-channel" version = "0.5.12" @@ -529,6 +542,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.19" @@ -1805,7 +1827,7 @@ version = "0.1.0" dependencies = [ "anyhow", "bitflags 2.5.0", - "crossbeam-channel", + "crossbeam", "ctrlc", "dashmap", "hashbrown 0.14.3", @@ -2342,7 +2364,7 @@ name = "ruff_server" version = "0.2.2" dependencies = [ "anyhow", - "crossbeam-channel", + "crossbeam", "insta", "jod-thread", "libc", diff --git a/Cargo.toml b/Cargo.toml index 468194b058e66e..d1d9f5d2135150 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ console_error_panic_hook = { version = "0.1.7" } console_log = { version = "1.0.0" } countme = { version = "3.0.1" } criterion = { version = "0.5.1", default-features = false } -crossbeam-channel = { version = "0.5.12" } +crossbeam = { version = "0.8.4" } dashmap = { version = "5.5.3" } dirs = { version = "5.0.0" } drop_bomb = { version = "0.1.5" } diff --git a/crates/red_knot/Cargo.toml b/crates/red_knot/Cargo.toml index 7907c8340a4e45..382d7d3062b465 100644 --- a/crates/red_knot/Cargo.toml +++ b/crates/red_knot/Cargo.toml @@ -22,7 +22,7 @@ ruff_notebook = { path = "../ruff_notebook" } anyhow = { workspace = true } bitflags = { workspace = true } ctrlc = "3.4.4" -crossbeam-channel = { workspace = true } +crossbeam = { workspace = true } dashmap = { workspace = true } hashbrown = { workspace = true } indexmap = { workspace = true } diff --git a/crates/red_knot/src/cache.rs b/crates/red_knot/src/cache.rs index 547c48c0f02ddc..e3a032a40fc9cf 100644 --- a/crates/red_knot/src/cache.rs +++ b/crates/red_knot/src/cache.rs @@ -2,6 +2,7 @@ use std::fmt::Formatter; use std::hash::Hash; use std::sync::atomic::{AtomicUsize, Ordering}; +use crate::db::QueryResult; use dashmap::mapref::entry::Entry; use crate::FxDashMap; @@ -27,11 +28,11 @@ where } } - pub fn get(&self, key: &K, compute: F) -> V + pub fn get(&self, key: &K, compute: F) -> QueryResult where - F: FnOnce(&K) -> V, + F: FnOnce(&K) -> QueryResult, { - match self.map.entry(key.clone()) { + Ok(match self.map.entry(key.clone()) { Entry::Occupied(cached) => { self.statistics.hit(); @@ -40,11 +41,11 @@ where Entry::Vacant(vacant) => { self.statistics.miss(); - let value = compute(key); + let value = compute(key)?; vacant.insert(value.clone()); value } - } + }) } pub fn set(&mut self, key: K, value: V) { diff --git a/crates/red_knot/src/cancellation.rs b/crates/red_knot/src/cancellation.rs index 0620d86ab5e761..e1b9084e6e46d4 100644 --- a/crates/red_knot/src/cancellation.rs +++ b/crates/red_knot/src/cancellation.rs @@ -1,6 +1,6 @@ use std::sync::{Arc, Condvar, Mutex}; -#[derive(Debug, Default)] +#[derive(Debug, Clone, Default)] pub struct CancellationTokenSource { signal: Arc<(Mutex, Condvar)>, } diff --git a/crates/red_knot/src/db.rs b/crates/red_knot/src/db.rs index 5f4b8d606be8a8..f654f67fcdcbc3 100644 --- a/crates/red_knot/src/db.rs +++ b/crates/red_knot/src/db.rs @@ -1,3 +1,7 @@ +mod jars; +mod query; +mod storage; + use std::path::Path; use std::sync::Arc; @@ -9,32 +13,35 @@ use crate::source::{Source, SourceStorage}; use crate::symbols::{SymbolId, SymbolTable, SymbolTablesStorage}; use crate::types::{Type, TypeStore}; +pub use jars::HasJar; +pub use query::{QueryError, QueryResult}; + pub trait SourceDb { // queries fn file_id(&self, path: &std::path::Path) -> FileId; fn file_path(&self, file_id: FileId) -> Arc; - fn source(&self, file_id: FileId) -> Source; + fn source(&self, file_id: FileId) -> QueryResult; - fn parse(&self, file_id: FileId) -> Parsed; + fn parse(&self, file_id: FileId) -> QueryResult; - fn lint_syntax(&self, file_id: FileId) -> Diagnostics; + fn lint_syntax(&self, file_id: FileId) -> QueryResult; } pub trait SemanticDb: SourceDb { // queries - fn resolve_module(&self, name: ModuleName) -> Option; + fn resolve_module(&self, name: ModuleName) -> QueryResult>; - fn file_to_module(&self, file_id: FileId) -> Option; + fn file_to_module(&self, file_id: FileId) -> QueryResult>; - fn path_to_module(&self, path: &Path) -> Option; + fn path_to_module(&self, path: &Path) -> QueryResult>; - fn symbol_table(&self, file_id: FileId) -> Arc; + fn symbol_table(&self, file_id: FileId) -> QueryResult>; - fn eval_symbol(&self, file_id: FileId, symbol_id: SymbolId) -> Type; + fn eval_symbol(&self, file_id: FileId, symbol_id: SymbolId) -> QueryResult; - fn lint_semantic(&self, file_id: FileId) -> Diagnostics; + fn lint_semantic(&self, file_id: FileId) -> QueryResult; // mutations @@ -60,26 +67,6 @@ pub struct SemanticJar { pub lint_semantic: LintSemanticStorage, } -/// Gives access to a specific jar in the database. -/// -/// Nope, the terminology isn't borrowed from Java but from Salsa , -/// which is an analogy to storing the salsa in different jars. -/// -/// The basic idea is that each crate can define its own jar and the jars can be combined to a single -/// database in the top level crate. Each crate also defines its own `Database` trait. The combination of -/// `Database` trait and the jar allows to write queries in isolation without having to know how they get composed at the upper levels. -/// -/// Salsa further defines a `HasIngredient` trait which slices the jar to a specific storage (e.g. a specific cache). -/// We don't need this just jet because we write our queries by hand. We may want a similar trait if we decide -/// to use a macro to generate the queries. -pub trait HasJar { - /// Gives a read-only reference to the jar. - fn jar(&self) -> &T; - - /// Gives a mutable reference to the jar. - fn jar_mut(&mut self) -> &mut T; -} - #[cfg(test)] pub(crate) mod tests { use std::path::Path; diff --git a/crates/red_knot/src/db/jars.rs b/crates/red_knot/src/db/jars.rs new file mode 100644 index 00000000000000..8431f90e477669 --- /dev/null +++ b/crates/red_knot/src/db/jars.rs @@ -0,0 +1,33 @@ +use crate::db::QueryResult; + +/// Gives access to a specific jar in the database. +/// +/// Nope, the terminology isn't borrowed from Java but from Salsa , +/// which is an analogy to storing the salsa in different jars. +/// +/// The basic idea is that each crate can define its own jar and the jars can be combined to a single +/// database in the top level crate. Each crate also defines its own `Database` trait. The combination of +/// `Database` trait and the jar allows to write queries in isolation without having to know how they get composed at the upper levels. +/// +/// Salsa further defines a `HasIngredient` trait which slices the jar to a specific storage (e.g. a specific cache). +/// We don't need this just jet because we write our queries by hand. We may want a similar trait if we decide +/// to use a macro to generate the queries. +pub trait HasJar { + /// Gives a read-only reference to the jar. + fn jar(&self) -> QueryResult<&T>; + + fn jar_by_pass_cancellation(&self) -> &T; + + /// Gives a mutable reference to the jar. + fn jar_mut(&mut self) -> &mut T; +} + +pub trait HasJars { + type Jars; + + fn jars(&self) -> &QueryResult; + + fn jars_unwrap(&self) -> &Self::Jars; + + fn jars_mut(&mut self) -> &mut Self::Jars; +} diff --git a/crates/red_knot/src/db/query.rs b/crates/red_knot/src/db/query.rs new file mode 100644 index 00000000000000..55980d8a601ca6 --- /dev/null +++ b/crates/red_knot/src/db/query.rs @@ -0,0 +1,6 @@ +#[derive(Debug, Clone, Copy)] +pub enum QueryError { + Cancelled, +} + +pub type QueryResult = Result; diff --git a/crates/red_knot/src/db/storage.rs b/crates/red_knot/src/db/storage.rs new file mode 100644 index 00000000000000..b053bfe65e7134 --- /dev/null +++ b/crates/red_knot/src/db/storage.rs @@ -0,0 +1,61 @@ +use crate::cancellation::{CancellationToken, CancellationTokenSource}; +use crate::db::jars::HasJars; +use crate::db::query::{QueryError, QueryResult}; +use crossbeam::sync::WaitGroup; +use std::sync::Arc; + +pub struct JarStorage +where + T: HasJars, +{ + db: T, +} + +#[derive(Clone, Debug)] +pub struct SharedStorage +where + T: HasJars, +{ + // It's important that the wait group is declared after `jars` to ensure that `jars` is dropped first. + // See https://doc.rust-lang.org/reference/destructors.html + jars: Arc, + + /// Used to count the references to `jars`. Allows implementing [`jars_mut`] without requiring to clone `jars`. + jars_references: WaitGroup, + + cancellation_token_source: CancellationTokenSource, +} + +impl SharedStorage +where + T: HasJars, +{ + pub(super) fn jars(&self) -> QueryResult<&T::Jars> { + self.err_if_cancelled()?; + Ok(&self.jars) + } + + pub(super) fn jars_mut(&mut self) -> &mut T::Jars { + // Cancel all pending queries. + self.cancellation_token_source.cancel(); + + let existing_wait = std::mem::take(&mut self.jars_references); + existing_wait.wait(); + self.cancellation_token_source = CancellationTokenSource::new(); + + // Now all other references to `self.jars` should have been released. We can now safely return a mutable reference + // to the Arc's content. + let jars = + Arc::get_mut(&mut self.jars).expect("All references to jars should have been released"); + + jars + } + + pub(super) fn err_if_cancelled(&self) -> QueryResult<()> { + if self.cancellation_token_source.is_cancelled() { + Err(QueryError::Cancelled) + } else { + Ok(()) + } + } +} diff --git a/crates/red_knot/src/lint.rs b/crates/red_knot/src/lint.rs index e817b337e51451..f6d77314ac6f27 100644 --- a/crates/red_knot/src/lint.rs +++ b/crates/red_knot/src/lint.rs @@ -6,7 +6,7 @@ use ruff_python_ast::visitor::Visitor; use ruff_python_ast::{ModModule, StringLiteral}; use crate::cache::KeyValueCache; -use crate::db::{HasJar, SemanticDb, SemanticJar, SourceDb, SourceJar}; +use crate::db::{HasJar, QueryResult, SemanticDb, SemanticJar, SourceDb, SourceJar}; use crate::files::FileId; use crate::parse::Parsed; use crate::source::Source; @@ -14,19 +14,19 @@ use crate::symbols::{Definition, SymbolId, SymbolTable}; use crate::types::Type; #[tracing::instrument(level = "debug", skip(db))] -pub(crate) fn lint_syntax(db: &Db, file_id: FileId) -> Diagnostics +pub(crate) fn lint_syntax(db: &Db, file_id: FileId) -> QueryResult where Db: SourceDb + HasJar, { - let storage = &db.jar().lint_syntax; + let storage = &db.jar()?.lint_syntax; storage.get(&file_id, |file_id| { let mut diagnostics = Vec::new(); - let source = db.source(*file_id); + let source = db.source(*file_id)?; lint_lines(source.text(), &mut diagnostics); - let parsed = db.parse(*file_id); + let parsed = db.parse(*file_id)?; if parsed.errors().is_empty() { let ast = parsed.ast(); @@ -41,7 +41,7 @@ where diagnostics.extend(parsed.errors().iter().map(|err| err.to_string())); } - Diagnostics::from(diagnostics) + Ok(Diagnostics::from(diagnostics)) }) } @@ -63,16 +63,16 @@ fn lint_lines(source: &str, diagnostics: &mut Vec) { } #[tracing::instrument(level = "debug", skip(db))] -pub(crate) fn lint_semantic(db: &Db, file_id: FileId) -> Diagnostics +pub(crate) fn lint_semantic(db: &Db, file_id: FileId) -> QueryResult where Db: SemanticDb + HasJar, { - let storage = &db.jar().lint_semantic; + let storage = &db.jar()?.lint_semantic; storage.get(&file_id, |file_id| { - let source = db.source(*file_id); - let parsed = db.parse(*file_id); - let symbols = db.symbol_table(*file_id); + let source = db.source(*file_id)?; + let parsed = db.parse(*file_id)?; + let symbols = db.symbol_table(*file_id)?; let context = SemanticLintContext { file_id: *file_id, @@ -83,24 +83,24 @@ where diagnostics: RefCell::new(Vec::new()), }; - lint_unresolved_imports(&context); + lint_unresolved_imports(&context)?; - Diagnostics::from(context.diagnostics.take()) + Ok(Diagnostics::from(context.diagnostics.take())) }) } -fn lint_unresolved_imports(context: &SemanticLintContext) { +fn lint_unresolved_imports(context: &SemanticLintContext) -> QueryResult<()> { for (symbol, definition) in context.symbols().all_definitions() { match definition { Definition::Import(import) => { - let ty = context.eval_symbol(symbol); + let ty = context.eval_symbol(symbol)?; if ty.is_unknown() { context.push_diagnostic(format!("Unresolved module {}", import.module)); } } Definition::ImportFrom(import) => { - let ty = context.eval_symbol(symbol); + let ty = context.eval_symbol(symbol)?; if ty.is_unknown() { let message = if let Some(module) = import.module() { @@ -123,6 +123,8 @@ fn lint_unresolved_imports(context: &SemanticLintContext) { _ => {} } } + + Ok(()) } pub struct SemanticLintContext<'a> { @@ -151,7 +153,7 @@ impl<'a> SemanticLintContext<'a> { &self.symbols } - pub fn eval_symbol(&self, symbol_id: SymbolId) -> Type { + pub fn eval_symbol(&self, symbol_id: SymbolId) -> QueryResult { self.db.eval_symbol(self.file_id, symbol_id) } diff --git a/crates/red_knot/src/main.rs b/crates/red_knot/src/main.rs index d8b16305f72128..0c88175482d819 100644 --- a/crates/red_knot/src/main.rs +++ b/crates/red_knot/src/main.rs @@ -4,6 +4,7 @@ use std::collections::hash_map::Entry; use std::path::Path; use std::sync::Mutex; +use crossbeam::channel as crossbeam_channel; use rustc_hash::FxHashMap; use tracing::subscriber::Interest; use tracing::{Level, Metadata}; @@ -13,10 +14,10 @@ use tracing_subscriber::{Layer, Registry}; use tracing_tree::time::Uptime; use red_knot::cancellation::CancellationTokenSource; -use red_knot::db::{HasJar, SourceDb, SourceJar}; +use red_knot::db::{HasJar, QueryError, SourceDb, SourceJar}; use red_knot::files::FileId; use red_knot::module::{ModuleSearchPath, ModuleSearchPathKind}; -use red_knot::program::check::{CheckError, RayonCheckScheduler}; +use red_knot::program::check::RayonCheckScheduler; use red_knot::program::{FileChange, FileChangeKind, Program}; use red_knot::watch::FileWatcher; use red_knot::Workspace; @@ -82,7 +83,7 @@ fn main() -> anyhow::Result<()> { main_loop.run(&mut program); - let source_jar: &SourceJar = program.jar(); + let source_jar: &SourceJar = program.jar().unwrap(); dbg!(source_jar.parsed.statistics()); dbg!(source_jar.sources.statistics()); @@ -158,7 +159,7 @@ impl MainLoop { Ok(result) => sender .send(OrchestratorMessage::CheckProgramCompleted(result)) .unwrap(), - Err(CheckError::Cancelled) => sender + Err(QueryError::Cancelled) => sender .send(OrchestratorMessage::CheckProgramCancelled) .unwrap(), } diff --git a/crates/red_knot/src/module.rs b/crates/red_knot/src/module.rs index 03f80bc9445493..29e6596ac1349a 100644 --- a/crates/red_knot/src/module.rs +++ b/crates/red_knot/src/module.rs @@ -6,7 +6,7 @@ use std::sync::Arc; use dashmap::mapref::entry::Entry; use smol_str::SmolStr; -use crate::db::{HasJar, SemanticDb, SemanticJar}; +use crate::db::{HasJar, QueryResult, SemanticDb, SemanticJar}; use crate::files::FileId; use crate::FxDashMap; @@ -15,22 +15,22 @@ use crate::FxDashMap; pub struct Module(u32); impl Module { - pub fn name(&self, db: &Db) -> ModuleName + pub fn name(&self, db: &Db) -> QueryResult where Db: HasJar, { - let modules = &db.jar().module_resolver; + let modules = &db.jar()?.module_resolver; - modules.modules.get(self).unwrap().name.clone() + Ok(modules.modules.get(self).unwrap().name.clone()) } - pub fn path(&self, db: &Db) -> ModulePath + pub fn path(&self, db: &Db) -> QueryResult where Db: HasJar, { - let modules = &db.jar().module_resolver; + let modules = &db.jar()?.module_resolver; - modules.modules.get(self).unwrap().path.clone() + Ok(modules.modules.get(self).unwrap().path.clone()) } } @@ -188,20 +188,26 @@ pub struct ModuleData { /// TODO: This would not work with Salsa because `ModuleName` isn't an ingredient and, therefore, cannot be used as part of a query. /// For this to work with salsa, it would be necessary to intern all `ModuleName`s. #[tracing::instrument(level = "debug", skip(db))] -pub fn resolve_module(db: &Db, name: ModuleName) -> Option +pub fn resolve_module(db: &Db, name: ModuleName) -> QueryResult> where Db: SemanticDb + HasJar, { let jar = db.jar(); - let modules = &jar.module_resolver; + let modules = &jar?.module_resolver; let entry = modules.by_name.entry(name.clone()); match entry { - Entry::Occupied(entry) => Some(*entry.get()), + Entry::Occupied(entry) => Ok(Some(*entry.get())), Entry::Vacant(entry) => { - let (root_path, absolute_path) = resolve_name(&name, &modules.search_paths)?; - let normalized = absolute_path.canonicalize().ok()?; + let Some((root_path, absolute_path)) = resolve_name(&name, &modules.search_paths) + else { + return Ok(None); + }; + + let Ok(normalized) = absolute_path.canonicalize() else { + return Ok(None); + }; let file_id = db.file_id(&normalized); let path = ModulePath::new(root_path.clone(), file_id); @@ -227,7 +233,7 @@ where entry.insert_entry(id); - Some(id) + Ok(Some(id)) } } } @@ -236,7 +242,7 @@ where /// /// Returns `None` if the file is not a module in `sys.path`. #[tracing::instrument(level = "debug", skip(db))] -pub fn file_to_module(db: &Db, file: FileId) -> Option +pub fn file_to_module(db: &Db, file: FileId) -> QueryResult> where Db: SemanticDb + HasJar, { @@ -248,34 +254,42 @@ where /// /// Returns `None` if the path is not a module in `sys.path`. #[tracing::instrument(level = "debug", skip(db))] -pub fn path_to_module(db: &Db, path: &Path) -> Option +pub fn path_to_module(db: &Db, path: &Path) -> QueryResult> where Db: SemanticDb + HasJar, { - let jar = db.jar(); + let jar = db.jar()?; let modules = &jar.module_resolver; debug_assert!(path.is_absolute()); if let Some(existing) = modules.by_path.get(path) { - return Some(*existing); + return Ok(Some(*existing)); } - let (root_path, relative_path) = modules.search_paths.iter().find_map(|root| { + let Some((root_path, relative_path)) = modules.search_paths.iter().find_map(|root| { let relative_path = path.strip_prefix(root.path()).ok()?; Some((root.clone(), relative_path)) - })?; + }) else { + return Ok(None); + }; - let module_name = ModuleName::from_relative_path(relative_path)?; + let Some(module_name) = ModuleName::from_relative_path(relative_path) else { + return Ok(None); + }; // Resolve the module name to see if Python would resolve the name to the same path. // If it doesn't, then that means that multiple modules have the same in different // root paths, but that the module corresponding to the past path is in a lower priority path, // in which case we ignore it. - let module_id = resolve_module(db, module_name)?; - let module_path = module_id.path(db); + let Some(module_id) = resolve_module(db, module_name)? else { + return Ok(None); + }; + let module_path = module_id.path(db)?; if module_path.root() == &root_path { - let normalized = path.canonicalize().ok()?; + let Ok(normalized) = path.canonicalize() else { + return Ok(None); + }; let module_path = db.file_path(module_path.file()); if !module_path.starts_with(normalized) { @@ -286,15 +300,15 @@ where // ``` // The module name of `src/foo.py` is `foo`, but the module loaded by Python is `src/foo/__init__.py`. // That means we need to ignore `src/foo.py` even though it resolves to the same module name. - return None; + return Ok(None); } // Path has been inserted by `resolved` - Some(module_id) + Ok(Some(module_id)) } else { // This path is for a module with the same name but in a module search path with a lower priority. // Ignore it. - None + Ok(None) } } @@ -328,7 +342,7 @@ where // TODO This needs tests // Note: Intentionally by-pass caching here. Module should not be in the cache yet. - let module = path_to_module(db, path)?; + let module = path_to_module(db, path).ok()??; // The code below is to handle the addition of `__init__.py` files. // When an `__init__.py` file is added, we need to remove all modules that are part of the same package. @@ -342,7 +356,7 @@ where return Some((module, Vec::new())); } - let Some(parent_name) = module.name(db).parent() else { + let Some(parent_name) = module.name(db).ok()?.parent() else { return Some((module, Vec::new())); }; diff --git a/crates/red_knot/src/parse.rs b/crates/red_knot/src/parse.rs index 20df9f4c8caf01..e76cd06706c137 100644 --- a/crates/red_knot/src/parse.rs +++ b/crates/red_knot/src/parse.rs @@ -6,7 +6,7 @@ use ruff_python_parser::{Mode, ParseError}; use ruff_text_size::{Ranged, TextRange}; use crate::cache::KeyValueCache; -use crate::db::{HasJar, SourceDb, SourceJar}; +use crate::db::{HasJar, QueryResult, SourceDb, SourceJar}; use crate::files::FileId; #[derive(Debug, Clone, PartialEq)] @@ -64,16 +64,16 @@ impl Parsed { } #[tracing::instrument(level = "debug", skip(db))] -pub(crate) fn parse(db: &Db, file_id: FileId) -> Parsed +pub(crate) fn parse(db: &Db, file_id: FileId) -> QueryResult where Db: SourceDb + HasJar, { - let parsed = db.jar(); + let parsed = db.jar()?; parsed.parsed.get(&file_id, |file_id| { - let source = db.source(*file_id); + let source = db.source(*file_id)?; - Parsed::from_text(source.text()) + Ok(Parsed::from_text(source.text())) }) } diff --git a/crates/red_knot/src/program/check.rs b/crates/red_knot/src/program/check.rs index 3e4a1102371dfd..241275832d5882 100644 --- a/crates/red_knot/src/program/check.rs +++ b/crates/red_knot/src/program/check.rs @@ -4,7 +4,7 @@ use rayon::max_num_threads; use rustc_hash::FxHashSet; use crate::cancellation::CancellationToken; -use crate::db::{SemanticDb, SourceDb}; +use crate::db::{QueryError, QueryResult, SemanticDb, SourceDb}; use crate::files::FileId; use crate::lint::Diagnostics; use crate::program::Program; @@ -16,7 +16,7 @@ impl Program { &self, scheduler: &dyn CheckScheduler, cancellation_token: CancellationToken, - ) -> Result, CheckError> { + ) -> QueryResult> { let check_loop = CheckFilesLoop::new(scheduler, cancellation_token); check_loop.run(self.workspace().open_files.iter().copied()) @@ -29,7 +29,7 @@ impl Program { file: FileId, scheduler: &dyn CheckScheduler, cancellation_token: CancellationToken, - ) -> Result, CheckError> { + ) -> QueryResult> { let check_loop = CheckFilesLoop::new(scheduler, cancellation_token); check_loop.run([file].into_iter()) @@ -40,14 +40,17 @@ impl Program { &self, file: FileId, context: &CheckContext, - ) -> Result { + ) -> QueryResult { context.cancelled_ok()?; - let symbol_table = self.symbol_table(file); + let symbol_table = self.symbol_table(file)?; let dependencies = symbol_table.dependencies(); if !dependencies.is_empty() { - let module_name = self.file_to_module(file).map(|module| module.name(self)); + let module_name = match self.file_to_module(file)? { + Some(module) => Some(module.name(self)?), + None => None, + }; // TODO scheduling all dependencies here is wasteful if we don't infer any types on them // but I think that's unlikely, so it is okay? @@ -58,9 +61,9 @@ impl Program { if let Some(dependency_name) = dependency_name { // TODO The resolver doesn't support native dependencies yet. - if let Some(dependency) = self.resolve_module(dependency_name) { - if dependency.path(self).root().kind().is_first_party() { - context.schedule_check_file(dependency.path(self).file()); + if let Some(dependency) = self.resolve_module(dependency_name)? { + if dependency.path(self)?.root().kind().is_first_party() { + context.schedule_check_file(dependency.path(self)?.file()); } } } @@ -70,8 +73,8 @@ impl Program { let mut diagnostics = Vec::new(); if self.workspace().is_file_open(file) { - diagnostics.extend_from_slice(&self.lint_syntax(file)); - diagnostics.extend_from_slice(&self.lint_semantic(file)); + diagnostics.extend_from_slice(&self.lint_syntax(file)?); + diagnostics.extend_from_slice(&self.lint_semantic(file)?); } Ok(Diagnostics::from(diagnostics)) @@ -151,11 +154,6 @@ impl CheckScheduler for SameThreadCheckScheduler<'_> { } } -#[derive(Debug, Clone)] -pub enum CheckError { - Cancelled, -} - #[derive(Debug)] pub struct CheckFileTask { file_id: FileId, @@ -171,7 +169,7 @@ impl CheckFileTask { .sender .send(CheckFileMessage::Completed(diagnostics)) .unwrap(), - Err(CheckError::Cancelled) => self + Err(QueryError::Cancelled) => self .context .sender .send(CheckFileMessage::Cancelled) @@ -183,13 +181,13 @@ impl CheckFileTask { #[derive(Clone, Debug)] struct CheckContext { cancellation_token: CancellationToken, - sender: crossbeam_channel::Sender, + sender: crossbeam::channel::Sender, } impl CheckContext { fn new( cancellation_token: CancellationToken, - sender: crossbeam_channel::Sender, + sender: crossbeam::channel::Sender, ) -> Self { Self { sender, @@ -208,9 +206,9 @@ impl CheckContext { self.cancellation_token.is_cancelled() } - fn cancelled_ok(&self) -> Result<(), CheckError> { + fn cancelled_ok(&self) -> QueryResult<()> { if self.is_cancelled() { - Err(CheckError::Cancelled) + Err(QueryError::Cancelled) } else { Ok(()) } @@ -235,13 +233,13 @@ impl<'a> CheckFilesLoop<'a> { } } - fn run(mut self, files: impl Iterator) -> Result, CheckError> { + fn run(mut self, files: impl Iterator) -> QueryResult> { let (sender, receiver) = if let Some(max_concurrency) = self.scheduler.max_concurrency() { - crossbeam_channel::bounded(max_concurrency.get()) + crossbeam::channel::bounded(max_concurrency.get()) } else { // The checks run on the current thread. That means it is necessary to store all messages // or we risk deadlocking when the main loop never gets a chance to read the messages. - crossbeam_channel::unbounded() + crossbeam::channel::unbounded() }; let context = CheckContext::new(self.cancellation_token.clone(), sender.clone()); @@ -255,11 +253,11 @@ impl<'a> CheckFilesLoop<'a> { fn run_impl( mut self, - receiver: crossbeam_channel::Receiver, + receiver: crossbeam::channel::Receiver, context: CheckContext, - ) -> Result, CheckError> { + ) -> QueryResult> { if self.cancellation_token.is_cancelled() { - return Err(CheckError::Cancelled); + return Err(QueryError::Cancelled); } let mut result = Vec::default(); @@ -279,7 +277,7 @@ impl<'a> CheckFilesLoop<'a> { self.queue_file(id, context.clone())?; } CheckFileMessage::Cancelled => { - return Err(CheckError::Cancelled); + return Err(QueryError::Cancelled); } } } @@ -287,9 +285,9 @@ impl<'a> CheckFilesLoop<'a> { Ok(result) } - fn queue_file(&mut self, file_id: FileId, context: CheckContext) -> Result<(), CheckError> { + fn queue_file(&mut self, file_id: FileId, context: CheckContext) -> QueryResult<()> { if context.is_cancelled() { - return Err(CheckError::Cancelled); + return Err(QueryError::Cancelled); } if self.queued_files.insert(file_id) { diff --git a/crates/red_knot/src/program/mod.rs b/crates/red_knot/src/program/mod.rs index cd995e1b37b8db..b3c9119e66e20b 100644 --- a/crates/red_knot/src/program/mod.rs +++ b/crates/red_knot/src/program/mod.rs @@ -3,7 +3,7 @@ pub mod check; use std::path::Path; use std::sync::Arc; -use crate::db::{Db, HasJar, SemanticDb, SemanticJar, SourceDb, SourceJar}; +use crate::db::{Db, HasJar, QueryResult, SemanticDb, SemanticJar, SourceDb, SourceJar}; use crate::files::{FileId, Files}; use crate::lint::{lint_semantic, lint_syntax, Diagnostics, LintSyntaxStorage}; use crate::module::{ @@ -83,43 +83,43 @@ impl SourceDb for Program { self.files.path(file_id) } - fn source(&self, file_id: FileId) -> Source { + fn source(&self, file_id: FileId) -> QueryResult { source_text(self, file_id) } - fn parse(&self, file_id: FileId) -> Parsed { + fn parse(&self, file_id: FileId) -> QueryResult { parse(self, file_id) } - fn lint_syntax(&self, file_id: FileId) -> Diagnostics { + fn lint_syntax(&self, file_id: FileId) -> QueryResult { lint_syntax(self, file_id) } } impl SemanticDb for Program { - fn resolve_module(&self, name: ModuleName) -> Option { + fn resolve_module(&self, name: ModuleName) -> QueryResult> { resolve_module(self, name) } - fn file_to_module(&self, file_id: FileId) -> Option { + fn file_to_module(&self, file_id: FileId) -> QueryResult> { file_to_module(self, file_id) } - fn path_to_module(&self, path: &Path) -> Option { + fn path_to_module(&self, path: &Path) -> QueryResult> { path_to_module(self, path) } - fn symbol_table(&self, file_id: FileId) -> Arc { + fn symbol_table(&self, file_id: FileId) -> QueryResult> { symbol_table(self, file_id) } // Mutations - fn eval_symbol(&self, file_id: FileId, symbol_id: SymbolId) -> Type { + fn eval_symbol(&self, file_id: FileId, symbol_id: SymbolId) -> QueryResult { eval_symbol(self, file_id, symbol_id) } - fn lint_semantic(&self, file_id: FileId) -> Diagnostics { + fn lint_semantic(&self, file_id: FileId) -> QueryResult { lint_semantic(self, file_id) } @@ -135,8 +135,8 @@ impl SemanticDb for Program { impl Db for Program {} impl HasJar for Program { - fn jar(&self) -> &SourceJar { - &self.source + fn jar(&self) -> QueryResult<&SourceJar> { + Ok(&self.source) } fn jar_mut(&mut self) -> &mut SourceJar { @@ -145,8 +145,8 @@ impl HasJar for Program { } impl HasJar for Program { - fn jar(&self) -> &SemanticJar { - &self.semantic + fn jar(&self) -> QueryResult<&SemanticJar> { + Ok(&self.semantic) } fn jar_mut(&mut self) -> &mut SemanticJar { diff --git a/crates/red_knot/src/source.rs b/crates/red_knot/src/source.rs index 08ad2d8abac3ed..69092d684453f9 100644 --- a/crates/red_knot/src/source.rs +++ b/crates/red_knot/src/source.rs @@ -1,5 +1,5 @@ use crate::cache::KeyValueCache; -use crate::db::{HasJar, SourceDb, SourceJar}; +use crate::db::{HasJar, QueryResult, SourceDb, SourceJar}; use ruff_notebook::Notebook; use ruff_python_ast::PySourceType; use std::ops::{Deref, DerefMut}; @@ -8,11 +8,11 @@ use std::sync::Arc; use crate::files::FileId; #[tracing::instrument(level = "debug", skip(db))] -pub(crate) fn source_text(db: &Db, file_id: FileId) -> Source +pub(crate) fn source_text(db: &Db, file_id: FileId) -> QueryResult where Db: SourceDb + HasJar, { - let sources = &db.jar().sources; + let sources = &db.jar()?.sources; sources.get(&file_id, |file_id| { let path = db.file_path(*file_id); @@ -43,7 +43,7 @@ where } }; - Source { kind } + Ok(Source { kind }) }) } diff --git a/crates/red_knot/src/symbols.rs b/crates/red_knot/src/symbols.rs index 79b7f215e6825e..2d02c607d1c886 100644 --- a/crates/red_knot/src/symbols.rs +++ b/crates/red_knot/src/symbols.rs @@ -14,22 +14,22 @@ use ruff_python_ast::visitor::preorder::PreorderVisitor; use crate::ast_ids::TypedNodeKey; use crate::cache::KeyValueCache; -use crate::db::{HasJar, SemanticDb, SemanticJar}; +use crate::db::{HasJar, QueryResult, SemanticDb, SemanticJar}; use crate::files::FileId; use crate::module::ModuleName; use crate::Name; #[allow(unreachable_pub)] #[tracing::instrument(level = "debug", skip(db))] -pub fn symbol_table(db: &Db, file_id: FileId) -> Arc +pub fn symbol_table(db: &Db, file_id: FileId) -> QueryResult> where Db: SemanticDb + HasJar, { - let jar = db.jar(); + let jar = db.jar()?; jar.symbol_tables.get(&file_id, |_| { - let parsed = db.parse(file_id); - Arc::from(SymbolTable::from_ast(parsed.ast())) + let parsed = db.parse(file_id)?; + Ok(Arc::from(SymbolTable::from_ast(parsed.ast()))) }) } diff --git a/crates/red_knot/src/types/eval.rs b/crates/red_knot/src/types/eval.rs index 342655b25afe56..e699d2d362b9b7 100644 --- a/crates/red_knot/src/types/eval.rs +++ b/crates/red_knot/src/types/eval.rs @@ -2,26 +2,26 @@ use ruff_python_ast::AstNode; -use crate::db::{HasJar, SemanticDb, SemanticJar}; +use crate::db::{HasJar, QueryResult, SemanticDb, SemanticJar}; use crate::module::ModuleName; use crate::symbols::{Definition, ImportFromDefinition, SymbolId}; use crate::types::Type; use crate::FileId; #[tracing::instrument(level = "trace", skip(db))] -pub fn eval_symbol(db: &Db, file_id: FileId, symbol_id: SymbolId) -> Type +pub fn eval_symbol(db: &Db, file_id: FileId, symbol_id: SymbolId) -> QueryResult where Db: SemanticDb + HasJar, { - let symbols = db.symbol_table(file_id); + let symbols = db.symbol_table(file_id)?; let defs = symbols.definitions(symbol_id); if let Some(ty) = db - .jar() + .jar()? .type_store .get_cached_symbol_type(file_id, symbol_id) { - return ty; + return Ok(ty); } // TODO handle multiple defs, conditional defs... @@ -37,11 +37,11 @@ where assert!(matches!(level, 0)); let module_name = ModuleName::new(module.as_ref().expect("TODO relative imports").as_str()); - if let Some(module) = db.resolve_module(module_name) { - let remote_file_id = module.path(db).file(); - let remote_symbols = db.symbol_table(remote_file_id); + if let Some(module) = db.resolve_module(module_name)? { + let remote_file_id = module.path(db)?.file(); + let remote_symbols = db.symbol_table(remote_file_id)?; if let Some(remote_symbol_id) = remote_symbols.root_symbol_id_by_name(name) { - db.eval_symbol(remote_file_id, remote_symbol_id) + db.eval_symbol(remote_file_id, remote_symbol_id)? } else { Type::Unknown } @@ -51,19 +51,19 @@ where } Definition::ClassDef(node_key) => { if let Some(ty) = db - .jar() + .jar()? .type_store .get_cached_node_type(file_id, node_key.erased()) { ty } else { - let parsed = db.parse(file_id); + let parsed = db.parse(file_id)?; let ast = parsed.ast(); let node = node_key .resolve(ast.as_any_node_ref()) .expect("node key should resolve"); - let store = &db.jar().type_store; + let store = &db.jar()?.type_store; let ty = store.add_class(file_id, &node.name.id); store.cache_node_type(file_id, *node_key.erased(), ty); ty @@ -72,19 +72,19 @@ where Definition::FunctionDef(node_key) => { if let Some(ty) = db - .jar() + .jar()? .type_store .get_cached_node_type(file_id, node_key.erased()) { ty } else { - let parsed = db.parse(file_id); + let parsed = db.parse(file_id)?; let ast = parsed.ast(); let node = node_key .resolve(ast.as_any_node_ref()) .expect("node key should resolve"); - let store = &db.jar().type_store; + let store = &db.jar()?.type_store; let ty = store.add_function(file_id, &node.name.id); store.cache_node_type(file_id, *node_key.erased(), ty); ty @@ -93,11 +93,11 @@ where _ => todo!("other kinds of definitions"), }; - db.jar() + db.jar()? .type_store .cache_symbol_type(file_id, symbol_id, ty); // TODO record dependencies - ty + Ok(ty) } #[cfg(test)] diff --git a/crates/ruff_server/Cargo.toml b/crates/ruff_server/Cargo.toml index 591a37b3152d5d..a93430d6eb0d31 100644 --- a/crates/ruff_server/Cargo.toml +++ b/crates/ruff_server/Cargo.toml @@ -26,7 +26,7 @@ ruff_text_size = { path = "../ruff_text_size" } ruff_workspace = { path = "../ruff_workspace" } anyhow = { workspace = true } -crossbeam-channel = { workspace = true } +crossbeam = { workspace = true } jod-thread = { workspace = true } libc = { workspace = true } lsp-server = { workspace = true }