// libhyperproxy/src/proxies/multiSocket.ts
// (310 lines, 8.1 KiB, TypeScript)

import Proxy, { ProxyOptions } from "../proxy.js";
import TcpSocket from "./multiSocket/tcpSocket.js";
import { json, raw, uint } from "compact-encoding";
import { deserializeError, serializeError } from "serialize-error";
import b4a from "b4a";
import type { TcpSocketConnectOpts } from "net";
import { DataSocketOptions, PeerOptions } from "../peer.js";
import {
roundRobinFactory,
idFactory,
maybeGetAsyncProperty,
} from "../util.js";
import {
CloseSocketRequest,
ErrorSocketRequest,
PeerEntity,
SocketRequest,
WriteSocketRequest,
} from "./multiSocket/types.js";
import DummySocket from "./multiSocket/dummySocket.js";
import Peer from "./multiSocket/peer.js";
/**
 * Construction options for the multi-socket proxy.
 */
export interface MultiSocketProxyOptions extends ProxyOptions {
  /**
   * `true` for the side that answers open-socket requests coming from peers;
   * `false` for the side that creates sockets locally.
   */
  server: boolean;
  /**
   * Socket implementation to instantiate. When omitted, a default is picked
   * from `server` (TcpSocket for servers, DummySocket otherwise).
   */
  socketClass?: any;
  /**
   * Whitelist of ports that a server-mode proxy is meant to allow in
   * open-socket requests. Empty or omitted means no restriction.
   */
  allowedPorts?: number[];
}
/**
 * Wire codec for messages that carry only a pair of socket ids.
 *
 * Layout: two compact `uint`s, written as `m.id` then `m.remoteId`.
 *
 * NOTE: decoding reads the two fields back in swapped roles. The first field
 * becomes `remoteId` and the second becomes `id`. A sender that puts its own
 * socket id in `id` and the peer's socket id in `remoteId` therefore lets the
 * receiver find its local socket under `id` and learn the sender's id from
 * `remoteId`.
 */
const socketEncoding = {
  preencode(state: any, m: SocketRequest) {
    const { id: localId, remoteId: peerId } = m;
    uint.preencode(state, localId);
    uint.preencode(state, peerId);
  },
  encode(state: any, m: SocketRequest) {
    const { id: localId, remoteId: peerId } = m;
    uint.encode(state, localId);
    uint.encode(state, peerId);
  },
  decode(state: any, m: any): SocketRequest {
    // Field 1 was the sender's own id, field 2 was the id the sender holds for us.
    const senderId = uint.decode(state, m);
    const receiverId = uint.decode(state, m);
    return { remoteId: senderId, id: receiverId };
  },
};
/**
 * Codec for a chunk of socket data. The message starts with the socket id
 * pair, in the same layout as `socketEncoding` and with the same id/remoteId
 * swap on decode. It is followed by the payload encoded with compact-encoding's
 * `raw` codec.
 */
const writeSocketEncoding = {
  preencode(state: any, m: WriteSocketRequest) {
    const { data: payload } = m;
    socketEncoding.preencode(state, m);
    raw.preencode(state, payload);
  },
  encode(state: any, m: WriteSocketRequest) {
    const { data: payload } = m;
    socketEncoding.encode(state, m);
    raw.encode(state, payload);
  },
  decode(state: any, m: any): WriteSocketRequest {
    // The id header must be consumed before the raw payload.
    const header = socketEncoding.decode(state, m);
    return Object.assign(header, { data: raw.decode(state, m) });
  },
};
/**
 * Codec for reporting a socket error to the peer. The message starts with the
 * socket id pair, in the same layout as `socketEncoding` and with the same
 * id/remoteId swap on decode. It is followed by the error as JSON.
 *
 * `serializeError` turns the Error into a plain object that `json` can carry;
 * it is invoked separately in the size pass and in the write pass.
 * `deserializeError` turns the received object back into an Error.
 */
const errorSocketEncoding = {
  preencode(state: any, m: ErrorSocketRequest) {
    socketEncoding.preencode(state, m);
    const plainError = serializeError(m.err);
    json.preencode(state, plainError);
  },
  encode(state: any, m: ErrorSocketRequest) {
    socketEncoding.encode(state, m);
    const plainError = serializeError(m.err);
    json.encode(state, plainError);
  },
  decode(state: any, m: any): ErrorSocketRequest {
    const header = socketEncoding.decode(state, m);
    const restored = deserializeError(json.decode(state, m));
    return Object.assign(header, { err: restored });
  },
};
/**
 * Generator for local socket ids, shared by every MultiSocketProxy created in
 * this module (seeded via `idFactory(1)`). The open-socket rejection path
 * below uses 0 as a "no local socket" filler value; this assumes the
 * generator never yields 0 — TODO confirm against `idFactory`.
 */
