This repository has been archived on 2023-04-09. You can view files and clone it, but cannot push or open issues or pull requests.
kernel-protomux/index.js

667 lines
16 KiB
JavaScript

const b4a = require('b4a')
const c = require('compact-encoding')
const queueTick = require('queue-tick')
const safetyCatch = require('safety-catch')
// Once the accounted size (see byteSize) of messages buffered for channels
// that are not fully open yet exceeds this, the underlying stream is paused.
const MAX_BUFFERED = 32 * 1024
// Upper bound on remote opens still waiting for a matching local channel.
const MAX_BACKLOG = Infinity // TODO: impl "open" backpressure
// A cork()ed batch that has already reached this many bytes (8 MiB) is
// flushed before the next entry is appended.
const MAX_BATCH = 8 * 1024 * 1024
/**
 * One logical channel multiplexed over a Protomux stream.
 *
 * Instances are created by Protomux#createChannel. `_localId` is a 1-based
 * index into mux._local and `_remoteId` a 1-based index into mux._remote;
 * 0 means "not assigned yet".
 */
class Channel {
  constructor (mux, info, userData, protocol, aliases, id, handshake, messages, onopen, onclose, ondestroy) {
    this.userData = userData
    this.protocol = protocol
    this.aliases = aliases
    this.id = id
    // decoded remote handshake; filled in by _fullyOpen
    this.handshake = null
    // message descriptors, indexed by message type (see addMessage)
    this.messages = []
    this.opened = false
    this.closed = false
    this.destroyed = false
    this.onopen = onopen
    this.onclose = onclose
    this.ondestroy = ondestroy
    // encoding for the handshake payload, or null when there is none
    this._handshake = handshake
    this._mux = mux
    // shared per-(protocol, id) bookkeeping object from Protomux#_get
    this._info = info
    this._localId = 0
    this._remoteId = 0
    // count of pending promises returned by user callbacks (see _track)
    this._active = 0
    this._extensions = null
    // pre-bound so _track does not allocate closures per promise
    this._decBound = this._dec.bind(this)
    this._decAndDestroyBound = this._decAndDestroy.bind(this)
    for (const m of messages) this.addMessage(m)
  }

  /**
   * Assigns a local id and sends an "open" control message
   * (channel 0, type 1) carrying localId, protocol, id and, when a
   * handshake encoding is configured, the encoded handshake.
   * @param {*} handshake - value encoded with this channel's handshake encoding
   */
  open (handshake) {
    // reuse a released local slot if there is one, otherwise grow mux._local
    const id = this._mux._free.length > 0
      ? this._mux._free.pop()
      : this._mux._local.push(null) - 1
    this._info.opened++
    this._localId = id + 1
    this._mux._local[id] = this
    // not paired with a remote open yet: the remote's matching open is
    // looked up through info.outgoing (see Protomux#_onopensession)
    if (this._remoteId === 0) {
      this._info.outgoing.push(this._localId)
    }
    // start at 2 to reserve the [0 = control channel, 1 = open] header bytes
    const state = { buffer: null, start: 2, end: 2 }
    c.uint.preencode(state, this._localId)
    c.string.preencode(state, this.protocol)
    c.buffer.preencode(state, this.id)
    if (this._handshake) this._handshake.preencode(state, handshake)
    state.buffer = this._mux._alloc(state.end)
    state.buffer[0] = 0
    state.buffer[1] = 1
    c.uint.encode(state, this._localId)
    c.string.encode(state, this.protocol)
    c.buffer.encode(state, this.id)
    if (this._handshake) this._handshake.encode(state, handshake)
    this._mux._write0(state.buffer)
  }

  // A tracked callback promise resolved; destroy once the channel is closed
  // and no tracked promises remain.
  _dec () {
    if (--this._active === 0 && this.closed === true) this._destroy()
  }

  // A tracked callback promise rejected: settle the counter, then tear down
  // the whole mux with the error.
  _decAndDestroy (err) {
    this._dec()
    this._mux._safeDestroy(err)
  }

