diff options
Diffstat (limited to 'Lib/tracemalloc.py')
| -rw-r--r-- | Lib/tracemalloc.py | 73 | 
1 files changed, 55 insertions, 18 deletions
| diff --git a/Lib/tracemalloc.py b/Lib/tracemalloc.py index 6288da8409..75b391891f 100644 --- a/Lib/tracemalloc.py +++ b/Lib/tracemalloc.py @@ -244,17 +244,21 @@ class Trace:      __slots__ = ("_trace",)      def __init__(self, trace): -        # trace is a tuple: (size, traceback), see Traceback constructor -        # for the format of the traceback tuple +        # trace is a tuple: (domain: int, size: int, traceback: tuple). +        # See Traceback constructor for the format of the traceback tuple.          self._trace = trace      @property -    def size(self): +    def domain(self):          return self._trace[0]      @property +    def size(self): +        return self._trace[1] + +    @property      def traceback(self): -        return Traceback(self._trace[1]) +        return Traceback(self._trace[2])      def __eq__(self, other):          return (self._trace == other._trace) @@ -266,8 +270,8 @@ class Trace:          return "%s: %s" % (self.traceback, _format_size(self.size, False))      def __repr__(self): -        return ("<Trace size=%s, traceback=%r>" -                % (_format_size(self.size, False), self.traceback)) +        return ("<Trace domain=%s size=%s, traceback=%r>" +                % (self.domain, _format_size(self.size, False), self.traceback))  class _Traces(Sequence): @@ -302,19 +306,29 @@ def _normalize_filename(filename):      return filename -class Filter: +class BaseFilter: +    def __init__(self, inclusive): +        self.inclusive = inclusive + +    def _match(self, trace): +        raise NotImplementedError + + +class Filter(BaseFilter):      def __init__(self, inclusive, filename_pattern, -                 lineno=None, all_frames=False): +                 lineno=None, all_frames=False, domain=None): +        super().__init__(inclusive)          self.inclusive = inclusive          self._filename_pattern = _normalize_filename(filename_pattern)          self.lineno = lineno          self.all_frames = all_frames +        self.domain = domain      @property      def filename_pattern(self):          return self._filename_pattern -    def __match_frame(self, filename, lineno): +    def _match_frame_impl(self, filename, lineno):          filename = _normalize_filename(filename)          if not fnmatch.fnmatch(filename, self._filename_pattern):              return False @@ -324,11 +338,11 @@ class Filter:              return (lineno == self.lineno)      def _match_frame(self, filename, lineno): -        return self.__match_frame(filename, lineno) ^ (not self.inclusive) +        return self._match_frame_impl(filename, lineno) ^ (not self.inclusive)      def _match_traceback(self, traceback):          if self.all_frames: -            if any(self.__match_frame(filename, lineno) +            if any(self._match_frame_impl(filename, lineno)                     for filename, lineno in traceback):                  return self.inclusive              else: @@ -337,6 +351,30 @@ class Filter:              filename, lineno = traceback[0]              return self._match_frame(filename, lineno) +    def _match(self, trace): +        domain, size, traceback = trace +        res = self._match_traceback(traceback) +        if self.domain is not None: +            if self.inclusive: +                return res and (domain == self.domain) +            else: +                return res or (domain != self.domain) +        return res + + +class DomainFilter(BaseFilter): +    def __init__(self, inclusive, domain): +        super().__init__(inclusive) +        self._domain = domain + +    @property +    def domain(self): +        return self._domain + +    def _match(self, trace): +        domain, size, traceback = trace +        return (domain == self.domain) ^ (not self.inclusive) +  class Snapshot:      """ @@ -365,13 +403,12 @@ class Snapshot:              return pickle.load(fp)      def _filter_trace(self, include_filters, exclude_filters, trace): -        traceback = trace[1]          if include_filters: -            if not any(trace_filter._match_traceback(traceback) +            if not any(trace_filter._match(trace)                         for trace_filter in include_filters):                  return False          if exclude_filters: -            if any(not trace_filter._match_traceback(traceback) +            if any(not trace_filter._match(trace)                     for trace_filter in exclude_filters):                  return False          return True @@ -379,8 +416,8 @@ class Snapshot:      def filter_traces(self, filters):          """          Create a new Snapshot instance with a filtered traces sequence, filters -        is a list of Filter instances.  If filters is an empty list, return a -        new Snapshot instance with a copy of the traces. +        is a list of Filter or DomainFilter instances.  If filters is an empty +        list, return a new Snapshot instance with a copy of the traces.          """          if not isinstance(filters, Iterable):              raise TypeError("filters must be a list of filters, not %s" @@ -412,7 +449,7 @@ class Snapshot:          tracebacks = {}          if not cumulative:              for trace in self.traces._traces: -                size, trace_traceback = trace +                domain, size, trace_traceback = trace                  try:                      traceback = tracebacks[trace_traceback]                  except KeyError: @@ -433,7 +470,7 @@ class Snapshot:          else:              # cumulative statistics              for trace in self.traces._traces: -                size, trace_traceback = trace +                domain, size, trace_traceback = trace                  for frame in trace_traceback:                      try:                          traceback = tracebacks[frame] | 