const nextSocketId = idFactory(1);

/**
 * Proxy that carries many logical sockets over per-peer message channels.
 *
 * Each logical socket has a local id on each side. Wire messages carry both
 * ids, and `socketEncoding`'s decoder swaps them so that a receiver always
 * finds its own socket under `id`.
 *
 * - Server mode (`options.server`): an incoming open-socket request (JSON
 *   connect options) is checked against `allowedPorts`. It is then answered by
 *   constructing a `socketClass` (TcpSocket by default) and calling
 *   `connect()` on it.
 * - Client mode: `createSocket` builds a `socketClass` (DummySocket by
 *   default) bound to a peer picked by `roundRobinFactory`. The server's
 *   open-socket reply fills in the socket's `remoteId` and emits `"connect"`.
 */
export default class MultiSocketProxy extends Proxy {
  private socketClass: any;
  private _peers: Map<string, PeerEntity> = new Map<string, PeerEntity>();
  private _nextPeer;
  private _server = false;
  /** Ports a server accepts in open-socket requests; empty = no restriction. */
  private _allowedPorts: number[] = [];
  private _socketMap = new Map<number, number>();
  private _sockets = new Map<number, typeof this.socketClass>();

  constructor(options: MultiSocketProxyOptions) {
    super(options);
    this._server = Boolean(options.server);
    // An explicit socketClass wins; otherwise the default depends on the role.
    this.socketClass =
      options.socketClass || (this._server ? TcpSocket : DummySocket);
    // This option used to be accepted but ignored, leaving the port check in
    // the open-socket handler dead. Copy it so that later mutation of the
    // caller's array cannot change the policy.
    this._allowedPorts = [...(options.allowedPorts ?? [])];
    this._nextPeer = roundRobinFactory(this._peers);
  }

  get socketMap(): Map<number, number> {
    return this._socketMap;
  }

  /** Logical sockets known to this proxy, keyed by local socket id. */
  get sockets(): Map<number, any> {
    return this._sockets;
  }

  /**
   * Wraps a newly connected peer in a multiSocket `Peer`, merging in this
   * proxy's `socketOptions`. It then awaits `init()` and emits `"peer"` with
   * the wrapper.
   */
  async handlePeer({
    peer,
    muxer,
    ...options
  }: DataSocketOptions & PeerOptions) {
    const conn = new Peer({
      ...this.socketOptions,
      proxy: this,
      peer,
      muxer,
      ...options,
    });
    await conn.init();
    this.emit("peer", conn);
  }

  /**
   * Records `peer` under its remote public key and registers every socket
   * message type on its channel.
   */
  async handleNewPeerChannel(peer: Peer) {
    this.update(await this._getPublicKey(peer), {
      peer,
    });
    // Kept strictly sequential. Presumably the channel identifies messages by
    // registration order, so both ends must register in this same order —
    // confirm against the channel implementation before reordering.
    await this._registerOpenSocketMessage(peer);
    await this._registerWriteSocketMessage(peer);
    await this._registerCloseSocketMessage(peer);
    await this._registerTimeoutSocketMessage(peer);
    await this._registerErrorSocketMessage(peer);
  }

  /**
   * Ends every socket bound to `peer`, then forgets the peer's entry.
   */
  async handleClosePeer(peer: Peer) {
    for (const socket of this._sockets.values()) {
      if (socket.peer.peer === peer) {
        socket.end();
      }
    }
    // Map#delete is a no-op for a missing key, so no `has` check is needed.
    this._peers.delete(this._toString(await this._getPublicKey(peer)));
  }

  /** Returns the entity stored for `pubkey`, or `undefined` if none. */
  public get(pubkey: Uint8Array): PeerEntity | undefined {
    return this._peers.get(this._toString(pubkey));
  }

  /**
   * Shallow-merges `data` into the entity stored for `pubkey`, creating the
   * entity if needed. `messages` is merged one level deeper so that
   * registering one message type keeps the ones registered before it.
   */
  public update(pubkey: Uint8Array, data: Partial<PeerEntity>): void {
    const existing = this.get(pubkey);
    this._peers.set(this._toString(pubkey), {
      ...existing,
      ...data,
      messages: {
        ...existing?.messages,
        ...data?.messages,
      },
    } as PeerEntity);
  }