  // Attach this channel to its remote slot now, but run onopen and drain
  // buffered messages on the next tick (via queueTick).
  _fullyOpenSoon () {
    this._mux._remote[this._remoteId - 1].session = this
    queueTick(this._fullyOpen.bind(this))
  }

  /**
   * Marks the channel open: decodes the remote handshake from the state saved
   * by Protomux#_onopensession, calls onopen, and delivers any messages that
   * were buffered while the channel was pending. No-op if already opened or
   * closed.
   */
  _fullyOpen () {
    if (this.opened === true || this.closed === true) return
    const remote = this._mux._remote[this._remoteId - 1]
    this.opened = true
    this.handshake = this._handshake ? this._handshake.decode(remote.state) : null
    this._track(this.onopen(this.handshake, this))
    remote.session = this
    // handshake state is no longer needed
    remote.state = null
    if (remote.pending !== null) this._drain(remote)
  }

  // Delivers buffered messages in arrival order, releases their accounted
  // bytes from mux._buffered, and lets the mux resume the stream if possible.
  _drain (remote) {
    for (let i = 0; i < remote.pending.length; i++) {
      const p = remote.pending[i]
      this._mux._buffered -= byteSize(p.state)
      this._recv(p.type, p.state)
    }
    // null pending => Protomux#_decode now routes messages straight to _recv
    remote.pending = null
    this._mux._resumeMaybe()
  }

  // If a user callback returned a thenable, count it in _active so
  // destruction waits for it; a rejection destroys the mux.
  _track (p) {
    if (isPromise(p) === true) {
      this._active++
      p.then(this._decBound, this._decAndDestroyBound)
    }
  }

  /**
   * Local bookkeeping for closing the channel: releases its remote and local
   * slots, gc's the shared info, calls onclose, and destroys immediately
   * unless tracked promises are still pending.
   * @param {boolean} isRemote - true when the close was triggered remotely
   *   (or by the mux shutting down), false for a local close()
   * NOTE(review): if this runs before open() assigned a local id, _localId is
   * still 0 and index -1 is used below - confirm callers always open first.
   */
  _close (isRemote) {
    if (this.closed === true) return
    this.closed = true
    this._info.opened--
    if (this._remoteId > 0) {
      this._mux._remote[this._remoteId - 1] = null
      this._remoteId = 0
      // The remote already knows this channel (it answered our open, or we
      // accepted its open), so the local id can be reused right away.
      // Otherwise the id is released later, when the remote's open/reject
      // for it arrives (see Protomux#_onopensession / #_onrejectsession).
      this._mux._free.push(this._localId - 1)
    }
    this._mux._local[this._localId - 1] = null
    this._localId = 0
    this._mux._gc(this._info)
    this._track(this.onclose(isRemote, this))
    if (this._active === 0) this._destroy()
  }

  // Runs ondestroy at most once.
  _destroy () {
    if (this.destroyed === true) return
    this.destroyed = true
    this._track(this.ondestroy(this))
  }

  // Dispatch an incoming message to its registered type; unknown types are
  // silently ignored.
  _recv (type, state) {
    if (type < this.messages.length) {
      this.messages[type].recv(state, this)
    }
  }

  // Corking is mux-wide: it batches writes of every channel on the stream.
  cork () {
    this._mux.cork()
  }

  uncork () {
    this._mux.uncork()
  }

  /**
   * Closes the channel locally and sends a "close" control message
   * (channel 0, type 3) carrying the local id.
   */
  close () {
    if (this.closed === true) return
    // Encode before _close(), which resets _localId to 0.
    const state = { buffer: null, start: 2, end: 2 }
    c.uint.preencode(state, this._localId)
    state.buffer = this._mux._alloc(state.end)
    state.buffer[0] = 0
    state.buffer[1] = 3
    c.uint.encode(state, this._localId)
    this._close(false)
    this._mux._write0(state.buffer)
  }

