212 lines
6.3 KiB
Rust
212 lines
6.3 KiB
Rust
use std::collections::HashMap;
|
|
use std::io::{Cursor, Read, Seek, SeekFrom, Write};
|
|
use std::net::SocketAddr;
|
|
use std::sync::Arc;
|
|
|
|
use abao::encode::Encoder;
|
|
use portpicker::pick_unused_port;
|
|
use tokio::sync::{Mutex, oneshot, RwLock};
|
|
use tonic::{Request, Response, Status};
|
|
use tonic::transport::Server;
|
|
use uuid::Uuid;
|
|
use bao::bao_server::Bao;
|
|
use tokio::signal::unix::{signal, SignalKind};
|
|
|
|
use crate::bao::{bao_server, FinishRequest, FinishResponse, HashRequest, HashResponse, NewHasherRequest, NewHasherResponse, VerifyRequest, VerifyResponse};
|
|
|
|
#[path = "proto/bao.rs"]
|
|
mod bao;
|
|
|
|
struct GlobalState {
|
|
hashers: HashMap<Uuid, Arc<Mutex<Encoder<Cursor<Vec<u8>>>>>>,
|
|
}
|
|
|
|
|
|
pub struct BaoService {
|
|
state: Arc<RwLock<GlobalState>>,
|
|
}
|
|
|
|
#[tonic::async_trait]
|
|
impl Bao for BaoService {
|
|
async fn new_hasher(&self, _request: Request<NewHasherRequest>) -> Result<Response<NewHasherResponse>, Status> {
|
|
let encoder = Encoder::new_outboard(Cursor::new(Vec::new()));
|
|
let id = Uuid::new_v4();
|
|
{
|
|
let mut state = self.state.write().await;
|
|
state.hashers.insert(id, Arc::new(Mutex::new(encoder)));
|
|
}
|
|
Ok(Response::new(NewHasherResponse {
|
|
id: id.to_string(),
|
|
}))
|
|
}
|
|
|
|
async fn hash(&self, request: Request<HashRequest>) -> Result<Response<HashResponse>, Status> {
|
|
let id = Uuid::parse_str(&request.get_ref().id).map_err(|_| Status::invalid_argument("invalid id"))?;
|
|
{
|
|
let state = self.state.read().await;
|
|
let encoder = state.hashers.get(&id).ok_or_else(|| Status::not_found("hasher not found"))?.clone();
|
|
let mut encoder = encoder.lock().await;
|
|
encoder.write_all(&request.get_ref().data).map_err(|_| Status::internal("write failed"))?;
|
|
}
|
|
Ok(Response::new(HashResponse {
|
|
status: true,
|
|
}))
|
|
}
|
|
|
|
async fn finish(&self, request: Request<FinishRequest>) -> Result<Response<FinishResponse>, Status> {
|
|
let id = Uuid::parse_str(&request.get_ref().id).map_err(|_| Status::invalid_argument("invalid id"))?;
|
|
let (hash, proof) = {
|
|
let mut state = self.state.write().await;
|
|
let encoder = state.hashers.remove(&id).ok_or_else(|| Status::not_found("hasher not found"))?;
|
|
let encoder = Arc::try_unwrap(encoder).unwrap(); // Unwrap the Arc
|
|
let mut encoder = encoder.lock().await;
|
|
let hash = encoder.finalize()?.as_bytes().to_vec();
|
|
let proof = encoder.inner_mut().get_ref().to_vec();
|
|
(hash, proof)
|
|
};
|
|
Ok(Response::new(FinishResponse {
|
|
hash,
|
|
proof,
|
|
}))
|
|
}
|
|
|
|
async fn verify(&self, request: Request<VerifyRequest>) -> Result<Response<VerifyResponse>, Status> {
|
|
let req = request.get_ref();
|
|
let res = verify_internal(
|
|
req.data.clone(),
|
|
req.offset,
|
|
req.proof.clone(),
|
|
from_vec_to_array(req.hash.clone()),
|
|
);
|
|
|
|
if res.is_err() {
|
|
|
|
Ok(Response::new(VerifyResponse {
|
|
status: false,
|
|
error: res.unwrap_err().to_string(),
|
|
|
|
}))
|
|
} else {
|
|
Ok(Response::new(VerifyResponse {
|
|
status: true,
|
|
error: String::from(""),
|
|
}))
|
|
}
|
|
}
|
|
}
|
|
|
|
fn verify_internal(
|
|
chunk_bytes: Vec<u8>,
|
|
offset: u64,
|
|
bao_outboard_bytes: Vec<u8>,
|
|
blake3_hash: [u8; 32],
|
|
) -> anyhow::Result<u8> {
|
|
let mut slice_stream = abao::encode::SliceExtractor::new_outboard(
|
|
FakeSeeker::new(&chunk_bytes[..]),
|
|
Cursor::new(&bao_outboard_bytes),
|
|
offset,
|
|
262144,
|
|
);
|
|
|
|
let mut decode_stream = abao::decode::SliceDecoder::new(
|
|
&mut slice_stream,
|
|
&abao::Hash::from(blake3_hash),
|
|
offset,
|
|
262144,
|
|
);
|
|
let mut decoded = Vec::new();
|
|
decode_stream.read_to_end(&mut decoded)?;
|
|
|
|
Ok(1)
|
|
}
|
|
|
|
/// Converts a `Vec<T>` into a fixed-size array `[T; N]`.
///
/// # Panics
/// Panics if the vector does not contain exactly `N` elements.
fn from_vec_to_array<T, const N: usize>(v: Vec<T>) -> [T; N] {
    match <[T; N] as std::convert::TryFrom<Vec<T>>>::try_from(v) {
        Ok(array) => array,
        // On failure the conversion hands the original vector back, which lets the
        // message report the actual length.
        Err(rejected) => panic!(
            "Expected a Vec of length {} but it was {}",
            N,
            rejected.len()
        ),
    }
}
|
|
|
|
/// Reader adapter that implements `Seek` without ever moving the underlying reader.
///
/// Reads are forwarded to the wrapped reader and counted; every seek request is
/// answered with the number of bytes read so far, whatever position was asked for.
struct FakeSeeker<R: Read> {
    /// The wrapped byte source.
    reader: R,
    /// Running total of bytes returned by `read`.
    bytes_read: u64,
}

