diff --git a/src/client.rs b/src/client.rs index d143425..f8ce7a0 100644 --- a/src/client.rs +++ b/src/client.rs @@ -13,7 +13,9 @@ use bson::Document; use super::Result; use super::BsoncError; use super::bsonc::Bsonc; -use super::collection::{Collection,CreatedBy}; +use super::collection; +use super::collection::Collection; +use super::database; use super::database::Database; use super::uri::Uri; use super::read_prefs::ReadPrefs; @@ -233,30 +235,52 @@ pub struct Client<'a> { } impl<'a> Client<'a> { + /// Borrow a collection pub fn get_collection>, CT: Into>>(&'a self, db: DBT, collection: CT) -> Collection<'a> { assert!(!self.inner.is_null()); - let coll = unsafe { - let db_cstring = CString::new(db).unwrap(); - let collection_cstring = CString::new(collection).unwrap(); - bindings::mongoc_client_get_collection( - self.inner, - db_cstring.as_ptr(), - collection_cstring.as_ptr() - ) - }; - Collection::new(CreatedBy::Client(self), coll) + let coll = unsafe { self.collection_ptr(db.into(), collection.into()) }; + Collection::new(collection::CreatedBy::BorrowedClient(self), coll) + } + + /// Take a collection, client is owned by the collection so the collection can easily + /// be passed around + pub fn take_collection>, CT: Into>>(self, db: DBT, collection: CT) -> Collection<'a> { + assert!(!self.inner.is_null()); + let coll = unsafe { self.collection_ptr(db.into(), collection.into()) }; + Collection::new(collection::CreatedBy::OwnedClient(self), coll) + } + + unsafe fn collection_ptr(&self, db: Vec, collection: Vec) -> *mut bindings::mongoc_collection_t { + let db_cstring = CString::new(db).unwrap(); + let collection_cstring = CString::new(collection).unwrap(); + bindings::mongoc_client_get_collection( + self.inner, + db_cstring.as_ptr(), + collection_cstring.as_ptr() + ) } + /// Borrow a database pub fn get_database>>(&'a self, db: S) -> Database<'a> { assert!(!self.inner.is_null()); - let coll = unsafe { - let db_cstring = CString::new(db).unwrap(); - bindings::mongoc_client_get_database( - self.inner, - db_cstring.as_ptr() - ) - }; - Database::new(self, coll) + let coll = unsafe { self.database_ptr(db.into()) }; + Database::new(database::CreatedBy::BorrowedClient(self), coll) + } + + /// Take a database, client is owned by the database so the database can easily + /// be passed around + pub fn take_database>>(self, db: S) -> Database<'a> { + assert!(!self.inner.is_null()); + let coll = unsafe { self.database_ptr(db.into()) }; + Database::new(database::CreatedBy::OwnedClient(self), coll) + } + + unsafe fn database_ptr(&self, db: Vec) -> *mut bindings::mongoc_database_t { + let db_cstring = CString::new(db).unwrap(); + bindings::mongoc_client_get_database( + self.inner, + db_cstring.as_ptr() + ) } /// Queries the server for the current server status. diff --git a/src/collection.rs b/src/collection.rs index 81e837d..4351fbf 100644 --- a/src/collection.rs +++ b/src/collection.rs @@ -21,8 +21,10 @@ use super::write_concern::WriteConcern; use super::read_prefs::ReadPrefs; pub enum CreatedBy<'a> { - Client(&'a Client<'a>), - Database(&'a Database<'a>) + BorrowedClient(&'a Client<'a>), + OwnedClient(Client<'a>), + BorrowedDatabase(&'a Database<'a>), + OwnedDatabase(Database<'a>) } pub struct Collection<'a> { diff --git a/src/database.rs b/src/database.rs index 509bd2d..2563f9e 100644 --- a/src/database.rs +++ b/src/database.rs @@ -16,20 +16,25 @@ use super::cursor; use super::cursor::Cursor; use flags::FlagsValue; +pub enum CreatedBy<'a> { + BorrowedClient(&'a Client<'a>), + OwnedClient(Client<'a>) +} + pub struct Database<'a> { - _client: &'a Client<'a>, + _created_by: CreatedBy<'a>, inner: *mut bindings::mongoc_database_t } impl<'a> Database<'a> { pub fn new( - client: &'a Client<'a>, + created_by: CreatedBy<'a>, inner: *mut bindings::mongoc_database_t ) -> Database<'a> { assert!(!inner.is_null()); Database { - _client: client, - inner: inner + _created_by: created_by, + inner: inner } } @@ -145,22 +150,33 @@ impl<'a> Database<'a> { }; if error.is_empty() { - Ok(Collection::new(collection::CreatedBy::Database(self), coll)) + Ok(Collection::new(collection::CreatedBy::BorrowedDatabase(self), coll)) } else { Err(error.into()) } } + /// Borrow a collection pub fn get_collection>>(&self, collection: S) -> Collection { assert!(!self.inner.is_null()); - let coll = unsafe { - let collection_cstring = CString::new(collection).unwrap(); - bindings::mongoc_database_get_collection( - self.inner, - collection_cstring.as_ptr() - ) - }; - Collection::new(collection::CreatedBy::Database(self), coll) + let coll = unsafe { self.collection_ptr(collection.into()) }; + Collection::new(collection::CreatedBy::BorrowedDatabase(self), coll) + } + + /// Take a collection, database is owned by the collection so the collection can easily + /// be passed around + pub fn take_collection>>(self, collection: S) -> Collection<'a> { + assert!(!self.inner.is_null()); + let coll = unsafe { self.collection_ptr(collection.into()) }; + Collection::new(collection::CreatedBy::OwnedDatabase(self), coll) + } + + unsafe fn collection_ptr(&self, collection: Vec) -> *mut bindings::mongoc_collection_t { + let collection_cstring = CString::new(collection).unwrap(); + bindings::mongoc_database_get_collection( + self.inner, + collection_cstring.as_ptr() + ) } pub fn get_name(&self) -> Cow { diff --git a/tests/client.rs b/tests/client.rs index 21a4f31..792f8d3 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -6,7 +6,7 @@ use mongo_driver::uri::Uri; use mongo_driver::client::{ClientPool,SslOptions}; #[test] -fn test_new_pool_and_pop_client() { +fn test_new_pool_pop_client_and_borrow_collection() { let uri = Uri::new("mongodb://localhost:27017/").unwrap(); let pool = ClientPool::new(uri.clone(), None); assert_eq!(pool.get_uri(), &uri); @@ -22,6 +22,32 @@ fn test_new_pool_and_pop_client() { assert_eq!("items", collection.get_name().to_mut()); } +#[test] +fn test_new_pool_pop_client_and_take_collection() { + let uri = Uri::new("mongodb://localhost:27017/").unwrap(); + let pool = ClientPool::new(uri.clone(), None); + assert_eq!(pool.get_uri(), &uri); + + // Pop a client and take collection + let client = pool.pop(); + let collection = client.take_collection("rust_test", "items"); + assert_eq!("items", collection.get_name().to_mut()); +} + +#[test] +fn test_new_pool_pop_client_and_take_database_and_collection() { + let uri = Uri::new("mongodb://localhost:27017/").unwrap(); + let pool = ClientPool::new(uri.clone(), None); + assert_eq!(pool.get_uri(), &uri); + + // Pop a client and take database and collection + let client = pool.pop(); + let database = client.take_database("rust_test"); + assert_eq!("rust_test", database.get_name().to_mut()); + let collection = database.take_collection("items"); + assert_eq!("items", collection.get_name().to_mut()); +} + #[test] fn test_new_pool_and_pop_client_in_threads() { let uri = Uri::new("mongodb://localhost:27017/").unwrap();