rewrote it

This commit is contained in:
Mathias Buus 2021-12-30 21:13:47 +01:00
parent c70c9f989f
commit 45e10f4bec
4 changed files with 494 additions and 311 deletions

113
README.md
View File

@ -10,8 +10,121 @@ npm install protomux
``` js
const Protomux = require('protomux')
const c = require('compact-encoding')
// By framed stream, it has be a stream that preserves the messages, ie something that length prefixes
// like @hyperswarm/secret-stream
const mux = new Protomux(aStreamThatFrames)
// Now add some protocols
const cool = mux.addProtocol({
name: 'cool-protocol',
version: {
major: 1,
minor: 0
},
// an array of compact encoders, each encoding/decoding the messages sent
messages: [
c.string,
c.bool
],
onremoteopen () {
console.log('the other side opened this protocol!')
},
onemoteclose () {
console.log('the other side closed this protocol!')
},
onmessage (type, message) {
console.log('the other side sent a message', type, message)
}
})
// And send some messages
cool.send(0, 'a string')
cool.send(1, true)
```
## API
#### `mux = new Protomux(stream, [options])`
Make a new instance. `stream` should be a framed stream, preserving the messages written.
Options include:
``` js
{
// Called when the muxer wants to allocate a message that is written, defaults to Buffer.allocUnsafe.
alloc (size) {},
// Hook that is called when an unknown protocol is received. Should return true/false.
async onacceptprotocol ({ name, version }) {}
// How many protocols can be remote open, that we haven't opened yet?
// Only used if you don't provide an accept hook.
backlog: 128
}
```
#### `const p = mux.addProtocol(opts)`
Add a new protocol.
Options include:
``` js
{
// Used to match the protocol
name: 'name of the protocol',
// You can have multiple versions of the same protocol on the same stream.
// Protocols are matched using the major version only.
version: {
major: 0,
minor: 0
},
// Array of the message types you want to send/receive. Should be compact-encoders
messages: [
...
],
// Called when the remote side adds this protocol.
// Errors here are caught and forwared to stream.destroy
async onremoteopen () {},
// Called when the remote side closes or rejects this protocol.
// Errors here are caught and forwared to stream.destroy
async onremoteclose () {},
// Called when the remote sends a message
// Errors here are caught and forwared to stream.destroy
async onmessage (type, message) {}
}
```
Each of the functions can also be set directly on the instance with the same name.
#### `p.close()`
Closes the protocol
#### `p.send(type, message)`
Send a message, type is the offset into the messages array.
#### `p.cork()`
Corking the protocol, makes it buffer messages and send them all in a batch when it uncorks.
#### `p.uncork()`
Uncork and send the batch.
#### `mux.cork()`
Same as `p.cork` but on the muxer instance.
#### `mux.uncork()`
Same as `p.uncork` but on the muxer instance.
## License
MIT

532
index.js
View File

