diff --git a/package.json b/package.json index 31f1d69..039ed98 100644 --- a/package.json +++ b/package.json @@ -3,15 +3,18 @@ "version": "0.1.0", "main": "dist/index.js", "devDependencies": { + "@types/b4a": "^1.6.0", "@types/node": "^18.11.18", "@types/streamx": "^2.9.1", "prettier": "^2.8.2", "typescript": "^4.9.4" }, "dependencies": { + "b4a": "^1.6.3", "buffer": "^6.0.3", "compact-encoding": "^2.11.0", "protomux": "^3.4.0", + "serialize-error": "^11.0.0", "streamx": "^2.13.0" } } diff --git a/src/index.ts b/src/index.ts index d10c697..39a2238 100644 --- a/src/index.ts +++ b/src/index.ts @@ -10,6 +10,10 @@ import Peer, { OnClose, } from "./peer.js"; import Server from "./server.js"; +import DummySocket from "./proxies/multiSocket/dummySocket.js"; +import TcpSocket from "./proxies/multiSocket/tcpSocket.js"; +import BasicProxy from "./proxies/basic.js"; +import MultiSocketProxy from "./proxies/multiSocket.js"; export { Proxy, @@ -23,6 +27,10 @@ export { OnSend, OnReceive, OnClose, + DummySocket, + TcpSocket, + BasicProxy, + MultiSocketProxy, }; export function createSocket(port: number, host: string): Socket { diff --git a/src/proxies/basic.ts b/src/proxies/basic.ts new file mode 100644 index 0000000..53bcef6 --- /dev/null +++ b/src/proxies/basic.ts @@ -0,0 +1,3 @@ +import Proxy from "../proxy.js"; + +export default class BasicProxy extends Proxy {} diff --git a/src/proxies/multiSocket.ts b/src/proxies/multiSocket.ts new file mode 100644 index 0000000..0f7d373 --- /dev/null +++ b/src/proxies/multiSocket.ts @@ -0,0 +1,276 @@ +import Proxy, { ProxyOptions } from "../proxy.js"; +import TcpSocket from "./multiSocket/tcpSocket.js"; +import { json, raw, uint } from "compact-encoding"; +import { deserializeError } from "serialize-error"; +import b4a from "b4a"; +import type { TcpSocketConnectOpts } from "net"; +import Peer from "../peer.js"; +import { roundRobinFactory, idFactory } from "../util.js"; +import { + CloseSocketRequest, + ErrorSocketRequest, + PeerEntity, + SocketRequest, + WriteSocketRequest, +} from "./multiSocket/types.js"; +import DummySocket from "./multiSocket/dummySocket.js"; + +export interface MultiSocketProxyOptions extends ProxyOptions { + socketClass?: any; + server: boolean; + allowedPorts?: number[]; +} + +const socketEncoding = { + preencode(state: any, m: SocketRequest) { + uint.preencode(state, m.id); + uint.preencode(state, m.remoteId); + }, + encode(state: any, m: SocketRequest) { + uint.encode(state, m.id); + uint.encode(state, m.remoteId); + }, + decode(state: any, m: any): SocketRequest { + return { + remoteId: uint.decode(state, m), + id: uint.decode(state, m), + }; + }, +}; + +const writeSocketEncoding = { + preencode(state: any, m: WriteSocketRequest) { + socketEncoding.preencode(state, m); + raw.preencode(state, m.data); + }, + encode(state: any, m: WriteSocketRequest) { + socketEncoding.encode(state, m); + raw.encode(state, m.data); + }, + decode(state: any, m: any): WriteSocketRequest { + const socket = socketEncoding.decode(state, m); + return { + ...socket, + data: raw.decode(state, m), + }; + }, +}; + +const errorSocketEncoding = { + decode(state: any, m: any): ErrorSocketRequest { + const socket = socketEncoding.decode(state, m); + return { + ...socket, + err: deserializeError(json.decode(state, m)), + }; + }, +}; + +const nextSocketId = idFactory(1); + +export default class MultiSocketProxy extends Proxy { + private socketClass: any; + private _peers: Map = new Map(); + private _nextPeer = roundRobinFactory(this._peers); + private _server = false; + private _allowedPorts = []; + + constructor(options: MultiSocketProxyOptions) { + super({ + createDefaultMessage: false, + ...options, + }); + if (options.socketClass) { + this.socketClass = options.socketClass; + } else { + if (options.server) { + this.socketClass = TcpSocket; + } else { + this.socketClass = DummySocket; + } + } + if (options.server) { + this._server = true; + } + } + + private _socketMap = new Map(); + + get socketMap(): Map { + return this._socketMap; + } + + private _sockets = new Map(); + + get sockets(): Map { + return this._sockets; + } + + handleNewPeerChannel(peer: Peer, channel: any) { + this.update(peer.socket.remotePublicKey, { peer }); + + this._registerOpenSocketMessage(peer, channel); + this._registerWriteSocketMessage(peer, channel); + this._registerCloseSocketMessage(peer, channel); + this._registerTimeoutSocketMessage(peer, channel); + this._registerErrorSocketMessage(peer, channel); + } + + async handleClosePeer(peer: Peer) { + for (const item of this._sockets) { + if (item[1].peer.peer === peer) { + item[1].end(); + } + } + + const pubkey = this._toString(peer.socket.remotePublicKey); + + if (this._peers.has(pubkey)) { + this._peers.delete(pubkey); + } + } + + public get(pubkey: Uint8Array): PeerEntity | undefined { + if (this._peers.has(this._toString(pubkey))) { + return this._peers.get(this._toString(pubkey)) as PeerEntity; + } + + return undefined; + } + + public update(pubkey: Uint8Array, data: Partial): void { + const peer = this.get(pubkey) ?? ({} as PeerEntity); + + this._peers.set(this._toString(pubkey), { + ...peer, + ...data, + ...{ + messages: { + ...peer?.messages, + ...data?.messages, + }, + }, + } as PeerEntity); + } + + public async createSocket( + options: TcpSocketConnectOpts + ): Promise { + if (!this._peers.size) { + throw new Error("no peers found"); + } + + const peer = this._nextPeer(); + const socketId = nextSocketId(); + const socket = new this.socketClass(socketId, this, peer, options); + this._sockets.set(socketId, socket); + + return socket; + } + + private _registerOpenSocketMessage(peer: Peer, channel: any) { + const self = this; + const message = channel.addMessage({ + encoding: { + preencode: json.preencode, + encode: json.encode, + decode: this._server ? json.encode : socketEncoding.decode, + }, + async onmessage(m: SocketRequest | TcpSocketConnectOpts) { + if ( + self._allowedPorts.length && + !self._allowedPorts.includes((m as TcpSocketConnectOpts).port) + ) { + self.get(peer.socket.remotePublicKey).messages.errorSocket.send({ + id: (m as SocketRequest).id, + err: new Error( + `port ${(m as TcpSocketConnectOpts).port} not allowed` + ), + }); + return; + } + + m = m as SocketRequest; + + if (self._server) { + new self.socketClass( + nextSocketId(), + m, + self, + self.get(peer.socket.remotePublicKey) as PeerEntity, + m + ).connect(); + return; + } + + const socket = self._sockets.get(m.id); + if (socket) { + socket.remoteId = m.remoteId; + // @ts-ignore + socket.emit("connect"); + } + }, + }); + this.update(peer.socket.remotePublicKey, { + messages: { openSocket: message }, + }); + } + + private _registerWriteSocketMessage(peer: Peer, channel: any) { + const self = this; + const message = channel.addMessage({ + encoding: writeSocketEncoding, + onmessage(m: WriteSocketRequest) { + self._sockets.get(m.id)?.push(m.data); + }, + }); + this.update(peer.socket.remotePublicKey, { + messages: { writeSocket: message }, + }); + } + + private _registerCloseSocketMessage(peer: Peer, channel: any) { + const self = this; + const message = channel.addMessage({ + encoding: socketEncoding, + onmessage(m: CloseSocketRequest) { + self._sockets.get(m.id)?.end(); + }, + }); + this.update(peer.socket.remotePublicKey, { + messages: { closeSocket: message }, + }); + } + + private _registerTimeoutSocketMessage(peer: Peer, channel: any) { + const self = this; + const message = channel.addMessage({ + encoding: socketEncoding, + onmessage(m: SocketRequest) { + // @ts-ignore + self._sockets.get(m.id)?.emit("timeout"); + }, + }); + this.update(peer.socket.remotePublicKey, { + messages: { timeoutSocket: message }, + }); + } + + private _registerErrorSocketMessage(peer: Peer, channel: any) { + const self = this; + const message = channel.addMessage({ + encoding: errorSocketEncoding, + onmessage(m: ErrorSocketRequest) { + // @ts-ignore + self._sockets.get(m.id)?.emit("error", m.err); + }, + }); + this.update(peer.socket.remotePublicKey, { + messages: { errorSocket: message }, + }); + } + + private _toString(pubkey: Uint8Array) { + return b4a.from(pubkey).toString("hex"); + } +} diff --git a/src/proxies/multiSocket/dummySocket.ts b/src/proxies/multiSocket/dummySocket.ts new file mode 100644 index 0000000..bd84e7e --- /dev/null +++ b/src/proxies/multiSocket/dummySocket.ts @@ -0,0 +1,82 @@ +import { Callback, Duplex } from "streamx"; +import { TcpSocketConnectOpts } from "net"; +import { clearTimeout } from "timers"; +import MultiSocketProxy from "../multiSocket.js"; +import { PeerEntity, SocketRequest, WriteSocketRequest } from "./types.js"; +import { maybeGetAsyncProperty } from "../../util.js"; + +export default class DummySocket extends Duplex { + private _options: TcpSocketConnectOpts; + private _id: number; + private _proxy: MultiSocketProxy; + + private _connectTimeout?: number; + + constructor( + id: number, + manager: MultiSocketProxy, + peer: PeerEntity, + options: TcpSocketConnectOpts + ) { + super(); + this._id = id; + this._proxy = manager; + this._peer = peer; + this._options = options; + + // @ts-ignore + this.on("timeout", () => { + if (this._connectTimeout) { + clearTimeout(this._connectTimeout); + } + }); + } + + private _remoteId = 0; + + set remoteId(value: number) { + this._remoteId = value; + this._proxy.socketMap.set(this._id, value); + } + + private _peer; + + get peer() { + return this._peer; + } + + public async _write(data: any, cb: any): Promise { + (await maybeGetAsyncProperty(this._peer.messages.writeSocket))?.send({ + id: this._id, + remoteId: this._remoteId, + data, + } as WriteSocketRequest); + cb(); + } + + public async _destroy(cb: Callback) { + (await maybeGetAsyncProperty(this._peer.messages.closeSocket))?.send({ + id: this._id, + remoteId: this._remoteId, + } as SocketRequest); + this._proxy.socketMap.delete(this._id); + this._proxy.sockets.delete(this._id); + } + + public async connect() { + (await maybeGetAsyncProperty(this._peer.messages.openSocket))?.send({ + ...this._options, + id: this._id, + }); + } + + public setTimeout(ms: number, cb: Function) { + if (this._connectTimeout) { + clearTimeout(this._connectTimeout); + } + + this._connectTimeout = setTimeout(() => { + cb && cb(); + }, ms) as any; + } +} diff --git a/src/proxies/multiSocket/tcpSocket.ts b/src/proxies/multiSocket/tcpSocket.ts new file mode 100644 index 0000000..b5a99b7 --- /dev/null +++ b/src/proxies/multiSocket/tcpSocket.ts @@ -0,0 +1,89 @@ +import { Callback, Duplex } from "streamx"; +import { Socket, TcpSocketConnectOpts } from "net"; +import MultiSocketProxy from "../multiSocket.js"; +import { PeerEntity, SocketRequest, WriteSocketRequest } from "./types.js"; +import * as net from "net"; + +export default class TcpSocket extends Duplex { + private _options; + private _id: number; + private _remoteId: number; + private _proxy: MultiSocketProxy; + + private _socket?: Socket; + + constructor( + id: number, + remoteId: number, + manager: MultiSocketProxy, + peer: PeerEntity, + options: TcpSocketConnectOpts + ) { + super(); + this._remoteId = remoteId; + this._proxy = manager; + this._id = id; + this._peer = peer; + this._options = options; + + this._proxy.sockets.set(this._id, this); + this._proxy.socketMap.set(this._id, this._remoteId); + } + + private _peer; + + get peer() { + return this._peer; + } + + public _write(data: any, cb: any): void { + this._peer.messages.writeSocket?.send({ + ...this._getSocketRequest(), + data, + } as WriteSocketRequest); + cb(); + } + + public _destroy(cb: Callback) { + this._proxy.sockets.delete(this._id); + this._proxy.socketMap.delete(this._id); + this._peer.messages.closeSocket?.send(this._getSocketRequest()); + } + + public connect() { + this.on("error", (err: Error) => { + this._peer.messages.errorSocket?.send({ + ...this._getSocketRequest(), + err, + }); + }); + + // @ts-ignore + this.on("timeout", () => { + this._peer.messages.timeoutSocket?.send(this._getSocketRequest()); + }); + // @ts-ignore + this.on("connect", () => { + this._peer.messages.openSocket?.send(this._getSocketRequest()); + }); + + this._socket = net.connect(this._options); + ["timeout", "error", "connect", "end", "destroy", "close"].forEach( + (event) => { + this._socket?.on(event, (...args: any) => + this.emit(event as any, ...args) + ); + } + ); + + this._socket.pipe(this as any); + this.pipe(this._socket); + } + + private _getSocketRequest(): SocketRequest { + return { + id: this._id, + remoteId: this._remoteId, + }; + } +} diff --git a/src/proxies/multiSocket/types.ts b/src/proxies/multiSocket/types.ts new file mode 100644 index 0000000..7f34523 --- /dev/null +++ b/src/proxies/multiSocket/types.ts @@ -0,0 +1,35 @@ +import { ProxyOptions } from "../../proxy.js"; +import Peer from "../../peer.js"; + +export interface SocketRequest { + remoteId: number; + id: number; +} + +export type CloseSocketRequest = SocketRequest; + +export interface WriteSocketRequest extends SocketRequest { + data: Uint8Array; +} + +export interface ErrorSocketRequest extends SocketRequest { + err: Error; +} + +type Message = { + send: (pubkey: Uint8Array | any) => void; +}; + +export interface PeerEntityMessages { + keyExchange: Message; + openSocket: Message; + writeSocket: Message; + closeSocket: Message; + timeoutSocket: Message; + errorSocket: Message; +} + +export interface PeerEntity { + messages: PeerEntityMessages | Partial; + peer: Peer; +} diff --git a/src/proxy.ts b/src/proxy.ts index 6926858..3900b51 100644 --- a/src/proxy.ts +++ b/src/proxy.ts @@ -8,7 +8,7 @@ export interface ProxyOptions extends DataSocketOptions { autostart?: boolean; } -export default class Proxy { +export default abstract class Proxy { private _listen: any; private _socketOptions: DataSocketOptions; private _autostart: boolean; diff --git a/src/util.ts b/src/util.ts new file mode 100644 index 0000000..c6b834b --- /dev/null +++ b/src/util.ts @@ -0,0 +1,40 @@ +export function idFactory(start: number, step = 1, limit = 2 ** 32) { + let id = start; + + return function nextId() { + const nextId = id; + id += step; + if (id >= limit) id = start; + return nextId; + }; +} +export function roundRobinFactory(list: Map) { + let index = 0; + + return (): T => { + const keys = [...list.keys()].sort(); + if (index >= keys.length) { + index = 0; + } + + return list.get(keys[index++]); + }; +} +export async function maybeGetAsyncProperty(object: any) { + if (typeof object === "function") { + object = object(); + } + + if (isPromise(object)) { + object = await object; + } + + return object; +} +export function isPromise(obj: Promise) { + return ( + !!obj && + (typeof obj === "object" || typeof obj === "function") && + typeof obj.then === "function" + ); +}