impl<R: Read> FakeSeeker<R> {
    /// Wraps `reader` with the byte counter starting at zero.
    fn new(reader: R) -> Self {
        FakeSeeker {
            reader,
            bytes_read: 0,
        }
    }
}

impl<R: Read> Read for FakeSeeker<R> {
    /// Delegates to the inner reader and adds the number of bytes read to the counter.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        self.reader.read(buf).map(|count| {
            self.bytes_read += count as u64;
            count
        })
    }
}

impl<R: Read> Seek for FakeSeeker<R> {
    /// Ignores the requested position and reports the bytes-read counter as the
    /// current position.
    fn seek(&mut self, _target: SeekFrom) -> std::io::Result<u64> {
        Ok(self.bytes_read)
    }
}
|
|
|
|
impl BaoService {
|
|
fn new(state: Arc<RwLock<GlobalState>>) -> Self {
|
|
BaoService { state }
|
|
}
|
|
}
|
|
|
|
|
|
#[tokio::main]
|
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
let (tx, rx) = oneshot::channel::<()>();
|
|
let health_reporter = tonic_health::server::health_reporter();
|
|
let port = match pick_unused_port() {
|
|
Some(p) => p,
|
|
None => {
|
|
return Err("Failed to pick an unused port".into());
|
|
}
|
|
};
|
|
|
|
let addr: SocketAddr = format!("127.0.0.1:{}", port).parse()?;
|
|
|
|
println!("1|1|tcp|127.0.0.1:{}|grpc", addr.port());
|
|
|
|
let global_state = Arc::new(RwLock::new(GlobalState {
|
|
hashers: HashMap::new(),
|
|
}));
|
|
|
|
tokio::spawn(async move {
|
|
let mut term_signal = signal(SignalKind::terminate()).expect("Could not create signal handler");
|
|
|
|
// Wait for the terminate signal
|
|
term_signal.recv().await;
|
|
println!("Termination signal received, shutting down server...");
|
|
|
|
// Sending a signal through the channel to initiate shutdown.
|
|
// If the receiver is dropped, we don't care about the error.
|
|
let _ = tx.send(());
|
|
});
|
|
|
|
let server = bao_server::BaoServer::new(BaoService::new(global_state.clone()))
|
|
.max_decoding_message_size(usize::MAX)
|
|
.max_encoding_message_size(usize::MAX);
|
|
|
|
Server::builder()
|
|
.max_frame_size((1 << 24) - 1)
|
|
.add_service(server)
|
|
.add_service(health_reporter.1)
|
|
.serve_with_shutdown(addr, async {
|
|
// This future completes when the shutdown signal is received,
|
|
// allowing the server to shut down gracefully.
|
|
rx.await.ok();
|
|
})
|
|
.await?;
|
|
|
|
Ok(())
|
|
}
|