diff --git a/CHANGELOG.md b/CHANGELOG.md index 24dd878..d0d1ecd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,14 @@ All changes in this project will be noted in this file. +### 0.8.7 (unreleased) + +> **BREAKING PATCH DUE TO MINIMUM VERSION UPGRADE** +> - **Minimum Supported Skytable Version**: 0.8.2 +> - **Field change warnings**: +> - The `Config` struct now has one additional field. This is not a breaking change because the functionality of the library remains unchanged +- Added support for pipelines + ### 0.8.6 Reduced allocations in `Query`. diff --git a/README.md b/README.md index 34cc013..160a8ed 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ ## Introduction -This library is the official client for the free and open-source NoSQL database [Skytable](https://github.com/skytable/skytable). First, go ahead and install Skytable by following the instructions [here](https://docs.skytable.io/getting-started). This library supports all Skytable versions that work with the [Skyhash 2 Protocol](https://docs.skytable.io/protocol/overview). This version of the library was tested with the latest Skytable release (release [0.8.0-beta](https://github.com/skytable/skytable/releases/v0.8.0-beta)). +This library is the official client for the free and open-source NoSQL database [Skytable](https://github.com/skytable/skytable). First, go ahead and install Skytable by following the instructions [here](https://docs.skytable.io/getting-started). This library supports all Skytable versions that work with the [Skyhash 2 Protocol](https://docs.skytable.io/protocol/overview). This version of the library was tested with the latest Skytable release (release [0.8.1](https://github.com/skytable/skytable/releases/v0.8.1)). [Read more about supported versions here](#version-support). 
## Definitive example @@ -53,6 +53,12 @@ assert_eq!(user, our_user); > **Read [docs here to learn BlueQL](https://docs.skytable.io/)** + +## Version support + +- Minimum Supported Rust Version (MSRV): 1.51.0 +- Minimum Supported Skytable Version: 0.8.0 + ## Features - Sync API @@ -64,8 +70,8 @@ assert_eq!(user, our_user); ## Contributing -Open-source, and contributions ... — they're always welcome! For ideas and suggestions, [create an issue on GitHub](https://github.com/skytable/client-rust/issues/new) and for patches, fork and open those pull requests [here](https://github.com/skytable/client-rust)! +Contributions are always welcome. To submit patches please fork this repository and submit a pull request. If you find any bugs, [please open an issue here](https://github.com/skytable/client-rust/issues/new). ## License -This client library is distributed under the permissive [Apache-2.0 License](https://github.com/skytable/client-rust/blob/next/LICENSE). Now go build great apps! +This library is distributed under the [Apache-2.0 License](https://github.com/skytable/client-rust/blob/next/LICENSE). diff --git a/src/config.rs b/src/config.rs index 154bd9f..51f3078 100644 --- a/src/config.rs +++ b/src/config.rs @@ -15,21 +15,23 @@ */ //! # Configuration -//! +//! //! This module provides items to help with database connection setup and configuration. -//! +//! //! ## Example -//! +//! //! ```no_run //! use skytable::Config; -//! +//! //! // establish a sync connection to 127.0.0.1:2003 //! let mut db = Config::new_default("username", "password").connect().unwrap(); -//! +//! //! // establish a connection to a specific host `subnetx2_db1` and port `2008` //! let mut db = Config::new("subnetx2_db1", 2008, "username", "password").connect().unwrap(); //! ``` +use crate::protocol::handshake::ProtocolVersion; + /// The default host /// /// NOTE: If you are using a clustering setup, don't use this! 
@@ -46,21 +48,40 @@ pub struct Config { port: u16, username: Box, password: Box, + pub(crate) protocol: ProtocolVersion, } impl Config { + fn _new( + host: Box, + port: u16, + username: Box, + password: Box, + protocol: ProtocolVersion, + ) -> Self { + Self { + host, + port, + username, + password, + protocol, + } + } /// Create a new [`Config`] using the default connection settings and using the provided username and password pub fn new_default(username: &str, password: &str) -> Self { Self::new(DEFAULT_HOST, DEFAULT_TCP_PORT, username, password) } - /// Create a new [`Config`] using the given settings + /// Create a new [`Config`] using the given settings. + /// + /// **PROTOCOL VERSION**: Defaults to [`ProtocolVersion::V2_0`] pub fn new(host: &str, port: u16, username: &str, password: &str) -> Self { - Self { - host: host.into(), + Self::_new( + host.into(), port, - username: username.into(), - password: password.into(), - } + username.into(), + password.into(), + ProtocolVersion::V2_0, + ) } /// Returns the host setting for this this configuration pub fn host(&self) -> &str { diff --git a/src/error.rs b/src/error.rs index 47202cf..0761dd9 100644 --- a/src/error.rs +++ b/src/error.rs @@ -109,6 +109,7 @@ impl fmt::Display for ProtocolError { Self::InvalidServerResponseUnknownDataType => { write!(f, "new or unknown data type received from server") } + Self::InvalidPacket => write!(f, "invalid packet received from server"), } } } diff --git a/src/aio.rs b/src/io/aio.rs similarity index 72% rename from src/aio.rs rename to src/io/aio.rs index e92b3b2..86e1193 100644 --- a/src/aio.rs +++ b/src/io/aio.rs @@ -24,7 +24,12 @@ use { crate::{ error::{ClientResult, ConnectionSetupError, Error}, - protocol::{ClientHandshake, DecodeState, Decoder, RState, ServerHandshake}, + protocol::{ + handshake::{ClientHandshake, ServerHandshake}, + state_init::{DecodeState, MRespState, PipelineResult, RState}, + Decoder, + }, + query::Pipeline, response::{FromResponse, Response}, Config, 
Query, }, @@ -80,17 +85,12 @@ impl DerefMut for ConnectionTlsAsync { impl Config { /// Establish an async connection to the database using the current configuration pub async fn connect_async(&self) -> ClientResult { - let mut tcpstream = TcpStream::connect((self.host(), self.port())).await?; - let handshake = ClientHandshake::new(self); - tcpstream.write_all(handshake.inner()).await?; - let mut resp = [0u8; 4]; - tcpstream.read_exact(&mut resp).await?; - match ServerHandshake::parse(resp)? { - ServerHandshake::Error(e) => return Err(ConnectionSetupError::HandshakeError(e).into()), - ServerHandshake::Okay(_suggestion) => { - return Ok(ConnectionAsync(TcpConnection::new(tcpstream))) - } - } + TcpStream::connect((self.host(), self.port())) + .await + .map(TcpConnection::new)? + ._handshake(self) + .await + .map(ConnectionAsync) } /// Establish an async TLS connection to the database using the current configuration. /// Pass the certificate in PEM format. @@ -110,22 +110,15 @@ impl Config { let connector = builder.build().map_err(|e| { ConnectionSetupError::Other(format!("failed to set up TLS acceptor: {e}")) })?; - // init - let mut stream = TlsConnector::from(connector) + // init and handshake + TlsConnector::from(connector) .connect(self.host(), stream) .await - .map_err(|e| ConnectionSetupError::Other(format!("TLS handshake failed: {e}")))?; - // handshake - let handshake = ClientHandshake::new(self); - stream.write_all(handshake.inner()).await?; - let mut resp = [0u8; 4]; - stream.read_exact(&mut resp).await?; - match ServerHandshake::parse(resp)? { - ServerHandshake::Error(e) => return Err(ConnectionSetupError::HandshakeError(e).into()), - ServerHandshake::Okay(_suggestion) => { - return Ok(ConnectionTlsAsync(TcpConnection::new(stream))) - } - } + .map(TcpConnection::new) + .map_err(|e| ConnectionSetupError::Other(format!("TLS handshake failed: {e}")))? 
+ ._handshake(self) + .await + .map(ConnectionTlsAsync) } } @@ -143,6 +136,49 @@ impl TcpConnection { buf: Vec::with_capacity(crate::BUFSIZE), } } + async fn _handshake(mut self, cfg: &Config) -> ClientResult { + let handshake = ClientHandshake::new(cfg); + self.con.write_all(handshake.inner()).await?; + let mut resp = [0u8; 4]; + self.con.read_exact(&mut resp).await?; + match ServerHandshake::parse(resp)? { + ServerHandshake::Error(e) => return Err(ConnectionSetupError::HandshakeError(e).into()), + ServerHandshake::Okay(_suggestion) => return Ok(self), + } + } + /// Execute a pipeline. The server returns the responses in the order the queries were sent (unless otherwise set). + pub async fn execute_pipeline(&mut self, pipeline: &Pipeline) -> ClientResult> { + self.buf.clear(); + self.buf.push(b'P'); + // packet size + self.buf + .extend(itoa::Buffer::new().format(pipeline.buf().len()).as_bytes()); + self.buf.push(b'\n'); + // write + self.con.write_all(&self.buf).await?; + self.con.write_all(pipeline.buf()).await?; + self.buf.clear(); + // read + let mut cursor = 0; + let mut state = MRespState::default(); + loop { + let mut buf = [0u8; crate::BUFSIZE]; + let n = self.con.read(&mut buf).await?; + if n == 0 { + return Err(Error::IoError(std::io::ErrorKind::ConnectionReset.into())); + } + self.buf.extend_from_slice(&buf[..n]); + let mut decoder = Decoder::new(&self.buf, cursor); + match decoder.validate_pipe(pipeline.query_count(), state) { + PipelineResult::Completed(r) => return Ok(r), + PipelineResult::Pending(_state) => { + cursor = decoder.position(); + state = _state; + } + PipelineResult::Error(e) => return Err(e.into()), + } + } + } /// Run a query and return a raw [`Response`] pub async fn query(&mut self, q: &Query) -> ClientResult { self.buf.clear(); diff --git a/src/io/mod.rs b/src/io/mod.rs new file mode 100644 index 0000000..306b1b6 --- /dev/null +++ b/src/io/mod.rs @@ -0,0 +1,18 @@ +/* + * Copyright 2023, Sayan Nandan + * + * Licensed under the Apache License, 
Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +pub mod aio; +pub mod sync; diff --git a/src/syncio.rs b/src/io/sync.rs similarity index 66% rename from src/syncio.rs rename to src/io/sync.rs index cb18ccf..226084e 100644 --- a/src/syncio.rs +++ b/src/io/sync.rs @@ -26,7 +26,12 @@ use { crate::{ config::Config, error::{ClientResult, ConnectionSetupError, Error}, - protocol::{ClientHandshake, DecodeState, Decoder, RState, ServerHandshake}, + protocol::{ + handshake::{ClientHandshake, ServerHandshake}, + state_init::{DecodeState, MRespState, PipelineResult, RState}, + Decoder, + }, + query::Pipeline, response::{FromResponse, Response}, Query, }, @@ -81,23 +86,17 @@ impl DerefMut for ConnectionTls { impl Config { /// Establish a connection to the database using the current configuration pub fn connect(&self) -> ClientResult { - let mut tcpstream = TcpStream::connect((self.host(), self.port()))?; - let handshake = ClientHandshake::new(self); - tcpstream.write_all(handshake.inner())?; - let mut resp = [0u8; 4]; - tcpstream.read_exact(&mut resp)?; - match ServerHandshake::parse(resp)? { - ServerHandshake::Error(e) => return Err(ConnectionSetupError::HandshakeError(e).into()), - ServerHandshake::Okay(_suggestion) => { - return Ok(Connection(TcpConnection::new(tcpstream))) - } - } + TcpStream::connect((self.host(), self.port())) + .map(TcpConnection::new)? + ._handshake(self) + .map(Connection) } /// Establish a TLS connection to the database using the current configuration. 
/// Pass the certificate in PEM format. pub fn connect_tls(&self, cert: &str) -> ClientResult { let stream = TcpStream::connect((self.host(), self.port()))?; - let mut stream = TlsConnector::builder() + TlsConnector::builder() + // build TLS connector .add_root_certificate(Certificate::from_pem(cert.as_bytes()).map_err(|e| { ConnectionSetupError::Other(format!("failed to parse certificate: {e}")) })?) @@ -106,18 +105,13 @@ impl Config { .map_err(|e| { ConnectionSetupError::Other(format!("failed to set up TLS acceptor: {e}")) })? + // connect .connect(self.host(), stream) - .map_err(|e| ConnectionSetupError::Other(format!("TLS handshake failed: {e}")))?; - let handshake = ClientHandshake::new(self); - stream.write_all(handshake.inner())?; - let mut resp = [0u8; 4]; - stream.read_exact(&mut resp)?; - match ServerHandshake::parse(resp)? { - ServerHandshake::Error(e) => return Err(ConnectionSetupError::HandshakeError(e).into()), - ServerHandshake::Okay(_suggestion) => { - return Ok(ConnectionTls(TcpConnection::new(stream))) - } - } + .map_err(|e| ConnectionSetupError::Other(format!("TLS handshake failed: {e}"))) + .map(TcpConnection::new)? + // handshake + ._handshake(self) + .map(ConnectionTls) } } @@ -127,22 +121,65 @@ impl Config { /// This can't be constructed directly! pub struct TcpConnection { con: C, - buffer: Vec, + buf: Vec, } impl TcpConnection { fn new(con: C) -> Self { Self { con, - buffer: Vec::with_capacity(crate::BUFSIZE), + buf: Vec::with_capacity(crate::BUFSIZE), + } + } + fn _handshake(mut self, cfg: &Config) -> ClientResult { + let handshake = ClientHandshake::new(cfg); + self.con.write_all(handshake.inner())?; + let mut resp = [0u8; 4]; + self.con.read_exact(&mut resp)?; + match ServerHandshake::parse(resp)? { + ServerHandshake::Error(e) => return Err(ConnectionSetupError::HandshakeError(e).into()), + ServerHandshake::Okay(_suggestion) => return Ok(self), + } + } + /// Execute a pipeline. 
The server returns the responses in the order the queries were sent (unless otherwise set). + pub fn execute_pipeline(&mut self, pipeline: &Pipeline) -> ClientResult> { + self.buf.clear(); + self.buf.push(b'P'); + // packet size + self.buf + .extend(itoa::Buffer::new().format(pipeline.buf().len()).as_bytes()); + self.buf.push(b'\n'); + // write + self.con.write_all(&self.buf)?; + self.con.write_all(pipeline.buf())?; + self.buf.clear(); + // read + let mut cursor = 0; + let mut state = MRespState::default(); + loop { + let mut buf = [0u8; crate::BUFSIZE]; + let n = self.con.read(&mut buf)?; + if n == 0 { + return Err(Error::IoError(std::io::ErrorKind::ConnectionReset.into())); + } + self.buf.extend_from_slice(&buf[..n]); + let mut decoder = Decoder::new(&self.buf, cursor); + match decoder.validate_pipe(pipeline.query_count(), state) { + PipelineResult::Completed(r) => return Ok(r), + PipelineResult::Pending(_state) => { + cursor = decoder.position(); + state = _state; + } + PipelineResult::Error(e) => return Err(e.into()), + } } } /// Run a query and return a raw [`Response`] pub fn query(&mut self, q: &Query) -> ClientResult { - self.buffer.clear(); - q.write_packet(&mut self.buffer).unwrap(); - self.con.write_all(&self.buffer)?; - self.buffer.clear(); + self.buf.clear(); + q.write_packet(&mut self.buf).unwrap(); + self.con.write_all(&self.buf)?; + self.buf.clear(); let mut state = RState::default(); let mut cursor = 0; loop { @@ -151,8 +188,8 @@ impl TcpConnection { if n == 0 { return Err(Error::IoError(std::io::ErrorKind::ConnectionReset.into())); } - self.buffer.extend_from_slice(&buf[..n]); - let mut decoder = Decoder::new(&self.buffer, cursor); + self.buf.extend_from_slice(&buf[..n]); + let mut decoder = Decoder::new(&self.buf, cursor); match decoder.validate_response(state) { DecodeState::ChangeState(new_state) => { state = new_state; @@ -171,6 +208,6 @@ impl TcpConnection { /// Call this if the internally allocated buffer is growing too large and impacting your 
performance. However, normally /// you will not need to call this pub fn reset_buffer(&mut self) { - self.buffer.shrink_to_fit() + self.buf.shrink_to_fit() } } diff --git a/src/lib.rs b/src/lib.rs index 0d2ee39..b559c0c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -113,25 +113,27 @@ mod macros; mod protocol; // public modules -pub mod aio; pub mod config; pub mod error; pub mod pool; pub mod query; pub mod response; -pub mod syncio; /// The `Query` derive macro enables you to directly pass complex types as parameters into queries pub use sky_derive::Query; /// The `Response` derive macro enables you to directly pass complex types as parameters into queries pub use sky_derive::Response; // re-exports pub use { - aio::{ConnectionAsync, ConnectionTlsAsync}, config::Config, error::ClientResult, - query::Query, - syncio::{Connection, ConnectionTls}, + io::{ + aio::{self, ConnectionAsync, ConnectionTlsAsync}, + sync::{self as syncio, Connection, ConnectionTls}, + }, + query::{Pipeline, Query}, }; +// private +mod io; /// we use a 8KB read buffer by default; allow this to be changed const BUFSIZE: usize = 8 * 1024; diff --git a/src/macros.rs b/src/macros.rs index 6e99130..0f12657 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -16,14 +16,14 @@ #[macro_export] /// This macro can be used to create a [`Query`](struct@crate::Query), almost like a variadic function -/// +/// /// ## Examples /// ``` /// use skytable::query; -/// +/// /// fn get_username() -> String { "sayan".to_owned() } /// fn get_counter() -> u64 { 100 } -/// +/// /// let query1 = query!("select * from myspace.mymodel WHERE username = ?", get_username()); /// assert_eq!(query1.param_cnt(), 1); /// let query2 = query!("update myspace.mymodel set counter += ? 
WHERE username = ?", get_counter(), get_username()); diff --git a/src/protocol/handshake.rs b/src/protocol/handshake.rs new file mode 100644 index 0000000..35a02e7 --- /dev/null +++ b/src/protocol/handshake.rs @@ -0,0 +1,74 @@ +/* + * Copyright 2024, Sayan Nandan + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +use crate::{ + error::{ConnectionSetupError, Error}, + ClientResult, Config, +}; + +#[derive(Debug, PartialEq, Clone, Copy)] +#[repr(u8)] +/// The Skyhash protocol version +pub(crate) enum ProtocolVersion { + /// Skyhash 2.0 + V2_0, +} + +impl ProtocolVersion { + pub(crate) const fn hs_block(&self) -> [u8; 6] { + match self { + Self::V2_0 => [b'H', 0, 0, 0, 0, 0], + } + } +} + +pub struct ClientHandshake(Box<[u8]>); +impl ClientHandshake { + pub(crate) fn new(cfg: &Config) -> Self { + Self::_new(cfg.protocol.hs_block(), cfg) + } + fn _new(hs: [u8; 6], cfg: &Config) -> Self { + let mut v = Vec::with_capacity(6 + cfg.username().len() + cfg.password().len() + 5); + v.extend(hs); + pushlen!(v, cfg.username().len()); + pushlen!(v, cfg.password().len()); + v.extend(cfg.username().as_bytes()); + v.extend(cfg.password().as_bytes()); + Self(v.into_boxed_slice()) + } + pub(crate) fn inner(&self) -> &[u8] { + &self.0 + } +} + +#[derive(Debug)] +pub enum ServerHandshake { + Okay(u8), + Error(u8), +} +impl ServerHandshake { + pub fn parse(v: [u8; 4]) -> ClientResult { + Ok(match v { + [b'H', 0, 0, msg] => Self::Okay(msg), + [b'H', 0, 1, msg] => 
Self::Error(msg), + _ => { + return Err(Error::ConnectionSetupErr( + ConnectionSetupError::InvalidServerHandshake, + )) + } + }) + } +} diff --git a/src/protocol.rs b/src/protocol/mod.rs similarity index 69% rename from src/protocol.rs rename to src/protocol/mod.rs index 5140d75..8756c33 100644 --- a/src/protocol.rs +++ b/src/protocol/mod.rs @@ -14,12 +14,25 @@ * limitations under the License. */ -use crate::{ - config::Config, - error::{ClientResult, ConnectionSetupError, Error}, - response::{Response, Row, Value}, +pub mod handshake; +mod pipe; +mod state; + +use { + self::state::{ + DecodeState, MetaState, MultiRowState, PendingValue, RState, ResponseState, RowState, + ValueDecodeState, ValueDecodeStateAny, ValueDecodeStateRaw, ValueState, ValueStateMeta, + }, + crate::response::{Response, Row, Value}, }; +pub mod state_init { + pub(crate) use super::{ + pipe::{MRespState, PipelineResult}, + state::{DecodeState, RState}, + }; +} + pub(crate) type ProtocolResult = Result; /// Errors that can happen when handling protocol level encoding and decoding @@ -30,6 +43,7 @@ pub enum ProtocolError { /// The server possibly returned an unknown data type and we can't decode it. 
Note that this might happen when you use an older client version with /// a newer version of Skytable InvalidServerResponseUnknownDataType, + InvalidPacket, } impl Value { @@ -41,119 +55,6 @@ impl Value { } } -/* - Decode state management -*/ - -type ValueDecodeStateRaw = ValueDecodeStateAny; -type ValueDecodeState = ValueDecodeStateAny; - -#[derive(Debug, PartialEq)] -enum ValueDecodeStateAny { - Pending(P), - Decoded(V), -} - -#[derive(Debug, PartialEq)] -struct ValueState { - v: Value, - meta: ValueStateMeta, -} - -impl ValueState { - fn new(v: Value, meta: ValueStateMeta) -> Self { - Self { v, meta } - } -} - -#[derive(Debug, PartialEq)] -struct ValueStateMeta { - start: usize, - md1: u64, - md1_flag: bool, -} - -impl ValueStateMeta { - fn zero() -> Self { - Self { - start: 0, - md1: 0, - md1_flag: false, - } - } - fn new(start: usize, md1: u64, md1_flag: bool) -> Self { - Self { - start, - md1, - md1_flag, - } - } -} - -#[derive(Debug, PartialEq)] -struct RowState { - meta: ValueStateMeta, - row: Vec, - tmp: Option, -} - -impl RowState { - fn new(meta: ValueStateMeta, row: Vec, tmp: Option) -> Self { - Self { meta, row, tmp } - } -} - -#[derive(Debug, PartialEq)] -struct MultiRowState { - c_row: Option, - rows: Vec, - md_state: u8, - md1_target: u64, - md2_col_cnt: u64, -} - -impl Default for MultiRowState { - fn default() -> Self { - Self::new(None, vec![], 0, 0, 0) - } -} - -impl MultiRowState { - fn new(c_row: Option, rows: Vec, md_s: u8, md_cnt: u64, md_target: u64) -> Self { - Self { - c_row, - rows, - md_state: md_s, - md1_target: md_target, - md2_col_cnt: md_cnt, - } - } -} - -#[derive(Debug, PartialEq)] -enum ResponseState { - Initial, - PValue(PendingValue), - PError, - PRow(RowState), - PMultiRow(MultiRowState), -} - -#[derive(Debug, PartialEq)] -pub enum DecodeState { - ChangeState(RState), - Completed(Response), - Error(ProtocolError), -} - -#[derive(Debug, PartialEq)] -pub struct RState(ResponseState); -impl Default for RState { - fn default() -> 
Self { - RState(ResponseState::Initial) - } -} - /* Decoder */ @@ -215,23 +116,14 @@ impl<'a> Decoder<'a> { } } fn resume_row(&mut self, mut row_state: RowState) -> DecodeState { - if !row_state.meta.md1_flag { - match self.__resume_decode(row_state.meta.md1, ValueStateMeta::zero()) { - Ok(ValueDecodeStateAny::Pending(ValueState { v, .. })) => { - row_state.meta.md1 = v.u64(); - return DecodeState::ChangeState(RState(ResponseState::PRow(row_state))); - } - Ok(ValueDecodeStateAny::Decoded(v)) => { - row_state.meta.md1 = v.u64(); - row_state.meta.md1_flag = true; - } - Err(e) => return DecodeState::Error(e), - } + match row_state.meta.md.finished(self) { + Ok(true) => self._decode_row_core(row_state), + Ok(false) => DecodeState::ChangeState(RState(ResponseState::PRow(row_state))), + Err(e) => DecodeState::Error(e), } - self._decode_row_core(row_state) } fn _decode_row_core(&mut self, mut row_state: RowState) -> DecodeState { - while row_state.row.len() as u64 != row_state.meta.md1 { + while row_state.row.len() as u64 != row_state.meta.md.val() { let r = match row_state.tmp.take() { None => { if self._cursor_eof() { @@ -262,32 +154,19 @@ impl<'a> Decoder<'a> { DecodeState::Completed(Response::Row(Row::new(row_state.row))) } fn resume_rows(&mut self, mut multirow: MultiRowState) -> DecodeState { - if multirow.md_state == 0 { - match self.__resume_decode(multirow.md1_target, ValueStateMeta::zero()) { - Ok(ValueDecodeStateAny::Pending(ValueState { v, .. })) => { - multirow.md1_target = v.u64(); - return DecodeState::ChangeState(RState(ResponseState::PMultiRow(multirow))); - } - Ok(ValueDecodeStateAny::Decoded(v)) => { - multirow.md1_target = v.u64(); - multirow.md_state += 1; - } - Err(e) => return DecodeState::Error(e), - } - } - if multirow.md_state == 1 { - match self.__resume_decode(multirow.md2_col_cnt, ValueStateMeta::zero()) { - Ok(ValueDecodeStateAny::Pending(ValueState { v, .. 
})) => { - multirow.md2_col_cnt = v.u64(); - return DecodeState::ChangeState(RState(ResponseState::PMultiRow(multirow))); - } - Ok(ValueDecodeStateAny::Decoded(v)) => { - multirow.md2_col_cnt = v.u64(); - multirow.md_state += 1; + macro_rules! finish { + ($completed:expr, $target:expr) => { + match MetaState::try_finish(self, $completed, &mut $target) { + Ok(true) => multirow.md_state += 1, + Ok(false) => { + return DecodeState::ChangeState(RState(ResponseState::PMultiRow(multirow))) + } + Err(e) => return DecodeState::Error(e), } - Err(e) => return DecodeState::Error(e), - } + }; } + finish!(multirow.md_state == 1, &mut multirow.md1_target); + finish!(multirow.md_state == 2, &mut multirow.md2_col_cnt); while multirow.rows.len() as u64 != multirow.md1_target { let ret = match multirow.c_row.take() { Some(r) => self._decode_row_core(r), @@ -324,6 +203,7 @@ impl<'a> Decoder<'a> { } let lf = self._creq(b'\n'); self._cursor_incr_if(lf); + // FIXME(@ohsayan): the below is not exactly necessary and we can actually remove this if it complicates state management okay &= !(lf & (self._cursor() == meta.start)); if okay & lf { let start = meta.start; @@ -342,32 +222,23 @@ impl<'a> Decoder<'a> { &mut self, mut meta: ValueStateMeta, ) -> ProtocolResult { - if !meta.md1_flag { - match self.__resume_decode(meta.md1, ValueStateMeta::zero())? { - ValueDecodeStateAny::Decoded(s) => { - let s = s.u64(); - meta.md1_flag = true; - meta.md1 = s; - } - ValueDecodeStateAny::Pending(ValueState { v, .. }) => { - meta.md1 = v.u64(); - return Ok(ValueDecodeStateRaw::Pending(ValueState::new( - T::empty(), - meta, - ))); - } - } - } - meta.start = self._cursor(); - if self._remaining() as u64 >= meta.md1 { - let buf = &self.b[meta.start..self._cursor() + meta.md1 as usize]; - self._cursor_incr_by(meta.md1 as usize); - T::finish(buf).map(ValueDecodeStateAny::Decoded) - } else { - Ok(ValueDecodeStateAny::Pending(ValueState::new( + if !meta.md.finished(self)? 
{ + Ok(ValueDecodeStateRaw::Pending(ValueState::new( T::empty(), meta, ))) + } else { + meta.start = self._cursor(); + if self._remaining() as u64 >= meta.md.val() { + let buf = &self.b[meta.start..self._cursor() + meta.md.val() as usize]; + self._cursor_incr_by(meta.md.val() as usize); + T::finish(buf).map(ValueDecodeStateAny::Decoded) + } else { + Ok(ValueDecodeStateAny::Pending(ValueState::new( + T::empty(), + meta, + ))) + } } } } @@ -402,6 +273,9 @@ impl<'a> Decoder<'a> { fn _creq(&self, b: u8) -> bool { (self.b[core::cmp::min(self.i, self.b.len() - 1)] == b) & !self._cursor_eof() } + fn _current(&self) -> &[u8] { + &self.b[self.i..] + } } trait DecodeDelimited { @@ -474,23 +348,6 @@ impl_fstr!( f64 as Float64 ); -#[derive(Debug, PartialEq)] -struct PendingValue { - state: ValueState, - tmp: Option, - stack: Vec<(Vec, ValueStateMeta)>, -} - -impl PendingValue { - fn new( - state: ValueState, - tmp: Option, - stack: Vec<(Vec, ValueStateMeta)>, - ) -> Self { - Self { state, tmp, stack } - } -} - impl<'a> Decoder<'a> { fn parse_list( &mut self, @@ -499,24 +356,14 @@ impl<'a> Decoder<'a> { ) -> ProtocolResult> { let (mut current_list, mut current_meta) = stack.pop().unwrap(); loop { - if !current_meta.md1_flag { - match self.__resume_decode(current_meta.md1, ValueStateMeta::zero())? { - ValueDecodeStateAny::Decoded(v) => { - current_meta.md1 = v.u64(); - current_meta.md1_flag = true; - } - ValueDecodeStateAny::Pending(ValueState { v, .. }) => { - current_meta.md1 = v.u64(); - stack.push((current_list, current_meta)); - return Ok(ValueDecodeStateAny::Pending(PendingValue::new( - ValueState::new(Value::List(vec![]), ValueStateMeta::zero()), - None, - stack, - ))); - } - } + if !current_meta.md.finished(self)? 
{ + return Ok(ValueDecodeStateAny::Pending(PendingValue::new( + ValueState::new(Value::List(vec![]), ValueStateMeta::zero()), + None, + stack, + ))); } - if current_list.len() as u64 == current_meta.md1 { + if current_list.len() as u64 == current_meta.md.val() { match stack.pop() { None => { return Ok(ValueDecodeStateAny::Decoded(Value::List(current_list))); @@ -667,59 +514,24 @@ impl<'a> Decoder<'a> { } } -pub struct ClientHandshake(Box<[u8]>); -impl ClientHandshake { - pub(crate) fn new(cfg: &Config) -> Self { - let mut v = Vec::with_capacity(6 + cfg.username().len() + cfg.password().len() + 5); - v.extend(b"H\x00\x00\x00\x00\x00"); - pushlen!(v, cfg.username().len()); - pushlen!(v, cfg.password().len()); - v.extend(cfg.username().as_bytes()); - v.extend(cfg.password().as_bytes()); - Self(v.into_boxed_slice()) - } - pub(crate) fn inner(&self) -> &[u8] { - &self.0 - } -} - -#[derive(Debug)] -pub enum ServerHandshake { - Okay(u8), - Error(u8), -} -impl ServerHandshake { - pub fn parse(v: [u8; 4]) -> ClientResult { - Ok(match v { - [b'H', 0, 0, msg] => Self::Okay(msg), - [b'H', 0, 1, msg] => Self::Error(msg), - _ => { - return Err(Error::ConnectionSetupErr( - ConnectionSetupError::InvalidServerHandshake, - )) - } - }) - } -} - -#[test] -fn t_row() { - let mut decoder = Decoder::new(b"\x115\n\x00\x01\x01\x0D5\nsayan\x0220\n\x0E0\n", 0); - assert_eq!( - decoder.validate_response(RState::default()), - DecodeState::Completed(Response::Row(Row::new(vec![ - Value::Null, - Value::Bool(true), - Value::String("sayan".into()), - Value::UInt8(20), - Value::List(vec![]) - ]))) - ); -} - #[test] fn t_mrow() { - let mut decoder = Decoder::new(b"\x133\n5\n\x00\x01\x01\x0D5\nsayan\x0220\n\x0E0\n\x00\x01\x01\x0D5\nelana\x0221\n\x0E0\n\x00\x01\x01\x0D5\nemily\x0222\n\x0E0\n", 0); + const MROW_QUERY: &[u8] = b"\x133\n5\n\x00\x01\x01\x0D5\nsayan\x0220\n\x0E0\n\x00\x01\x01\x0D5\nelana\x0221\n\x0E0\n\x00\x01\x01\x0D5\nemily\x0222\n\x0E0\n"; + for i in 1..MROW_QUERY.len() { + let mut 
decoder = Decoder::new(&MROW_QUERY[..i], 0); + if i == 1 { + assert!(matches!( + decoder.validate_response(RState::default()), + DecodeState::ChangeState(RState(_)) + )); + } else { + assert!(matches!( + decoder.validate_response(RState::default()), + DecodeState::ChangeState(RState(ResponseState::PMultiRow(_))) + )); + } + } + let mut decoder = Decoder::new(MROW_QUERY, 0); assert_eq!( decoder.validate_response(RState::default()), DecodeState::Completed(Response::Rows(vec![ @@ -747,3 +559,31 @@ fn t_mrow() { ])) ); } +#[test] +fn t_num() { + const NUM: &[u8] = b"1234\n"; + fn decoder(i: usize) -> Decoder<'static> { + Decoder::new(&NUM[..i], 0) + } + for (i, expected) in [1, 12, 123, 1234u64] + .iter() + .enumerate() + .map(|(a, b)| (a + 1, *b)) + { + assert_eq!( + decoder(i) + .__resume_decode(0u64, ValueStateMeta::zero()) + .unwrap(), + ValueDecodeStateAny::Pending(ValueState::new( + Value::UInt64(expected), + ValueStateMeta::zero() + )) + ); + } + assert_eq!( + decoder(NUM.len()) + .__resume_decode(0u64, ValueStateMeta::zero()) + .unwrap(), + ValueDecodeStateAny::Decoded(Value::UInt64(1234)) + ); +} diff --git a/src/protocol/pipe.rs b/src/protocol/pipe.rs new file mode 100644 index 0000000..fd01fe6 --- /dev/null +++ b/src/protocol/pipe.rs @@ -0,0 +1,131 @@ +/* + * Copyright 2024, Sayan Nandan + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+*/ + +use { + super::{ + state::{DecodeState, RState, ResponseState}, + Decoder, ProtocolError, + }, + crate::response::Response, +}; + +const ILLEGAL_PACKET_ESCAPE: u8 = 0xFF; + +#[derive(Debug, PartialEq, Default)] +pub(crate) struct MRespState { + processed: Vec, + pending: Option, +} + +#[derive(Debug, PartialEq)] +pub(crate) enum PipelineResult { + Completed(Vec), + Pending(MRespState), + Error(ProtocolError), +} + +impl MRespState { + #[cold] + fn except() -> PipelineResult { + PipelineResult::Error(ProtocolError::InvalidPacket) + } + fn step(mut self, decoder: &mut Decoder, expected: usize) -> PipelineResult { + loop { + if decoder._cursor_eof() { + return PipelineResult::Pending(self); + } + if decoder._cursor_value() == ILLEGAL_PACKET_ESCAPE { + return Self::except(); + } + match decoder.validate_response(RState( + self.pending.take().unwrap_or(ResponseState::Initial), + )) { + DecodeState::ChangeState(RState(s)) => { + self.pending = Some(s); + return PipelineResult::Pending(self); + } + DecodeState::Completed(c) => { + self.processed.push(c); + if self.processed.len() == expected { + return PipelineResult::Completed(self.processed); + } + } + DecodeState::Error(e) => return PipelineResult::Error(e), + } + } + } +} + +impl<'a> Decoder<'a> { + pub fn validate_pipe(&mut self, expected: usize, state: MRespState) -> PipelineResult { + state.step(self, expected) + } +} + +#[cfg(test)] +const QUERY: &[u8] = b"\x12\x10\xFF\xFF\x115\n\x00\x01\x01\x0D5\nsayan\x0220\n\x0E0\n\x115\n\x00\x01\x01\x0D5\nelana\x0221\n\x0E0\n\x115\n\x00\x01\x01\x0D5\nemily\x0222\n\x0E0\n"; + +#[test] +fn t_pipe() { + use crate::response::{Response, Row, Value}; + let mut decoder = Decoder::new(QUERY, 0); + assert_eq!( + decoder.validate_pipe(5, MRespState::default()), + PipelineResult::Completed(vec![ + Response::Empty, + Response::Error(u16::MAX), + Response::Row(Row::new(vec![ + Value::Null, + Value::Bool(true), + Value::String("sayan".into()), + Value::UInt8(20), + 
Value::List(vec![]) + ])), + Response::Row(Row::new(vec![ + Value::Null, + Value::Bool(true), + Value::String("elana".into()), + Value::UInt8(21), + Value::List(vec![]) + ])), + Response::Row(Row::new(vec![ + Value::Null, + Value::Bool(true), + Value::String("emily".into()), + Value::UInt8(22), + Value::List(vec![]) + ])) + ]) + ); +} + +#[test] +fn t_pipe_staged() { + for i in Decoder::MIN_READBACK..QUERY.len() { + let mut dec = Decoder::new(&QUERY[..i], 0); + if i < 3 { + assert!(matches!( + dec.validate_pipe(5, MRespState::default()), + PipelineResult::Pending(_) + )); + } else { + assert!(matches!( + dec.validate_pipe(5, MRespState::default()), + PipelineResult::Pending(_) + )); + } + } +} diff --git a/src/protocol/state.rs b/src/protocol/state.rs new file mode 100644 index 0000000..3e5753d --- /dev/null +++ b/src/protocol/state.rs @@ -0,0 +1,254 @@ +/* + * Copyright 2024, Sayan Nandan + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+*/ + +use { + super::{Decoder, ProtocolError, ProtocolResult}, + crate::response::{Response, Row, Value}, +}; + +pub type ValueDecodeStateRaw = ValueDecodeStateAny; +pub type ValueDecodeState = ValueDecodeStateAny; + +/* + pending value + --- + a stack is useful for recursive types +*/ + +#[derive(Debug, PartialEq)] +pub struct PendingValue { + pub(super) state: ValueState, + pub(super) tmp: Option, + pub(super) stack: Vec<(Vec, ValueStateMeta)>, +} + +impl PendingValue { + pub fn new( + state: ValueState, + tmp: Option, + stack: Vec<(Vec, ValueStateMeta)>, + ) -> Self { + Self { state, tmp, stack } + } +} + +/* + value state +*/ + +#[derive(Debug, PartialEq)] +pub enum ValueDecodeStateAny { + Pending(P), + Decoded(V), +} + +#[derive(Debug, PartialEq)] +pub struct ValueState { + pub(super) v: Value, + pub(super) meta: ValueStateMeta, +} + +impl ValueState { + pub fn new(v: Value, meta: ValueStateMeta) -> Self { + Self { v, meta } + } +} + +#[derive(Debug, PartialEq)] +pub struct ValueStateMeta { + pub(super) start: usize, + pub(super) md: MetaState, +} + +impl ValueStateMeta { + pub fn zero() -> Self { + Self { + start: 0, + md: MetaState::default(), + } + } + pub fn new(start: usize, md1: u64, md1_flag: bool) -> Self { + Self { + start, + md: MetaState::new(md1_flag, md1), + } + } +} + +/* + metadata init state +*/ + +#[derive(Debug, Default, PartialEq)] +pub struct MetaState { + completed: bool, + val: u64, +} + +impl MetaState { + pub fn new(completed: bool, val: u64) -> Self { + Self { completed, val } + } + #[inline(always)] + pub fn finished(&mut self, decoder: &mut Decoder) -> ProtocolResult { + self.finish_or_continue(decoder, || Ok(true), || Ok(false), |e| Err(e)) + } + #[inline(always)] + pub fn finish_or_continue( + &mut self, + decoder: &mut Decoder, + if_completed: impl FnOnce() -> T, + if_pending: impl FnOnce() -> T, + if_err: impl FnOnce(ProtocolError) -> T, + ) -> T { + Self::try_finish_or_continue( + self.completed, + &mut self.val, + decoder, + 
if_completed, + if_pending, + if_err, + ) + } + #[inline(always)] + pub fn try_finish( + decoder: &mut Decoder, + completed: bool, + val: &mut u64, + ) -> ProtocolResult { + Self::try_finish_or_continue( + completed, + val, + decoder, + || Ok(true), + || Ok(false), + |e| Err(e), + ) + } + #[inline(always)] + pub fn try_finish_or_continue( + completed: bool, + val: &mut u64, + decoder: &mut Decoder, + if_completed: impl FnOnce() -> T, + if_pending: impl FnOnce() -> T, + if_err: impl FnOnce(ProtocolError) -> T, + ) -> T { + if completed { + if_completed() + } else { + match decoder.__resume_decode(*val, ValueStateMeta::zero()) { + Ok(vs) => match vs { + ValueDecodeStateAny::Pending(ValueState { v, .. }) => { + *val = v.u64(); + if_pending() + } + ValueDecodeStateAny::Decoded(v) => { + *val = v.u64(); + if_completed() + } + }, + Err(e) => if_err(e), + } + } + } + #[inline(always)] + pub fn val(&self) -> u64 { + self.val + } +} + +/* + row state +*/ + +#[derive(Debug, PartialEq)] +pub struct RowState { + pub(super) meta: ValueStateMeta, + pub(super) row: Vec, + pub(super) tmp: Option, +} + +impl RowState { + pub fn new(meta: ValueStateMeta, row: Vec, tmp: Option) -> Self { + Self { meta, row, tmp } + } +} + +/* + multi row state +*/ + +#[derive(Debug, PartialEq)] +pub struct MultiRowState { + pub(super) c_row: Option, + pub(super) rows: Vec, + pub(super) md_state: u8, + pub(super) md1_target: u64, + pub(super) md2_col_cnt: u64, +} + +impl Default for MultiRowState { + fn default() -> Self { + Self::new(None, vec![], 0, 0, 0) + } +} + +impl MultiRowState { + pub fn new( + c_row: Option, + rows: Vec, + md_s: u8, + md_cnt: u64, + md_target: u64, + ) -> Self { + Self { + c_row, + rows, + md_state: md_s, + md1_target: md_target, + md2_col_cnt: md_cnt, + } + } +} + +/* + response state +*/ + +#[derive(Debug, PartialEq)] +pub enum ResponseState { + Initial, + PValue(PendingValue), + PError, + PRow(RowState), + PMultiRow(MultiRowState), +} + +#[derive(Debug, PartialEq)] +pub 
enum DecodeState { + ChangeState(RState), + Completed(Response), + Error(ProtocolError), +} + +#[derive(Debug, PartialEq)] +pub struct RState(pub(super) ResponseState); +impl Default for RState { + fn default() -> Self { + RState(ResponseState::Initial) + } +} diff --git a/src/query.rs b/src/query.rs index 351e521..0d4b67a 100644 --- a/src/query.rs +++ b/src/query.rs @@ -32,6 +32,7 @@ use std::{ io::{self, Write}, + iter::FromIterator, num::{ NonZeroI16, NonZeroI32, NonZeroI64, NonZeroI8, NonZeroIsize, NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8, NonZeroUsize, @@ -128,6 +129,99 @@ impl Query { } } +/// # Pipeline +/// +/// A pipeline can be used to send multiple queries at once to the server. Queries in a pipeline are executed independently +/// of one another, but they are executed serially unless otherwise configured +pub struct Pipeline { + cnt: usize, + buf: Vec, +} + +impl Pipeline { + /// Create a new pipeline + pub const fn new() -> Self { + Self { + cnt: 0, + buf: Vec::new(), + } + } + pub(crate) fn buf(&self) -> &[u8] { + &self.buf + } + /// Returns the number of queries that were appended to this pipeline + pub fn query_count(&self) -> usize { + self.cnt + } + /// Add a query to this pipeline + /// + /// Note: It's not possible to get the query back from the pipeline since it's not indexed (and doing so would be an unnecessary + /// waste of space and time). 
That's why we take a reference which allows the caller to continue owning the [`Query`] item + pub fn push(&mut self, q: &Query) { + // qlen + self.buf + .extend(itoa::Buffer::new().format(q.q_window).as_bytes()); + self.buf.push(b'\n'); + // plen + self.buf.extend( + itoa::Buffer::new() + .format(q.buf.len() - q.q_window) + .as_bytes(), + ); + self.buf.push(b'\n'); + // body + self.buf.extend(&q.buf); + self.cnt += 1; + } + /// Add a query to this pipeline (builder pattern) + /// + /// This is intended to be used with the + /// ["builder pattern"](https://rust-unofficial.github.io/patterns/patterns/creational/builder.html). For example: + /// ``` + /// use skytable::{query, Pipeline}; + /// + /// let pipeline = Pipeline::new() + /// .add(&query!("create space myspace")) + /// .add(&query!("drop space myspace")); + /// assert_eq!(pipeline.query_count(), 2); + /// ``` + pub fn add(mut self, q: &Query) -> Self { + self.push(q); + self + } +} + +impl, I> From for Pipeline +where + I: Iterator, +{ + fn from(iter: I) -> Self { + let mut pipeline = Pipeline::new(); + iter.into_iter().for_each(|q| pipeline.push(q.as_ref())); + pipeline + } +} + +impl> Extend for Pipeline { + fn extend>(&mut self, iter: T) { + iter.into_iter().for_each(|q| self.push(q.as_ref())) + } +} + +impl> FromIterator for Pipeline { + fn from_iter>(iter: T) -> Self { + let mut pipe = Pipeline::new(); + iter.into_iter().for_each(|q| pipe.push(q.as_ref())); + pipe + } +} + +impl AsRef for Query { + fn as_ref(&self) -> &Query { + self + } +} + /* Query parameters */ diff --git a/tests/pipe.rs b/tests/pipe.rs new file mode 100644 index 0000000..59e30cd --- /dev/null +++ b/tests/pipe.rs @@ -0,0 +1,12 @@ +use skytable::{query, query::Pipeline}; + +#[test] +fn compile_add_queries() { + let mut pipeline: Pipeline = (0..123) + .map(|num| query!("select * from mymodel where username = ?", num as u64)) + .collect(); + assert_eq!(pipeline.query_count(), 123); + let query = query!("systemctl report status"); + 
pipeline.extend(vec![&query]); + assert_eq!(pipeline.query_count(), 124); +}