跳到内容

代码

节点解析器。

CodeSplitter #

基类: TextSplitter

使用 AST 解析器分割代码。

感谢 Kevin Lu / SweepAI 提出了这个优雅的代码分割解决方案。https://docs.sweep.dev/blogs/chunking-2m-files

参数

名称 类型 描述 默认值
language str

要分割的代码的编程语言。

必需
chunk_lines int

每个代码块包含的行数。

40
chunk_lines_overlap int

每个代码块与前一个块重叠的行数。

15
max_chars int

每个代码块的最大字符数。

1500
源代码位于 llama-index-core/llama_index/core/node_parser/text/code.py
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
class CodeSplitter(TextSplitter):
    """
    Split code using an AST parser.

    Thank you to Kevin Lu / SweepAI for suggesting this elegant code splitting solution.
    https://docs.sweep.dev/blogs/chunking-2m-files

    The code is parsed with a tree-sitter ``Parser`` and the syntax tree is
    walked so that chunk boundaries fall on node boundaries. Chunk size is
    bounded by ``max_chars``. ``chunk_lines`` and ``chunk_lines_overlap`` are
    validated and stored, but the splitting logic in this class does not use them.
    """

    language: str = Field(
        description="The programming language of the code being split."
    )
    chunk_lines: int = Field(
        default=DEFAULT_CHUNK_LINES,
        description="The number of lines to include in each chunk.",
        gt=0,
    )
    chunk_lines_overlap: int = Field(
        default=DEFAULT_LINES_OVERLAP,
        description="How many lines of code each chunk overlaps with.",
        gt=0,
    )
    max_chars: int = Field(
        default=DEFAULT_MAX_CHARS,
        description="Maximum number of characters per chunk.",
        gt=0,
    )
    # tree-sitter Parser used by split_text; assigned once in __init__.
    _parser: Any = PrivateAttr()

    def __init__(
        self,
        language: str,
        chunk_lines: int = DEFAULT_CHUNK_LINES,
        chunk_lines_overlap: int = DEFAULT_LINES_OVERLAP,
        max_chars: int = DEFAULT_MAX_CHARS,
        parser: Any = None,
        callback_manager: Optional[CallbackManager] = None,
        include_metadata: bool = True,
        include_prev_next_rel: bool = True,
        id_func: Optional[Callable[[int, Document], str]] = None,
    ) -> None:
        """
        Initialize a CodeSplitter.

        Args:
            language: Language name. When ``parser`` is not given, it is used to
                look up a parser via ``tree_sitter_language_pack.get_parser``.
            chunk_lines: Stored on the instance (not used by the splitting logic here).
            chunk_lines_overlap: Stored on the instance (not used by the splitting logic here).
            max_chars: Upper bound on chunk size, compared against UTF-8 byte lengths.
            parser: A pre-built ``tree_sitter.Parser``. When given, no lookup by
                ``language`` is performed.
            callback_manager: Defaults to ``CallbackManager([])``.
            include_metadata: Forwarded to the base splitter.
            include_prev_next_rel: Forwarded to the base splitter.
            id_func: Forwarded to the base splitter; defaults to ``default_id_func``.

        Raises:
            ImportError: If no parser is given and ``tree_sitter_language_pack``
                cannot be imported.
            ValueError: If the resulting parser is not a ``tree_sitter.Parser``.
            Exception: Anything ``get_parser`` raises (e.g. an unknown language)
                is re-raised after printing a hint.
        """
        from tree_sitter import Parser  # pants: no-infer-dep

        callback_manager = callback_manager or CallbackManager([])
        id_func = id_func or default_id_func

        super().__init__(
            language=language,
            chunk_lines=chunk_lines,
            chunk_lines_overlap=chunk_lines_overlap,
            max_chars=max_chars,
            callback_manager=callback_manager,
            include_metadata=include_metadata,
            include_prev_next_rel=include_prev_next_rel,
            id_func=id_func,
        )

        if parser is None:
            try:
                import tree_sitter_language_pack  # pants: no-infer-dep

                parser = tree_sitter_language_pack.get_parser(language)  # type: ignore
            except ImportError as err:
                raise ImportError(
                    "Please install tree_sitter_language_pack to use CodeSplitter. "
                    "Or pass in a parser object."
                ) from err
            except Exception:
                print(
                    f"Could not get parser for language {language}. Check "
                    "https://github.com/Goldziher/tree-sitter-language-pack?tab=readme-ov-file#available-languages "
                    "for a list of valid languages."
                )
                raise
        if not isinstance(parser, Parser):
            raise ValueError("Parser must be a tree-sitter Parser object.")

        self._parser = parser

    @classmethod
    def from_defaults(
        cls,
        language: str,
        chunk_lines: int = DEFAULT_CHUNK_LINES,
        chunk_lines_overlap: int = DEFAULT_LINES_OVERLAP,
        max_chars: int = DEFAULT_MAX_CHARS,
        callback_manager: Optional[CallbackManager] = None,
        parser: Any = None,
    ) -> "CodeSplitter":
        """Create a CodeSplitter with default values."""
        return cls(
            language=language,
            chunk_lines=chunk_lines,
            chunk_lines_overlap=chunk_lines_overlap,
            max_chars=max_chars,
            callback_manager=callback_manager,
            parser=parser,
        )

    @classmethod
    def class_name(cls) -> str:
        return "CodeSplitter"

    def _chunk_node(self, node: Any, text: str, last_end: int = 0) -> List[str]:
        """
        Recursively chunk a node into pieces bounded by ``max_chars``.

        Args:
            node (Any): The AST node to chunk.
            text (str): The original source code text (the same text that was parsed).
            last_end (int, optional): UTF-8 byte offset where the previous chunk
                ended. Defaults to 0.

        Returns:
            List[str]: Code chunks. A chunk exceeds ``max_chars`` only when it is a
            single leaf node that cannot be split further.

        """
        # tree-sitter node positions are byte offsets into the UTF-8 encoding,
        # so slicing must happen on bytes, not on the str.
        return self._chunk_node_bytes(node, text.encode("utf-8"), last_end)

    def _chunk_node_bytes(
        self, node: Any, text_bytes: bytes, last_end: int = 0
    ) -> List[str]:
        """Chunk ``node`` using byte offsets into ``text_bytes``, the buffer that was parsed."""
        new_chunks: List[str] = []
        current_chunk = b""
        for child in node.children:
            child_size = child.end_byte - child.start_byte
            if child_size > self.max_chars:
                # Child is too big on its own: flush the pending chunk, then split it.
                if current_chunk:
                    new_chunks.append(current_chunk.decode("utf-8"))
                current_chunk = b""
                if child.children:
                    new_chunks.extend(
                        self._chunk_node_bytes(child, text_bytes, last_end)
                    )
                else:
                    # A leaf (e.g. a long comment or string body) cannot be split
                    # along the AST; keep it as one oversized chunk instead of
                    # silently dropping its text.
                    new_chunks.append(
                        text_bytes[last_end : child.end_byte].decode("utf-8")
                    )
            elif len(current_chunk) + child_size > self.max_chars:
                # Child would make the current chunk too big, so start a new chunk.
                if current_chunk:
                    new_chunks.append(current_chunk.decode("utf-8"))
                # The slice starts at last_end so inter-node whitespace is kept.
                current_chunk = text_bytes[last_end : child.end_byte]
            else:
                current_chunk += text_bytes[last_end : child.end_byte]
            last_end = child.end_byte
        if current_chunk:
            new_chunks.append(current_chunk.decode("utf-8"))
        return new_chunks

    def split_text(self, text: str) -> List[str]:
        """
        Split incoming code into chunks using the AST parser.

        The input is parsed into a syntax tree, which is then chunked along node
        boundaries. Each chunk is stripped of surrounding whitespace.

        Args:
            text (str): The source code text to split.

        Returns:
            List[str]: A list of code chunks.

        Raises:
            ValueError: If the first top-level node of the parse tree is an
                ``ERROR`` node, i.e. the code cannot be parsed for the language.

        """
        with self.callback_manager.event(
            CBEventType.CHUNKING, payload={EventPayload.CHUNKS: [text]}
        ) as event:
            # Keep the exact bytes handed to the parser: node offsets index into them.
            text_bytes = text.encode("utf-8")
            tree = self._parser.parse(text_bytes)
            root = tree.root_node

            if root.children and root.children[0].type == "ERROR":
                raise ValueError(f"Could not parse code with language {self.language}.")

            chunks = [
                chunk.strip() for chunk in self._chunk_node_bytes(root, text_bytes)
            ]
            event.on_end(
                payload={EventPayload.CHUNKS: chunks},
            )

            return chunks

