Merge pull request #12491 from vigoux/treesitter-set-ranges

[RDY] Treesitter set ranges
This commit is contained in:
Matthieu Coudron 2020-06-30 00:02:46 +02:00 committed by GitHub
commit 1920ba4b55
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 161 additions and 33 deletions

View File

@ -512,6 +512,9 @@ retained for the lifetime of a buffer but this is subject to change. A plugin
should keep a reference to the parser object as long as it wants incremental
updates.
Parser methods *lua-treesitter-parser*
tsparser:parse() *tsparser:parse()*
Whenever you need to access the current syntax tree, parse the buffer: >
tstree = parser:parse()
@ -528,6 +531,16 @@ shouldn't be done directly in the change callback anyway as they will be very
frequent. Rather a plugin that does any kind of analysis on a tree should use
a timer to throttle too frequent updates.
tsparser:set_included_ranges(ranges) *tsparser:set_included_ranges()*
Changes the ranges the parser should consider. This is used for
language injection. `ranges` should be of the form (all zero-based): >
{
{start_node, end_node},
...
}
<
NOTE: `start_node` and `end_node` are both inclusive.
Tree methods *lua-treesitter-tree*
tstree:root() *tstree:root()*

View File

@ -30,6 +30,12 @@ function Parser:_on_lines(bufnr, _, start_row, old_stop_row, stop_row, old_byte_
self.valid = false
end
function Parser:set_included_ranges(ranges)
self._parser:set_included_ranges(ranges)
-- The buffer will need to be parsed again later
self.valid = false
end
local M = {
parse_query = vim._ts_parse_query,
}

View File

@ -1128,21 +1128,11 @@ void ex_luafile(exarg_T *const eap)
}
}
static int create_tslua_parser(lua_State *L)
{
if (lua_gettop(L) < 1 || !lua_isstring(L, 1)) {
return luaL_error(L, "string expected");
}
const char *lang_name = lua_tostring(L, 1);
return tslua_push_parser(L, lang_name);
}
static void nlua_add_treesitter(lua_State *const lstate) FUNC_ATTR_NONNULL_ALL
{
tslua_init(lstate);
lua_pushcfunction(lstate, create_tslua_parser);
lua_pushcfunction(lstate, tslua_push_parser);
lua_setfield(lstate, -2, "_create_ts_parser");
lua_pushcfunction(lstate, tslua_add_language);

View File

@ -20,6 +20,7 @@
#include "nvim/lua/treesitter.h"
#include "nvim/api/private/handle.h"
#include "nvim/memline.h"
#include "nvim/buffer.h"
typedef struct {
TSParser *parser;
@ -41,6 +42,7 @@ static struct luaL_Reg parser_meta[] = {
{ "parse_buf", parser_parse_buf },
{ "edit", parser_edit },
{ "tree", parser_tree },
{ "set_included_ranges", parser_set_ranges },
{ NULL, NULL }
};
@ -214,8 +216,13 @@ int tslua_inspect_lang(lua_State *L)
return 1;
}
int tslua_push_parser(lua_State *L, const char *lang_name)
int tslua_push_parser(lua_State *L)
{
// Gather language
if (lua_gettop(L) < 1 || !lua_isstring(L, 1)) {
return luaL_error(L, "string expected");
}
const char *lang_name = lua_tostring(L, 1);
TSLanguage *lang = pmap_get(cstr_t)(langs, lang_name);
if (!lang) {
return luaL_error(L, "no such language: %s", lang_name);
@ -377,6 +384,57 @@ static int parser_edit(lua_State *L)
return 0;
}
static int parser_set_ranges(lua_State *L)
{
if (lua_gettop(L) < 2) {
return luaL_error(
L,
"not enough args to parser:set_included_ranges()");
}
TSLua_parser *p = parser_check(L);
if (!p || !p->tree) {
return 0;
}
if (!lua_istable(L, 2)) {
return luaL_error(
L,
"argument for parser:set_included_ranges() should be a table.");
}
size_t tbl_len = lua_objlen(L, 2);
TSRange *ranges = xmalloc(sizeof(TSRange) * tbl_len);
// [ parser, ranges ]
for (size_t index = 0; index < tbl_len; index++) {
lua_rawgeti(L, 2, index + 1); // [ parser, ranges, range ]
TSNode node;
if (!node_check(L, -1, &node)) {
xfree(ranges);
return luaL_error(
L,
"ranges should be tables of nodes.");
}
lua_pop(L, 1); // [ parser, ranges ]
ranges[index] = (TSRange) {
.start_point = ts_node_start_point(node),
.end_point = ts_node_end_point(node),
.start_byte = ts_node_start_byte(node),
.end_byte = ts_node_end_byte(node)
};
}
// This memcpies ranges, thus we can free it afterwards
ts_parser_set_included_ranges(p->parser, ranges, tbl_len);
xfree(ranges);
return 0;
}
// Tree methods
@ -459,9 +517,9 @@ static void push_node(lua_State *L, TSNode node, int uindex)
lua_setfenv(L, -2); // [udata]
}
static bool node_check(lua_State *L, TSNode *res)
static bool node_check(lua_State *L, int index, TSNode *res)
{
TSNode *ud = luaL_checkudata(L, 1, "treesitter_node");
TSNode *ud = luaL_checkudata(L, index, "treesitter_node");
if (ud) {
*res = *ud;
return true;
@ -473,7 +531,7 @@ static bool node_check(lua_State *L, TSNode *res)
static int node_tostring(lua_State *L)
{
TSNode node;
if (!node_check(L, &node)) {
if (!node_check(L, 1, &node)) {
return 0;
}
lua_pushstring(L, "<node ");
@ -486,7 +544,7 @@ static int node_tostring(lua_State *L)
static int node_eq(lua_State *L)
{
TSNode node;
if (!node_check(L, &node)) {
if (!node_check(L, 1, &node)) {
return 0;
}
// This should only be called if both x and y in "x == y" has the
@ -503,7 +561,7 @@ static int node_eq(lua_State *L)
static int node_range(lua_State *L)
{
TSNode node;
if (!node_check(L, &node)) {
if (!node_check(L, 1, &node)) {
return 0;
}
TSPoint start = ts_node_start_point(node);
@ -518,7 +576,7 @@ static int node_range(lua_State *L)
static int node_start(lua_State *L)
{
TSNode node;
if (!node_check(L, &node)) {
if (!node_check(L, 1, &node)) {
return 0;
}
TSPoint start = ts_node_start_point(node);
@ -532,7 +590,7 @@ static int node_start(lua_State *L)
static int node_end(lua_State *L)
{
TSNode node;
if (!node_check(L, &node)) {
if (!node_check(L, 1, &node)) {
return 0;
}
TSPoint end = ts_node_end_point(node);
@ -546,7 +604,7 @@ static int node_end(lua_State *L)
static int node_child_count(lua_State *L)
{
TSNode node;
if (!node_check(L, &node)) {
if (!node_check(L, 1, &node)) {
return 0;
}
uint32_t count = ts_node_child_count(node);
@ -557,7 +615,7 @@ static int node_child_count(lua_State *L)
static int node_named_child_count(lua_State *L)
{
TSNode node;
if (!node_check(L, &node)) {
if (!node_check(L, 1, &node)) {
return 0;
}
uint32_t count = ts_node_named_child_count(node);
@ -568,7 +626,7 @@ static int node_named_child_count(lua_State *L)
static int node_type(lua_State *L)
{
TSNode node;
if (!node_check(L, &node)) {
if (!node_check(L, 1, &node)) {
return 0;
}
lua_pushstring(L, ts_node_type(node));
@ -578,7 +636,7 @@ static int node_type(lua_State *L)
static int node_symbol(lua_State *L)
{
TSNode node;
if (!node_check(L, &node)) {
if (!node_check(L, 1, &node)) {
return 0;
}
TSSymbol symbol = ts_node_symbol(node);
@ -589,7 +647,7 @@ static int node_symbol(lua_State *L)
static int node_named(lua_State *L)
{
TSNode node;
if (!node_check(L, &node)) {
if (!node_check(L, 1, &node)) {
return 0;
}
lua_pushboolean(L, ts_node_is_named(node));
@ -599,7 +657,7 @@ static int node_named(lua_State *L)
static int node_sexpr(lua_State *L)
{
TSNode node;
if (!node_check(L, &node)) {
if (!node_check(L, 1, &node)) {
return 0;
}
char *allocated = ts_node_string(node);
@ -611,7 +669,7 @@ static int node_sexpr(lua_State *L)
static int node_missing(lua_State *L)
{
TSNode node;
if (!node_check(L, &node)) {
if (!node_check(L, 1, &node)) {
return 0;
}
lua_pushboolean(L, ts_node_is_missing(node));
@ -621,7 +679,7 @@ static int node_missing(lua_State *L)
static int node_has_error(lua_State *L)
{
TSNode node;
if (!node_check(L, &node)) {
if (!node_check(L, 1, &node)) {
return 0;
}
lua_pushboolean(L, ts_node_has_error(node));
@ -631,7 +689,7 @@ static int node_has_error(lua_State *L)
static int node_child(lua_State *L)
{
TSNode node;
if (!node_check(L, &node)) {
if (!node_check(L, 1, &node)) {
return 0;
}
long num = lua_tointeger(L, 2);
@ -644,7 +702,7 @@ static int node_child(lua_State *L)
static int node_named_child(lua_State *L)
{
TSNode node;
if (!node_check(L, &node)) {
if (!node_check(L, 1, &node)) {
return 0;
}
long num = lua_tointeger(L, 2);
@ -657,7 +715,7 @@ static int node_named_child(lua_State *L)
static int node_descendant_for_range(lua_State *L)
{
TSNode node;
if (!node_check(L, &node)) {
if (!node_check(L, 1, &node)) {
return 0;
}
TSPoint start = { (uint32_t)lua_tointeger(L, 2),
@ -673,7 +731,7 @@ static int node_descendant_for_range(lua_State *L)
static int node_named_descendant_for_range(lua_State *L)
{
TSNode node;
if (!node_check(L, &node)) {
if (!node_check(L, 1, &node)) {
return 0;
}
TSPoint start = { (uint32_t)lua_tointeger(L, 2),
@ -689,7 +747,7 @@ static int node_named_descendant_for_range(lua_State *L)
static int node_parent(lua_State *L)
{
TSNode node;
if (!node_check(L, &node)) {
if (!node_check(L, 1, &node)) {
return 0;
}
TSNode parent = ts_node_parent(node);
@ -771,7 +829,7 @@ static int query_next_capture(lua_State *L)
static int node_rawquery(lua_State *L)
{
TSNode node;
if (!node_check(L, &node)) {
if (!node_check(L, 1, &node)) {
return 0;
}
TSQuery *query = query_check(L, 2);

View File

@ -404,4 +404,65 @@ static int nlua_schedule(lua_State *const lstate)
end
eq({true,true}, {has_named,has_anonymous})
end)
it('allows to set simple ranges', function()
if not check_parser() then return end
insert(test_text)
local res = exec_lua([[
parser = vim.treesitter.get_parser(0, "c")
return { parser:parse():root():range() }
]])
eq({0, 0, 19, 0}, res)
-- The following sets the included ranges for the current parser
-- As stated here, this only includes the function (thus the whole buffer, without the last line)
local res2 = exec_lua([[
local root = parser:parse():root()
parser:set_included_ranges({root:child(0)})
parser.valid = false
return { parser:parse():root():range() }
]])
eq({0, 0, 18, 1}, res2)
end)
it("allows to set complex ranges", function()
if not check_parser() then return end
insert(test_text)
local res = exec_lua([[
parser = vim.treesitter.get_parser(0, "c")
query = vim.treesitter.parse_query("c", "(declaration) @decl")
local nodes = {}
for _, node in query:iter_captures(parser:parse():root(), 0, 0, 19) do
table.insert(nodes, node)
end
parser:set_included_ranges(nodes)
local root = parser:parse():root()
local res = {}
for i=0,(root:named_child_count() - 1) do
table.insert(res, { root:named_child(i):range() })
end
return res
]])
eq({
{ 2, 2, 2, 40 },
{ 3, 3, 3, 32 },
{ 4, 7, 4, 8 },
{ 4, 8, 4, 25 },
{ 8, 2, 8, 6 },
{ 8, 7, 8, 33 },
{ 9, 8, 9, 20 },
{ 10, 4, 10, 5 },
{ 10, 5, 10, 20 },
{ 14, 9, 14, 27 } }, res)
end)
end)