use actix_web::body::EitherBody; use actix_web::dev::{Decompress, Payload}; use actix_web::error::PayloadError; use actix_web::http::StatusCode; use actix_web::web::{Bytes, BytesMut}; use actix_web::{Error, FromRequest, HttpRequest, HttpResponse, Responder, ResponseError}; use derive_more::Display; use futures::{ready, Stream}; use rmp_serde::{decode::Error as RmpDecodeError, encode::Error as RmpEncodeError}; use serde::de::DeserializeOwned; use serde::Serialize; use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; use std::task::Poll::{Pending, Ready}; use std::task::{Context, Poll}; pub struct MsgPack(pub T); impl FromRequest for MsgPack { type Error = Error; type Future = MsgPackExtractFuture; fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { MsgPackExtractFuture::new(req.clone(), payload) } } #[derive(Debug, Display)] pub enum MsgPackError { #[display(fmt = "Can not deserialize")] Deserialize(RmpDecodeError), #[display(fmt = "Reading payload error {}", _0)] ReadPayload(PayloadError), #[display(fmt = "Can not serialize msgpack")] Serialize(RmpEncodeError), } impl ResponseError for MsgPackError { fn status_code(&self) -> StatusCode { match self { MsgPackError::Deserialize(_) => StatusCode::UNPROCESSABLE_ENTITY, MsgPackError::ReadPayload(err) => err.status_code(), MsgPackError::Serialize(_) => StatusCode::INTERNAL_SERVER_ERROR, } } } pub struct MsgPackExtractFuture { _req: HttpRequest, fut: MsgPackBody, _res: PhantomData, } impl MsgPackExtractFuture { fn new(req: HttpRequest, payload: &mut Payload) -> Self { MsgPackExtractFuture { _req: req.clone(), fut: MsgPackBody::new(req.clone(), payload), _res: PhantomData {}, } } } impl Unpin for MsgPackExtractFuture {} impl Future for MsgPackExtractFuture { type Output = Result, Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); let res = ready!(Pin::new(&mut this.fut).poll(cx)); match res { Ok(data) => Ready(Ok(MsgPack(data))), Err(err) => Ready(Err(err.into())), } } } struct MsgPackBody { buf: BytesMut, payload: Decompress, _res: PhantomData, } impl Unpin for MsgPackBody {} impl MsgPackBody { fn new(req: HttpRequest, payload: &mut Payload) -> Self { MsgPackBody { payload: Decompress::from_headers(payload.take(), req.headers()), buf: BytesMut::new(), _res: PhantomData {}, } } } impl Future for MsgPackBody { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); let next = ready!(Pin::new(&mut this.payload).poll_next(cx)); match next { None => { let res = match rmp_serde::from_slice(&this.buf) { Ok(data) => data, Err(err) => { return Ready(Err(MsgPackError::Deserialize(err))); } }; Ready(Ok(res)) } Some(res) => { let data: Bytes = match res { Ok(data) => data, Err(err) => { return Ready(Err(MsgPackError::ReadPayload(err))); } }; this.buf.extend_from_slice(&data); Pending } } } } impl Responder for MsgPack { type Body = EitherBody; fn respond_to(self, _: &HttpRequest) -> HttpResponse { match rmp_serde::to_vec(&self.0) { Ok(data) => { match HttpResponse::Ok() .content_type(mime::APPLICATION_MSGPACK) .message_body(Bytes::from(data)) { Ok(response) => response.map_into_left_body(), Err(err) => HttpResponse::from_error(err).map_into_right_body(), } } Err(err) => { HttpResponse::from_error(MsgPackError::Serialize(err)).map_into_right_body() } } } }