Skip to content

Commit

Permalink
Address some review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaReiser committed Apr 27, 2024
1 parent 88b1a51 commit 8491e8f
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 34 deletions.
158 changes: 131 additions & 27 deletions crates/red_knot/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,27 +32,37 @@ impl Module {

modules.modules.get(self).unwrap().path.clone()
}
}

/// A module name, e.g. `foo.bar`.
///
/// Always normalized to the absolute form (never a relative module name).
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct ModuleName(smol_str::SmolStr);

impl ModuleName {
pub fn new(name: &str) -> Self {
debug_assert!(!name.is_empty());
pub fn kind<Db>(&self, db: &Db) -> ModuleKind
where
Db: HasJar<SemanticJar>,
{
let modules = &db.jar().module_resolver;

Self(smol_str::SmolStr::new(name))
modules.modules.get(self).unwrap().kind
}

pub fn relative(&self, level: u32, module: Option<&str>) -> Option<Self> {
let mut components = self.components().peekable();

// Skip over the relative parts.
for _ in 0..level {
components.next_back()?;
pub fn relative_name<Db>(&self, db: &Db, level: u32, module: Option<&str>) -> Option<ModuleName>
where
Db: HasJar<SemanticJar>,
{
let name = self.name(db);
let kind = self.kind(db);

let mut components = name.components().peekable();

if level > 0 {
let start = match kind {
// `.` resolves to the enclosing package
ModuleKind::Module => 0,
// `.` resolves to the current package
ModuleKind::Package => 1,
};

// Skip over the relative parts.
for _ in start..level {
components.next_back()?;
}
}

let mut name = String::new();
Expand All @@ -65,10 +75,28 @@ impl ModuleName {
name.push_str(part);
}

Some(Self(SmolStr::new(name)))
if name.is_empty() {
None
} else {
Some(ModuleName(SmolStr::new(name)))
}
}
}

/// A module name, e.g. `foo.bar`.
///
/// Always normalized to the absolute form (never a relative module name).
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct ModuleName(smol_str::SmolStr);

impl ModuleName {
pub fn new(name: &str) -> Self {
debug_assert!(!name.is_empty());

Self(smol_str::SmolStr::new(name))
}

pub fn from_relative_path(path: &Path) -> Option<Self> {
fn from_relative_path(path: &Path) -> Option<Self> {
let path = if path.ends_with("__init__.py") || path.ends_with("__init__.pyi") {
path.parent()?
} else {
Expand Down Expand Up @@ -119,6 +147,14 @@ impl std::fmt::Display for ModuleName {
}
}

#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum ModuleKind {
Module,

/// A python package (a `__init__.py` or `__init__.pyi` file)
Package,
}

/// A search path in which to search modules.
/// Corresponds to a path in [`sys.path`](https://docs.python.org/3/library/sys_path_init.html) at runtime.
///
Expand Down Expand Up @@ -178,6 +214,7 @@ impl ModuleSearchPathKind {
pub struct ModuleData {
name: ModuleName,
path: ModulePath,
kind: ModuleKind,
}

//////////////////////////////////////////////////////
Expand All @@ -200,7 +237,7 @@ where
match entry {
Entry::Occupied(entry) => Some(*entry.get()),
Entry::Vacant(entry) => {
let (root_path, absolute_path) = resolve_name(&name, &modules.search_paths)?;
let (root_path, absolute_path, kind) = resolve_name(&name, &modules.search_paths)?;
let normalized = absolute_path.canonicalize().ok()?;

let file_id = db.file_id(&normalized);
Expand All @@ -214,7 +251,7 @@ where

modules
.modules
.insert(id, Arc::from(ModuleData { name, path }));
.insert(id, Arc::from(ModuleData { name, path, kind }));

// A path can map to multiple modules because of symlinks:
// ```
Expand Down Expand Up @@ -463,7 +500,7 @@ impl ModulePath {
fn resolve_name(
name: &ModuleName,
search_paths: &[ModuleSearchPath],
) -> Option<(ModuleSearchPath, PathBuf)> {
) -> Option<(ModuleSearchPath, PathBuf, ModuleKind)> {
for search_path in search_paths {
let mut components = name.components();
let module_name = components.next_back()?;
Expand All @@ -475,21 +512,24 @@ fn resolve_name(
package_path.push(module_name);

// Must be a `__init__.pyi` or `__init__.py` or it isn't a package.
if package_path.is_dir() {
let kind = if package_path.is_dir() {
package_path.push("__init__");
}
ModuleKind::Package
} else {
ModuleKind::Module
};

// TODO Implement full https://peps.python.org/pep-0561/#type-checker-module-resolution-order resolution
let stub = package_path.with_extension("pyi");

if stub.is_file() {
return Some((search_path.clone(), stub));
return Some((search_path.clone(), stub, kind));
}

let module = package_path.with_extension("py");

if module.is_file() {
return Some((search_path.clone(), module));
return Some((search_path.clone(), module, kind));
}

// For regular packages, don't search the next search path. All files of that
Expand Down Expand Up @@ -597,7 +637,7 @@ impl PackageKind {
mod tests {
use crate::db::tests::TestDb;
use crate::db::{SemanticDb, SourceDb};
use crate::module::{ModuleName, ModuleSearchPath, ModuleSearchPathKind};
use crate::module::{ModuleKind, ModuleName, ModuleSearchPath, ModuleSearchPathKind};

struct TestCase {
temp_dir: tempfile::TempDir,
Expand Down Expand Up @@ -653,6 +693,7 @@ mod tests {

assert_eq!(ModuleName::new("foo"), foo_module.name(&db));
assert_eq!(&src, foo_module.path(&db).root());
assert_eq!(ModuleKind::Module, foo_module.kind(&db));
assert_eq!(&foo_path, &*db.file_path(foo_module.path(&db).file()));

assert_eq!(Some(foo_module), db.path_to_module(&foo_path));
Expand Down Expand Up @@ -709,6 +750,7 @@ mod tests {

assert_eq!(&src, foo_module.path(&db).root());
assert_eq!(&foo_init, &*db.file_path(foo_module.path(&db).file()));
assert_eq!(ModuleKind::Package, foo_module.kind(&db));

assert_eq!(Some(foo_module), db.path_to_module(&foo_init));
assert_eq!(None, db.path_to_module(&foo_py));
Expand Down Expand Up @@ -927,4 +969,66 @@ mod tests {

Ok(())
}

#[test]
fn relative_name() -> std::io::Result<()> {
let TestCase {
src,
db,
temp_dir: _temp_dir,
..
} = create_resolver()?;

let foo_dir = src.path().join("foo");
let foo_path = foo_dir.join("__init__.py");
let bar_path = foo_dir.join("bar.py");

std::fs::create_dir(&foo_dir)?;
std::fs::write(foo_path, "from .bar import test")?;
std::fs::write(bar_path, "test = 'Hello world'")?;

let foo_module = db.resolve_module(ModuleName::new("foo")).unwrap();
let bar_module = db.resolve_module(ModuleName::new("foo.bar")).unwrap();

// `from . import bar` in `foo/__init__.py` resolves to `foo`
assert_eq!(
Some(ModuleName::new("foo")),
foo_module.relative_name(&db, 1, None)
);

// `from baz import bar` in `foo/__init__.py` should resolve to `foo/baz.py`
assert_eq!(
Some(ModuleName::new("foo.baz")),
foo_module.relative_name(&db, 0, Some("baz"))
);

// from .bar import test in `foo/__init__.py` should resolve to `foo/bar.py`
assert_eq!(
Some(ModuleName::new("foo.bar")),
foo_module.relative_name(&db, 1, Some("bar"))
);

// from .. import test in `foo/__init__.py` resolves to `` which is not a module
assert_eq!(None, foo_module.relative_name(&db, 2, None));

// `from . import test` in `foo/bar.py` resolves to `foo`
assert_eq!(
Some(ModuleName::new("foo")),
bar_module.relative_name(&db, 1, None)
);

// `from baz import test` in `foo/bar.py` resolves to `foo.bar.baz`
assert_eq!(
Some(ModuleName::new("foo.bar.baz")),
bar_module.relative_name(&db, 0, Some("baz"))
);

// `from .baz import test` in `foo/bar.py` resolves to `foo.baz`.
assert_eq!(
Some(ModuleName::new("foo.baz")),
bar_module.relative_name(&db, 1, Some("baz"))
);

Ok(())
}
}
8 changes: 5 additions & 3 deletions crates/red_knot/src/program/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,19 @@ impl Program {
let dependencies = symbol_table.dependencies();

if !dependencies.is_empty() {
let module_name = self.file_to_module(file).map(|module| module.name(self));
let module = self.file_to_module(file);

// TODO scheduling all dependencies here is wasteful if we don't infer any types on them
// but I think that's unlikely, so it is okay?
// Anyway, we need to figure out a way to retrieve the dependencies of a module
// from the persistent cache. So maybe it should be a separate query after all.
for dependency in dependencies {
let dependency_name = dependency.module_name(module_name.as_ref());
let dependency_name = dependency.module_name(self, module);

if let Some(dependency_name) = dependency_name {
// TODO The resolver doesn't support native dependencies yet.
// TODO We may want to have a different check functions for non-first-party
// files because we only need to index them and not check them.
// Supporting non-first-party code also requires supporting typing stubs.
if let Some(dependency) = self.resolve_module(dependency_name) {
if dependency.path(self).root().kind().is_first_party() {
context.schedule_check_file(dependency.path(self).file());
Expand Down
9 changes: 6 additions & 3 deletions crates/red_knot/src/symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::ast_ids::TypedNodeKey;
use crate::cache::KeyValueCache;
use crate::db::{HasJar, SemanticDb, SemanticJar};
use crate::files::FileId;
use crate::module::ModuleName;
use crate::module::{Module, ModuleName};
use crate::Name;

#[allow(unreachable_pub)]
Expand Down Expand Up @@ -128,11 +128,14 @@ pub(crate) enum Dependency {
}

impl Dependency {
pub(crate) fn module_name(&self, relative_to: Option<&ModuleName>) -> Option<ModuleName> {
pub(crate) fn module_name<Db>(&self, db: &Db, relative_to: Option<Module>) -> Option<ModuleName>
where
Db: SemanticDb + HasJar<SemanticJar>,
{
match self {
Dependency::Module(name) => Some(ModuleName::new(name.as_str())),
Dependency::Relative { level, module } => {
relative_to?.relative(*level, module.as_ref().map(|name| name.as_str()))
relative_to?.relative_name(db, *level, module.as_deref())
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/red_knot/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ where
}) => {
// TODO relative imports
assert!(matches!(level, 0));
let module_name = ModuleName::new(&module.as_ref().expect("TODO relative imports"));
let module_name = ModuleName::new(module.as_ref().expect("TODO relative imports"));
if let Some(module) = db.resolve_module(module_name) {
let remote_file_id = module.path(db).file();
let remote_symbols = db.symbol_table(remote_file_id);
Expand Down

0 comments on commit 8491e8f

Please sign in to comment.