150 lines
4.5 KiB
Rust
150 lines
4.5 KiB
Rust
use actix_web::body::EitherBody;
|
|
use actix_web::dev::{Decompress, Payload};
|
|
use actix_web::error::PayloadError;
|
|
use actix_web::http::StatusCode;
|
|
use actix_web::web::{Bytes, BytesMut};
|
|
use actix_web::{Error, FromRequest, HttpRequest, HttpResponse, Responder, ResponseError};
|
|
use derive_more::Display;
|
|
use futures::{ready, Stream};
|
|
use rmp_serde::{decode::Error as RmpDecodeError, encode::Error as RmpEncodeError};
|
|
use serde::de::DeserializeOwned;
|
|
use serde::Serialize;
|
|
use std::future::Future;
|
|
use std::marker::PhantomData;
|
|
use std::pin::Pin;
|
|
use std::task::Poll::{Pending, Ready};
|
|
use std::task::{Context, Poll};
|
|
|
|
/// MessagePack extractor and responder wrapper.
///
/// As an extractor (`FromRequest`), the whole request body is read, decompressed
/// according to the request headers, and deserialized with `rmp_serde`.
/// As a responder, the wrapped value is serialized with `rmp_serde::to_vec` and
/// sent with the `mime::APPLICATION_MSGPACK` content type.
#[derive(Debug)]
pub struct MsgPack<T>(pub T);

impl<T> MsgPack<T> {
    /// Unwrap into the inner value.
    pub fn into_inner(self) -> T {
        self.0
    }
}

impl<T> std::ops::Deref for MsgPack<T> {
    type Target = T;

    fn deref(&self) -> &T {
        &self.0
    }
}

impl<T> std::ops::DerefMut for MsgPack<T> {
    fn deref_mut(&mut self) -> &mut T {
        &mut self.0
    }
}
|
|
|
|
impl<T: DeserializeOwned> FromRequest for MsgPack<T> {
|
|
type Error = Error;
|
|
type Future = MsgPackExtractFuture<T>;
|
|
|
|
fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
|
|
MsgPackExtractFuture::new(req.clone(), payload)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Display)]
|
|
pub enum MsgPackError {
|
|
#[display(fmt = "Can not deserialize")]
|
|
Deserialize(RmpDecodeError),
|
|
#[display(fmt = "Reading payload error {}", _0)]
|
|
ReadPayload(PayloadError),
|
|
#[display(fmt = "Can not serialize msgpack")]
|
|
Serialize(RmpEncodeError),
|
|
}
|
|
|
|
impl ResponseError for MsgPackError {
|
|
fn status_code(&self) -> StatusCode {
|
|
match self {
|
|
MsgPackError::Deserialize(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
|
MsgPackError::ReadPayload(err) => err.status_code(),
|
|
MsgPackError::Serialize(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct MsgPackExtractFuture<T> {
|
|
_req: HttpRequest,
|
|
fut: MsgPackBody<T>,
|
|
_res: PhantomData<T>,
|
|
}
|
|
|
|
impl<T: DeserializeOwned> MsgPackExtractFuture<T> {
|
|
fn new(req: HttpRequest, payload: &mut Payload) -> Self {
|
|
MsgPackExtractFuture {
|
|
_req: req.clone(),
|
|
fut: MsgPackBody::new(req.clone(), payload),
|
|
_res: PhantomData {},
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<T> Unpin for MsgPackExtractFuture<T> {}
|
|
|
|
impl<T: DeserializeOwned> Future for MsgPackExtractFuture<T> {
|
|
type Output = Result<MsgPack<T>, Error>;
|
|
|
|
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
let this = self.get_mut();
|
|
|
|
let res = ready!(Pin::new(&mut this.fut).poll(cx));
|
|
match res {
|
|
Ok(data) => Ready(Ok(MsgPack(data))),
|
|
Err(err) => Ready(Err(err.into())),
|
|
}
|
|
}
|
|
}
|
|
|
|
struct MsgPackBody<T> {
|
|
buf: BytesMut,
|
|
payload: Decompress<Payload>,
|
|
_res: PhantomData<T>,
|
|
}
|
|
|
|
impl<T> Unpin for MsgPackBody<T> {}
|
|
|
|
impl<T: DeserializeOwned> MsgPackBody<T> {
|
|
fn new(req: HttpRequest, payload: &mut Payload) -> Self {
|
|
MsgPackBody {
|
|
payload: Decompress::from_headers(payload.take(), req.headers()),
|
|
buf: BytesMut::new(),
|
|
_res: PhantomData {},
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<T: DeserializeOwned> Future for MsgPackBody<T> {
|
|
type Output = Result<T, MsgPackError>;
|
|
|
|
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
let this = self.get_mut();
|
|
|
|
let next = ready!(Pin::new(&mut this.payload).poll_next(cx));
|
|
match next {
|
|
None => {
|
|
let res = match rmp_serde::from_slice(&this.buf) {
|
|
Ok(data) => data,
|
|
Err(err) => {
|
|
return Ready(Err(MsgPackError::Deserialize(err)));
|
|
}
|
|
};
|
|
Ready(Ok(res))
|
|
}
|
|
Some(res) => {
|
|
let data: Bytes = match res {
|
|
Ok(data) => data,
|
|
Err(err) => {
|
|
return Ready(Err(MsgPackError::ReadPayload(err)));
|
|
}
|
|
};
|
|
this.buf.extend_from_slice(&data);
|
|
Pending
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<T: Serialize> Responder for MsgPack<T> {
|
|
type Body = EitherBody<Bytes>;
|
|
|
|
fn respond_to(self, _: &HttpRequest) -> HttpResponse<Self::Body> {
|
|
match rmp_serde::to_vec(&self.0) {
|
|
Ok(data) => {
|
|
match HttpResponse::Ok()
|
|
.content_type(mime::APPLICATION_MSGPACK)
|
|
.message_body(Bytes::from(data))
|
|
{
|
|
Ok(response) => response.map_into_left_body(),
|
|
Err(err) => HttpResponse::from_error(err).map_into_right_body(),
|
|
}
|
|
}
|
|
Err(err) => {
|
|
HttpResponse::from_error(MsgPackError::Serialize(err)).map_into_right_body()
|
|
}
|
|
}
|
|
}
|
|
}
|