summaryrefslogtreecommitdiff
path: root/sphinx/util/matching.py
diff options
context:
space:
mode:
Diffstat (limited to 'sphinx/util/matching.py')
-rw-r--r--sphinx/util/matching.py74
1 files changed, 68 insertions, 6 deletions
diff --git a/sphinx/util/matching.py b/sphinx/util/matching.py
index 53a893338..de4a776cf 100644
--- a/sphinx/util/matching.py
+++ b/sphinx/util/matching.py
@@ -1,9 +1,10 @@
"""Pattern-matching utility functions for Sphinx."""
+import os.path
import re
-from typing import Callable, Dict, Iterable, List, Match, Optional, Pattern
+from typing import Callable, Dict, Iterable, Iterator, List, Match, Optional, Pattern
-from sphinx.util.osutil import canon_path
+from sphinx.util.osutil import canon_path, path_stabilize
def _translate_pattern(pat: str) -> str:
@@ -52,7 +53,7 @@ def _translate_pattern(pat: str) -> str:
return res + '$'
-def compile_matchers(patterns: List[str]) -> List[Callable[[str], Optional[Match[str]]]]:
+def compile_matchers(patterns: Iterable[str]) -> List[Callable[[str], Optional[Match[str]]]]:
return [re.compile(_translate_pattern(pat)).match for pat in patterns]
@@ -63,9 +64,10 @@ class Matcher:
For example, "**/index.rst" matches with "index.rst"
"""
- def __init__(self, patterns: List[str]) -> None:
- expanded = [pat[3:] for pat in patterns if pat.startswith('**/')]
- self.patterns = compile_matchers(patterns + expanded)
+ def __init__(self, exclude_patterns: Iterable[str],
+ include_patterns: Iterable[str] = ()) -> None:
+ expanded = [pat[3:] for pat in exclude_patterns if pat.startswith('**/')]
+ self.patterns = compile_matchers(list(exclude_patterns) + expanded)
def __call__(self, string: str) -> bool:
return self.match(string)
@@ -99,3 +101,63 @@ def patfilter(names: Iterable[str], pat: str) -> List[str]:
_pat_cache[pat] = re.compile(_translate_pattern(pat))
match = _pat_cache[pat].match
return list(filter(match, names))
+
+
+def get_matching_files(
+ dirname: str,
+ exclude_patterns: Iterable[str] = (),
+ include_patterns: Iterable[str] = ("**",)
+) -> Iterator[str]:
+ """Get all file names in a directory, recursively.
+
+ Filter file names by the glob-style include_patterns and exclude_patterns.
+ The default values include all files ("**") and exclude nothing ("").
+
+ Only files matching some pattern in *include_patterns* are included, and
+ exclusions from *exclude_patterns* take priority over inclusions.
+
+ """
+ # dirname is a normalized absolute path.
+ dirname = os.path.normpath(os.path.abspath(dirname))
+
+ exclude_matchers = compile_matchers(exclude_patterns)
+ include_matchers = compile_matchers(include_patterns)
+
+ for root, dirs, files in os.walk(dirname, followlinks=True):
+ relative_root = os.path.relpath(root, dirname)
+ if relative_root == ".":
+ relative_root = "" # suppress dirname for files on the target dir
+
+ # Filter files
+ included_files = []
+ for entry in sorted(files):
+ entry = path_stabilize(os.path.join(relative_root, entry))
+ keep = False
+ for matcher in include_matchers:
+ if matcher(entry):
+ keep = True
+ break # break the inner loop
+
+ for matcher in exclude_matchers:
+ if matcher(entry):
+ keep = False
+ break # break the inner loop
+
+ if keep:
+ included_files.append(entry)
+
+ # Filter directories
+ filtered_dirs = []
+ for dir_name in sorted(dirs):
+ normalised = path_stabilize(os.path.join(relative_root, dir_name))
+ for matcher in exclude_matchers:
+ if matcher(normalised):
+ break # break the inner loop
+ else:
+ # if the loop didn't break
+ filtered_dirs.append(dir_name)
+
+ dirs[:] = filtered_dirs
+
+ # Yield filtered files
+ yield from included_files