  /**
   * Registers the next message type (types are assigned sequentially).
   * A falsy `opts` reserves the type slot with a no-op placeholder.
   * @param {object} [opts]
   * @param {object} [opts.encoding] - compact-encoding codec, default c.raw
   * @param {Function} [opts.onmessage] - called as (message, channel)
   * @returns {object} descriptor with `send(message, session?)`; send
   *   returns false if the channel is closed, true when batched, and the
   *   stream.write() result otherwise
   */
  addMessage (opts) {
    if (!opts) return this._skipMessage()
    const type = this.messages.length
    const encoding = opts.encoding || c.raw
    const onmessage = opts.onmessage || noop
    const s = this
    // byte length of the type header, computed once per message type
    const typeLen = encodingLength(c.uint, type)
    const m = {
      type,
      encoding,
      onmessage,
      recv (state, session) {
        // reads m.onmessage at call time, so it can be reassigned later
        session._track(m.onmessage(encoding.decode(state), session))
      },
      // NOTE: the `m` parameter (the payload) shadows the descriptor `m`.
      send (m, session = s) {
        if (session.closed === true) return false
        const mux = session._mux
        const state = { buffer: null, start: 0, end: typeLen }
        if (mux._batch !== null) {
          // corked: buffer holds [type, payload]; the mux adds the channel id
          encoding.preencode(state, m)
          state.buffer = mux._alloc(state.end)
          c.uint.encode(state, type)
          encoding.encode(state, m)
          mux._pushBatch(session._localId, state.buffer)
          return true
        }
        // uncorked: frame is [localId, type, payload]
        c.uint.preencode(state, session._localId)
        encoding.preencode(state, m)
        state.buffer = mux._alloc(state.end)
        c.uint.encode(state, session._localId)
        c.uint.encode(state, type)
        encoding.encode(state, m)
        return mux.stream.write(state.buffer)
      }
    }
    this.messages.push(m)
    return m
  }

