refactor: move to a go-plugin based GRPC approach for bao

This commit is contained in:
Derrick Hammer 2023-05-15 12:34:55 -04:00
parent 435445dda5
commit a8d2ad3393
Signed by: pcfreak30
GPG Key ID: C997C339BE476FF2
14 changed files with 1660 additions and 116 deletions

1132
bao/Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -5,10 +5,24 @@ edition = "2021"
[dependencies]
abao = { version = "0.2.0", features = ["group_size_256k", "tokio_io"], default-features = false }
wasmedge-bindgen = "0.4.1"
wasmedge-bindgen-macro = "0.4.1"
anyhow = "1.0.71"
async-stream = "0.3.5"
async-trait = "0.1.68"
atomic-counter = "1.0.1"
futures = "0.3.28"
gag = "1.0.0"
hyper = "0.14.26"
log = "0.4.17"
parking_lot = "0.12.1"
portpicker = "0.1.1"
prost = "0.11.9"
serde = { version = "1.0.163", features = ["derive"] }
thiserror = "1.0.40"
tokio = { version = "1.28.1", features = ["rt", "rt-multi-thread"] }
tokio-stream = "0.1.14"
tonic = "0.9.2"
tonic-health = "0.9.2"
tower = "0.4.13"
[lib]
name = "bao"
path = "src/lib.rs"
crate-type = ["cdylib"]
[build-dependencies]
tonic-build = "0.9.2"

View File

