From 0a3f41bb0c33afcc0305527502920d4467e89974 Mon Sep 17 00:00:00 2001
From: Michael Jansen
Date: Mon, 16 Apr 2018 17:23:11 -0400
Subject: [PATCH] Add database.command_batch which automatically calls getMore
 for find and aggregate.

---
 Cargo.toml      |   2 +
 src/cursor.rs   | 122 +++++++++++++++++++++++++++++++++++++++++++++++-
 src/database.rs |  43 ++++++++++++++++-
 src/lib.rs      |   5 ++
 tests/cursor.rs |  46 ++++++++++++++++++
 5 files changed, 215 insertions(+), 3 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 99fb54b..18dca7f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -22,6 +22,8 @@ name = "tests"
 libc = "^0.2"
 log = "^0.3"
 bson = "^0.11"
+serde = "1.0"
+serde_derive = "1.0"
 
 [dependencies.mongoc-sys]
 path = "mongoc-sys"
diff --git a/src/cursor.rs b/src/cursor.rs
index 6439398..30b7955 100644
--- a/src/cursor.rs
+++ b/src/cursor.rs
@@ -4,9 +4,10 @@ use std::iter::Iterator;
 use std::ptr;
 use std::thread;
 use std::time::Duration;
+use std::collections::VecDeque;
 
 use mongoc::bindings;
-use bson::{Bson,Document,oid};
+use bson::{self,Bson,Document,oid};
 
 use super::BsoncError;
 use super::bsonc;
@@ -15,6 +16,7 @@ use super::database::Database;
 use super::flags::QueryFlag;
 use super::collection::{Collection,TailOptions};
 use super::CommandAndFindOptions;
+use super::MongoError::ValueAccessError;
 
 use super::Result;
 
@@ -251,3 +253,121 @@ impl<'a> Iterator for TailingCursor<'a> {
         }
     }
 }
+
+type DocArray = VecDeque<Document>;
+type CursorId = i64;
+
+/// Cursor that unpacks command replies into individual result documents
+/// and issues getMore commands as each batch is exhausted.
+pub struct BatchCursor<'a> {
+    cursor: Cursor<'a>,
+    db: &'a Database<'a>,
+    coll_name: String,
+    cursor_id: Option<CursorId>,
+    documents: Option<DocArray>
+}
+
+impl<'a> BatchCursor<'a> {
+    pub fn new(
+        cursor: Cursor<'a>,
+        db: &'a Database<'a>,
+        coll_name: String
+    ) -> BatchCursor<'a> {
+        BatchCursor {
+            cursor,
+            db,
+            coll_name,
+            cursor_id: None,
+            documents: None
+        }
+    }
+
+    // Pull the next reply from the underlying cursor and unpack its batch
+    // into the local document buffer.
+    fn get_cursor_next(&mut self) -> Option<Result<Document>> {
+        let item_opt = self.cursor.next();
+        if let Some(item_res) = item_opt {
+            if let Ok(item) = item_res {
+                let docs_ret = batch_to_array(item);
+                if let Ok(docs) = docs_ret {
+                    self.documents = docs.0;
+                    if docs.1.is_some() { self.cursor_id = docs.1 }
+                    let res = self.get_next_doc();
+                    if res.is_some() { return res; }
+                } else {
+                    return Some(Err(docs_ret.err().unwrap()));
+                }
+            }
+        }
+        None
+    }
+
+    // Pop the next document from the local buffer, if any remain.
+    fn get_next_doc(&mut self) -> Option<Result<Document>> {
+        if let Some(ref mut docs) = self.documents {
+            if !docs.is_empty() {
+                let doc = docs.pop_front().unwrap();
+                return Some(Ok(doc));
+            }
+        }
+        None
+    }
+}
+
+#[derive(Deserialize, Debug)]
+#[serde(rename_all = "camelCase")]
+struct CommandSimpleBatch {
+    id: CursorId,
+    first_batch: Option<DocArray>,
+    next_batch: Option<DocArray>
+}
+
+#[derive(Deserialize, Debug)]
+struct CommandSimpleResult {
+    cursor: CommandSimpleBatch
+}
+
+fn batch_to_array(doc: Document) -> Result<(Option<DocArray>, Option<CursorId>)> {
+    let doc_result: Result<CommandSimpleResult> =
+        bson::from_bson(Bson::Document(doc.clone()))
+            .map_err(|err| {
+                error!("cannot read batch from db: {}", err);
+                ValueAccessError(bson::ValueAccessError::NotPresent)
+            });
+
+    trace!("input: {}, result: {:?}", doc, doc_result);
+
+    doc_result.map(|v| {
+        if v.cursor.first_batch.is_some() { return (v.cursor.first_batch, Some(v.cursor.id)); }
+        if v.cursor.next_batch.is_some() { return (v.cursor.next_batch, Some(v.cursor.id)); }
+        (None, None)
+    })
+}
+
+impl<'a> Iterator for BatchCursor<'a> {
+    type Item = Result<Document>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        // (1) try the local document buffer
+        let res = self.get_next_doc();
+        if res.is_some() { return res; }
+
+        // (2) try the next batch already held by the underlying cursor
+        let res = self.get_cursor_next();
+        if res.is_some() { return res; }
+
+        // (3) try getMore to fetch the next batch from the server
+        if let Some(cid) = self.cursor_id {
+            let command = doc! {
+                "getMore": cid as i64,
+                "collection": self.coll_name.clone()
+            };
+            let cur_result = self.db.command(command, None);
+            if let Ok(cur) = cur_result {
+                self.cursor = cur;
+                let res = self.get_cursor_next();
+                if res.is_some() { return res; }
+            }
+        }
+        None
+    }
+}
\ No newline at end of file
diff --git a/src/database.rs b/src/database.rs
index c037c53..501f75b 100644
--- a/src/database.rs
+++ b/src/database.rs
@@ -16,6 +16,7 @@ use super::collection;
 use super::collection::Collection;
 use super::cursor;
 use super::cursor::Cursor;
+use super::cursor::BatchCursor;
 use super::read_prefs::ReadPrefs;
 
 use flags::FlagsValue;
@@ -25,6 +26,17 @@ pub enum CreatedBy<'a> {
     OwnedClient(Client<'a>)
 }
 
+#[doc(hidden)]
+fn get_coll_name_from_doc(doc: &Document) -> Result<String> {
+    const VALID_COMMANDS: &'static [&'static str] = &["find", "aggregate"];
+    for s in VALID_COMMANDS {
+        if let Ok(val) = doc.get_str(s) {
+            return Ok(val.to_owned())
+        }
+    }
+    Err(InvalidParamsError.into())
+}
+
 /// Provides access to a MongoDB database.
 ///
 /// A database instance can be created by calling `get_database` or `take_database` on a `Client` instance.
@@ -58,8 +70,8 @@ impl<'a> Database<'a> {
         assert!(!self.inner.is_null());
 
         let default_options = CommandAndFindOptions::default();
-        let options = options.unwrap_or(&default_options);
-        let fields_bsonc = options.fields_bsonc();
+        let options = options.unwrap_or(&default_options);
+        let fields_bsonc = options.fields_bsonc();
 
         let cursor_ptr = unsafe {
             bindings::mongoc_database_command(
@@ -91,6 +103,23 @@ impl<'a> Database<'a> {
         ))
     }
 
+    /// Execute a command on the database.
+    /// Automates the process of fetching the next batch with getMore
+    /// and parses each batch so that only the result documents are returned.
+    /// Note: when to prefer this over the CRUD functions is still an open question.
+    pub fn command_batch(
+        &'a self,
+        command: Document,
+        options: Option<&CommandAndFindOptions>
+    ) -> Result<BatchCursor<'a>> {
+        let coll_name = get_coll_name_from_doc(&command)?;
+        Ok(BatchCursor::new(
+            self.command(command, options)?,
+            self,
+            coll_name
+        ))
+    }
+
     /// Simplified version of `command` that returns the first document immediately.
     pub fn command_simple(
         &'a self,
@@ -226,3 +255,13 @@ impl<'a> Drop for Database<'a> {
         }
     }
 }
+
+#[test]
+fn test_get_coll_name_from_doc() {
+    let command = doc! {"find": "cursor_items"};
+    assert_eq!("cursor_items", get_coll_name_from_doc(&command).unwrap());
+    let command = doc! {"aggregate": "cursor_items"};
+    assert_eq!("cursor_items", get_coll_name_from_doc(&command).unwrap());
+    let command = doc! {"error": "cursor_items"};
+    assert!(get_coll_name_from_doc(&command).is_err());
+}
\ No newline at end of file
diff --git a/src/lib.rs b/src/lib.rs
index 4c1383a..3343ab3 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -31,6 +31,11 @@ extern crate bson;
 #[macro_use]
 extern crate log;
 
+#[macro_use]
+extern crate serde_derive;
+extern crate serde;
+
+
 use std::ffi::CStr;
 use std::ptr;
 use std::result;
diff --git a/tests/cursor.rs b/tests/cursor.rs
index f7b80d0..ea837d4 100644
--- a/tests/cursor.rs
+++ b/tests/cursor.rs
@@ -87,3 +87,49 @@ fn test_tailing_cursor() {
     // 15 results.
     assert_eq!(25, guard.join().expect("Thread failed"));
 }
+
+#[test]
+fn test_batch_cursor() {
+    let uri = Uri::new("mongodb://localhost:27017/").unwrap();
+    let pool = Arc::new(ClientPool::new(uri, None));
+    let client = pool.pop();
+    let database = client.get_database("rust_test");
+
+    const TEST_COLLECTION_NAME: &str = "test_batch_cursor";
+    const NUM_TO_TEST: i32 = 10000;
+
+    let mut collection = database.get_collection(TEST_COLLECTION_NAME);
+    if database.has_collection(TEST_COLLECTION_NAME).unwrap() {
+        collection.drop().unwrap(); // if a previous run failed, the old collection may still exist
+    }
+
+    // add test rows; many are needed to exercise multiple batches
+    {
+        let bulk_operation = collection.create_bulk_operation(None);
+
+        for i in 0..NUM_TO_TEST {
+            bulk_operation.insert(&doc!{"key": i}).unwrap();
+        }
+
+        let result = bulk_operation.execute();
+        assert!(result.is_ok());
+
+        assert_eq!(
+            result.ok().unwrap().get("nInserted").unwrap(), // why is this an i32?
+            &bson::Bson::I32(NUM_TO_TEST)
+        );
+        assert_eq!(NUM_TO_TEST as i64, collection.count(&doc!{}, None).unwrap());
+    }
+
+    {
+        let cur = database.command_batch(doc!{"find": TEST_COLLECTION_NAME}, None);
+        let mut count = 0;
+        for doc in cur.unwrap() {
+            count += 1;
+            println!("doc: {:?}", doc);
+        }
+        assert_eq!(count, NUM_TO_TEST);
+    }
+
+    collection.drop().unwrap();
+}
\ No newline at end of file
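
Usage note for reviewers (not part of the patch): the sketch below shows how command_batch is meant to be driven once a collection already holds documents, using a small top-level batchSize so the iterator is forced to issue getMore itself rather than finishing on the first reply. It reuses the same scaffolding as tests/cursor.rs (Uri, ClientPool, Arc, the bson doc! macro); the "cursor_items" collection name and the batchSize value are illustrative assumptions only.

    #[test]
    fn batch_cursor_usage_sketch() {
        let uri = Uri::new("mongodb://localhost:27017/").unwrap();
        let pool = Arc::new(ClientPool::new(uri, None));
        let client = pool.pop();
        let database = client.get_database("rust_test");

        // A small batchSize keeps each server reply short, so walking past the
        // first batch makes BatchCursor send getMore commands on its own.
        let command = doc! {
            "find": "cursor_items",
            "batchSize": 100
        };

        let mut total = 0;
        for doc in database.command_batch(command, None).unwrap() {
            // Each item is already a single result document; the firstBatch /
            // nextBatch unwrapping happens inside the iterator.
            let _doc = doc.unwrap();
            total += 1;
        }
        println!("iterated {} documents", total);
    }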