from_defaults classmethod #

from_defaults(language: str, chunk_lines: int = DEFAULT_CHUNK_LINES, chunk_lines_overlap: int = DEFAULT_LINES_OVERLAP, max_chars: int = DEFAULT_MAX_CHARS, callback_manager: Optional[CallbackManager] = None, parser: Any = None) -> CodeSplitter

使用默认值创建一个 CodeSplitter。

源代码位于 llama-index-core/llama_index/core/node_parser/text/code.py
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
@classmethod
def from_defaults(
    cls,
    language: str,
    chunk_lines: int = DEFAULT_CHUNK_LINES,
    chunk_lines_overlap: int = DEFAULT_LINES_OVERLAP,
    max_chars: int = DEFAULT_MAX_CHARS,
    callback_manager: Optional[CallbackManager] = None,
    parser: Any = None,
) -> "CodeSplitter":
    """
    Build a CodeSplitter by forwarding every argument unchanged to ``cls``.

    Omitted size arguments take the module-level defaults
    (``DEFAULT_CHUNK_LINES``, ``DEFAULT_LINES_OVERLAP``, ``DEFAULT_MAX_CHARS``).
    ``callback_manager`` and ``parser`` are passed through as ``None`` when omitted.
    """
    settings = {
        "language": language,
        "chunk_lines": chunk_lines,
        "chunk_lines_overlap": chunk_lines_overlap,
        "max_chars": max_chars,
        "callback_manager": callback_manager,
        "parser": parser,
    }
    return cls(**settings)

