diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index 7006ab0..81d42d3 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -71,6 +71,7 @@ body: - C - C++ - Ruby + - Kotlin - Not language-specific validations: required: false diff --git a/.gitignore b/.gitignore index f5bb03e..94fcf31 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ build/ # Virtual environments .venv/ venv/ +uv.lock # Testing .pytest_cache/ diff --git a/AGENTS.md b/AGENTS.md index a6877d0..2995f63 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -52,6 +52,7 @@ All `file_path` arguments are **relative to the repo root** (e.g., `"src/main.py | `.c`, `.h` | CPlugin | `function_definition`, `struct_specifier`, `type_definition` | | `.cpp`, `.cc`, `.cxx`, `.hpp`, `.hh` | CppPlugin | `class_specifier`, `function_definition`, `struct_specifier`, `namespace_definition` | | `.rb` | RubyPlugin | `class`, `module`, `method`, `singleton_method` | +| `.kt` | KotlinPlugin | `class_declaration`, `object_declaration`, `function_declaration` | ## Commands @@ -120,6 +121,7 @@ MCP tool call → server.py → indexer.py → FileEntry.plugin → tree-sitter | `c.py` | CPlugin | | `cpp.py` | CppPlugin (inherits CPlugin) | | `ruby.py` | RubyPlugin | +| `kotlin.py` | KotlinPlugin | | `_template.py` | Boilerplate for adding new languages | Each plugin implements: diff --git a/README.md b/README.md index 45c8d8e..4e5821a 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ **Stop feeding entire files to your AI agent.** -codetree is an [MCP](https://modelcontextprotocol.io/) server that gives coding agents structured code understanding via [tree-sitter](https://tree-sitter.github.io/) — so they ask precise questions instead of reading thousands of lines. 23 tools, 10 languages, ~1 second startup. No vector DB, no embedding model, no config. 
+codetree is an [MCP](https://modelcontextprotocol.io/) server that gives coding agents structured code understanding via [tree-sitter](https://tree-sitter.github.io/) — so they ask precise questions instead of reading thousands of lines. 23 tools, 11 languages, ~1 second startup. No vector DB, no embedding model, no config. ## Quick Start @@ -152,6 +152,7 @@ The agent sees every class, method, and docstring — with line numbers — with | C | `.c`, `.h` | | C++ | `.cpp`, `.cc`, `.cxx`, `.hpp`, `.hh` | | Ruby | `.rb` | +| Kotlin | `.kt`, `.kts` | ## Editor Setup diff --git a/docs/LANDING_PAGE.md b/docs/LANDING_PAGE.md index 7349130..bbf2d04 100644 --- a/docs/LANDING_PAGE.md +++ b/docs/LANDING_PAGE.md @@ -190,7 +190,7 @@ cls Calculator:4 # A scientific calculator. ## 6. Supported Languages -10 languages. 16 file extensions. All backed by official tree-sitter grammars. Plus a persistent graph layer for onboarding, change impact, and security analysis. +11 languages. 18 file extensions. All backed by official tree-sitter grammars. Plus a persistent graph layer for onboarding, change impact, and security analysis. | Language | Extensions | |----------|-----------| @@ -204,6 +204,7 @@ cls Calculator:4 # A scientific calculator. | C | `.c` `.h` | | C++ | `.cpp` `.cc` `.cxx` `.hpp` `.hh` | | Ruby | `.rb` | +| Kotlin | `.kt` `.kts` | Adding a new language is mechanical: copy a template file, implement 5 methods, register in one file, done. @@ -308,7 +309,7 @@ claude mcp add codetree -- uvx --from mcp-server-codetree codetree --root . ┌─────────────────────────────────────────────────┐ │ Language Plugins (one per language) │ │ │ -│ Python │ JS │ TS │ Go │ Rust │ Java │ C │ C++ │ Ruby │ +│ Python │ JS │ TS │ Go │ Rust │ Java │ C │ C++ │ Ruby │ Kotlin │ │ │ │ Each implements: │ │ extract_skeleton() │ @@ -335,8 +336,8 @@ claude mcp add codetree -- uvx --from mcp-server-codetree codetree --root . 
| Metric | Value | |--------|-------| | MCP tools | 23 | -| Supported languages | 10 | -| File extensions | 16 | +| Supported languages | 11 | +| File extensions | 18 | | Test count | 999 | | Startup time | ~1 second | | Install | `uvx --from mcp-server-codetree codetree` | diff --git a/pyproject.toml b/pyproject.toml index 48debe1..5b02f76 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "tree-sitter-c>=0.23.0", "tree-sitter-cpp>=0.23.0", "tree-sitter-ruby>=0.23.0", + "tree-sitter-kotlin>=0.23.0", "fastmcp>=2.0.0", ] diff --git a/src/codetree/languages/kotlin.py b/src/codetree/languages/kotlin.py new file mode 100644 index 0000000..1178091 --- /dev/null +++ b/src/codetree/languages/kotlin.py @@ -0,0 +1,292 @@
+"""Kotlin language plugin built on the tree-sitter Kotlin grammar."""
+from tree_sitter import Language, Parser, Query
+import tree_sitter_kotlin as tskotlin
+from .base import LanguagePlugin, _matches, _fill_docs_from_siblings
+
+_LANGUAGE = Language(tskotlin.language())
+_PARSER = Parser(_LANGUAGE)
+
+# Matches every function declaration (top-level, member or local) with its name.
+_FUNCTION_QUERY = "(function_declaration (identifier) @name) @def"
+
+# Node types counted as one decision point each, mapped to their breakdown label.
+_BRANCH_LABELS = {
+    "if_expression": "if",
+    "for_statement": "for",
+    "while_statement": "while",
+    "do_while_statement": "do_while",
+    "catch_block": "catch",
+    "when_expression": "when",
+    "when_entry": "case",
+}
+
+
+def _parse(source: bytes):
+    """Parse *source* with the shared module-level parser and return the tree."""
+    return _PARSER.parse(source)
+
+
+def _text(node) -> str:
+    """Decode a node's text as UTF-8, replacing undecodable bytes."""
+    return node.text.decode("utf-8", errors="replace")
+
+
+def _find_function(root, fn_name: str):
+    """Return the first function_declaration under *root* named *fn_name*, or None."""
+    for _, m in _matches(Query(_LANGUAGE, _FUNCTION_QUERY), root):
+        if _text(m["name"]) == fn_name:
+            return m["def"]
+    return None
+
+
+class KotlinPlugin(LanguagePlugin):
+    """Language plugin for Kotlin sources (``.kt``) and Kotlin scripts (``.kts``)."""
+
+    extensions = (".kt", ".kts")
+
+    def extract_skeleton(self, source: bytes) -> list[dict]:
+        """List the classes, interfaces, objects, functions and methods in *source*.
+
+        Each entry is a dict with ``type``, ``name``, ``line`` (1-based line of
+        the name), ``parent`` (name of the owning class/object, ``None`` for
+        top-level symbols), ``params`` (parameter list text, ``""`` for types)
+        and ``doc``. Entries are returned sorted by line.
+        """
+        tree = _parse(source)
+        root = tree.root_node
+        results = []
+
+        # Top-level classes and interfaces: both parse as class_declaration, so
+        # an `interface` keyword child is what tells them apart.
+        q = Query(_LANGUAGE, "(source_file (class_declaration (identifier) @name) @def)")
+        for _, m in _matches(q, root):
+            is_interface = any(child.type == "interface" for child in m["def"].children)
+            results.append({
+                "type": "interface" if is_interface else "class",
+                "name": _text(m["name"]),
+                "line": m["name"].start_point[0] + 1,
+                "parent": None,
+                "params": "",
+            })
+
+        # Top-level objects are reported with type "class".
+        q = Query(_LANGUAGE, "(source_file (object_declaration (identifier) @name) @def)")
+        for _, m in _matches(q, root):
+            results.append({
+                "type": "class",
+                "name": _text(m["name"]),
+                "line": m["name"].start_point[0] + 1,
+                "parent": None,
+                "params": "",
+            })
+
+        # Top-level functions
+        q = Query(_LANGUAGE, "(source_file (function_declaration (identifier) @name (function_value_parameters) @params) @def)")
+        for _, m in _matches(q, root):
+            results.append({
+                "type": "function",
+                "name": _text(m["name"]),
+                "line": m["name"].start_point[0] + 1,
+                "parent": None,
+                "params": _text(m["params"]),
+            })
+
+        # Functions declared directly in a class/interface body, then in an
+        # object body; the two queries differ only in the owner node type.
+        for owner in ("class_declaration", "object_declaration"):
+            q = Query(_LANGUAGE, f"""
+                ({owner}
+                    (identifier) @class_name
+                    (class_body
+                        (function_declaration
+                            (identifier) @method_name
+                            (function_value_parameters) @params)))
+            """)
+            for _, m in _matches(q, root):
+                results.append({
+                    "type": "method",
+                    "name": _text(m["method_name"]),
+                    "line": m["method_name"].start_point[0] + 1,
+                    "parent": _text(m["class_name"]),
+                    "params": _text(m["params"]),
+                })
+
+        # Default every doc to "" before the shared helper fills them in
+        # (from preceding comments, per its name and the original note here).
+        for item in results:
+            item.setdefault("doc", "")
+        _fill_docs_from_siblings(results, root, _LANGUAGE, [
+            "(class_declaration (identifier) @name) @def",
+            "(object_declaration (identifier) @name) @def",
+            _FUNCTION_QUERY,
+        ])
+
+        results.sort(key=lambda x: x["line"])
+        return results
+
+    def extract_symbol_source(self, source: bytes, name: str) -> tuple[str, int] | None:
+        """Return ``(source_text, start_line)`` for the declaration named *name*.
+
+        Classes/interfaces are searched first, then objects, then functions, and
+        the first match wins. ``start_line`` is 1-based. Returns ``None`` when
+        nothing is named *name*.
+        """
+        tree = _parse(source)
+        for pattern in (
+            "(class_declaration (identifier) @name) @def",
+            "(object_declaration (identifier) @name) @def",
+            _FUNCTION_QUERY,
+        ):
+            for _, m in _matches(Query(_LANGUAGE, pattern), tree.root_node):
+                if _text(m["name"]) == name:
+                    node = m["def"]
+                    snippet = source[node.start_byte:node.end_byte]
+                    return snippet.decode("utf-8", errors="replace"), node.start_point[0] + 1
+        return None
+
+    def extract_calls_in_function(self, source: bytes, fn_name: str) -> list[str]:
+        """Return the sorted, de-duplicated names called inside function *fn_name*.
+
+        Returns an empty list when no function named *fn_name* exists.
+        """
+        tree = _parse(source)
+        fn_node = _find_function(tree.root_node, fn_name)
+        if fn_node is None:
+            return []
+
+        calls = set()
+        # Plain calls `foo()` and qualified calls `foo.bar()`.
+        q = Query(_LANGUAGE, """
+            (call_expression
+                [
+                    (identifier) @called
+                    (navigation_expression (identifier) @called)
+                ])
+        """)
+        for _, m in _matches(q, fn_node):
+            node = m["called"]
+            parent = node.parent
+            if parent and parent.type == "navigation_expression":
+                # Keep only the last identifier of the chain (the callee), not
+                # the receiver in front of it.
+                ids = [c for c in parent.children if c.type == "identifier"]
+                if ids and node == ids[-1]:
+                    calls.add(_text(node))
+            else:
+                calls.add(_text(node))
+
+        return sorted(calls)
+
+    def extract_symbol_usages(self, source: bytes, name: str) -> list[dict]:
+        """Return every identifier occurrence of *name* as ``{"line", "col"}`` dicts.
+
+        ``line`` is 1-based and ``col`` is the 0-based start column reported by
+        tree-sitter. Results are sorted by ``(line, col)``.
+        """
+        tree = _parse(source)
+        usages = []
+        seen = set()
+        # Match all identifiers and compare the text here instead of formatting
+        # *name* into an `#eq?` predicate: a name containing `"` or `\` would
+        # otherwise yield a malformed query.
+        q = Query(_LANGUAGE, "(identifier) @name")
+        for _, m in _matches(q, tree.root_node):
+            node = m["name"]
+            if _text(node) != name:
+                continue
+            key = (node.start_point[0], node.start_point[1])
+            if key not in seen:
+                seen.add(key)
+                usages.append({"line": node.start_point[0] + 1, "col": node.start_point[1]})
+
+        usages.sort(key=lambda x: (x["line"], x["col"]))
+        return usages
+
+    def extract_imports(self, source: bytes) -> list[dict]:
+        """Return each import as ``{"line": 1-based line, "text": stripped source text}``."""
+        tree = _parse(source)
+        results = [
+            {"line": m["imp"].start_point[0] + 1, "text": _text(m["imp"]).strip()}
+            for _, m in _matches(Query(_LANGUAGE, "(import) @imp"), tree.root_node)
+        ]
+        results.sort(key=lambda x: x["line"])
+        return results
+
+    def compute_complexity(self, source: bytes, fn_name: str) -> dict | None:
+        """Return a cyclomatic-style complexity estimate for function *fn_name*.
+
+        The result is ``{"total": 1 + number of decision points, "breakdown":
+        {label: count}}``. Decision points are the node types in
+        ``_BRANCH_LABELS`` plus ``&&`` and ``||``. Returns ``None`` when the
+        function is not found.
+        """
+        tree = _parse(source)
+        fn_node = _find_function(tree.root_node, fn_name)
+        if fn_node is None:
+            return None
+
+        counts: dict[str, int] = {}
+        # Explicit stack instead of recursion so deeply nested code cannot hit
+        # Python's recursion limit. Children are pushed reversed to keep the
+        # pre-order visit, and with it the key order of the breakdown.
+        stack = [fn_node]
+        while stack:
+            node = stack.pop()
+            label = _BRANCH_LABELS.get(node.type)
+            if label is None and node.type in ("&&", "||"):
+                label = node.type
+            if label is not None:
+                counts[label] = counts.get(label, 0) + 1
+            stack.extend(reversed(node.children))
+
+        return {"total": 1 + sum(counts.values()), "breakdown": counts}
+
+    def extract_variables(self, source: bytes, fn_name: str) -> list[dict]:
+        """Return the parameters and local declarations of function *fn_name*.
+
+        Each entry is ``{"name", "line", "type", "kind"}`` where ``kind`` is
+        ``"parameter"`` or ``"local"`` and ``type`` is the declared type text when
+        the query captured one (``""`` otherwise). Only the first occurrence of a
+        name is kept, parameters first. Returns ``[]`` when the function is not
+        found.
+        """
+        tree = _parse(source)
+        fn_node = _find_function(tree.root_node, fn_name)
+        if fn_node is None:
+            return []
+
+        results = []
+        seen = set()
+        for pattern, kind in (
+            ("(parameter (identifier) @name (user_type)? @type)", "parameter"),
+            ("(variable_declaration (identifier) @name (user_type)? @type)", "local"),
+        ):
+            for _, m in _matches(Query(_LANGUAGE, pattern), fn_node):
+                var_name = _text(m["name"])
+                if var_name in seen:
+                    continue
+                seen.add(var_name)
+                type_node = m.get("type")
+                results.append({
+                    "name": var_name,
+                    "line": m["name"].start_point[0] + 1,
+                    "type": _text(type_node) if type_node else "",
+                    "kind": kind,
+                })
+
+        return results
+
+    def check_syntax(self, source: bytes) -> bool:
+        """Return the parser's ``has_error`` flag for *source* (``True`` = syntax errors).
+
+        NOTE(review): confirm this polarity matches ``LanguagePlugin.check_syntax``.
+        """
+        return _parse(source).root_node.has_error
+
+    def _get_parser(self):
+        """Return the shared module-level Kotlin parser."""
+        return _PARSER
+
+    def _get_language(self):
+        """Return the tree-sitter ``Language`` object for Kotlin."""
+        return _LANGUAGE
diff --git a/src/codetree/registry.py b/src/codetree/registry.py index 9dd695e..68df33d 100644 --- a/src/codetree/registry.py +++ b/src/codetree/registry.py @@ -9,6 +9,7 @@ from .languages.c import CPlugin from .languages.cpp import CppPlugin from .languages.ruby import RubyPlugin +from .languages.kotlin import KotlinPlugin # All supported file extensions mapped to plugin instances. # To add a new language: import its plugin and add its extensions here. 
@@ -29,6 +30,8 @@ ".hpp": CppPlugin(), ".hh": CppPlugin(), ".rb": RubyPlugin(), + ".kt": KotlinPlugin(), + ".kts": KotlinPlugin(), } diff --git a/tests/languages/test_kotlin.py b/tests/languages/test_kotlin.py new file mode 100644 index 0000000..016cd17 --- /dev/null +++ b/tests/languages/test_kotlin.py @@ -0,0 +1,93 @@ +import pytest +from codetree.languages.kotlin import KotlinPlugin + +PLUGIN = KotlinPlugin() + +SAMPLE = b"""\ +class Calculator { + fun add(a: Int, b: Int): Int { + return a + b + } + fun divide(a: Int, b: Int): Int { + if (b == 0) throw IllegalArgumentException("div by zero") + return a / b + } +} + +object Helper { + fun run(): Int { + val calc = Calculator() + return calc.add(1, 2) + } +} + +fun topLevel(x: Int) = x * 2 +""" + + +def test_skeleton_finds_classes_and_objects(): + result = PLUGIN.extract_skeleton(SAMPLE) + names = [item["name"] for item in result] + assert "Calculator" in names + assert "Helper" in names + + +def test_skeleton_finds_methods(): + result = PLUGIN.extract_skeleton(SAMPLE) + names = [item["name"] for item in result] + assert "add" in names + assert "divide" in names + assert "run" in names + + +def test_skeleton_finds_top_level_function(): + result = PLUGIN.extract_skeleton(SAMPLE) + names = [item["name"] for item in result] + assert "topLevel" in names + + +def test_skeleton_method_has_parent(): + result = PLUGIN.extract_skeleton(SAMPLE) + add = next(item for item in result if item["name"] == "add") + assert add["parent"] == "Calculator" + run = next(item for item in result if item["name"] == "run") + assert run["parent"] == "Helper" + + +def test_extract_symbol_finds_class(): + result = PLUGIN.extract_symbol_source(SAMPLE, "Calculator") + assert result is not None + source, _ = result + assert "class Calculator" in source + + +def test_extract_symbol_finds_method(): + result = PLUGIN.extract_symbol_source(SAMPLE, "add") + assert result is not None + source, _ = result + assert "fun add" in source + + +def 
test_extract_symbol_returns_none_for_missing(): + assert PLUGIN.extract_symbol_source(SAMPLE, "nonexistent") is None + + +def test_extract_calls_in_function(): + calls = PLUGIN.extract_calls_in_function(SAMPLE, "run") + # For Kotlin, it should find 'add' + assert "add" in calls + + +def test_extract_symbol_usages(): + usages = PLUGIN.extract_symbol_usages(SAMPLE, "Calculator") + assert len(usages) >= 2 # Definition + instantiation + + +def test_kts_support(): + kts_sample = b""" + fun buildConfig() { + println("building...") + } + """ + result = PLUGIN.extract_skeleton(kts_sample) + assert any(x["name"] == "buildConfig" for x in result) diff --git a/tests/languages/test_kotlin_comprehensive.py b/tests/languages/test_kotlin_comprehensive.py new file mode 100644 index 0000000..1980d8c --- /dev/null +++ b/tests/languages/test_kotlin_comprehensive.py @@ -0,0 +1,231 @@ +""" +Exhaustive tests for the Kotlin plugin covering every realistic code pattern. + +Code style categories: + - Classes: plain, open, abstract, data, sealed, with extends/implements + - Interfaces: plain, with methods + - Objects: plain, companion object + - Functions: top-level, extension, lambda, member + - extract_symbol_source and extract_calls_in_function +""" +import pytest +from codetree.languages.kotlin import KotlinPlugin + +P = KotlinPlugin() + + +# ─── Class styles ────────────────────────────────────────────────────────────── + +def test_plain_class(): + src = b"class Foo {}\n" + assert any(x["type"] == "class" and x["name"] == "Foo" for x in P.extract_skeleton(src)) + + +def test_data_class(): + src = b"data class User(val id: Int, val name: String)\n" + assert any(x["name"] == "User" for x in P.extract_skeleton(src)) + + +def test_abstract_class(): + src = b"abstract class Shape {\n abstract fun area(): Double\n}\n" + result = P.extract_skeleton(src) + assert any(x["name"] == "Shape" for x in result) + + +def test_class_inheritance(): + src = b"class Dog : Animal(), Runnable {\n override 
fun run() {}\n}\n" + result = P.extract_skeleton(src) + assert any(x["name"] == "Dog" for x in result) + assert any(x["name"] == "run" and x["parent"] == "Dog" for x in result) + + +# ─── Interface styles ───────────────────────────────────────────────────────── + +def test_plain_interface(): + src = b"interface Printable {\n fun print()\n}\n" + result = P.extract_skeleton(src) + assert any(x["type"] == "interface" and x["name"] == "Printable" for x in result) + + +def test_interface_methods_in_skeleton(): + src = b"interface Dao {\n fun load(id: Int): String\n fun store(v: String)\n}\n" + result = P.extract_skeleton(src) + assert any(x["name"] == "load" and x["parent"] == "Dao" for x in result) + assert any(x["name"] == "store" and x["parent"] == "Dao" for x in result) + + +# ─── Object styles ──────────────────────────────────────────────────────────── + +def test_plain_object(): + src = b"object Database {\n fun connect() {}\n}\n" + result = P.extract_skeleton(src) + assert any(x["name"] == "Database" for x in result) + assert any(x["name"] == "connect" and x["parent"] == "Database" for x in result) + + +# ─── Function styles ────────────────────────────────────────────────────────── + +def test_top_level_function(): + src = b"fun sum(a: Int, b: Int) = a + b\n" + result = P.extract_skeleton(src) + assert any(x["name"] == "sum" and x["parent"] is None for x in result) + + +def test_extension_function(): + src = b"fun String.shout() = this.uppercase()\n" + result = P.extract_skeleton(src) + # Extension functions are tricky in tree-sitter, usually identifier is shout + assert any(x["name"] == "shout" for x in result) + + +def test_member_function(): + src = b"class Logger {\n fun log(msg: String) { println(msg) }\n}\n" + result = P.extract_skeleton(src) + assert any(x["name"] == "log" and x["parent"] == "Logger" for x in result) + + +# ─── Mixed file ──────────────────────────────────────────────────────────────── + +MIXED_SRC = b""" +interface Drawable { + fun 
draw() +} + +abstract class Shape(val color: String) { + abstract fun area(): Double +} + +class Circle(val radius: Double) : Shape("red"), Drawable { + override fun area() = 3.14 * radius * radius + override fun draw() { println("Drawing circle") } +} + +object App { + fun run() { + val c = Circle(5.0) + c.draw() + } +} +""" + + +def test_mixed_interface_found(): + result = P.extract_skeleton(MIXED_SRC) + assert any(x["type"] == "interface" and x["name"] == "Drawable" for x in result) + + +def test_mixed_abstract_class_found(): + result = P.extract_skeleton(MIXED_SRC) + assert any(x["name"] == "Shape" for x in result) + + +def test_mixed_concrete_class_found(): + result = P.extract_skeleton(MIXED_SRC) + assert any(x["name"] == "Circle" for x in result) + + +def test_mixed_object_found(): + result = P.extract_skeleton(MIXED_SRC) + assert any(x["name"] == "App" for x in result) + + +def test_mixed_member_functions(): + result = P.extract_skeleton(MIXED_SRC) + assert any(x["name"] == "area" and x["parent"] == "Circle" for x in result) + assert any(x["name"] == "draw" and x["parent"] == "Circle" for x in result) + assert any(x["name"] == "run" and x["parent"] == "App" for x in result) + + +def test_mixed_sorted_by_line(): + result = P.extract_skeleton(MIXED_SRC) + lines = [x["line"] for x in result] + assert lines == sorted(lines) + + +# ─── extract_symbol_source ───────────────────────────────────────────────────── + +def test_symbol_source_class(): + src = b"class Calc {\n fun add(a: Int, b: Int) = a + b\n}\n" + result = P.extract_symbol_source(src, "Calc") + assert result is not None + source, line = result + assert "class Calc" in source + assert line == 1 + + +def test_symbol_source_function(): + src = b"fun hello() = \"world\"\n" + result = P.extract_symbol_source(src, "hello") + assert result is not None + source, _ = result + assert "fun hello" in source + + +# ─── extract_calls_in_function ───────────────────────────────────────────────── + +def 
test_calls_direct(): + src = b"fun run() {\n init()\n start()\n}\n" + calls = P.extract_calls_in_function(src, "run") + assert "init" in calls + assert "start" in calls + + +def test_calls_navigation(): + src = b"fun run(db: Database) {\n db.connect()\n db.query(\"x\")\n}\n" + calls = P.extract_calls_in_function(src, "run") + assert "connect" in calls + assert "query" in calls + + +def test_calls_instantiation(): + src = b"fun create() = Widget()\n" + calls = P.extract_calls_in_function(src, "create") + assert "Widget" in calls + + +# ─── Complexity & Variables ─────────────────────────────────────────────────── + +def test_compute_complexity(): + src = b""" + fun complex(x: Int) { + if (x > 0) { + for (i in 1..x) { + while (true) { + when(i) { + 1 -> println(1) + else -> println(0) + } + } + } + } + } + """ + comp = P.compute_complexity(src, "complex") + assert comp is not None + assert comp["total"] >= 5 + assert "if" in comp["breakdown"] + assert "for" in comp["breakdown"] + assert "while" in comp["breakdown"] + assert "when" in comp["breakdown"] + + +def test_extract_variables(): + src = b""" + fun vars(a: Int, b: String) { + val x = 1 + var y: Double = 2.0 + for (item in list) { + println(item) + } + } + """ + vars = P.extract_variables(src, "vars") + names = [v["name"] for v in vars] + assert "a" in names + assert "b" in names + assert "x" in names + assert "y" in names + # The for-loop variable `item` is expected in the result as well. + # extract_variables has no dedicated for-loop pattern; per the original author's note, + # tree-sitter-kotlin 1.1.0 wraps loop variables in variable_declaration nodes (TODO confirm). + assert "item" in names