  /**
   * Creates a `socketClass` instance with a fresh local id, bound to the next
   * peer from the round-robin selector, and registers it in `sockets`.
   *
   * @throws Error("no peers found") when no peer is connected.
   */
  public createSocket(options: TcpSocketConnectOpts): typeof this.socketClass {
    if (!this._peers.size) {
      throw new Error("no peers found");
    }
    const peer = this._nextPeer();
    const socketId = nextSocketId();
    const socket = new this.socketClass(socketId, this, peer, options);
    this._sockets.set(socketId, socket);
    return socket;
  }

  /**
   * Registers the asymmetric open-socket message. Clients send JSON connect
   * options; servers send a `socketEncoding` id pair back.
   */
  private async _registerOpenSocketMessage(peer: Peer) {
    const self = this;
    const message = await peer.channel.addMessage({
      encoding: {
        preencode: this._server ? socketEncoding.preencode : json.preencode,
        encode: this._server ? socketEncoding.encode : json.encode,
        decode: this._server ? json.decode : socketEncoding.decode,
      },
      async onmessage(m: SocketRequest | TcpSocketConnectOpts) {
        if (self._server) {
          // JSON request from a client; `id` is the client's local socket id.
          const request = m as SocketRequest & TcpSocketConnectOpts;
          const entity = self.get(
            await self._getPublicKey(peer)
          ) as PeerEntity;
          if (
            self._allowedPorts.length &&
            !self._allowedPorts.includes(request.port)
          ) {
            // socketEncoding swaps the fields on decode, so the client's id
            // must travel in `remoteId` for the client to find its socket
            // under `id`. We have no local socket, hence the 0 filler.
            entity.messages.errorSocket.send({
              id: 0,
              remoteId: request.id,
              err: new Error(`port ${request.port} not allowed`),
            });
            return;
          }
          new self.socketClass(
            nextSocketId(),
            request.id,
            self,
            entity,
            request
          ).connect();
          return;
        }
        // Client: the server's reply. After the decode swap, `id` is our
        // local socket id and `remoteId` is the server's.
        const reply = m as SocketRequest;
        const socket = self._sockets.get(reply.id);
        if (socket) {
          socket.remoteId = reply.remoteId;
          // @ts-ignore
          socket.emit("connect");
        }
      },
    });
    this.update(await this._getPublicKey(peer), {
      messages: { openSocket: message },
    });
  }

  /** Registers the data message; payloads are pushed into the local socket. */
  private async _registerWriteSocketMessage(peer: Peer) {
    const self = this;
    const message = await peer.channel.addMessage({
      encoding: writeSocketEncoding,
      onmessage(m: WriteSocketRequest) {
        self._sockets.get(m.id)?.push(b4a.from(m.data));
      },
    });
    this.update(await this._getPublicKey(peer), {
      messages: { writeSocket: message },
    });
  }

  /** Registers the close message; it calls `end()` on the local socket. */
  private async _registerCloseSocketMessage(peer: Peer) {
    const self = this;
    const message = await peer.channel.addMessage({
      encoding: socketEncoding,
      onmessage(m: CloseSocketRequest) {
        self._sockets.get(m.id)?.end();
      },
    });
    this.update(await this._getPublicKey(peer), {
      messages: { closeSocket: message },
    });
  }

  /** Registers the timeout message; it emits `"timeout"` on the local socket. */
  private async _registerTimeoutSocketMessage(peer: Peer) {
    const self = this;
    const message = await peer.channel.addMessage({
      encoding: socketEncoding,
      onmessage(m: SocketRequest) {
        // @ts-ignore
        self._sockets.get(m.id)?.emit("timeout");
      },
    });
    this.update(await this._getPublicKey(peer), {
      messages: { timeoutSocket: message },
    });
  }

  /** Registers the error message; it emits `"error"` on the local socket. */
  private async _registerErrorSocketMessage(peer: Peer) {
    const self = this;
    const message = await peer.channel.addMessage({
      encoding: errorSocketEncoding,
      onmessage(m: ErrorSocketRequest) {
        // @ts-ignore
        self._sockets.get(m.id)?.emit("error", m.err);
      },
    });
    this.update(await this._getPublicKey(peer), {
      messages: { errorSocket: message },
    });
  }

  /** Hex form of a public key, used as the `_peers` map key. */
  private _toString(pubkey: Uint8Array) {
    return b4a.from(pubkey).toString("hex");
  }

  /** Resolves the remote public key of `peer`'s stream. */
  private async _getPublicKey(peer: Peer) {
    return maybeGetAsyncProperty(peer.stream.remotePublicKey);
  }
}