  // Placeholder descriptor that keeps type numbering aligned; it drops
  // incoming messages and sends nothing.
  _skipMessage () {
    const type = this.messages.length
    const m = {
      type,
      encoding: c.raw,
      onmessage: noop,
      recv (state, session) {},
      send (m, session) {}
    }
    this.messages.push(m)
    return m
  }
}
module.exports = class Protomux {
constructor (stream, { alloc } = {}) {
this.isProtomux = true
this.stream = stream
this.corked = 0
this._alloc = alloc || (typeof stream.alloc === 'function' ? stream.alloc.bind(stream) : b4a.allocUnsafe)
this._safeDestroyBound = this._safeDestroy.bind(this)
this._remoteBacklog = 0
this._buffered = 0
this._paused = false
this._remote = []
this._local = []
this._free = []
this._batch = null
this._batchState = null
this._infos = new Map()
this._notify = new Map()
this.stream.on('data', this._ondata.bind(this))
this.stream.on('end', this._onend.bind(this))
this.stream.on('error', noop) // we handle this in "close"
this.stream.on('close', this._shutdown.bind(this))
}
static from (stream, opts) {
if (stream.isProtomux) return stream
return new this(stream, opts)
}
static isProtomux (mux) {
return typeof mux === 'object' && mux.isProtomux === true
}
* [Symbol.iterator] () {
for (const session of this._local) {
if (session !== null) yield session
}
}
cork () {
if (++this.corked === 1) {
this._batch = []
this._batchState = { buffer: null, start: 0, end: 1 }
}
}
uncork () {
if (--this.corked === 0) {
this._sendBatch(this._batch, this._batchState)
this._batch = null
this._batchState = null
}
}
pair ({ protocol, id = null }, notify) {
this._notify.set(toKey(protocol, id), notify)
}
unpair ({ protocol, id = null }) {
this._notify.delete(toKey(protocol, id))
}
opened ({ protocol, id = null }) {
const key = toKey(protocol, id)
const info = this._infos.get(key)
return info ? info.opened > 0 : false
}
createChannel ({ userData = null, protocol, aliases = [], id = null, unique = true, handshake = null, messages = [], onopen = noop, onclose = noop, ondestroy = noop }) {
if (this.stream.destroyed) return null
const info = this._get(protocol, id, aliases)
if (unique && info.opened > 0) return null
if (info.incoming.length === 0) {
return new Channel(this, info, userData, protocol, aliases, id, handshake, messages, onopen, onclose, ondestroy)
}
this._remoteBacklog--
const remoteId = info.incoming.shift()
const r = this._remote[remoteId - 1]
if (r === null) return null
const session = new Channel(this, info, userData, protocol, aliases, id, handshake, messages, onopen, onclose, ondestroy)
session._remoteId = remoteId
session._fullyOpenSoon()
return session
}
_pushBatch (localId, buffer) {
if (this._batchState.end >= MAX_BATCH) {
this._sendBatch(this._batch, this._batchState)
this._batch = []
this._batchState = { buffer: null, start: 0, end: 1 }
}
if (this._batch.length === 0 || this._batch[this._batch.length - 1].localId !== localId) {
this._batchState.end++
c.uint.preencode(this._batchState, localId)
}
c.buffer.preencode(this._batchState, buffer)
this._batch.push({ localId, buffer })
}
_sendBatch (batch, state) {
if (batch.length === 0) return
let prev = batch[0].localId
state.buffer = this._alloc(state.end)
state.buffer[state.start++] = 0
state.buffer[state.start++] = 0
c.uint.encode(state, prev)
for (let i = 0; i < batch.length; i++) {
const b = batch[i]
if (prev !== b.localId) {
state.buffer[state.start++] = 0
c.uint.encode(state, (prev = b.localId))
}
c.buffer.encode(state, b.buffer)
}
this.stream.write(state.buffer)
}
_get (protocol, id, aliases = []) {
const key = toKey(protocol, id)
let info = this._infos.get(key)
if (info) return info
info = { key, protocol, aliases: [], id, pairing: 0, opened: 0, incoming: [], outgoing: [] }
this._infos.set(key, info)
for (const alias of aliases) {
const key = toKey(alias, id)
info.aliases.push(key)
this._infos.set(key, info)
}
return info
}
_gc (info) {
if (info.opened === 0 && info.outgoing.length === 0 && info.incoming.length === 0) {
this._infos.delete(info.key)
for (const alias of info.aliases) this._infos.delete(alias)
}
}
_ondata (buffer) {
try {
const state = { buffer, start: 0, end: buffer.byteLength }
this._decode(c.uint.decode(state), state)
} catch (err) {
this._safeDestroy(err)
}
}
_onend () { // TODO: support half open mode for the users who wants that here
this.stream.end()
}
_decode (remoteId, state) {
const type = c.uint.decode(state)
if (remoteId === 0) {
this._oncontrolsession(type, state)
return
}
const r = remoteId <= this._remote.length ? this._remote[remoteId - 1] : null
// if the channel is closed ignore - could just be a pipeline message...
if (r === null) return
if (r.pending !== null) {
this._bufferMessage(r, type, state)
return
}
r.session._recv(type, state)
}
_oncontrolsession (type, state) {
switch (type) {
case 0:
this._onbatch(state)
break
case 1:
this._onopensession(state)
break
case 2:
this._onrejectsession(state)
break
case 3:
this._onclosesession(state)
break
}
}
_bufferMessage (r, type, { buffer, start, end }) {
const state = { buffer, start, end } // copy
r.pending.push({ type, state })
this._buffered += byteSize(state)
this._pauseMaybe()
}
_pauseMaybe () {
if (this._paused === true || this._buffered <= MAX_BUFFERED) return
this._paused = true
this.stream.pause()
}
_resumeMaybe () {
if (this._paused === false || this._buffered > MAX_BUFFERED) return
this._paused = false
this.stream.resume()
}
_onbatch (state) {
const end = state.end
let remoteId = c.uint.decode(state)
while (state.end > state.start) {
const len = c.uint.decode(state)
if (len === 0) {
remoteId = c.uint.decode(state)
continue
}
state.end = state.start + end
this._decode(remoteId, state)
state.end = end
}
}
_onopensession (state) {
const remoteId = c.uint.decode(state)
const protocol = c.string.decode(state)
const id = c.buffer.decode(state)
// remote tried to open the control session - auto reject for now
// as we can use as an explicit control protocol declaration if we need to
if (remoteId === 0) {
this._rejectSession(0)
return
}
const rid = remoteId - 1
const info = this._get(protocol, id)
// allow the remote to grow the ids by one
if (this._remote.length === rid) {
this._remote.push(null)
}
if (rid >= this._remote.length || this._remote[rid] !== null) {
throw new Error('Invalid open message')
}
if (info.outgoing.length > 0) {
const localId = info.outgoing.shift()
const session = this._local[localId - 1]
if (session === null) { // we already closed the channel - ignore
this._free.push(localId - 1)
return
}
this._remote[rid] = { state, pending: null, session: null }
session._remoteId = remoteId
session._fullyOpen()
return
}
this._remote[rid] = { state, pending: [], session: null }
if (++this._remoteBacklog > MAX_BACKLOG) {
throw new Error('Remote exceeded backlog')
}
info.pairing++
info.incoming.push(remoteId)
this._requestSession(protocol, id, info).catch(this._safeDestroyBound)
}
_onrejectsession (state) {
const localId = c.uint.decode(state)
// TODO: can be done smarter...
for (const info of this._infos.values()) {
const i = info.outgoing.indexOf(localId)
if (i === -1) continue
info.outgoing.splice(i, 1)
const session = this._local[localId - 1]
this._free.push(localId - 1)
if (session !== null) session._close(true)
this._gc(info)
return
}
throw new Error('Invalid reject message')
}
_onclosesession (state) {
const remoteId = c.uint.decode(state)
if (remoteId === 0) return // ignore
const rid = remoteId - 1
const r = rid < this._remote.length ? this._remote[rid] : null
if (r === null) return
if (r.session !== null) r.session._close(true)
}
async _requestSession (protocol, id, info) {
const notify = this._notify.get(toKey(protocol, id)) || this._notify.get(toKey(protocol, null))
if (notify) await notify(id)
if (--info.pairing > 0) return
while (info.incoming.length > 0) {
this._rejectSession(info, info.incoming.shift())
}
this._gc(info)
}
_rejectSession (info, remoteId) {
if (remoteId > 0) {
const r = this._remote[remoteId - 1]
if (r.pending !== null) {
for (let i = 0; i < r.pending.length; i++) {
this._buffered -= byteSize(r.pending[i].state)
}
}
this._remote[remoteId - 1] = null
this._resumeMaybe()
}
const state = { buffer: null, start: 2, end: 2 }
c.uint.preencode(state, remoteId)
state.buffer = this._alloc(state.end)
state.buffer[0] = 0
state.buffer[1] = 2
c.uint.encode(state, remoteId)
this._write0(state.buffer)
}
_write0 (buffer) {
if (this._batch !== null) {
this._pushBatch(0, buffer.subarray(1))
return
}
this.stream.write(buffer)
}
destroy (err) {
this.stream.destroy(err)
}
_safeDestroy (err) {
safetyCatch(err)
this.stream.destroy(err)
}
_shutdown () {
for (const s of this._local) {
if (s !== null) s._close(true)
}
}
}
// Shared do-nothing callback: default channel handlers and the swallowed
// stream 'error' listener.
function noop () {
  // intentionally empty
}
/**
 * Builds the map key for a (protocol, id) pair: the protocol, a '##'
 * separator, then the hex encoding of `id` (empty when `id` is falsy).
 */
function toKey (protocol, id) {
  let suffix = ''
  if (id) suffix = b4a.toString(id, 'hex')
  return protocol + '##' + suffix
}
/**
 * Accounted size of a buffered message: its unread byte count
 * (end - start) plus a fixed 512-byte per-message overhead.
 */
function byteSize (state) {
  const { start, end } = state
  const unread = end - start
  return unread + 512
}
/**
 * Duck-typed thenable check: true when `p` is truthy and exposes a
 * callable `then`.
 */
function isPromise (p) {
  if (!p) return false
  return typeof p.then === 'function'
}
/**
 * Number of bytes `enc` needs for `val`, obtained by running only the
 * encoding's preencode step against a fresh, empty state.
 */
function encodingLength (enc, val) {
  const probe = { buffer: null, start: 0, end: 0 }
  enc.preencode(probe, val)
  const { end: size } = probe
  return size
}