@ -1,252 +1,124 @@
const c = require('compact-encoding')
const m = require('./messages')
const b4a = require('b4a')
const safetyCatch = require('safety-catch')
const { addProtocol } = require('./messages')
const EMPTY = []
class Protocol {
constructor (muxer, offset, protocol) {
this.muxer = muxer
this.stream = muxer.stream
this.start = offset
this.end = offset + protocol.messages.length
this.name = protocol.name
this.version = protocol.version || { major: 0, minor: 0 }
this.messages = protocol.messages.length
constructor (mux) {
this.mux = mux
this.name = null
this.version = null
this.messages = EMPTY
this.offset = 0
this.length = 0
this.opened = false
this.remoteVersion = null
this.remoteOffset = 0
this.remoteEnd = 0
this.remoteOpened = false
this.removed = false
this.encodings = protocol.messages
this.remoteClosed = false
this.onmessage = noop
this.onopen = noop
this.onclose = noop
this.onremoteopen = noop
this.onremoteclose = noop
}
get corked () {
return this.muxer.corked
_attach ({ name, version = { major: 0, minor: 0 }, messages, onmessage = noop, onremoteopen = noop, onremoteclose = noop }) {
const opened = this.opened
this.name = name
this.version = version
this.messages = messages
this.offset = this.mux.offset
this.length = messages.length
this.opened = true
this.corked = false
this.onmessage = onmessage
this.onremoteopen = onremoteopen
this.onremoteclose = onremoteclose
return !opened
}
cork () {
this.muxer.cork()
if (this.corked) return
this.corked = true
this.mux.cork()
}
uncork () {
this.muxer.uncork()
if (!this.corked) return
this.corked = false
this.mux.uncork()
}
send (type, message) {
const t = this.start + type
const enc = this.encodings[type]
if (!this.opened) return false
if (this.muxer._handshakeSent === false) {
this.muxer._sendHandshake()
}
const t = this.offset + type
const m = this.messages[type]
if (this.muxer.corked > 0) {
this.muxer._batch.push({ type: t, encoding: enc, message })
return false
}
const state = { start: 0, end: 0, buffer: null }
c.uint.preencode(state, t)
enc.preencode(state, message)
state.buffer = this.muxer._alloc(state.end)
c.uint.encode(state, t)
enc.encode(state, message)
return this.stream.write(state.buffer)
return this.mux._push(t, m, message)
}
recv (type, state) {
this.onmessage(type, this.encodings[type].decode(state))
close () {
if (this.opened === false) return
this.opened = false
this.mux._unopened++
const offset = this.offset
this.version = null
this.messages = EMPTY
this.offset = 0
this.length = 0
this.onmessage = this.onremoteopen = this.onremoteclose = noop
this.mux._push(2, c.uint, offset)
this._gc()
if (this.corked) this.uncork()
}
_gc () {
if (this.opened || this.remoteOpened) return
this.mux._removeProtocol(this)
}
_recv (type, state) {
if (type >= this.messages.length) return
const m = this.messages[type]
const message = m.decode(state)
this.mux._catch(this.onmessage(type, message))
}
}
module.exports = class Protomux {
constructor (stream, protocols, opts = {}) {
constructor (stream, { backlog = 128, alloc, onacceptprotocol } = {}) {
this.stream = stream
this.protocols = []
this.offset = 2
this.remoteProtocols = []
this.remoteOffset = 2
this.remoteHandshake = null
this.onhandshake = noop
this.offset = 4 // 4 messages reserved
this.corked = 0
this.backlog = backlog
this.onacceptprotocol = onacceptprotocol || (() => this._unopened < this.backlog)
this._unopened = 0
this._batch = null
this._unmatchedProtocols = []
this._handshakeSent = false
this._alloc = opts.alloc || (typeof stream.alloc === 'function' ? stream.alloc.bind(stream) : Buffer.allocUnsafe)
for (const p of protocols) this.addProtocol(p)
this._alloc = alloc || (typeof stream.alloc === 'function' ? stream.alloc.bind(stream) : b4a.allocUnsafe)
this._safeDestroyBound = this._safeDestroy.bind(this)
this.stream.on('data', this._ondata.bind(this))
queueMicrotask(this._sendHandshake.bind(this))
this.stream.on('close', this._shutdown.bind(this))
}
remoteOpened (name, version) {
for (const { remote } of this.remoteProtocols) {
if (remote.name === name || (version === undefined || version.major === remote.version.major)) return true
}
for (const { remote } of this._unmatchedProtocols) {
if (remote.name === name || (version === undefined || version.major === remote.version.major)) return true
}
return false
}
addProtocol (p) {
const local = new Protocol(this, this.offset, p)
this.protocols.push(local)
this.offset += p.messages.length
for (let i = 0; i < this._unmatchedProtocols.length; i++) {
const { start, remote } = this._unmatchedProtocols[i]
if (remote.name !== local.name || remote.version.major !== local.version.major) continue
local.remoteOpened = true
this._unmatchedProtocols.splice(i, 1)
const end = start + Math.min(remote.messages, local.messages)
this.remoteProtocols.push({ local, remote, start, end })
break
}
return local
}
removeProtocol (p) {
const { name, version = { major: 0, minor: 0 } } = typeof p === 'string' ? { name: p, version: undefined } : p
for (let i = 0; i < this.protocols.length; i++) {
const local = this.protocols[i]
if (local.name !== name || local.version.major !== version.major) continue
p.removed = true
this.protocols.splice(i, 1)
}
for (let i = 0; i < this.remoteProtocols.length; i++) {
const { local, remote, start } = this.remoteProtocols[i]
if (local.name !== name || local.version.major !== version.major) continue
this.remoteProtocols.splice(i, 1)
this._unmatchedProtocols.push({ start, remote })
}
}
addRemoteProtocol (p) {
if (!p.version) p = { name: p.name, version: { major: 0, minor: 0 }, messages: p.messages }
const local = this.get(p.name)
const start = this.remoteOffset
this.remoteOffset += p.messages
if (!local || local.version.major !== p.version.major) {
this._unmatchedProtocols.push({ start, remote: p })
return
}
if (local.remoteOpened) {
this.destroy(new Error('Remote sent duplicate protocols'))
return
}
const end = start + Math.min(p.messages, local.messages)
this.remoteProtocols.push({ local, remote: p, start, end })
local.remoteOpened = true
local.onopen()
}
removeRemoteProtocol ({ name, version = { major: 0, minor: 0 } }) {
for (let i = 0; i < this.remoteProtocols.length; i++) {
const { local } = this.remoteProtocols[i]
if (local.name !== name || local.version.major !== version.major) continue
this.remoteProtocols.splice(i, 1)
local.remoteOpened = false
local.onclose()
break
}
for (let i = 0; i < this._unmatchedProtocols.length; i++) {
const { remote } = this._unmatchedProtocols[i]
if (remote.name !== name || remote.version.major !== version.major) continue
this._unmatchedProtocols.splice(i, 1)
break
}
}
_ondata (buffer) {
if (buffer.byteLength === 0) return // keep alive
const state = { start: 0, end: buffer.byteLength, buffer }
try {
this._recv(state)
} catch (err) {
this.destroy(err)
}
}
_recv (state) {
const t = c.uint.decode(state)
if (t < 2) {
if (t === 0) {
this._recvBatch(state)
} else {
this._recvHandshake(state)
}
return
}
for (let i = 0; i < this.remoteProtocols.length; i++) {
const p = this.remoteProtocols[i]
if (p.start <= t && t < p.end) {
p.local.recv(t - p.start, state)
break
}
}
state.start = state.end
}
_recvBatch (state) {
const end = state.end
while (state.start < state.end) {
const len = c.uint.decode(state)
state.end = state.start + len
this._recv(state)
state.end = end
}
}
_recvHandshake (state) {
if (this.remoteHandshake !== null) {
this.destroy(new Error('Double handshake'))
return
}
this.remoteHandshake = m.handshake.decode(state)
for (const p of this.remoteHandshake.protocols) this.addRemoteProtocol(p)
this.onhandshake(this.remoteHandshake)
}
destroy (err) {
this._handshakeSent = true // just to avoid sending it again
this.stream.destroy(err)
}
get (name) {
for (const p of this.protocols) {
if (p.name === name) return p
}
return null
sendKeepAlive () {
this.stream.write(this._alloc(0))
}
cork () {
@ -285,35 +157,237 @@ module.exports = class Protomux {
this.stream.write(state.buffer)
}
sendKeepAlive () {
this.stream.write(this._alloc(0))
hasProtocol (opts) {
return !!this.getProtocol(opts)
}
_sendHandshake () {
if (this._handshakeSent) return
this._handshakeSent = true
getProtocol ({ name, version }) {
return this._getProtocol(name, version, false)
}
const hs = {
protocols: this.protocols
addProtocol (opts) {
const p = this._getProtocol(opts.name, (opts.version && opts.version.major) || 0, true)
if (opts.cork) p.cork()
if (!p._attach(opts)) return p
this._unopened--
this.offset += p.length
this._push(1, addProtocol, {
name: p.name,
version: p.version,
offset: p.offset,
length: p.length
})
return p
}
destroy (err) {
this.stream.destroy(err)
}
_shutdown () {
while (this.protocols.length) {
const p = this.protocols.pop()
if (!p.remoteOpened) continue
if (p.remoteClosed) continue
p.remoteOpened = false
p.remoteClosed = true
this._catch(p.onremoteclose())
}
}
_safeDestroy (err) {
safetyCatch(err)
this.destroy(err)
}
_catch (p) {
if (isPromise(p)) p.catch(this._safeDestroyBound)
}
async _acceptMaybe (added) {
let accept = false
try {
accept = await this.onacceptprotocol(added)
} catch (err) {
this._safeDestroy(err)
return
}
if (this.corked > 0) {
this._batch.push({ type: 1, encoding: m.handshake, message: hs })
if (!accept) this._rejectProtocol(added)
}
_rejectProtocol (added) {
for (let i = 0; i < this.protocols.length; i++) {
const p = this.protocols[i]
if (p.opened || p.name !== added.name || !p.remoteOpened) continue
if (p.remoteVersion.major !== added.version.major) continue
this._unopened--
this.protocols.splice(i, 1)
this._push(3, c.uint, added.offset)
return
}
}
_ondata (buffer) {
if (buffer.byteLength === 0) return // keep alive
const end = buffer.byteLength
const state = { start: 0, end, buffer }
try {
const type = c.uint.decode(state)
if (type === 0) this._recvBatch(end, state)
else this._recv(type, state)
} catch (err) {
this._safeDestroy(err)
}
}
_getProtocol (name, major, upsert) {
for (let i = 0; i < this.protocols.length; i++) {
const p = this.protocols[i]
const v = p.remoteVersion === null ? p.version : p.remoteVersion
if (p.name === name && (v !== null && v.major === major)) return p
}
if (!upsert) return null
const p = new Protocol(this)
this.protocols.push(p)
this._unopened++
return p
}
_removeProtocol (p) {
const i = this.protocols.indexOf(this)
if (i > -1) this.protocols.splice(i, 1)
if (!p.opened) this._unopened--
}
_recvAddProtocol (state) {
const add = addProtocol.decode(state)
const p = this._getProtocol(add.name, add.version.major, true)
if (p.remoteOpened) throw new Error('Duplicate protocol received')
p.name = add.name
p.remoteVersion = add.version
p.remoteOffset = add.offset
p.remoteEnd = add.offset + add.length
p.remoteOpened = true
p.remoteClosed = false
if (p.opened) {
this._catch(p.onremoteopen())
} else {
this._acceptMaybe(add)
}
}
_recvRemoveProtocol (state) {
const offset = c.uint.decode(state)
for (let i = 0; i < this.protocols.length; i++) {
const p = this.protocols[i]
if (p.remoteOffset === offset && p.remoteOpened) {
p.remoteVersion = null
p.remoteOpened = false
p.remoteClosed = true
this._catch(p.onremoteclose())
p._gc()
return
}
}
}
_recvRejectedProtocol (state) {
const offset = c.uint.decode(state)
for (let i = 0; i < this.protocols.length; i++) {
const p = this.protocols[i]
if (p.offset === offset && !p.remoteClosed) {
p.remoteClosed = true
this._catch(p.onremoteclose())
p._gc()
}
}
}
_recvBatch (end, state) {
while (state.start < state.end) {
const len = c.uint.decode(state)
const type = c.uint.decode(state)
state.end = state.start + len
this._recv(type, state)
state.end = end
}
}
_recv (type, state) {
if (type < 4) {
if (type === 0) {
throw new Error('Invalid nested batch')
}
if (type === 1) {
this._recvAddProtocol(state)
return
}
if (type === 2) {
this._recvRemoveProtocol(state)
return
}
if (type === 3) {
this._recvRejectedProtocol(state)
return
}
return
}
// TODO: Consider make this array sorted by remoteOffset and use a bisect here.
// For now we use very few protocols in practice, so it might be overkill.
for (let i = 0; i < this.protocols.length; i++) {
const p = this.protocols[i]
if (p.remoteOffset <= type && type < p.remoteEnd) {
p._recv(type - p.remoteOffset, state)
break
}
}
}
_push (type, enc, message) {
if (this.corked > 0) {
this._batch.push({ type, encoding: enc, message })
return false
}
const state = { start: 0, end: 0, buffer: null }
c.uint.preencode(state, 1)
m.handshake.preencode(state, hs)
c.uint.preencode(state, type)
enc.preencode(state, message)
state.buffer = this._alloc(state.end)
c.uint.encode(state, 1)
m.handshake.encode(state, hs)
c.uint.encode(state, type)
enc.encode(state, message)
this.stream.write(state.buffer)
return this.stream.write(state.buffer)
}
}
function noop () {}
function isPromise (p) {
return typeof p === 'object' && p !== null && !!p.catch
}

View File

@ -17,41 +17,25 @@ const version = {
}
}
const protocol = {
exports.addProtocol = {
preencode (state, p) {
c.string.preencode(state, p.name)
version.preencode(state, p.version)
c.uint.preencode(state, p.messages)
c.uint.preencode(state, p.offset)
c.uint.preencode(state, p.length)
},
encode (state, p) {
c.string.encode(state, p.name)
version.encode(state, p.version)
c.uint.encode(state, p.messages)
c.uint.encode(state, p.offset)
c.uint.encode(state, p.length)
},
decode (state, p) {
return {
name: c.string.decode(state),
version: version.decode(state),
messages: c.uint.decode(state)
}
}
}
const protocolArray = c.array(protocol)
exports.handshake = {
preencode (state, h) {
state.end++ // reversed flags
protocolArray.preencode(state, h.protocols)
},
encode (state, h) {
state.buffer[state.start++] = 0 // reserved flags
protocolArray.encode(state, h.protocols)
},
decode (state) {
c.uint.decode(state) // not using any flags for now
return {
protocols: protocolArray.decode(state)
offset: c.uint.decode(state),
length: c.uint.decode(state)
}
}
}

130
test.js
View File

@ -4,40 +4,35 @@ const test = require('brittle')
const c = require('compact-encoding')
test('basic', function (t) {
const a = new Protomux(new SecretStream(true), [{
name: 'foo',
messages: [c.string]
}])
const b = new Protomux(new SecretStream(false), [{
name: 'foo',
messages: [c.string]
}])
const a = new Protomux(new SecretStream(true))
const b = new Protomux(new SecretStream(false))
replicate(a, b)
const ap = a.get('foo')
const bp = b.get('foo')
a.addProtocol({
name: 'foo',
messages: [c.string],
onremoteopen () {
t.pass('a remote opened')
},
onmessage (type, message) {
t.is(type, 0)
t.is(message, 'hello world')
}
})
const bp = b.addProtocol({
name: 'foo',
messages: [c.string]
})
t.plan(3)
ap.onopen = function () {
t.pass('a opened')
}
ap.onmessage = function (type, message) {
t.is(type, 0)
t.is(message, 'hello world')
}
bp.send(0, 'hello world')
})
test('echo message', function (t) {
const a = new Protomux(new SecretStream(true), [{
name: 'foo',
messages: [c.string]
}])
const a = new Protomux(new SecretStream(true))
const b = new Protomux(new SecretStream(false), [{
name: 'other',
@ -49,48 +44,60 @@ test('echo message', function (t) {
replicate(a, b)
const ap = a.get('foo')
const bp = b.get('foo')
const ap = a.addProtocol({
name: 'foo',
messages: [c.string],
onmessage (type, message) {
ap.send(type, 'echo: ' + message)
}
})
b.addProtocol({
name: 'other',
messages: [c.bool, c.bool]
})
const bp = b.addProtocol({
name: 'foo',
messages: [c.string],
onremoteopen () {
t.pass('b remote opened')
},
onmessage (type, message) {
t.is(type, 0)
t.is(message, 'echo: hello world')
}
})
t.plan(3)
ap.onmessage = function (type, message) {
ap.send(type, 'echo: ' + message)
}
bp.send(0, 'hello world')
bp.onopen = function () {
t.pass('b opened')
}
bp.onmessage = function (type, message) {
t.is(type, 0)
t.is(message, 'echo: hello world')
}
})
test('multi message', function (t) {
const a = new Protomux(new SecretStream(true), [{
const a = new Protomux(new SecretStream(true))
a.addProtocol({
name: 'other',
messages: [c.bool, c.bool]
}, {
})
const ap = a.addProtocol({
name: 'multi',
messages: [c.int, c.string, c.string]
}])
})
const b = new Protomux(new SecretStream(false), [{
const b = new Protomux(new SecretStream(false))
const bp = b.addProtocol({
name: 'multi',
messages: [c.int, c.string]
}])
})
replicate(a, b)
t.plan(4)
const ap = a.get('multi')
const bp = b.get('multi')
ap.send(0, 42)
ap.send(1, 'a string with 42')
ap.send(2, 'should be ignored')
@ -108,26 +115,31 @@ test('multi message', function (t) {
})
test('corks', function (t) {
const a = new Protomux(new SecretStream(true), [{
const a = new Protomux(new SecretStream(true))
a.cork()
a.addProtocol({
name: 'other',
messages: [c.bool, c.bool]
}, {
name: 'multi',
messages: [c.int, c.string]
}])
})
const b = new Protomux(new SecretStream(false), [{
const ap = a.addProtocol({
name: 'multi',
messages: [c.int, c.string]
}])
})
const b = new Protomux(new SecretStream(false))
const bp = b.addProtocol({
name: 'multi',
messages: [c.int, c.string]
})
replicate(a, b)
t.plan(8 + 1)
const ap = a.get('multi')
const bp = b.get('multi')
const expected = [
[0, 1],
[0, 2],
@ -135,12 +147,12 @@ test('corks', function (t) {
[1, 'a string']
]
ap.cork()
ap.send(0, 1)
ap.send(0, 2)
ap.send(0, 3)
ap.send(1, 'a string')
ap.uncork()
a.uncork()
b.stream.once('data', function (data) {
t.ok(expected.length === 0, 'received all messages in one data packet')