From 45e10f4bec177f44ff37c38e8c27374b350fed75 Mon Sep 17 00:00:00 2001 From: Mathias Buus Date: Thu, 30 Dec 2021 21:13:47 +0100 Subject: [PATCH] rewrote it --- README.md | 113 +++++++++++ index.js | 532 ++++++++++++++++++++++++++++++---------------------- messages.js | 30 +-- test.js | 130 +++++++------ 4 files changed, 494 insertions(+), 311 deletions(-) diff --git a/README.md b/README.md index 11c898b..5fb5378 100644 --- a/README.md +++ b/README.md @@ -10,8 +10,121 @@ npm install protomux ``` js const Protomux = require('protomux') +const c = require('compact-encoding') + +// By framed stream, it has be a stream that preserves the messages, ie something that length prefixes +// like @hyperswarm/secret-stream + +const mux = new Protomux(aStreamThatFrames) + +// Now add some protocols + +const cool = mux.addProtocol({ + name: 'cool-protocol', + version: { + major: 1, + minor: 0 + }, + // an array of compact encoders, each encoding/decoding the messages sent + messages: [ + c.string, + c.bool + ], + onremoteopen () { + console.log('the other side opened this protocol!') + }, + onemoteclose () { + console.log('the other side closed this protocol!') + }, + onmessage (type, message) { + console.log('the other side sent a message', type, message) + } +}) + +// And send some messages + +cool.send(0, 'a string') +cool.send(1, true) ``` +## API + +#### `mux = new Protomux(stream, [options])` + +Make a new instance. `stream` should be a framed stream, preserving the messages written. + +Options include: + +``` js +{ + // Called when the muxer wants to allocate a message that is written, defaults to Buffer.allocUnsafe. + alloc (size) {}, + // Hook that is called when an unknown protocol is received. Should return true/false. + async onacceptprotocol ({ name, version }) {} + // How many protocols can be remote open, that we haven't opened yet? + // Only used if you don't provide an accept hook. + backlog: 128 +} +``` + +#### `const p = mux.addProtocol(opts)` + +Add a new protocol. + +Options include: + +``` js +{ + // Used to match the protocol + name: 'name of the protocol', + // You can have multiple versions of the same protocol on the same stream. + // Protocols are matched using the major version only. + version: { + major: 0, + minor: 0 + }, + // Array of the message types you want to send/receive. Should be compact-encoders + messages: [ + ... + ], + // Called when the remote side adds this protocol. + // Errors here are caught and forwared to stream.destroy + async onremoteopen () {}, + // Called when the remote side closes or rejects this protocol. + // Errors here are caught and forwared to stream.destroy + async onremoteclose () {}, + // Called when the remote sends a message + // Errors here are caught and forwared to stream.destroy + async onmessage (type, message) {} +} +``` + +Each of the functions can also be set directly on the instance with the same name. + +#### `p.close()` + +Closes the protocol + +#### `p.send(type, message)` + +Send a message, type is the offset into the messages array. + +#### `p.cork()` + +Corking the protocol, makes it buffer messages and send them all in a batch when it uncorks. + +#### `p.uncork()` + +Uncork and send the batch. + +#### `mux.cork()` + +Same as `p.cork` but on the muxer instance. + +#### `mux.uncork()` + +Same as `p.uncork` but on the muxer instance. + ## License MIT diff --git a/index.js b/index.js index 6b4cc00..abc739c 100644 --- a/index.js +++ b/index.js @@ -1,252 +1,124 @@ const c = require('compact-encoding') -const m = require('./messages') +const b4a = require('b4a') +const safetyCatch = require('safety-catch') +const { addProtocol } = require('./messages') + +const EMPTY = [] class Protocol { - constructor (muxer, offset, protocol) { - this.muxer = muxer - this.stream = muxer.stream - this.start = offset - this.end = offset + protocol.messages.length - this.name = protocol.name - this.version = protocol.version || { major: 0, minor: 0 } - this.messages = protocol.messages.length + constructor (mux) { + this.mux = mux + this.name = null + this.version = null + this.messages = EMPTY + this.offset = 0 + this.length = 0 + this.opened = false + + this.remoteVersion = null + this.remoteOffset = 0 + this.remoteEnd = 0 this.remoteOpened = false - this.removed = false - this.encodings = protocol.messages + this.remoteClosed = false + this.onmessage = noop - this.onopen = noop - this.onclose = noop + this.onremoteopen = noop + this.onremoteclose = noop } - get corked () { - return this.muxer.corked + _attach ({ name, version = { major: 0, minor: 0 }, messages, onmessage = noop, onremoteopen = noop, onremoteclose = noop }) { + const opened = this.opened + + this.name = name + this.version = version + this.messages = messages + this.offset = this.mux.offset + this.length = messages.length + this.opened = true + this.corked = false + + this.onmessage = onmessage + this.onremoteopen = onremoteopen + this.onremoteclose = onremoteclose + + return !opened } cork () { - this.muxer.cork() + if (this.corked) return + this.corked = true + this.mux.cork() } uncork () { - this.muxer.uncork() + if (!this.corked) return + this.corked = false + this.mux.uncork() } send (type, message) { - const t = this.start + type - const enc = this.encodings[type] + if (!this.opened) return false - if (this.muxer._handshakeSent === false) { - this.muxer._sendHandshake() - } + const t = this.offset + type + const m = this.messages[type] - if (this.muxer.corked > 0) { - this.muxer._batch.push({ type: t, encoding: enc, message }) - return false - } - - const state = { start: 0, end: 0, buffer: null } - - c.uint.preencode(state, t) - enc.preencode(state, message) - - state.buffer = this.muxer._alloc(state.end) - - c.uint.encode(state, t) - enc.encode(state, message) - - return this.stream.write(state.buffer) + return this.mux._push(t, m, message) } - recv (type, state) { - this.onmessage(type, this.encodings[type].decode(state)) + close () { + if (this.opened === false) return + this.opened = false + this.mux._unopened++ + + const offset = this.offset + + this.version = null + this.messages = EMPTY + this.offset = 0 + this.length = 0 + this.onmessage = this.onremoteopen = this.onremoteclose = noop + this.mux._push(2, c.uint, offset) + this._gc() + + if (this.corked) this.uncork() + } + + _gc () { + if (this.opened || this.remoteOpened) return + this.mux._removeProtocol(this) + } + + _recv (type, state) { + if (type >= this.messages.length) return + + const m = this.messages[type] + const message = m.decode(state) + + this.mux._catch(this.onmessage(type, message)) } } module.exports = class Protomux { - constructor (stream, protocols, opts = {}) { + constructor (stream, { backlog = 128, alloc, onacceptprotocol } = {}) { this.stream = stream - this.protocols = [] - this.offset = 2 - this.remoteProtocols = [] - this.remoteOffset = 2 - - this.remoteHandshake = null - this.onhandshake = noop - + this.offset = 4 // 4 messages reserved this.corked = 0 + this.backlog = backlog + this.onacceptprotocol = onacceptprotocol || (() => this._unopened < this.backlog) + this._unopened = 0 this._batch = null - this._unmatchedProtocols = [] - this._handshakeSent = false - this._alloc = opts.alloc || (typeof stream.alloc === 'function' ? stream.alloc.bind(stream) : Buffer.allocUnsafe) - - for (const p of protocols) this.addProtocol(p) + this._alloc = alloc || (typeof stream.alloc === 'function' ? stream.alloc.bind(stream) : b4a.allocUnsafe) + this._safeDestroyBound = this._safeDestroy.bind(this) this.stream.on('data', this._ondata.bind(this)) - queueMicrotask(this._sendHandshake.bind(this)) + this.stream.on('close', this._shutdown.bind(this)) } - remoteOpened (name, version) { - for (const { remote } of this.remoteProtocols) { - if (remote.name === name || (version === undefined || version.major === remote.version.major)) return true - } - for (const { remote } of this._unmatchedProtocols) { - if (remote.name === name || (version === undefined || version.major === remote.version.major)) return true - } - return false - } - - addProtocol (p) { - const local = new Protocol(this, this.offset, p) - - this.protocols.push(local) - this.offset += p.messages.length - - for (let i = 0; i < this._unmatchedProtocols.length; i++) { - const { start, remote } = this._unmatchedProtocols[i] - if (remote.name !== local.name || remote.version.major !== local.version.major) continue - local.remoteOpened = true - this._unmatchedProtocols.splice(i, 1) - const end = start + Math.min(remote.messages, local.messages) - this.remoteProtocols.push({ local, remote, start, end }) - break - } - - return local - } - - removeProtocol (p) { - const { name, version = { major: 0, minor: 0 } } = typeof p === 'string' ? { name: p, version: undefined } : p - - for (let i = 0; i < this.protocols.length; i++) { - const local = this.protocols[i] - if (local.name !== name || local.version.major !== version.major) continue - p.removed = true - this.protocols.splice(i, 1) - } - - for (let i = 0; i < this.remoteProtocols.length; i++) { - const { local, remote, start } = this.remoteProtocols[i] - if (local.name !== name || local.version.major !== version.major) continue - this.remoteProtocols.splice(i, 1) - this._unmatchedProtocols.push({ start, remote }) - } - } - - addRemoteProtocol (p) { - if (!p.version) p = { name: p.name, version: { major: 0, minor: 0 }, messages: p.messages } - - const local = this.get(p.name) - const start = this.remoteOffset - - this.remoteOffset += p.messages - - if (!local || local.version.major !== p.version.major) { - this._unmatchedProtocols.push({ start, remote: p }) - return - } - - if (local.remoteOpened) { - this.destroy(new Error('Remote sent duplicate protocols')) - return - } - - const end = start + Math.min(p.messages, local.messages) - - this.remoteProtocols.push({ local, remote: p, start, end }) - - local.remoteOpened = true - local.onopen() - } - - removeRemoteProtocol ({ name, version = { major: 0, minor: 0 } }) { - for (let i = 0; i < this.remoteProtocols.length; i++) { - const { local } = this.remoteProtocols[i] - if (local.name !== name || local.version.major !== version.major) continue - this.remoteProtocols.splice(i, 1) - local.remoteOpened = false - local.onclose() - break - } - - for (let i = 0; i < this._unmatchedProtocols.length; i++) { - const { remote } = this._unmatchedProtocols[i] - if (remote.name !== name || remote.version.major !== version.major) continue - this._unmatchedProtocols.splice(i, 1) - break - } - } - - _ondata (buffer) { - if (buffer.byteLength === 0) return // keep alive - - const state = { start: 0, end: buffer.byteLength, buffer } - - try { - this._recv(state) - } catch (err) { - this.destroy(err) - } - } - - _recv (state) { - const t = c.uint.decode(state) - - if (t < 2) { - if (t === 0) { - this._recvBatch(state) - } else { - this._recvHandshake(state) - } - return - } - - for (let i = 0; i < this.remoteProtocols.length; i++) { - const p = this.remoteProtocols[i] - - if (p.start <= t && t < p.end) { - p.local.recv(t - p.start, state) - break - } - } - - state.start = state.end - } - - _recvBatch (state) { - const end = state.end - - while (state.start < state.end) { - const len = c.uint.decode(state) - state.end = state.start + len - this._recv(state) - state.end = end - } - } - - _recvHandshake (state) { - if (this.remoteHandshake !== null) { - this.destroy(new Error('Double handshake')) - return - } - - this.remoteHandshake = m.handshake.decode(state) - for (const p of this.remoteHandshake.protocols) this.addRemoteProtocol(p) - - this.onhandshake(this.remoteHandshake) - } - - destroy (err) { - this._handshakeSent = true // just to avoid sending it again - this.stream.destroy(err) - } - - get (name) { - for (const p of this.protocols) { - if (p.name === name) return p - } - return null + sendKeepAlive () { + this.stream.write(this._alloc(0)) } cork () { @@ -285,35 +157,237 @@ module.exports = class Protomux { this.stream.write(state.buffer) } - sendKeepAlive () { - this.stream.write(this._alloc(0)) + hasProtocol (opts) { + return !!this.getProtocol(opts) } - _sendHandshake () { - if (this._handshakeSent) return - this._handshakeSent = true + getProtocol ({ name, version }) { + return this._getProtocol(name, version, false) + } - const hs = { - protocols: this.protocols + addProtocol (opts) { + const p = this._getProtocol(opts.name, (opts.version && opts.version.major) || 0, true) + + if (opts.cork) p.cork() + if (!p._attach(opts)) return p + + this._unopened-- + this.offset += p.length + this._push(1, addProtocol, { + name: p.name, + version: p.version, + offset: p.offset, + length: p.length + }) + + return p + } + + destroy (err) { + this.stream.destroy(err) + } + + _shutdown () { + while (this.protocols.length) { + const p = this.protocols.pop() + if (!p.remoteOpened) continue + if (p.remoteClosed) continue + p.remoteOpened = false + p.remoteClosed = true + this._catch(p.onremoteclose()) + } + } + + _safeDestroy (err) { + safetyCatch(err) + this.destroy(err) + } + + _catch (p) { + if (isPromise(p)) p.catch(this._safeDestroyBound) + } + + async _acceptMaybe (added) { + let accept = false + + try { + accept = await this.onacceptprotocol(added) + } catch (err) { + this._safeDestroy(err) + return } - if (this.corked > 0) { - this._batch.push({ type: 1, encoding: m.handshake, message: hs }) + if (!accept) this._rejectProtocol(added) + } + + _rejectProtocol (added) { + for (let i = 0; i < this.protocols.length; i++) { + const p = this.protocols[i] + if (p.opened || p.name !== added.name || !p.remoteOpened) continue + if (p.remoteVersion.major !== added.version.major) continue + + this._unopened-- + this.protocols.splice(i, 1) + this._push(3, c.uint, added.offset) return } + } + + _ondata (buffer) { + if (buffer.byteLength === 0) return // keep alive + + const end = buffer.byteLength + const state = { start: 0, end, buffer } + + try { + const type = c.uint.decode(state) + if (type === 0) this._recvBatch(end, state) + else this._recv(type, state) + } catch (err) { + this._safeDestroy(err) + } + } + + _getProtocol (name, major, upsert) { + for (let i = 0; i < this.protocols.length; i++) { + const p = this.protocols[i] + const v = p.remoteVersion === null ? p.version : p.remoteVersion + if (p.name === name && (v !== null && v.major === major)) return p + } + + if (!upsert) return null + + const p = new Protocol(this) + this.protocols.push(p) + this._unopened++ + return p + } + + _removeProtocol (p) { + const i = this.protocols.indexOf(this) + if (i > -1) this.protocols.splice(i, 1) + if (!p.opened) this._unopened-- + } + + _recvAddProtocol (state) { + const add = addProtocol.decode(state) + + const p = this._getProtocol(add.name, add.version.major, true) + if (p.remoteOpened) throw new Error('Duplicate protocol received') + + p.name = add.name + p.remoteVersion = add.version + p.remoteOffset = add.offset + p.remoteEnd = add.offset + add.length + p.remoteOpened = true + p.remoteClosed = false + + if (p.opened) { + this._catch(p.onremoteopen()) + } else { + this._acceptMaybe(add) + } + } + + _recvRemoveProtocol (state) { + const offset = c.uint.decode(state) + + for (let i = 0; i < this.protocols.length; i++) { + const p = this.protocols[i] + + if (p.remoteOffset === offset && p.remoteOpened) { + p.remoteVersion = null + p.remoteOpened = false + p.remoteClosed = true + this._catch(p.onremoteclose()) + p._gc() + return + } + } + } + + _recvRejectedProtocol (state) { + const offset = c.uint.decode(state) + + for (let i = 0; i < this.protocols.length; i++) { + const p = this.protocols[i] + + if (p.offset === offset && !p.remoteClosed) { + p.remoteClosed = true + this._catch(p.onremoteclose()) + p._gc() + } + } + } + + _recvBatch (end, state) { + while (state.start < state.end) { + const len = c.uint.decode(state) + const type = c.uint.decode(state) + state.end = state.start + len + this._recv(type, state) + state.end = end + } + } + + _recv (type, state) { + if (type < 4) { + if (type === 0) { + throw new Error('Invalid nested batch') + } + + if (type === 1) { + this._recvAddProtocol(state) + return + } + + if (type === 2) { + this._recvRemoveProtocol(state) + return + } + + if (type === 3) { + this._recvRejectedProtocol(state) + return + } + + return + } + + // TODO: Consider make this array sorted by remoteOffset and use a bisect here. + // For now we use very few protocols in practice, so it might be overkill. + for (let i = 0; i < this.protocols.length; i++) { + const p = this.protocols[i] + + if (p.remoteOffset <= type && type < p.remoteEnd) { + p._recv(type - p.remoteOffset, state) + break + } + } + } + + _push (type, enc, message) { + if (this.corked > 0) { + this._batch.push({ type, encoding: enc, message }) + return false + } const state = { start: 0, end: 0, buffer: null } - c.uint.preencode(state, 1) - m.handshake.preencode(state, hs) + c.uint.preencode(state, type) + enc.preencode(state, message) state.buffer = this._alloc(state.end) - c.uint.encode(state, 1) - m.handshake.encode(state, hs) + c.uint.encode(state, type) + enc.encode(state, message) - this.stream.write(state.buffer) + return this.stream.write(state.buffer) } } function noop () {} + +function isPromise (p) { + return typeof p === 'object' && p !== null && !!p.catch +} diff --git a/messages.js b/messages.js index b9c6efa..4ea764c 100644 --- a/messages.js +++ b/messages.js @@ -17,41 +17,25 @@ const version = { } } -const protocol = { +exports.addProtocol = { preencode (state, p) { c.string.preencode(state, p.name) version.preencode(state, p.version) - c.uint.preencode(state, p.messages) + c.uint.preencode(state, p.offset) + c.uint.preencode(state, p.length) }, encode (state, p) { c.string.encode(state, p.name) version.encode(state, p.version) - c.uint.encode(state, p.messages) + c.uint.encode(state, p.offset) + c.uint.encode(state, p.length) }, decode (state, p) { return { name: c.string.decode(state), version: version.decode(state), - messages: c.uint.decode(state) - } - } -} - -const protocolArray = c.array(protocol) - -exports.handshake = { - preencode (state, h) { - state.end++ // reversed flags - protocolArray.preencode(state, h.protocols) - }, - encode (state, h) { - state.buffer[state.start++] = 0 // reserved flags - protocolArray.encode(state, h.protocols) - }, - decode (state) { - c.uint.decode(state) // not using any flags for now - return { - protocols: protocolArray.decode(state) + offset: c.uint.decode(state), + length: c.uint.decode(state) } } } diff --git a/test.js b/test.js index 59f2587..180bdd5 100644 --- a/test.js +++ b/test.js @@ -4,40 +4,35 @@ const test = require('brittle') const c = require('compact-encoding') test('basic', function (t) { - const a = new Protomux(new SecretStream(true), [{ - name: 'foo', - messages: [c.string] - }]) - - const b = new Protomux(new SecretStream(false), [{ - name: 'foo', - messages: [c.string] - }]) + const a = new Protomux(new SecretStream(true)) + const b = new Protomux(new SecretStream(false)) replicate(a, b) - const ap = a.get('foo') - const bp = b.get('foo') + a.addProtocol({ + name: 'foo', + messages: [c.string], + onremoteopen () { + t.pass('a remote opened') + }, + onmessage (type, message) { + t.is(type, 0) + t.is(message, 'hello world') + } + }) + + const bp = b.addProtocol({ + name: 'foo', + messages: [c.string] + }) t.plan(3) - ap.onopen = function () { - t.pass('a opened') - } - - ap.onmessage = function (type, message) { - t.is(type, 0) - t.is(message, 'hello world') - } - bp.send(0, 'hello world') }) test('echo message', function (t) { - const a = new Protomux(new SecretStream(true), [{ - name: 'foo', - messages: [c.string] - }]) + const a = new Protomux(new SecretStream(true)) const b = new Protomux(new SecretStream(false), [{ name: 'other', @@ -49,48 +44,60 @@ test('echo message', function (t) { replicate(a, b) - const ap = a.get('foo') - const bp = b.get('foo') + const ap = a.addProtocol({ + name: 'foo', + messages: [c.string], + onmessage (type, message) { + ap.send(type, 'echo: ' + message) + } + }) + + b.addProtocol({ + name: 'other', + messages: [c.bool, c.bool] + }) + + const bp = b.addProtocol({ + name: 'foo', + messages: [c.string], + onremoteopen () { + t.pass('b remote opened') + }, + onmessage (type, message) { + t.is(type, 0) + t.is(message, 'echo: hello world') + } + }) t.plan(3) - ap.onmessage = function (type, message) { - ap.send(type, 'echo: ' + message) - } - bp.send(0, 'hello world') - - bp.onopen = function () { - t.pass('b opened') - } - - bp.onmessage = function (type, message) { - t.is(type, 0) - t.is(message, 'echo: hello world') - } }) test('multi message', function (t) { - const a = new Protomux(new SecretStream(true), [{ + const a = new Protomux(new SecretStream(true)) + + a.addProtocol({ name: 'other', messages: [c.bool, c.bool] - }, { + }) + + const ap = a.addProtocol({ name: 'multi', messages: [c.int, c.string, c.string] - }]) + }) - const b = new Protomux(new SecretStream(false), [{ + const b = new Protomux(new SecretStream(false)) + + const bp = b.addProtocol({ name: 'multi', messages: [c.int, c.string] - }]) + }) replicate(a, b) t.plan(4) - const ap = a.get('multi') - const bp = b.get('multi') - ap.send(0, 42) ap.send(1, 'a string with 42') ap.send(2, 'should be ignored') @@ -108,26 +115,31 @@ test('multi message', function (t) { }) test('corks', function (t) { - const a = new Protomux(new SecretStream(true), [{ + const a = new Protomux(new SecretStream(true)) + + a.cork() + + a.addProtocol({ name: 'other', messages: [c.bool, c.bool] - }, { - name: 'multi', - messages: [c.int, c.string] - }]) + }) - const b = new Protomux(new SecretStream(false), [{ + const ap = a.addProtocol({ name: 'multi', messages: [c.int, c.string] - }]) + }) + + const b = new Protomux(new SecretStream(false)) + + const bp = b.addProtocol({ + name: 'multi', + messages: [c.int, c.string] + }) replicate(a, b) t.plan(8 + 1) - const ap = a.get('multi') - const bp = b.get('multi') - const expected = [ [0, 1], [0, 2], @@ -135,12 +147,12 @@ test('corks', function (t) { [1, 'a string'] ] - ap.cork() ap.send(0, 1) ap.send(0, 2) ap.send(0, 3) ap.send(1, 'a string') - ap.uncork() + + a.uncork() b.stream.once('data', function (data) { t.ok(expected.length === 0, 'received all messages in one data packet')