diff --git a/src/rpc/common.ts b/src/rpc/common.ts index 56e37ba..0085ff9 100644 --- a/src/rpc/common.ts +++ b/src/rpc/common.ts @@ -8,6 +8,10 @@ import { ERR_INVALID_CHAIN } from "../error.js"; type RpcProviderMethod = (method: string, params: Array) => Promise; +interface RpcContext { + chain?: string; +} + const gatewayProviders: { [name: string]: RpcProviderMethod } = {}; const gatewayMethods: { @@ -126,3 +130,13 @@ class RpcError extends Error { export function rpcError(message: string): Promise { return Promise.reject(new RpcError(message)); } + +export function validateChain(chain: string, handler: any) { + return async (args: any, context: RpcContext) => { + if (!context?.chain || "hns" !== context?.chain) { + return rpcError(ERR_INVALID_CHAIN); + } + + return handler(args, context); + }; +} diff --git a/src/rpc/dns.ts b/src/rpc/dns.ts index 71687b9..de70845 100644 --- a/src/rpc/dns.ts +++ b/src/rpc/dns.ts @@ -2,7 +2,7 @@ //const require = createRequire(import.meta.url); import { isIp } from "../util.js"; -import { RpcMethodList } from "./index.js"; +import { RpcMethodList, validateChain } from "./index.js"; // @ts-ignore import bns from "bns"; const { StubResolver, RecursiveResolver } = bns; @@ -59,7 +59,7 @@ async function getDnsRecords( } export default { - dnslookup: async function (args: any) { + dnslookup: validateChain("icann", async function (args: any) { let dnsResults: string[] = []; let domain = args.domain; let ns = args.nameserver; @@ -104,5 +104,5 @@ export default { } return false; - }, + }), } as RpcMethodList; diff --git a/src/rpc/handshake.ts b/src/rpc/handshake.ts index c2aa275..b70e9fd 100644 --- a/src/rpc/handshake.ts +++ b/src/rpc/handshake.ts @@ -1,7 +1,7 @@ //const require = createRequire(import.meta.url); //import { createRequire } from "module"; -import { rpcError, RpcMethodList } from "./index.js"; +import { rpcError, RpcMethodList, validateChain } from "./index.js"; // @ts-ignore import rand from "random-key"; // @ts-ignore @@ -78,12 +78,7 @@ if (!config.bool("hsd-use-external-node")) { const hnsClient = new NodeClient(clientArgs); export default { - getnameresource: async (args: any, context: object) => { - // @ts-ignore - if ("hns" !== context.chain) { - throw rpcError(ERR_INVALID_CHAIN); - } - + getnameresource: validateChain("hns", async (args: any) => { let resp; try { resp = await hnsClient.execute("getnameresource", args); @@ -100,5 +95,5 @@ export default { } return resp; - }, + }), } as RpcMethodList; diff --git a/src/rpc/misc.ts b/src/rpc/misc.ts index 6f603e2..c9adc44 100644 --- a/src/rpc/misc.ts +++ b/src/rpc/misc.ts @@ -1,7 +1,7 @@ -import { RpcMethodList } from "./index"; +import { RpcMethodList, validateChain } from "./index.js"; export default { - ping: async () => { + ping: validateChain("misc", async () => { return { pong: true }; - }, + }), } as RpcMethodList;