@ -2,46 +2,96 @@ package bao
import (
_ "embed"
"errors"
"github.com/second-state/WasmEdge-go/wasmedge"
bindgen "github.com/second-state/wasmedge-bindgen/host/go"
"github.com/hashicorp/go-plugin"
"io"
"io/fs"
"log"
"os"
"os/exec"
"os/signal"
"syscall"
)
//go:embed target/wasm32-wasi/release/bao.wasm
var wasm []byte
//go:generate protoc --proto_path=proto/ bao.proto --go_out=proto --go_opt=paths=source_relative --go-grpc_out=proto --go-grpc_opt=paths=source_relative
var conf *wasmedge.Configure
//go:embed target/release/bao
var baoPlugin []byte
var baoInstance Bao
type Bao interface {
Init() (uint32, error)
Write(id uint32, data []byte) error
Finalize(id uint32) ([]byte, error)
Destroy(id uint32) error
}
func init() {
wasmedge.SetLogErrorLevel()
conf = wasmedge.NewConfigure(wasmedge.WASI)
baoExec, err := os.CreateTemp("", "lumeportal")
_, err = baoExec.Write(baoPlugin)
if err != nil {
log.Fatalf("Error:", err.Error())
}
err = baoExec.Sync()
if err != nil {
log.Fatalf("Error:", err.Error())
}
err = baoExec.Chmod(fs.ModePerm)
if err != nil {
log.Fatalf("Error:", err.Error())
}
err = baoExec.Close()
if err != nil {
log.Fatalf("Error:", err.Error())
}
pluginMap := map[string]plugin.Plugin{
"bao": &BAOPlugin{},
}
client := plugin.NewClient(&plugin.ClientConfig{
HandshakeConfig: plugin.HandshakeConfig{
ProtocolVersion: 1,
MagicCookieKey: "foo",
MagicCookieValue: "bar",
},
Plugins: pluginMap,
Cmd: exec.Command("sh", "-c", baoExec.Name()),
AllowedProtocols: []plugin.Protocol{plugin.ProtocolGRPC},
})
// Connect via RPC
rpcClient, err := client.Client()
if err != nil {
log.Fatalf("Error:", err.Error())
}
// Request the plugin
raw, err := rpcClient.Dispense("bao")
if err != nil {
log.Fatalf("Error:", err.Error())
}
baoInstance = raw.(Bao)
signalCh := make(chan os.Signal, 1)
signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM)
go func() {
<-signalCh
err := os.Remove(baoExec.Name())
if err != nil {
log.Fatalf("Error:", err.Error())
}
}()
}
func ComputeBaoTree(reader io.Reader) ([]byte, error) {
var vm = wasmedge.NewVMWithConfig(conf)
var wasi = vm.GetImportModule(wasmedge.WASI)
wasi.InitWasi(
os.Args[1:], // The args
os.Environ(), // The envs
[]string{".:."}, // The mapping preopens
)
err := vm.LoadWasmBuffer(wasm)
if err != nil {
return nil, err
}
err = vm.Validate()
if err != nil {
return nil, err
}
bg := bindgen.New(vm)
bg.Instantiate()
_, _, err = bg.Execute("init")
instance, err := baoInstance.Init()
if err != nil {
bg.Release()
return nil, err
}
@ -50,7 +100,7 @@ func ComputeBaoTree(reader io.Reader) ([]byte, error) {
n, err := reader.Read(b)
if n > 0 {
err := write(*bg, &b)
err := write(instance, &b)
if err != nil {
return nil, err
}
@ -59,7 +109,7 @@ func ComputeBaoTree(reader io.Reader) ([]byte, error) {
if err != nil {
var result []byte
if err == io.EOF {
result, err = finalize(*bg)
result, err = finalize(instance)
if err == nil {
return result, nil
}
@ -69,37 +119,38 @@ func ComputeBaoTree(reader io.Reader) ([]byte, error) {
}
}
func write(bg bindgen.Bindgen, bytes *[]byte) error {
_, _, err := bg.Execute("write", *bytes)
func write(instance uint32, bytes *[]byte) error {
err := baoInstance.Write(instance, *bytes)
if err != nil {
bg.Release()
derr := destroy(instance)
if derr != nil {
return derr
}
return err
}
if err != nil {
derr := destroy(instance)
if derr != nil {
return derr
}
return err
}
return nil
}
func finalize(bg bindgen.Bindgen) ([]byte, error) {
var byteResult []byte
result, _, err := bg.Execute("finalize")
func finalize(instance uint32) ([]byte, error) {
result, err := baoInstance.Finalize(instance)
if err != nil {
bg.Release()
derr := destroy(instance)
if derr != nil {
return nil, derr
}
return nil, err
}
// Iterate over each element in the result slice
for _, elem := range result {
// Type assert the element to []byte
byteSlice, ok := elem.([]byte)
if !ok {
// If the element is not a byte slice, return an error
return nil, errors.New("result element is not a byte slice")
}
// Concatenate the byte slice to the byteResult slice
byteResult = append(byteResult, byteSlice...)
}
return byteResult, nil
return result, nil
}
func destroy(instance uint32) error {
return baoInstance.Destroy(instance)
}

15
bao/build.rs Normal file
View File

@ -0,0 +1,15 @@
fn main() -> Result<(), Box<dyn std::error::Error>> {
tonic_build::configure()
.build_server(true)
.out_dir("src/proto")
.compile_well_known_types(true)
.include_file("mod.rs")
.type_attribute(".", "#[derive(serde::Deserialize)]")
.type_attribute(".", "#[derive(serde::Serialize)]")
.compile(&[
"proto/grpc_stdio.proto",
"proto/bao.proto"
], &["bao"])
.unwrap();
Ok(())
}

47
bao/client.go Normal file
View File

@ -0,0 +1,47 @@
package bao
import (
"context"
"git.lumeweb.com/LumeWeb/portal/bao/proto"
"github.com/golang/protobuf/ptypes/empty"
"github.com/golang/protobuf/ptypes/wrappers"
)
// GRPCClient is an implementation of KV that talks over RPC.
type GRPCClient struct{ client proto.BaoClient }
func (g *GRPCClient) Init() (uint32, error) {
init, err := g.client.Init(context.Background(), &empty.Empty{})
if err != nil {
return 0, err
}
return init.Value, nil
}
func (g *GRPCClient) Write(id uint32, data []byte) error {
_, err := g.client.Write(context.Background(), &proto.WriteRequest{Id: id, Data: data})
if err != nil {
return err
}
return nil
}
func (g *GRPCClient) Finalize(id uint32) ([]byte, error) {
tree, err := g.client.Finalize(context.Background(), &wrappers.UInt32Value{Value: id})
if err != nil {
return nil, err
}
return tree.Value, nil
}
func (g *GRPCClient) Destroy(id uint32) error {
_, err := g.client.Destroy(context.Background(), &wrappers.UInt32Value{Value: id})
if err != nil {
return err
}
return nil
}

21
bao/plugin.go Normal file
View File

@ -0,0 +1,21 @@
package bao
import (
"context"
"git.lumeweb.com/LumeWeb/portal/bao/proto"
"github.com/hashicorp/go-plugin"
"google.golang.org/grpc"
)
type BAOPlugin struct {
plugin.Plugin
Impl Bao
}
func (p *BAOPlugin) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) error {
return nil
}
func (b *BAOPlugin) GRPCClient(_ context.Context, broker *plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) {
return &GRPCClient{client: proto.NewBaoClient(c)}, nil
}

20
bao/proto/bao.proto Normal file
View File

@ -0,0 +1,20 @@
syntax = "proto3";
import "google/protobuf/empty.proto";
import "google/protobuf/wrappers.proto";
option go_package = "git.lumeweb.com/LumeWeb/portal/bao/proto";
package bao;
message WriteRequest {
uint32 id = 1;
bytes data = 2;
}
service bao {
rpc Init (google.protobuf.Empty) returns (google.protobuf.UInt32Value);
rpc Write(WriteRequest) returns (google.protobuf.Empty);
rpc Finalize (google.protobuf.UInt32Value) returns (google.protobuf.BytesValue);
rpc Destroy (google.protobuf.UInt32Value) returns (google.protobuf.Empty);
}

View File

@ -0,0 +1,33 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
syntax = "proto3";
package plugin;
option go_package = "plugin";
import "google/protobuf/empty.proto";
// GRPCStdio is a service that is automatically run by the plugin process
// to stream any stdout/err data so that it can be mirrored on the plugin
// host side.
service GRPCStdio {
// StreamStdio returns a stream that contains all the stdout/stderr.
// This RPC endpoint must only be called ONCE. Once stdio data is consumed
// it is not sent again.
//
// Callers should connect early to prevent blocking on the plugin process.
rpc StreamStdio(google.protobuf.Empty) returns (stream StdioData);
}
// StdioData is a single chunk of stdout or stderr data that is streamed
// from GRPCStdio.
message StdioData {
enum Channel {
INVALID = 0;
STDOUT = 1;
STDERR = 2;
}
Channel channel = 1;
bytes data = 2;
}

34
bao/src/grpc/error.rs Normal file
View File

@ -0,0 +1,34 @@
use hyper::http::uri::InvalidUri;
use thiserror::Error as ThisError;
use tokio::sync::mpsc::error::SendError;
use tonic::transport::Error as TonicError;
use std::fmt::{Debug};
pub fn into_status(err: Error) -> tonic::Status {
tonic::Status::unknown(format!("{}", err))
}
#[derive(Debug, ThisError)]
pub enum Error {
#[error("Error with IO: {0}")]
Io(#[from] std::io::Error),
#[error("Error with tonic (gRPC) transport: {0}")]
TonicTransport(#[from] TonicError),
#[error("Error parsing string into a network address: {0}")]
AddrParser(#[from] std::net::AddrParseError),
#[error("Error sending on a mpsc channel: {0}")]
Send(String),
#[error("Invalid Uri: {0}")]
InvalidUri(#[from] InvalidUri),
#[error(transparent)]
Other(#[from] anyhow::Error),
}
impl<T> From<SendError<T>> for Error {
fn from(_err: SendError<T>) -> Self {
Self::Send(format!(
"unable to send {} on a mpsc channel",
std::any::type_name::<T>()
))
}
}

View File

@ -0,0 +1,98 @@
// Copied from: https://github.com/hashicorp/go-plugin/blob/master/grpc_controller.go
use anyhow::{Context, Result};
use async_stream::stream;
use futures::stream::Stream;
use gag::BufferRedirect;
use std::io::Read;
use std::pin::Pin;
use tokio::time::{sleep, Duration};
use tokio_stream::StreamExt;
use tonic::{async_trait, Request, Response, Status};
use crate::proto::google::protobuf::Empty;
use crate::proto::plugin::grpc_stdio_server::{GrpcStdio, GrpcStdioServer};
use crate::proto::plugin::stdio_data::Channel;
use crate::proto::plugin::{StdioData};
use crate::grpc::error::into_status;
const CONSOLE_POLL_SLEEP_MILLIS: u64 = 500;
pub fn new_server() -> GrpcStdioServer<GrpcStdioImpl> {
GrpcStdioServer::new(GrpcStdioImpl {})
}
#[derive(Clone)]
pub struct GrpcStdioImpl {}
impl GrpcStdioImpl {
fn new_combined_stream() -> Result<<Self as GrpcStdio>::StreamStdioStream, Status> {
log::trace!("new_inner_stream called. Asked for a stream of stdout and stderr");
log::info!("Gagging stdout and stderr to a buffer for redirection to plugin's host.",);
let stdoutbuf = BufferRedirect::stdout()
.context("Failed to create a BufferRedirec from stdout")
.map_err(|e| e.into())
.map_err(into_status)?;
let stdout_stream = GrpcStdioImpl::new_stream("stdout", Channel::Stdout as i32, stdoutbuf);
let stderrbuf = BufferRedirect::stderr()
.context("Failed to create a BufferRedirec from stderr")
.map_err(|e| e.into())
.map_err(into_status)?;
let stderr_stream = GrpcStdioImpl::new_stream("stderr", Channel::Stderr as i32, stderrbuf);
let merged_stream = stdout_stream.merge(stderr_stream);
Ok(Box::pin(merged_stream))
}
fn new_stream(
stream_name: &'static str,
channel: i32,
mut redirected_buf: BufferRedirect,
) -> impl Stream<Item = Result<StdioData, Status>> {
stream! {
loop {
log::trace!("beginning next iteration of {} reading and streaming...", stream_name);
let mut readbuf = String::new();
match redirected_buf.read_to_string(&mut readbuf) {
Ok(len) => match len{
0 => {
log::trace!("{} had zero bytes. Sleeping to avoid polling...", stream_name);
sleep(Duration::from_millis(CONSOLE_POLL_SLEEP_MILLIS)).await;
},
_ => {
log::trace!("Sending {} {} bytes of data: {}", stream_name, len, readbuf);
yield Ok(StdioData{
channel,
data: readbuf.into_bytes(),
});
},
},
Err(e) => {
log::error!("Error reading {} data: {:?}", stream_name, e);
yield Err(Status::unknown(format!("Error reading from Stderr of plugin's process: {:?}", e)));
},
}
}
}
}
}
#[async_trait]
impl GrpcStdio for GrpcStdioImpl {
type StreamStdioStream =
Pin<Box<dyn Stream<Item = Result<StdioData, Status>> + Send + 'static>>;
async fn stream_stdio(
&self,
_req: Request<Empty>,
) -> Result<Response<Self::StreamStdioStream>, Status> {
log::trace!("stream_stdio called.");
let s = GrpcStdioImpl::new_combined_stream()?;
log::trace!("stream_stdio responding with a stream of StdioData.",);
Ok(Response::new(s))
}
}

2
bao/src/grpc/mod.rs Normal file
View File

@ -0,0 +1,2 @@
pub mod grpc_stdio;
pub mod error;

View File

@ -1,32 +0,0 @@
use abao::encode::Encoder;
use std::io::{Cursor, Write};
#[allow(unused_imports)]
use wasmedge_bindgen::*;
use wasmedge_bindgen_macro::*;
static mut TREE: Option<Vec<u8>> = None;
static mut CURSOR: Option<Cursor<Vec<u8>>> = None;
static mut ENCODER: Option<Encoder<Cursor<Vec<u8>>>> = None;
#[wasmedge_bindgen]
pub unsafe fn init() {
TREE = Option::Some(Vec::new());
CURSOR = Option::Some(Cursor::new(TREE.take().unwrap()));
ENCODER = Option::Some(Encoder::new_outboard(CURSOR.take().unwrap()));
}
#[wasmedge_bindgen]
pub unsafe fn write(v: Vec<u8>) -> Result<u64, String> {
let encoder = ENCODER.take().unwrap();
let bytes_written = encoder.to_owned().write(&v).map_err(|e| e.to_string())?;
ENCODER = Some(encoder); // Restore the value
Ok(bytes_written as u64)
}
#[wasmedge_bindgen]
pub unsafe fn finalize() -> Vec<u8> {
let mut encoder = ENCODER.take().unwrap();
let bytes = encoder.finalize().unwrap().as_bytes().to_vec();
ENCODER = Some(encoder); // Restore the value
bytes
}

116
bao/src/main.rs Normal file
View File

@ -0,0 +1,116 @@
#![feature(async_fn_in_trait)]
#![allow(incomplete_features)]
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::io::{Cursor, Write};
use std::sync::{Arc};
use abao::encode::Encoder;
use async_trait::async_trait;
use atomic_counter::{AtomicCounter, ConsistentCounter};
use parking_lot::Mutex;
use tonic::{Request, Response, Status};
use tonic::transport::Server;
use tonic_health::server::HealthReporter;
use crate::proto::bao::bao_server::{Bao, BaoServer};
use crate::proto::bao::WriteRequest;
use crate::proto::google::protobuf::{BytesValue, Empty, UInt32Value};
use crate::unique_port::UniquePort;
mod proto;
mod unique_port;
mod grpc;
async fn driver_service_status(mut reporter: HealthReporter) {
reporter.set_serving::<BaoServer<BaoService>>().await;
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut uport = UniquePort::default();
let port = uport.get_unused_port().expect("No ports free");
println!("{}", format!("1|1|tcp|127.0.0.1:{}|grpc", port));
let (mut health_reporter, health_service) = tonic_health::server::health_reporter();
health_reporter.set_serving::<BaoServer<BaoService>>().await;
tokio::spawn(driver_service_status(health_reporter.clone()));
let addr = format!("127.0.0.1:{}", port).parse().unwrap();
let bao_service = BaoService::default();
let server = BaoServer::new(bao_service);
Server::builder()
.add_service(health_service)
.add_service(server)
.add_service(grpc::grpc_stdio::new_server())
.serve(addr)
.await?;
Ok(())
}
#[derive(Debug, Default)]
pub struct BaoService {
requests: Arc<Mutex<HashMap<u32, Encoder<Cursor<Vec<u8>>>>>>,
counter: ConsistentCounter,
}
#[async_trait]
impl Bao for BaoService {
async fn init(&self, _request: Request<Empty>) -> Result<Response<UInt32Value>, Status> {
let next_id = self.counter.inc() as u32;
let tree = Vec::new();
let cursor = Cursor::new(tree);
let encoder = Encoder::new(cursor);
let mut req = self.requests.lock();
req.insert(next_id, encoder);
Ok(Response::new(UInt32Value { value: next_id }))
}
async fn write(&self, request: Request<WriteRequest>) -> Result<Response<Empty>, Status> {
let r = request.into_inner();
let mut req = self.requests.lock();
if let Some(encoder) = req.get_mut(&r.id) {
encoder.write(&r.data)?;
} else {
return Err(Status::invalid_argument("invalid id"));
}
Ok(Response::new(Empty::default()))
}
async fn finalize(
&self,
request: Request<UInt32Value>,
) -> Result<Response<BytesValue>, Status> {
let r = request.into_inner();
let mut req = self.requests.lock();
match req.entry(r.value) {
Entry::Occupied(mut entry) => {
let encoder = entry.get_mut();
let ret = encoder.finalize().unwrap();
let bytes = ret.as_bytes().to_vec();
Ok(Response::new(BytesValue { value: bytes }))
}
Entry::Vacant(_) => {
Err(Status::invalid_argument("invalid id"))
}
}
}
async fn destroy(&self, request: Request<UInt32Value>) -> Result<Response<Empty>, Status> {
let r = request.into_inner();
let mut req = self.requests.lock();
if req.remove(&r.value).is_none() {
return Err(Status::invalid_argument("invalid id"));
}
Ok(Response::new(Empty::default()))
}
}

45
bao/src/unique_port.rs Normal file
View File

@ -0,0 +1,45 @@
use portpicker::Port;
pub struct UniquePort {
vended_ports: Vec<Port>,
}
impl UniquePort {
pub fn new() -> Self {
Self {
vended_ports: vec![],
}
}
pub fn get_unused_port(&mut self) -> Option<Port> {
let mut counter = 0;
loop {
counter += 1;
if counter > 1000 {
// no luck in 1000 tries? Give up!
return None;
}
match portpicker::pick_unused_port() {
None => return None,
Some(p) => {
if self.vended_ports.contains(&p) {
log::trace!("Skipped port: {} because it is in the list of previously vended ports: {:?}", p, self.vended_ports);
continue;
} else {
log::trace!("Vending port: {}", p);
self.vended_ports.push(p);
return Some(p);
}
}
}
}
}
}
impl Default for UniquePort {
fn default() -> Self {
Self::new()
}
}