diff --git a/.travis.yml b/.travis.yml index 9a7e99b5..e911484e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,5 +1,5 @@ -# http://travis-ci.org/#!/MongoEngine/flask_mongoengine language: python + python: - "2.6" - "2.7" @@ -44,6 +44,7 @@ script: notifications: irc: "irc.freenode.org#flask-mongoengine" + branches: only: - master diff --git a/flask_mongoengine/__init__.py b/flask_mongoengine/__init__.py index 3ab5a066..44f641c5 100644 --- a/flask_mongoengine/__init__.py +++ b/flask_mongoengine/__init__.py @@ -16,33 +16,6 @@ from .wtf import WtfBaseField -def redirect_connection_calls(cls): - """ - Monkey-patch mongoengine's connection methods so that they use - Flask-MongoEngine's equivalents. - - Given a random mongoengine class (`cls`), get the module it's in, - and iterate through all of that module's members to find the - particular methods we want to monkey-patch. - """ - # TODO this is so whack... Why don't we pass particular connection - # settings down to mongoengine and just use their original implementation? - - # Map of mongoengine method/variable names and flask-mongoengine - # methods they should point to - connection_methods = { - 'get_db': get_db, - 'DEFAULT_CONNECTION_NAME': DEFAULT_CONNECTION_NAME, - 'get_connection': get_connection - } - cls_module = inspect.getmodule(cls) - if cls_module != mongoengine.connection: - for attr in inspect.getmembers(cls_module): - n = attr[0] - if n in connection_methods: - setattr(cls_module, n, connection_methods[n]) - - def _patch_base_field(obj, name): """ If the object submitted has a class whose base class is @@ -59,6 +32,8 @@ def _patch_base_field(obj, name): @param obj: MongoEngine instance in which we should locate the class. @param name: Name of an attribute which may or may not be a BaseField. """ + # TODO is there a less hacky way to accomplish the same level of + # extensibility/control? # get an attribute of the MongoEngine class and return if it's not # a class @@ -79,7 +54,6 @@ def _patch_base_field(obj, name): # re-assign the class back to the MongoEngine instance delattr(obj, name) setattr(obj, name, cls) - redirect_connection_calls(cls) def _include_mongoengine(obj): @@ -99,10 +73,7 @@ def _include_mongoengine(obj): def current_mongoengine_instance(): - """ - Obtain instance of MongoEngine in the - current working app instance. - """ + """Return a MongoEngine instance associated with current Flask app.""" me = current_app.extensions.get('mongoengine', {}) for k, v in me.items(): if isinstance(k, MongoEngine): @@ -139,36 +110,22 @@ def init_app(self, app, config=None): raise Exception('Extension already initialized') if not config: - # If not passed a config then we - # read the connection settings from - # the app config. + # If not passed a config then we read the connection settings + # from the app config. config = app.config - # Obtain db connection - connection = create_connection(config, app) + # Obtain db connection(s) + connections = create_connections(config) - # Store objects in application instance - # so that multiple apps do not end up - # accessing the same objects. - s = {'app': app, 'conn': connection} + # Store objects in application instance so that multiple apps do not + # end up accessing the same objects. + s = {'app': app, 'conn': connections} app.extensions['mongoengine'][self] = s - def disconnect(self): - """Close all connections to MongoDB.""" - conn_settings = fetch_connection_settings(current_app.config) - if isinstance(conn_settings, list): - for setting in conn_settings: - alias = setting.get('alias', DEFAULT_CONNECTION_NAME) - disconnect(alias, setting.get('preserve_temp_db', False)) - else: - alias = conn_settings.get('alias', DEFAULT_CONNECTION_NAME) - disconnect(alias, conn_settings.get('preserve_temp_db', False)) - return True - @property def connection(self): """ - Return MongoDB connection associated with this MongoEngine + Return MongoDB connection(s) associated with this MongoEngine instance. """ return current_app.extensions['mongoengine'][self]['conn'] diff --git a/flask_mongoengine/connection.py b/flask_mongoengine/connection.py index 9fa63a1e..708bcb37 100644 --- a/flask_mongoengine/connection.py +++ b/flask_mongoengine/connection.py @@ -1,389 +1,122 @@ -import atexit -import os.path -import shutil -import subprocess -import tempfile -import time - -from flask import current_app import mongoengine -from mongoengine import connection -from pymongo import MongoClient, ReadPreference, errors -from pymongo.errors import InvalidURI +from pymongo import ReadPreference, uri_parser __all__ = ( - 'create_connection', 'disconnect', 'get_connection', - 'DEFAULT_CONNECTION_NAME', 'fetch_connection_settings', - 'InvalidSettingsError', 'get_db' + 'create_connections', 'get_connection_settings', 'InvalidSettingsError', ) -DEFAULT_CONNECTION_NAME = 'default-mongodb-connection' - -_connection_settings = {} -_connections = {} -_tmpdir = None -_conn = None -_process = None -_app_instance = current_app - class InvalidSettingsError(Exception): pass -class ConnectionError(Exception): - pass - - -def disconnect(alias=DEFAULT_CONNECTION_NAME, preserved=False): - global _connections, _process, _tmpdir - - if alias in _connections: - conn = get_connection(alias=alias) - client = conn.client - if client: - client.close() - else: - conn.close() - del _connections[alias] - - if _process: - _process.terminate() - _process.wait() - _process = None - - if (not preserved and _tmpdir): - sock_file = 'mongodb-27111.sock' - if os.path.exists(_tmpdir): - shutil.rmtree(_tmpdir, ignore_errors=True) - if os.path.exists(sock_file): - os.remove("{0}/{1}".format(tempfile.gettempdir(), sock_file)) - - -def _validate_settings(is_test, temp_db, preserved, conn_host): - """ - Validate unitest settings to ensure - valid values are supplied before obtaining - connection. +def _sanitize_settings(settings): + """Given a dict of connection settings, sanitize the keys and fall + back to some sane defaults. """ - if (not isinstance(is_test, bool) or not isinstance(temp_db, bool) or - not isinstance(preserved, bool)): - msg = ('`TESTING`, `TEMP_DB`, and `PRESERVE_TEMP_DB`' - ' must be boolean values') - raise InvalidSettingsError(msg) - - elif not is_test and conn_host.startswith('mongomock://'): - msg = ("`MongoMock` connection is only required for `unittest`." - "To enable this set `TESTING` to true`.") - raise InvalidURI(msg) - - elif not is_test and temp_db or preserved: - msg = ('`TESTING` and/or `TEMP_DB` can be used ' - 'only when `TESTING` is set to true.') - raise InvalidSettingsError(msg) - - -def __get_app_config(key): - return (_app_instance.get(key, False) - if isinstance(_app_instance, dict) - else _app_instance.config.get(key, False)) - - -def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): - global _connections - set_global_attributes() - - if reconnect: - disconnect(alias, _connection_settings.get('preserve_temp_db', False)) - - # Establish new connection unless - # already established - if alias not in _connections: - if alias not in _connection_settings: - msg = 'Connection with alias "%s" has not been defined' % alias - if alias == DEFAULT_CONNECTION_NAME: - msg = 'You have not defined a default connection' - raise ConnectionError(msg) - - conn_settings = _connection_settings[alias].copy() - conn_host = conn_settings['host'] - db_name = conn_settings.pop('name') - - is_test = __get_app_config('TESTING') - temp_db = __get_app_config('TEMP_DB') - preserved = __get_app_config('PRESERVE_TEMP_DB') - - # Validation - _validate_settings(is_test, temp_db, preserved, conn_host) - - # Obtain connection - if is_test: - connection_class = None - - if temp_db: - db_alias = conn_settings['alias'] - port = conn_settings['port'] - return _register_test_connection(port, db_alias, preserved) - - elif (conn_host.startswith('mongomock://') and - mongoengine.VERSION < (0, 10, 6)): - # Use MongoClient from mongomock - try: - import mongomock - except ImportError: - msg = 'You need mongomock installed to mock MongoEngine.' - raise RuntimeError(msg) - - # `mongomock://` is not a valid url prefix and - # must be replaced by `mongodb://` - conn_settings['host'] = \ - conn_host.replace('mongomock://', 'mongodb://', 1) - connection_class = mongomock.MongoClient - else: - # Let mongoengine handle the default - _connections[alias] = mongoengine.connect(db_name, **conn_settings) - else: - # Let mongoengine handle the default - _connections[alias] = mongoengine.connect(db_name, **conn_settings) - - try: - connection = None - - # check for shared connections - connection_settings_iterator = ( - (db_alias, settings.copy()) - for db_alias, settings in _connection_settings.items() - ) - for db_alias, connection_settings in connection_settings_iterator: - connection_settings.pop('name', None) - connection_settings.pop('username', None) - connection_settings.pop('password', None) - - if _connections.get(db_alias, None): - connection = _connections[db_alias] - break - - if connection: - _connections[alias] = connection - else: - if connection_class: - _connections[alias] = connection_class(**conn_settings) - - except Exception as e: - msg = "Cannot connect to database %s :\n%s" % (alias, e) - raise ConnectionError(msg) - - return mongoengine.connection.get_db(alias) - - -def _sys_exec(cmd, shell=True, env=None): - if env is None: - env = os.environ - - a = subprocess.Popen(cmd, shell=shell, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, env=env) - a.wait() # Wait for process to terminate - if a.returncode: # Not 0 => Error has occured - raise Exception(a.communicate()[1]) - return a.communicate()[0] - - -def set_global_attributes(): - setattr(connection, '_connection_settings', _connection_settings) - setattr(connection, '_connections', _connections) - setattr(connection, 'disconnect', disconnect) - - -def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False): - set_global_attributes() - return connection.get_db(alias, reconnect) - - -def _register_test_connection(port, db_alias, preserved): - global _process, _tmpdir - - # Lets check MongoDB is installed locally - # before making connection to it - try: - found = _sys_exec("mongod --version") or False - except: - msg = 'You need `MongoDB` service installed on localhost'\ - ' to create a TEMP_DB instance.' - raise RuntimeError(msg) - - if found: - # TEMP_DB setting uses 27111 as - # default port - if not port or port == 27017: - port = 27111 - - _tmpdir = current_app.config.get('TEMP_DB_LOC', tempfile.mkdtemp()) - print("@@ TEMP_DB_LOC = %s" % _tmpdir) - print("@@ TEMP_DB port = %s" % str(port)) - print("@@ TEMP_DB host = localhost") - _conn = _connections.get(db_alias, None) - - if _conn is None: - _process = subprocess.Popen([ - 'mongod', '--bind_ip', 'localhost', - '--port', str(port), - '--dbpath', _tmpdir, - '--nojournal', '--nohttpinterface', - '--noauth', '--smallfiles', - '--syncdelay', '0', - '--maxConns', '10', - '--nssize', '1', ], - stdout=open(os.devnull, 'wb'), - stderr=subprocess.STDOUT) - atexit.register(disconnect, preserved=preserved) - - # wait for the instance db to be ready - # before opening a Connection. - for i in range(3): - time.sleep(0.1) - try: - _conn = MongoClient('localhost', port) - except errors.ConnectionFailure: - continue - else: - break - else: - msg = 'Cannot connect to the mongodb test instance' - raise mongoengine.ConnectionError(msg) - _connections[db_alias] = _conn - return _conn - - -def _resolve_settings(settings, settings_prefix=None, remove_pass=True): - - if settings and isinstance(settings, dict): - resolved_settings = dict() - for k, v in settings.items(): - if settings_prefix: - # Only resolve parameters that contain the prefix, ignoring the rest. - if k.startswith(settings_prefix): - resolved_settings[k[len(settings_prefix):].lower()] = v - else: - # If no prefix is provided then we assume that all parameters are relevant for the DB connection string. - resolved_settings[k.lower()] = v + # Remove the "MONGODB_" prefix and make all settings keys lower case. + resolved_settings = {} + for k, v in settings.items(): + if k.startswith('MONGODB_'): + k = k[len('MONGODB_'):] + k = k.lower() + resolved_settings[k] = v + + # Handle uri style connections + if "://" in resolved_settings.get('host', ''): + uri_dict = uri_parser.parse_uri(resolved_settings['host']) + resolved_settings['db'] = uri_dict['database'] + + # Add a default name param or use the "db" key if exists + if resolved_settings.get('db'): + resolved_settings['name'] = resolved_settings.pop('db') + else: + resolved_settings['name'] = 'test' - # Add various default values. - resolved_settings['alias'] = resolved_settings.get('alias', DEFAULT_CONNECTION_NAME) - if 'db' in resolved_settings: - resolved_settings['name'] = resolved_settings.pop('db') - else: - resolved_settings['name'] = 'test' + # Add various default values. + resolved_settings['alias'] = resolved_settings.get('alias', mongoengine.DEFAULT_CONNECTION_NAME) # TODO do we have to specify it here? MongoEngine should take care of that + resolved_settings['host'] = resolved_settings.get('host', 'localhost') # TODO this is the default host in pymongo.mongo_client.MongoClient, we may not need to explicitly set a default here + resolved_settings['port'] = resolved_settings.get('port', 27017) # TODO this is the default port in pymongo.mongo_client.MongoClient, we may not need to explicitly set a default here - resolved_settings['host'] = resolved_settings.get('host', 'localhost') - resolved_settings['port'] = resolved_settings.get('port', 27017) - resolved_settings['username'] = resolved_settings.get('username', None) + # Default to ReadPreference.PRIMARY if no read_preference is supplied + resolved_settings['read_preference'] = resolved_settings.get('read_preference', ReadPreference.PRIMARY) - # default to ReadPreference.PRIMARY if no read_preference is supplied - resolved_settings['read_preference'] = resolved_settings.get('read_preference', ReadPreference.PRIMARY) - if 'replicaset' in resolved_settings: - resolved_settings['replicaSet'] = resolved_settings.pop('replicaset') - if remove_pass: - try: - del resolved_settings['password'] - except KeyError: - # Password not specified, ignore. - pass + # Rename "replicaset" to "replicaSet" if it exists in the dict + # TODO is this necessary? PyMongo normalizes the options and makes them + # all lowercase via pymongo.common.validate (which is called in + # MongoClient.__init__), so both "replicaset and "replicaSet" should be + # valid + # if 'replicaset' in resolved_settings: + # resolved_settings['replicaSet'] = resolved_settings.pop('replicaset') - return resolved_settings + # Clean up empty values + for k, v in resolved_settings.items(): + if v is None: + del resolved_settings[k] - return settings + return resolved_settings -def fetch_connection_settings(config, remove_pass=True): +def get_connection_settings(config): """ - Fetch DB connection settings from FlaskMongoEngine - application instance configuration. For backward - compactibility reasons the settings name has not - been replaced. - - It has instead been mapped correctly - to avoid connection issues. - - @param config: FlaskMongoEngine instance config - - @param remove_pass: Flag to instruct the method to either - remove password or maintain as is. - By default a call to this method returns - settings without password. + Given a config dict, return a sanitized dict of MongoDB connection + settings that we can then use to establish connections. For new + applications, settings should exist in a "MONGODB_SETTINGS" key, but + for backward compactibility we also support several config keys + prefixed by "MONGODB_", e.g. "MONGODB_HOST", "MONGODB_PORT", etc. """ - # TODO why do we need remove_pass and why is the default True? - # this function is only used in this file (called with remove_pass=False) - # and in __init__.py (where it's passed to `disconnect`, which doesn't - # do anything password-related either...) - + # Sanitize all the settings living under a "MONGODB_SETTINGS" config var if 'MONGODB_SETTINGS' in config: settings = config['MONGODB_SETTINGS'] + + # If MONGODB_SETTINGS is a list of settings dicts, sanitize each + # dict separately. if isinstance(settings, list): # List of connection settings. settings_list = [] for setting in settings: - settings_list.append(_resolve_settings(setting, remove_pass=remove_pass)) + settings_list.append(_sanitize_settings(setting)) return settings_list + + # Otherwise, it should be a single dict describing a single connection. else: - # Connection settings provided as a dictionary. - return _resolve_settings(settings, remove_pass=remove_pass) + return _sanitize_settings(settings) + + # If "MONGODB_SETTINGS" doesn't exist, sanitize all the keys starting with + # "MONGODB_" as if they all describe a single connection. else: - # Connection settings provided in standard format. - return _resolve_settings(config, settings_prefix='MONGODB_', remove_pass=remove_pass) + config = dict((k, v) for k, v in config.items() if k.startswith('MONGODB_')) # ugly dict comprehention in order to support python 2.6 + return _sanitize_settings(config) -def create_connection(config, app): +def create_connections(config): """ - Connection is created based on application configuration - setting. Application settings which is enabled as TESTING - can submit MongoMock URI or enable TEMP_DB setting to provide - default temporary MongoDB instance on localhost for testing - purposes. This connection is initiated with a separate temporary - directory location. - - Unless PRESERVE_TEST_DB is setting is enabled in application - configuration, temporary MongoDB instance will be deleted when - application instance goes out of scope. - - Setting to request MongoMock instance connection: - >> app.config['TESTING'] = True - >> app.config['MONGODB_ALIAS'] = 'unittest' - >> app.config['MONGODB_HOST'] = 'mongo://localhost' - - Setting to request temporary localhost instance of MongoDB - connection: - >> app.config['TESTING'] = True - >> app.config['TEMP_DB'] = True - - To avoid temporary localhost instance of MongoDB been deleted - when application go out of scope: - >> app.config['PRESERVE_TEMP_DB'] = true - - You can specify the location of the temporary database instance - by setting TEMP_DB_LOC. If not specified, a default temp directory - location will be generated and used instead: - >> app.config['TEMP_DB_LOC'] = '/path/to/temp_dir/' - - @param config: Flask-MongoEngine application configuration. - @param app: instance of flask.Flask + Given Flask application's config dict, extract relevant config vars + out of it and establish MongoEngine connection(s) based on them. """ - global _connection_settings, _app_instance - _app_instance = app if app else config - + # Validate that the config is a dict if config is None or not isinstance(config, dict): - raise InvalidSettingsError("Invalid application configuration") + raise InvalidSettingsError('Invalid application configuration') - conn_settings = fetch_connection_settings(config, remove_pass=False) + # Get sanitized connection settings based on the config + conn_settings = get_connection_settings(config) - # if conn_settings is a list, set up each item as a separate connection + # If conn_settings is a list, set up each item as a separate connection + # and return a dict of connection aliases and their connections. if isinstance(conn_settings, list): connections = {} - for conn_setting in conn_settings: - alias = conn_setting['alias'] - _connection_settings[alias] = conn_setting - connections[alias] = get_connection(alias) + for each in conn_settings: + alias = each['alias'] + connections[alias] = _connect(each) return connections - else: - alias = conn_settings.get('alias', DEFAULT_CONNECTION_NAME) - _connection_settings[alias] = conn_settings - return get_connection(alias) + + # Otherwise, return a single connection + return _connect(conn_settings) + + +def _connect(conn_settings): + """Given a dict of connection settings, create a connection to + MongoDB by calling mongoengine.connect and return its result. + """ + db_name = conn_settings.pop('name') + return mongoengine.connect(db_name, **conn_settings) diff --git a/flask_mongoengine/json.py b/flask_mongoengine/json.py index 9fb81a59..558b706b 100644 --- a/flask_mongoengine/json.py +++ b/flask_mongoengine/json.py @@ -17,6 +17,8 @@ def default(self, obj): return json_util._json_convert(obj.as_pymongo()) return superclass.default(self, obj) return MongoEngineJSONEncoder + + MongoEngineJSONEncoder = _make_encoder(JSONEncoder) diff --git a/setup.cfg b/setup.cfg index a7191198..08e7a30e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,7 +8,7 @@ cover-package = flask_mongoengine tests = tests [flake8] -ignore=E501,F403,F405 +ignore=E501,F403,F405,I201 exclude=build,dist,docs,examples,venv,.tox,.eggs max-complexity=17 application-import-names=flask_mongoengine,tests diff --git a/setup.py b/setup.py index 435ae19b..caacc17c 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ except: pass -test_requirements = ['coverage', 'mongomock', 'nose', 'rednose'] +test_requirements = ['coverage', 'nose', 'rednose'] setup( name='flask-mongoengine', diff --git a/tests/__init__.py b/tests/__init__.py index 974ddf16..3928f3e4 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,7 +1,6 @@ import unittest import flask -from flask_mongoengine import current_mongoengine_instance class FlaskMongoEngineTestCase(unittest.TestCase): @@ -15,7 +14,4 @@ def setUp(self): self.ctx.push() def tearDown(self): - me_instance = current_mongoengine_instance() - if me_instance: - me_instance.disconnect() self.ctx.pop() diff --git a/tests/test_connection.py b/tests/test_connection.py index f98c2a10..b07c3167 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,58 +1,19 @@ -import mongoengine -import mongomock +from mongoengine.context_managers import switch_db import pymongo from pymongo.errors import InvalidURI +from pymongo.read_preferences import ReadPreference + +from flask_mongoengine import MongoEngine -from flask_mongoengine import InvalidSettingsError, MongoEngine from tests import FlaskMongoEngineTestCase class ConnectionTestCase(FlaskMongoEngineTestCase): - def ensure_mongomock_connection(self): - db = MongoEngine(self.app) - self.assertTrue(isinstance(db.connection.client, mongomock.MongoClient)) - - def test_mongomock_connection_request_on_most_recent_mongoengine(self): - self.app.config['TESTING'] = True - self.app.config['MONGODB_ALIAS'] = 'unittest_0' - self.app.config['MONGODB_HOST'] = 'mongomock://localhost' - - if mongoengine.VERSION >= (0, 10, 6): - self.ensure_mongomock_connection() - - def test_mongomock_connection_request_on_most_old_mongoengine(self): - self.app.config['TESTING'] = 'True' - self.assertRaises(InvalidSettingsError, MongoEngine, self.app) - - self.app.config['TESTING'] = True - self.app.config['MONGODB_ALIAS'] = 'unittest_1' - self.app.config['MONGODB_HOST'] = 'mongomock://localhost' - - if mongoengine.VERSION < (0, 10, 6): - self.ensure_mongomock_connection() - - def test_live_connection(self): - db = MongoEngine() - self.app.config['TEMP_DB'] = True - self.app.config['MONGODB_SETTINGS'] = { - 'HOST': 'localhost', - 'PORT': 27017, - 'USERNAME': None, - 'PASSWORD': None, - 'DB': 'test' - } - - self._do_persist(db) - - def test_uri_connection_string(self): - db = MongoEngine() - self.app.config['TEMP_DB'] = True - self.app.config['MONGO_URI'] = 'mongodb://localhost:27017/test_uri' - - self._do_persist(db) - def _do_persist(self, db): + """Initialize a test Flask application and persist some data in + MongoDB, ultimately asserting that the connection works. + """ class Todo(db.Document): title = db.StringField(max_length=60) text = db.StringField() @@ -71,22 +32,51 @@ class Todo(db.Document): f_to = Todo.objects().first() self.assertEqual(s_todo.title, f_to.title) + def test_simple_connection(self): + """Make sure a simple connection to a standalone MongoDB works.""" + db = MongoEngine() + self.app.config['MONGODB_SETTINGS'] = { + 'ALIAS': 'simple_conn', + 'HOST': 'localhost', + 'PORT': 27017, + 'DB': 'flask_mongoengine_test_db' + } + self._do_persist(db) + + def test_host_as_uri_string(self): + """Make sure we can connect to a standalone MongoDB if we specify + the host as a MongoDB URI. + """ + db = MongoEngine() + self.app.config['MONGODB_HOST'] = 'mongodb://localhost:27017/flask_mongoengine_test_db' + self._do_persist(db) + + def test_host_as_list(self): + """Make sure MONGODB_HOST can be a list hosts.""" + db = MongoEngine() + self.app.config['MONGODB_SETTINGS'] = { + 'ALIAS': 'host_list', + 'HOST': ['localhost:27017'], + } + self._do_persist(db) + def test_multiple_connections(self): + """Make sure establishing multiple connections to a standalone + MongoDB and switching between them works. + """ db = MongoEngine() - self.app.config['TESTING'] = True - self.app.config['TEMP_DB'] = True self.app.config['MONGODB_SETTINGS'] = [ { 'ALIAS': 'default', - 'DB': 'testing_db1', + 'DB': 'flask_mongoengine_test_db_1', 'HOST': 'localhost', 'PORT': 27017 }, { - "ALIAS": "testing_db2", - "DB": 'testing_db2', - "HOST": 'localhost', - "PORT": 27017 + 'ALIAS': 'alternative', + 'DB': 'flask_mongoengine_test_db_2', + 'HOST': 'localhost', + 'PORT': 27017 }, ] @@ -94,13 +84,12 @@ class Todo(db.Document): title = db.StringField(max_length=60) text = db.StringField() done = db.BooleanField(default=False) - meta = {"db_alias": "testing_db2"} + meta = {'db_alias': 'alternative'} db.init_app(self.app) Todo.drop_collection() - # Switch DB - from mongoengine.context_managers import switch_db + # Test saving a doc via the default connection with switch_db(Todo, 'default') as Todo: todo = Todo() todo.text = "Sample" @@ -111,40 +100,47 @@ class Todo(db.Document): f_to = Todo.objects().first() self.assertEqual(s_todo.title, f_to.title) - def test_mongodb_temp_instance(self): - # String value used instead of boolean - self.app.config['TESTING'] = True - self.app.config['TEMP_DB'] = 'True' - self.assertRaises(InvalidSettingsError, MongoEngine, self.app) + # Make sure the doc doesn't exist in the alternative db + with switch_db(Todo, 'alternative') as Todo: + doc = Todo.objects().first() + self.assertEqual(doc, None) - self.app.config['TEMP_DB'] = True - db = MongoEngine(self.app) - self.assertTrue(isinstance(db.connection, pymongo.MongoClient)) + # Make sure switching back to the default connection shows the doc + with switch_db(Todo, 'default') as Todo: + doc = Todo.objects().first() + self.assertNotEqual(doc, None) - def test_InvalidURI_exception_connections(self): - # Invalid URI - self.app.config['TESTING'] = True - self.app.config['MONGODB_ALIAS'] = 'unittest_2' + def test_connection_with_invalid_uri(self): + """Make sure connecting via an invalid URI raises an InvalidURI + exception. + """ self.app.config['MONGODB_HOST'] = 'mongo://localhost' self.assertRaises(InvalidURI, MongoEngine, self.app) - def test_parse_uri_if_testing_true_and_not_uses_mongomock_schema(self): - # TESTING is false but mongomock URI - self.app.config['TESTING'] = False - self.app.config['MONGODB_ALIAS'] = 'unittest_3' - self.app.config['MONGODB_HOST'] = 'mongomock://localhost' - self.assertRaises(InvalidURI, MongoEngine, self.app) + def test_connection_kwargs(self): + """Make sure additional connection kwargs work.""" - def test_temp_db_with_false_testing(self): - # TEMP_DB is set to true but testing is false - self.app.config['TESTING'] = False - self.app.config['TEMP_DB'] = True - self.app.config['MONGODB_ALIAS'] = 'unittest_4' - self.assertRaises(InvalidSettingsError, MongoEngine, self.app) + # Figure out whether to use "MAX_POOL_SIZE" or "MAXPOOLSIZE" based + # on PyMongo version (former was changed to the latter as described + # in https://jira.mongodb.org/browse/PYTHON-854) + # TODO remove once PyMongo < 3.0 support is dropped + if pymongo.version_tuple[0] >= 3: + MAX_POOL_SIZE_KEY = 'MAXPOOLSIZE' + else: + MAX_POOL_SIZE_KEY = 'MAX_POOL_SIZE' - def test_connection_kwargs(self): - self.app.config['MONGODB_SETTINGS'] = {'DB': 'testing_tz_aware', 'alias': 'tz_aware_true', 'TZ_AWARE': True} - self.app.config['TESTING'] = True + self.app.config['MONGODB_SETTINGS'] = { + 'ALIAS': 'tz_aware_true', + 'DB': 'flask_mongoengine_testing_tz_aware', + 'TZ_AWARE': True, + 'READ_PREFERENCE': ReadPreference.SECONDARY, + MAX_POOL_SIZE_KEY: 10, + } db = MongoEngine() db.init_app(self.app) - self.assertTrue(db.connection.client.codec_options.tz_aware) + self.assertTrue(db.connection.codec_options.tz_aware) + self.assertEqual(db.connection.max_pool_size, 10) + self.assertEqual( + db.connection.read_preference, + ReadPreference.SECONDARY + ) diff --git a/tests/test_session.py b/tests/test_session.py index 1f2bbf9b..250bf8c0 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -51,5 +51,6 @@ def test_setting_session(self): self.assertEqual(resp.status_code, 200) self.assertEquals(resp.data.decode('utf-8'), 'sessions: 1') + if __name__ == '__main__': unittest.main()