Skip to content
Snippets Groups Projects
Commit b5036334 authored by Olivier Bertrand's avatar Olivier Bertrand
Browse files

Add memory loading for db

parent 42f25af3
No related branches found
No related tags found
No related merge requests found
...@@ -34,6 +34,23 @@ def convert_array(text): ...@@ -34,6 +34,23 @@ def convert_array(text):
return np.load(out) return np.load(out)
def database2memory(source):
""" Load a database in the RAM of the machine
It can be useful to speed up certain calculation
However quite some memory may be used.
:param source: path to the databse to load into memory
:returns: a database
"""
db = DataBaseLoad(source)
dbmem = DataBaseSave(":memory:")
for row_id, posorient in db.posorients.iterrows():
scene = db.scene(posorient)
dbmem.write_image(posorient, np.squeeze(scene[..., 0]))
return dbmem
# Converts np.array to TEXT when inserting # Converts np.array to TEXT when inserting
sqlite3.register_adapter(np.ndarray, adapt_array) sqlite3.register_adapter(np.ndarray, adapt_array)
# Converts TEXT to np.array when selecting # Converts TEXT to np.array when selecting
...@@ -65,7 +82,7 @@ class DataBase(): ...@@ -65,7 +82,7 @@ class DataBase():
self._logger.exception(msg) self._logger.exception(msg)
raise TypeError(msg) raise TypeError(msg)
_, ext = os.path.splitext(filename) _, ext = os.path.splitext(filename)
if ext != '.db': if ext != '.db' and (filename != ':memory:'):
msg = 'filename must have the .db extension' msg = 'filename must have the .db extension'
self._logger.exception(msg) self._logger.exception(msg)
raise NameError(msg) raise NameError(msg)
...@@ -116,12 +133,18 @@ class DataBase(): ...@@ -116,12 +133,18 @@ class DataBase():
for col in self.normalisation_columns: for col in self.normalisation_columns:
self.tablecolumns['normalisation'][col] = 'real' self.tablecolumns['normalisation'][col] = 'real'
if os.path.exists(filename): if self.create is False:
self._logger.info('Connect to database') self._logger.info('Connect to database')
self.db = sqlite3.connect(
'file:' + filename + '?cache=shared', uri=True, if os.path.exists(filename) or filename == ':memory:':
detect_types=sqlite3.PARSE_DECLTYPES, self.db = sqlite3.connect(
timeout=10) 'file:' + filename + '?cache=shared', uri=True,
detect_types=sqlite3.PARSE_DECLTYPES,
timeout=10)
else:
msg = 'Database {} does not exist'.format(filename)
self._logger.exception(msg)
raise NameError(msg)
self.db_cursor = self.db.cursor() self.db_cursor = self.db.cursor()
# Check table # Check table
self._logger.debug('Check tables') self._logger.debug('Check tables')
...@@ -131,10 +154,12 @@ class DataBase(): ...@@ -131,10 +154,12 @@ class DataBase():
named {}'.format(filename, tablename) named {}'.format(filename, tablename)
self._logger.exception(msg) self._logger.exception(msg)
raise Exception(msg) raise Exception(msg)
elif self.create: else:
self._logger.info('Create to database') self._logger.info('Create to database')
self.db = sqlite3.connect( self.db = sqlite3.connect(
filename, detect_types=sqlite3.PARSE_DECLTYPES) 'file:' + filename + '?cache=shared', uri=True,
detect_types=sqlite3.PARSE_DECLTYPES,
timeout=10)
self.db_cursor = self.db.cursor() self.db_cursor = self.db.cursor()
# Create table # Create table
self._logger.info('Create tables') self._logger.info('Create tables')
...@@ -146,10 +171,6 @@ class DataBase(): ...@@ -146,10 +171,6 @@ class DataBase():
self.db_cursor.execute( self.db_cursor.execute(
"create table {} {}".format(key, columns)) "create table {} {}".format(key, columns))
self.db.commit() self.db.commit()
else:
msg = 'Database {} does not exist'.format(filename)
self._logger.exception(msg)
raise NameError(msg)
azimuth = np.deg2rad(np.linspace(-180, 180, 360)) azimuth = np.deg2rad(np.linspace(-180, 180, 360))
elevation = np.deg2rad(np.linspace(-90, 90, 180)) elevation = np.deg2rad(np.linspace(-90, 90, 180))
...@@ -780,12 +801,12 @@ class DataBaseLoad(DataBase): ...@@ -780,12 +801,12 @@ class DataBaseLoad(DataBase):
return denormed_im return denormed_im
class DataBaseSave(DataBase): class DataBaseSave(DataBaseLoad):
def __init__(self, filename, channels=['R', 'G', 'B', 'D'], def __init__(self, filename, channels=['R', 'G', 'B', 'D'],
arr_dtype=np.uint8): arr_dtype=np.uint8):
""" """
""" """
DataBase.__init__(self, filename, channels=channels) DataBaseLoad.__init__(self, filename, channels=channels)
self.arr_dtype = arr_dtype self.arr_dtype = arr_dtype
@property @property
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment