Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions src/exploit_iq_commons/utils/chain_of_calls_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ def __find_caller_function_dfs(self, document_function: Document, function_packa
if parents:
direct_parents.extend(parents)
function_name_to_search = self.language_parser.get_function_name(document_function)
if not function_name_to_search:
return None
if function_name_to_search == self.language_parser.get_constructor_method_name():
function_name_to_search = self.language_parser.get_class_name_from_class_function(document_function)
function_file_name = document_function.metadata.get('source')
Expand Down Expand Up @@ -319,6 +321,8 @@ def get_possible_docs(self, function_name_to_search: str, package: str, exclusio
(self.language_parser.is_function(doc) or self.language_parser.is_script_language()) and
not self._is_doc_excluded(doc, exclusions)]

if not function_name_to_search:
return []
return [doc for doc in filter_1 if doc.page_content.__contains__(f"{function_name_to_search}(")]

def __find_caller_functions_bfs(self, document_function: Document, function_package: str,
Expand All @@ -344,6 +348,8 @@ def __find_caller_functions_bfs(self, document_function: Document, function_pack
# direct_parents.extend([function_package])
# gets list of documents to search in only from parents of function' package.
function_name_to_search = self.language_parser.get_function_name(document_function)
if not function_name_to_search:
return []
function_file_name = document_function.metadata.get('source')
relevant_docs_to_search_in = list()
# Search for caller functions only at parents according to dependency tree.
Expand All @@ -365,7 +371,12 @@ def __find_caller_functions_bfs(self, document_function: Document, function_pack
file_name = doc.metadata.get('source')
if doc.metadata.get('state') == "invalid":
continue
func_name = self.language_parser.get_function_name(doc)
try:
func_name = self.language_parser.get_function_name(doc)
except ValueError:
continue
if not func_name:
continue
# check for same doc
if (function_name_to_search == func_name) and (file_name == function_file_name):
continue
Expand Down Expand Up @@ -438,6 +449,8 @@ def _breadth_first_search(self, matching_documents: List[Document], target_funct
logger.debug("get_relevant_documents: invalid function %s", target_doc.metadata['source'])
continue
function_name = self.language_parser.get_function_name(target_doc)
if not function_name:
continue

function_file = target_doc.metadata.get('source')
hashed_value = calculate_hashable_string_for_function(function_file, function_name)
Expand Down Expand Up @@ -577,9 +590,9 @@ def get_relevant_documents(self, query: str) -> tuple[List[Document], bool]:
root_package = [key for (key, value) in self.tree_dict.items() if ROOT_LEVEL_SENTINEL in value]
prefix_of_3rd_parties_libs = self.language_parser.dir_name_for_3rd_party_packages()
# find all parents ( all importing packages) of the ibput package so we'll have candidate pkgs to search in.
parents = list({self.language_parser.get_package_names(doc)[1] for doc in importing_docs if
parents = list({self.language_parser.get_package_names(doc)[-1] for doc in importing_docs if
doc.metadata['source'].startswith(
prefix_of_3rd_parties_libs) and self.language_parser.get_package_names(doc)[1]
prefix_of_3rd_parties_libs) and self.language_parser.get_package_names(doc)[-1]
in self.tree_dict.keys()})
for doc in importing_docs:
if not doc.metadata.get('source').startswith(prefix_of_3rd_parties_libs):
Expand Down Expand Up @@ -651,7 +664,11 @@ def __find_initial_function(self, function_name: str, package_name: str, documen
self.get_functions_for_package(package_name, relevant_docs, sources_location_packages=True),
self.get_functions_for_package(package_name, relevant_docs, sources_location_packages=False),
):
if function_name.lower() == self.language_parser.get_function_name(document).lower():
doc_func_name = self.language_parser.get_function_name(document)
Comment thread
TamarW0 marked this conversation as resolved.
if not doc_func_name:
logger.warning("Skipping document with empty function name: %s", document.metadata.get('source', ''))
continue
if function_name.lower() == doc_func_name.lower():
package_exclusions.append(document)
return document

Expand Down Expand Up @@ -684,7 +701,10 @@ def print_call_hierarchy(self, call_hierarchy_list: list[Document]) -> list[str]
package_name = package_function.metadata['source']
try:
function_name = self.language_parser.get_function_name(package_function)
current_level = f"(package={package_name},function={function_name},depth={i})"
if not function_name:
current_level = f"(document={package_name},depth={i})"
else:
current_level = f"(package={package_name},function={function_name},depth={i})"
except ValueError:
current_level = f"(document={package_name},depth={i})"
results.append(current_level)
Expand Down
4 changes: 4 additions & 0 deletions src/exploit_iq_commons/utils/document_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ def lazy_parse(self, blob: Blob) -> typing.Iterator[Document]:
)
return

segmenter_cls = self.LANGUAGE_SEGMENTERS.get(language)
if segmenter_cls is not None and hasattr(segmenter_cls, "should_skip") and isinstance(blob.source, str) and segmenter_cls.should_skip(blob.source):
return

Comment thread
zvigrinberg marked this conversation as resolved.
if self.parser_threshold >= len(code.splitlines()):
yield Document(
page_content=code,
Expand Down
Loading