Skip to content

Commit

Permalink
feat: support dynamic language in pyo3
Browse files Browse the repository at this point in the history
fix #1143
  • Loading branch information
HerringtonDarkholme committed Aug 8, 2024
1 parent 9240e7b commit 09aedaf
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 10 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions crates/pyo3/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 19,11 @@ crate-type = ["cdylib"]
ast-grep-core.workspace = true
ast-grep-config.workspace = true
ast-grep-language.workspace = true
ast-grep-dynamic.workspace = true
anyhow.workspace = true
pyo3 = { version = "0.21.2", optional = true, features = ["anyhow"] }
pythonize = { version = "0.21.1", optional = true }
serde.workspace = true

# uncomment default features when developing pyo3
[features]
Expand Down
7 changes: 4 additions & 3 deletions crates/pyo3/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,13 1,14 @@
#![cfg(not(test))]
#![cfg(feature = "python")]
mod py_lang;
mod py_node;
mod range;
mod unicode_position;
use py_node::{Edit, SgNode};
use range::{Pos, Range};

use ast_grep_core::{AstGrep, Language, NodeMatch, StrDoc};
use ast_grep_language::SupportLang;
use py_lang::PyLang;
use pyo3::prelude::*;

use unicode_position::UnicodePosition;
Expand All @@ -25,7 26,7 @@ fn ast_grep_py(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {

#[pyclass]
struct SgRoot {
inner: AstGrep<StrDoc<SupportLang>>,
inner: AstGrep<StrDoc<PyLang>>,
filename: String,
pub(crate) position: UnicodePosition,
}
Expand All @@ -35,7 36,7 @@ impl SgRoot {
#[new]
fn new(src: &str, lang: &str) -> Self {
let position = UnicodePosition::new(src);
let lang: SupportLang = lang.parse().unwrap();
let lang: PyLang = lang.parse().unwrap();
let inner = lang.ast_grep(src);
Self {
inner,
Expand Down
114 changes: 114 additions & 0 deletions crates/pyo3/src/py_lang.rs
Original file line number Diff line number Diff line change
@@ -0,0 1,114 @@
use ast_grep_core::language::TSLanguage;
use ast_grep_dynamic::{DynamicLang, Registration};
use ast_grep_language::{Language, SupportLang};
use serde::{Deserialize, Serialize};

use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt::{Debug, Display, Formatter};
use std::path::{Path, PathBuf};
use std::str::FromStr;

#[derive(Serialize, Deserialize, Clone)]
pub struct CustomLang {
library_path: PathBuf,
/// the dylib symbol to load ts-language, default is `tree_sitter_{name}`
language_symbol: Option<String>,
meta_var_char: Option<char>,
expando_char: Option<char>,
extensions: Vec<String>,
}

impl CustomLang {
pub fn register(base: PathBuf, langs: HashMap<String, CustomLang>) {
let registrations = langs
.into_iter()
.map(|(name, custom)| to_registration(name, custom, &base))
.collect();
// TODO, add error handling
unsafe { DynamicLang::register(registrations).expect("TODO") }
}
}

fn to_registration(name: String, custom_lang: CustomLang, base: &Path) -> Registration {
let path = base.join(custom_lang.library_path);
let sym = custom_lang
.language_symbol
.unwrap_or_else(|| format!("tree_sitter_{name}"));
Registration {
lang_name: name,
lib_path: path,
symbol: sym,
meta_var_char: custom_lang.meta_var_char,
expando_char: custom_lang.expando_char,
extensions: custom_lang.extensions,
}
}

#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub enum PyLang {
// inlined support lang expando char
Builtin(SupportLang),
Custom(DynamicLang),
}
#[derive(Debug)]
pub enum PyLangErr {
LanguageNotSupported(String),
}

impl Display for PyLangErr {
fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
use PyLangErr::*;
match self {
LanguageNotSupported(lang) => write!(f, "{} is not supported!", lang),
}
}
}

impl std::error::Error for PyLangErr {}

impl FromStr for PyLang {
type Err = PyLangErr;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if let Ok(b) = SupportLang::from_str(s) {
Ok(PyLang::Builtin(b))
} else if let Ok(c) = DynamicLang::from_str(s) {
Ok(PyLang::Custom(c))
} else {
Err(PyLangErr::LanguageNotSupported(s.into()))
}
}
}

use PyLang::*;
impl Language for PyLang {
fn get_ts_language(&self) -> TSLanguage {
match self {
Builtin(b) => b.get_ts_language(),
Custom(c) => c.get_ts_language(),
}
}

fn pre_process_pattern<'q>(&self, query: &'q str) -> Cow<'q, str> {
match self {
Builtin(b) => b.pre_process_pattern(query),
Custom(c) => c.pre_process_pattern(query),

Check warning on line 95 in crates/pyo3/src/py_lang.rs

View check run for this annotation

Codecov / codecov/patch

crates/pyo3/src/py_lang.rs#L92-L95

Added lines #L92 - L95 were not covered by tests
}
}

#[inline]
fn meta_var_char(&self) -> char {
match self {
Builtin(b) => b.meta_var_char(),
Custom(c) => c.meta_var_char(),

Check warning on line 103 in crates/pyo3/src/py_lang.rs

View check run for this annotation

Codecov / codecov/patch

crates/pyo3/src/py_lang.rs#L100-L103

Added lines #L100 - L103 were not covered by tests
}
}

#[inline]
fn expando_char(&self) -> char {
match self {
Builtin(b) => b.expando_char(),
Custom(c) => c.expando_char(),

Check warning on line 111 in crates/pyo3/src/py_lang.rs

View check run for this annotation

Codecov / codecov/patch

crates/pyo3/src/py_lang.rs#L108-L111

Added lines #L108 - L111 were not covered by tests
}
}
}
11 changes: 4 additions & 7 deletions crates/pyo3/src/py_node.rs
Original file line number Diff line number Diff line change
@@ -1,9 1,9 @@
use crate::py_lang::PyLang;
use crate::range::Range;
use crate::SgRoot;

