diff --git a/bao/bao.go b/bao/bao.go index 4de5869..899bc31 100644 --- a/bao/bao.go +++ b/bao/bao.go @@ -23,6 +23,7 @@ type Bao interface { Write(id uint32, data []byte) error Finalize(id uint32) ([]byte, error) Destroy(id uint32) error + ComputeFile(path string) ([]byte, error) } func init() { @@ -88,8 +89,7 @@ func init() { } -func ComputeBaoTree(reader io.Reader) ([]byte, error) { - +func ComputeTreeStreaming(reader io.Reader) ([]byte, error) { instance, err := baoInstance.Init() if err != nil { return nil, err @@ -119,6 +119,15 @@ func ComputeBaoTree(reader io.Reader) ([]byte, error) { } } +func ComputeTreeFile(file *os.File) ([]byte, error) { + tree, err := baoInstance.ComputeFile(file.Name()) + if err != nil { + return nil, err + } + + return tree, nil +} + func write(instance uint32, bytes *[]byte) error { err := baoInstance.Write(instance, *bytes) if err != nil { diff --git a/bao/client.go b/bao/client.go index 784f046..b490f26 100644 --- a/bao/client.go +++ b/bao/client.go @@ -45,3 +45,11 @@ func (g *GRPCClient) Destroy(id uint32) error { return nil } +func (g *GRPCClient) ComputeFile(path string) ([]byte, error) { + tree, err := g.client.ComputeFile(context.Background(), &wrappers.StringValue{Value: path}) + if err != nil { + return nil, err + } + + return tree.Value, nil +} diff --git a/bao/proto/bao.proto b/bao/proto/bao.proto index 861a4ca..7c09ec3 100644 --- a/bao/proto/bao.proto +++ b/bao/proto/bao.proto @@ -17,4 +17,5 @@ service bao { rpc Write(WriteRequest) returns (google.protobuf.Empty); rpc Finalize (google.protobuf.UInt32Value) returns (google.protobuf.BytesValue); rpc Destroy (google.protobuf.UInt32Value) returns (google.protobuf.Empty); + rpc ComputeFile (google.protobuf.StringValue) returns (google.protobuf.BytesValue); } diff --git a/bao/src/main.rs b/bao/src/main.rs index 05c65db..ac162de 100644 --- a/bao/src/main.rs +++ b/bao/src/main.rs @@ -1,8 +1,11 @@ #![feature(async_fn_in_trait)] #![allow(incomplete_features)] +use io::Read; use std::collections::hash_map::Entry; use std::collections::HashMap; +use std::fs::{File}; +use std::io; use std::io::{Cursor, Write}; use std::sync::{Arc}; @@ -16,7 +19,7 @@ use tonic_health::server::HealthReporter; use crate::proto::bao::bao_server::{Bao, BaoServer}; use crate::proto::bao::WriteRequest; -use crate::proto::google::protobuf::{BytesValue, Empty, UInt32Value}; +use crate::proto::google::protobuf::{BytesValue, Empty, StringValue, UInt32Value}; use crate::unique_port::UniquePort; mod proto; @@ -64,7 +67,7 @@ impl Bao for BaoService { let next_id = self.counter.inc() as u32; let tree = Vec::new(); let cursor = Cursor::new(tree); - let encoder = Encoder::new(cursor); + let encoder = Encoder::new_outboard(cursor); let mut req = self.requests.lock(); req.insert(next_id, encoder); @@ -113,4 +116,36 @@ impl Bao for BaoService { Ok(Response::new(Empty::default())) } + + async fn compute_file(&self, request: Request) -> Result, Status> { + let r = request.into_inner(); + let tree = Vec::new(); + let cursor = Cursor::new(tree); + let mut encoder = Encoder::new_outboard(cursor); + let mut input = File::open(r.value)?; + + copy_reader_to_writer(&mut input, &mut encoder)?; + + let ret = encoder.finalize().unwrap(); + let bytes = ret.as_bytes().to_vec(); + Ok(Response::new(BytesValue { value: bytes })) + } +} +fn copy_reader_to_writer( + reader: &mut impl Read, + writer: &mut impl Write, +) -> io::Result { + // At least 16 KiB is necessary to use AVX-512 with BLAKE3. + let mut buf = [0; 65536]; + let mut written = 0; + loop { + let len = match reader.read(&mut buf) { + Ok(0) => return Ok(written), + Ok(len) => len, + Err(ref e) if e.kind() == io::ErrorKind::Interrupted => continue, + Err(e) => return Err(e), + }; + writer.write_all(&buf[..len])?; + written += len as u64; + } }