use std::collections::HashMap; use std::io::{Cursor, Write}; use std::net::SocketAddr; use std::sync::Arc; use abao::encode::Encoder; use portpicker::pick_unused_port; use tokio::sync::{Mutex, oneshot, RwLock}; use tonic::{Request, Response, Status}; use tonic::transport::Server; use uuid::Uuid; use bao::bao_server::Bao; use tokio::signal::unix::{signal, SignalKind}; use crate::bao::{bao_server, FinishRequest, FinishResponse, HashRequest, HashResponse, NewHasherRequest, NewHasherResponse}; #[path = "proto/bao.rs"] mod bao; struct GlobalState { hashers: HashMap>>>>>, } pub struct BaoService { state: Arc>, } #[tonic::async_trait] impl Bao for BaoService { async fn new_hasher(&self, _request: Request) -> Result, Status> { let encoder = Encoder::new(Cursor::new(Vec::new())); let id = Uuid::new_v4(); { let mut state = self.state.write().await; state.hashers.insert(id,Arc::new( Mutex::new(encoder))); } Ok(Response::new(NewHasherResponse { id: id.to_string(), })) } async fn hash(&self, request: Request) -> Result, Status> { let id = Uuid::parse_str(&request.get_ref().id).map_err(|_| Status::invalid_argument("invalid id"))?; { let state = self.state.read().await; let encoder = state.hashers.get(&id).ok_or_else(|| Status::not_found("hasher not found"))?.clone(); let mut encoder = encoder.lock().await; encoder.write(&request.get_ref().data).map_err(|_| Status::internal("write failed"))?; } Ok(Response::new(HashResponse { status: true, })) } async fn finish(&self, request: Request) -> Result, Status> { let id = Uuid::parse_str(&request.get_ref().id).map_err(|_| Status::invalid_argument("invalid id"))?; let (hash, proof) = { let mut state = self.state.write().await; let encoder = state.hashers.remove(&id).ok_or_else(|| Status::not_found("hasher not found"))?; let encoder = Arc::try_unwrap(encoder).unwrap(); // Unwrap the Arc let mut encoder = encoder.lock().await; let hash = encoder.finalize()?.as_bytes().to_vec(); let proof = encoder.inner_mut().get_ref().to_vec(); (hash, proof) }; Ok(Response::new(FinishResponse { hash, proof, })) } } impl BaoService { fn new(state: Arc>) -> Self { BaoService { state } } } #[tokio::main] async fn main() -> Result<(), Box> { let (tx, rx) = oneshot::channel::<()>(); let health_reporter = tonic_health::server::health_reporter(); let port = match pick_unused_port() { Some(p) => p, None => { return Err("Failed to pick an unused port".into()); } }; let addr: SocketAddr = format!("127.0.0.1:{}", port).parse()?; println!("1|1|tcp|127.0.0.1:{}|grpc", addr.port()); let global_state = Arc::new(RwLock::new(GlobalState { hashers: HashMap::new(), })); tokio::spawn(async move { let mut term_signal = signal(SignalKind::terminate()).expect("Could not create signal handler"); // Wait for the terminate signal term_signal.recv().await; println!("Termination signal received, shutting down server..."); // Sending a signal through the channel to initiate shutdown. // If the receiver is dropped, we don't care about the error. let _ = tx.send(()); }); Server::builder() .max_frame_size( (1 << 24) - 1) .add_service(bao_server::BaoServer::new(BaoService::new(global_state.clone()))) .add_service(health_reporter.1) .serve_with_shutdown(addr, async { // This future completes when the shutdown signal is received, // allowing the server to shut down gracefully. rx.await.ok(); }) .await?; Ok(()) }