use ast_grep_config::{DeserializeEnv, RuleCore, SerializableRuleCore};
use ast_grep_core::{NodeMatch, StrDoc};
use ast_grep_language::SupportLang;

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
Expand All @@ -16,7 16,7 @@ use pythonize::depythonize_bound;

#[pyclass(mapping)]
pub struct SgNode {
pub inner: NodeMatch<'static, StrDoc<SupportLang>>,
pub inner: NodeMatch<'static, StrDoc<PyLang>>,
// refcount SgRoot
pub(crate) root: Py<SgRoot>,
}
Expand Down Expand Up @@ -328,7 328,7 @@ impl SgNode {
&self,
config: Option<Bound<PyDict>>,
kwargs: Option<Bound<PyDict>>,
) -> PyResult<RuleCore<SupportLang>> {
) -> PyResult<RuleCore<PyLang>> {
let lang = self.inner.lang();
let config = if let Some(config) = config {
config_from_dict(config)?
Expand Down Expand Up @@ -358,10 358,7 @@ fn config_from_rule(dict: Bound<PyDict>) -> PyResult<SerializableRuleCore> {
})
}

fn get_matcher_from_rule(
lang: &SupportLang,
dict: Option<Bound<PyDict>>,
) -> PyResult<RuleCore<SupportLang>> {
fn get_matcher_from_rule(lang: &PyLang, dict: Option<Bound<PyDict>>) -> PyResult<RuleCore<PyLang>> {
let rule = dict.ok_or_else(|| PyErr::new::<PyValueError, _>("rule must not be empty"))?;
let env = DeserializeEnv::new(*lang);
let config = config_from_rule(rule)?;
Expand Down

0 comments on commit 09aedaf

Please sign in to comment.