split_text #

split_text(text: str) -> List[str]

使用 AST 解析器将输入的代码分割成代码块。

此方法将输入的代码解析为 AST，然后在保留语法结构的前提下进行分割；如果代码无法按指定语言正确解析，则抛出错误。

参数

名称 类型 描述 默认值
text str

要分割的源代码文本。

必需

返回

类型 描述
List[str]

List[str]: 代码块列表。

抛出

类型 描述
ValueError

如果无法解析指定语言的代码。

源代码位于 llama-index-core/llama_index/core/node_parser/text/code.py
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
def split_text(self, text: str) -> List[str]:
    """
    Split source code into whitespace-stripped chunks guided by its syntax tree.

    The text is encoded as UTF-8 and parsed, then the root node is chunked via
    ``_chunk_node``. The chunking is reported to the callback manager as a
    ``CHUNKING`` event whose end payload carries the resulting chunks.

    Args:
        text (str): The source code text to split.

    Returns:
        List[str]: A list of code chunks.

    Raises:
        ValueError: If the first top-level node of the parse tree is an ``ERROR``
            node.

    """
    with self.callback_manager.event(
        CBEventType.CHUNKING, payload={EventPayload.CHUNKS: [text]}
    ) as event:
        root = self._parser.parse(bytes(text, "utf-8")).root_node
        top_level = root.children

        # Guard clause: a leading ERROR node means the grammar rejected the input.
        if top_level and top_level[0].type == "ERROR":
            raise ValueError(f"Could not parse code with language {self.language}.")

        pieces = self._chunk_node(root, text)
        chunks = [piece.strip() for piece in pieces]
        event.on_end(payload={EventPayload.CHUNKS: chunks})
        return chunks