summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/npyio.py5
-rw-r--r--numpy/lib/tests/test_io.py15
2 files changed, 19 insertions, 1 deletions
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py
index 3a575a048..c502e2cc5 100644
--- a/numpy/lib/npyio.py
+++ b/numpy/lib/npyio.py
@@ -10,6 +10,7 @@ import re
import sys
import itertools
import warnings
+import weakref
from operator import itemgetter
from cPickle import load as _cload, loads
@@ -108,7 +109,8 @@ class BagObj(object):
"""
def __init__(self, obj):
- self._obj = obj
+ # Use weakref to make NpzFile objects collectable by refcount
+ self._obj = weakref.proxy(obj)
def __getattribute__(self, key):
try:
return object.__getattribute__(self, '_obj')[key]
@@ -212,6 +214,7 @@ class NpzFile(object):
if self.fid is not None:
self.fid.close()
self.fid = None
+ self.f = None # break reference cycle
def __del__(self):
self.close()
diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py
index 04ca3fb4e..949d8fb45 100644
--- a/numpy/lib/tests/test_io.py
+++ b/numpy/lib/tests/test_io.py
@@ -6,6 +6,7 @@ from tempfile import mkstemp, NamedTemporaryFile
import time
from datetime import datetime
import warnings
+import gc
from numpy.testing.utils import WarningManager
import numpy as np
@@ -1525,6 +1526,20 @@ def test_npzfile_dict():
assert_('x' in list(z.iterkeys()))
+def test_load_refcount():
+ # Check that objects returned by np.load are directly freed based on
+ # their refcount, rather than needing the gc to collect them.
+
+ f = StringIO()
+ np.savez(f, [1, 2, 3])
+ f.seek(0)
+
+ gc.collect()
+ n_before = len(gc.get_objects())
+ np.load(f)
+ n_after = len(gc.get_objects())
+
+ assert_equal(n_before, n_after)
if __name__ == "__main__":
run_module_suite()