feat(abigen): support overloaded functions with different casing (#650)
* fix: support overloaded functions with different casing * chore: fmt * chore: typos * feat: better method alias handling
This commit is contained in:
parent
f10b47e600
commit
279a2c316c
|
@ -13,6 +13,7 @@ use ethers_core::{
|
||||||
macros::{ethers_contract_crate, ethers_core_crate, ethers_providers_crate},
|
macros::{ethers_contract_crate, ethers_core_crate, ethers_providers_crate},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use crate::contract::methods::MethodAlias;
|
||||||
use proc_macro2::{Ident, Literal, TokenStream};
|
use proc_macro2::{Ident, Literal, TokenStream};
|
||||||
use quote::quote;
|
use quote::quote;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
@ -78,7 +79,7 @@ pub struct Context {
|
||||||
contract_name: Ident,
|
contract_name: Ident,
|
||||||
|
|
||||||
/// Manually specified method aliases.
|
/// Manually specified method aliases.
|
||||||
method_aliases: BTreeMap<String, Ident>,
|
method_aliases: BTreeMap<String, MethodAlias>,
|
||||||
|
|
||||||
/// Derives added to event structs and enums.
|
/// Derives added to event structs and enums.
|
||||||
event_derives: Vec<Path>,
|
event_derives: Vec<Path>,
|
||||||
|
@ -204,7 +205,11 @@ impl Context {
|
||||||
// method will be re-defined.
|
// method will be re-defined.
|
||||||
let mut method_aliases = BTreeMap::new();
|
let mut method_aliases = BTreeMap::new();
|
||||||
for (signature, alias) in args.method_aliases.into_iter() {
|
for (signature, alias) in args.method_aliases.into_iter() {
|
||||||
let alias = syn::parse_str(&alias)?;
|
let alias = MethodAlias {
|
||||||
|
function_name: util::safe_ident(&alias),
|
||||||
|
struct_name: util::safe_pascal_case_ident(&alias),
|
||||||
|
};
|
||||||
|
|
||||||
if method_aliases.insert(signature.clone(), alias).is_some() {
|
if method_aliases.insert(signature.clone(), alias).is_some() {
|
||||||
return Err(anyhow!("duplicate method signature '{}' in method aliases", signature,))
|
return Err(anyhow!("duplicate method signature '{}' in method aliases", signature,))
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use std::collections::{btree_map::Entry, BTreeMap};
|
use std::collections::{btree_map::Entry, BTreeMap, HashMap};
|
||||||
|
|
||||||
use anyhow::{Context as _, Result};
|
use anyhow::{Context as _, Result};
|
||||||
use inflector::Inflector;
|
use inflector::Inflector;
|
||||||
|
@ -47,7 +47,7 @@ impl Context {
|
||||||
fn expand_call_struct(
|
fn expand_call_struct(
|
||||||
&self,
|
&self,
|
||||||
function: &Function,
|
function: &Function,
|
||||||
alias: Option<&Ident>,
|
alias: Option<&MethodAlias>,
|
||||||
) -> Result<TokenStream> {
|
) -> Result<TokenStream> {
|
||||||
let call_name = expand_call_struct_name(function, alias);
|
let call_name = expand_call_struct_name(function, alias);
|
||||||
let fields = self.expand_input_pairs(function)?;
|
let fields = self.expand_input_pairs(function)?;
|
||||||
|
@ -82,7 +82,7 @@ impl Context {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Expands all structs
|
/// Expands all structs
|
||||||
fn expand_call_structs(&self, aliases: BTreeMap<String, Ident>) -> Result<TokenStream> {
|
fn expand_call_structs(&self, aliases: BTreeMap<String, MethodAlias>) -> Result<TokenStream> {
|
||||||
let mut struct_defs = Vec::new();
|
let mut struct_defs = Vec::new();
|
||||||
let mut struct_names = Vec::new();
|
let mut struct_names = Vec::new();
|
||||||
let mut variant_names = Vec::new();
|
let mut variant_names = Vec::new();
|
||||||
|
@ -236,7 +236,11 @@ impl Context {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Expands a single function with the given alias
|
/// Expands a single function with the given alias
|
||||||
fn expand_function(&self, function: &Function, alias: Option<Ident>) -> Result<TokenStream> {
|
fn expand_function(
|
||||||
|
&self,
|
||||||
|
function: &Function,
|
||||||
|
alias: Option<MethodAlias>,
|
||||||
|
) -> Result<TokenStream> {
|
||||||
let name = expand_function_name(function, alias.as_ref());
|
let name = expand_function_name(function, alias.as_ref());
|
||||||
let selector = expand_selector(function.selector());
|
let selector = expand_selector(function.selector());
|
||||||
|
|
||||||
|
@ -275,10 +279,21 @@ impl Context {
|
||||||
// The first function or the function with the least amount of arguments should
|
// The first function or the function with the least amount of arguments should
|
||||||
// be named as in the ABI, the following functions suffixed with _with_ +
|
// be named as in the ABI, the following functions suffixed with _with_ +
|
||||||
// additional_params[0].name + (_and_(additional_params[1+i].name))*
|
// additional_params[0].name + (_and_(additional_params[1+i].name))*
|
||||||
fn get_method_aliases(&self) -> Result<BTreeMap<String, Ident>> {
|
fn get_method_aliases(&self) -> Result<BTreeMap<String, MethodAlias>> {
|
||||||
let mut aliases = self.method_aliases.clone();
|
let mut aliases = self.method_aliases.clone();
|
||||||
|
|
||||||
|
// it might be the case that there are functions with different capitalization so we sort
|
||||||
|
// them all by lc name first
|
||||||
|
let mut all_functions = HashMap::new();
|
||||||
|
for function in self.abi.functions() {
|
||||||
|
all_functions
|
||||||
|
.entry(function.name.to_lowercase())
|
||||||
|
.or_insert_with(Vec::new)
|
||||||
|
.push(function);
|
||||||
|
}
|
||||||
|
|
||||||
// find all duplicates, where no aliases where provided
|
// find all duplicates, where no aliases where provided
|
||||||
for functions in self.abi.functions.values() {
|
for functions in all_functions.values() {
|
||||||
if functions.iter().filter(|f| !aliases.contains_key(&f.abi_signature())).count() <= 1 {
|
if functions.iter().filter(|f| !aliases.contains_key(&f.abi_signature())).count() <= 1 {
|
||||||
// no overloads, hence no conflicts
|
// no overloads, hence no conflicts
|
||||||
continue
|
continue
|
||||||
|
@ -318,7 +333,7 @@ impl Context {
|
||||||
let mut diffs = Vec::new();
|
let mut diffs = Vec::new();
|
||||||
|
|
||||||
/// helper function that checks if there are any conflicts due to parameter names
|
/// helper function that checks if there are any conflicts due to parameter names
|
||||||
fn name_conflicts(idx: usize, diffs: &[(usize, Vec<&Param>, &Function)]) -> bool {
|
fn name_conflicts(idx: usize, diffs: &[(usize, Vec<&Param>, &&Function)]) -> bool {
|
||||||
let diff = &diffs.iter().find(|(i, _, _)| *i == idx).expect("diff exists").1;
|
let diff = &diffs.iter().find(|(i, _, _)| *i == idx).expect("diff exists").1;
|
||||||
|
|
||||||
for (_, other, _) in diffs.iter().filter(|(i, _, _)| *i != idx) {
|
for (_, other, _) in diffs.iter().filter(|(i, _, _)| *i != idx) {
|
||||||
|
@ -333,7 +348,6 @@ impl Context {
|
||||||
}
|
}
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
|
||||||
// compare each overloaded function with the `first_fun`
|
// compare each overloaded function with the `first_fun`
|
||||||
for (idx, overloaded_fun) in functions.into_iter().skip(1) {
|
for (idx, overloaded_fun) in functions.into_iter().skip(1) {
|
||||||
// attempt to find diff in the input arguments
|
// attempt to find diff in the input arguments
|
||||||
|
@ -357,6 +371,29 @@ impl Context {
|
||||||
for (idx, diff, overloaded_fun) in &diffs {
|
for (idx, diff, overloaded_fun) in &diffs {
|
||||||
let alias = match diff.len() {
|
let alias = match diff.len() {
|
||||||
0 => {
|
0 => {
|
||||||
|
// this may happen if there are functions with different casing,
|
||||||
|
// like `INDEX`and `index`
|
||||||
|
if overloaded_fun.name != first_fun.name {
|
||||||
|
let overloaded_id = overloaded_fun.name.to_snake_case();
|
||||||
|
let first_fun_id = first_fun.name.to_snake_case();
|
||||||
|
if first_fun_id != overloaded_id {
|
||||||
|
// no conflict
|
||||||
|
overloaded_id
|
||||||
|
} else {
|
||||||
|
let overloaded_alias = MethodAlias {
|
||||||
|
function_name: util::safe_ident(&overloaded_fun.name),
|
||||||
|
struct_name: util::safe_ident(&overloaded_fun.name),
|
||||||
|
};
|
||||||
|
aliases.insert(overloaded_fun.abi_signature(), overloaded_alias);
|
||||||
|
|
||||||
|
let first_fun_alias = MethodAlias {
|
||||||
|
function_name: util::safe_ident(&first_fun.name),
|
||||||
|
struct_name: util::safe_ident(&first_fun.name),
|
||||||
|
};
|
||||||
|
aliases.insert(first_fun.abi_signature(), first_fun_alias);
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
} else {
|
||||||
// this should not happen since functions with same name and inputs are
|
// this should not happen since functions with same name and inputs are
|
||||||
// illegal
|
// illegal
|
||||||
anyhow::bail!(
|
anyhow::bail!(
|
||||||
|
@ -364,6 +401,7 @@ impl Context {
|
||||||
overloaded_fun.name
|
overloaded_fun.name
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
1 => {
|
1 => {
|
||||||
// single additional input params
|
// single additional input params
|
||||||
if diff[0].name.is_empty() ||
|
if diff[0].name.is_empty() ||
|
||||||
|
@ -404,13 +442,17 @@ impl Context {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
aliases.insert(overloaded_fun.abi_signature(), util::safe_ident(&alias));
|
let alias = MethodAlias::new(&alias);
|
||||||
|
aliases.insert(overloaded_fun.abi_signature(), alias);
|
||||||
}
|
}
|
||||||
|
|
||||||
if needs_alias_for_first_fun_using_idx {
|
if needs_alias_for_first_fun_using_idx {
|
||||||
// insert an alias for the root duplicated call
|
// insert an alias for the root duplicated call
|
||||||
let prev_alias = format!("{}{}", first_fun.name.to_snake_case(), first_fun_idx);
|
let prev_alias = format!("{}{}", first_fun.name.to_snake_case(), first_fun_idx);
|
||||||
aliases.insert(first_fun.abi_signature(), util::safe_ident(&prev_alias));
|
|
||||||
|
let alias = MethodAlias::new(&prev_alias);
|
||||||
|
|
||||||
|
aliases.insert(first_fun.abi_signature(), alias);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -426,7 +468,7 @@ impl Context {
|
||||||
for function in functions {
|
for function in functions {
|
||||||
if let Entry::Vacant(entry) = aliases.entry(function.abi_signature()) {
|
if let Entry::Vacant(entry) = aliases.entry(function.abi_signature()) {
|
||||||
// use the full name as alias
|
// use the full name as alias
|
||||||
entry.insert(util::ident(name.as_str()));
|
entry.insert(MethodAlias::new(name.as_str()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -455,27 +497,34 @@ fn expand_selector(selector: Selector) -> TokenStream {
|
||||||
quote! { [#( #bytes ),*] }
|
quote! { [#( #bytes ),*] }
|
||||||
}
|
}
|
||||||
|
|
||||||
fn expand_function_name(function: &Function, alias: Option<&Ident>) -> Ident {
|
/// Represents the aliases to use when generating method related elements
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct MethodAlias {
|
||||||
|
pub function_name: Ident,
|
||||||
|
pub struct_name: Ident,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MethodAlias {
|
||||||
|
pub fn new(alias: &str) -> Self {
|
||||||
|
MethodAlias {
|
||||||
|
function_name: util::safe_snake_case_ident(alias),
|
||||||
|
struct_name: util::safe_pascal_case_ident(alias),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn expand_function_name(function: &Function, alias: Option<&MethodAlias>) -> Ident {
|
||||||
if let Some(alias) = alias {
|
if let Some(alias) = alias {
|
||||||
// snake_case strips leading and trailing underscores so we simply add them back if the
|
alias.function_name.clone()
|
||||||
// alias starts/ends with underscores
|
|
||||||
let alias = alias.to_string();
|
|
||||||
let ident = alias.to_snake_case();
|
|
||||||
util::ident(&util::preserve_underscore_delim(&ident, &alias))
|
|
||||||
} else {
|
} else {
|
||||||
util::safe_ident(&function.name.to_snake_case())
|
util::safe_ident(&function.name.to_snake_case())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Expands to the name of the call struct
|
/// Expands to the name of the call struct
|
||||||
fn expand_call_struct_name(function: &Function, alias: Option<&Ident>) -> Ident {
|
fn expand_call_struct_name(function: &Function, alias: Option<&MethodAlias>) -> Ident {
|
||||||
let name = if let Some(alias) = alias {
|
let name = if let Some(alias) = alias {
|
||||||
// pascal_case strips leading and trailing underscores so we simply add them back if the
|
format!("{}Call", alias.struct_name)
|
||||||
// alias starts/ends with underscores
|
|
||||||
let alias = alias.to_string();
|
|
||||||
let ident = alias.to_pascal_case();
|
|
||||||
let alias = util::preserve_underscore_delim(&ident, &alias);
|
|
||||||
format!("{}Call", alias)
|
|
||||||
} else {
|
} else {
|
||||||
format!("{}Call", function.name.to_pascal_case())
|
format!("{}Call", function.name.to_pascal_case())
|
||||||
};
|
};
|
||||||
|
@ -483,15 +532,12 @@ fn expand_call_struct_name(function: &Function, alias: Option<&Ident>) -> Ident
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Expands to the name of the call struct
|
/// Expands to the name of the call struct
|
||||||
fn expand_call_struct_variant_name(function: &Function, alias: Option<&Ident>) -> Ident {
|
fn expand_call_struct_variant_name(function: &Function, alias: Option<&MethodAlias>) -> Ident {
|
||||||
let name = if let Some(alias) = alias {
|
if let Some(alias) = alias {
|
||||||
let alias = alias.to_string();
|
alias.struct_name.clone()
|
||||||
let ident = alias.to_pascal_case();
|
|
||||||
util::preserve_underscore_delim(&ident, &alias)
|
|
||||||
} else {
|
} else {
|
||||||
function.name.to_pascal_case()
|
util::safe_ident(&function.name.to_pascal_case())
|
||||||
};
|
}
|
||||||
util::ident(&name)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Expands to the tuple struct definition
|
/// Expands to the tuple struct definition
|
||||||
|
|
|
@ -9,7 +9,7 @@ use quote::quote;
|
||||||
|
|
||||||
use syn::{Ident as SynIdent, Path};
|
use syn::{Ident as SynIdent, Path};
|
||||||
|
|
||||||
/// Expands a identifier string into an token.
|
/// Expands a identifier string into a token.
|
||||||
pub fn ident(name: &str) -> Ident {
|
pub fn ident(name: &str) -> Ident {
|
||||||
Ident::new(name, Span::call_site())
|
Ident::new(name, Span::call_site())
|
||||||
}
|
}
|
||||||
|
@ -22,6 +22,18 @@ pub fn safe_ident(name: &str) -> Ident {
|
||||||
syn::parse_str::<SynIdent>(name).unwrap_or_else(|_| ident(&format!("{}_", name)))
|
syn::parse_str::<SynIdent>(name).unwrap_or_else(|_| ident(&format!("{}_", name)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Expands an identifier as snakecase and preserve any leading or trailing underscores
|
||||||
|
pub fn safe_snake_case_ident(name: &str) -> Ident {
|
||||||
|
let i = name.to_snake_case();
|
||||||
|
ident(&preserve_underscore_delim(&i, name))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Expands an identifier as pascal case and preserve any leading or trailing underscores
|
||||||
|
pub fn safe_pascal_case_ident(name: &str) -> Ident {
|
||||||
|
let i = name.to_pascal_case();
|
||||||
|
ident(&preserve_underscore_delim(&i, name))
|
||||||
|
}
|
||||||
|
|
||||||
/// Reapplies leading and trailing underscore chars to the ident
|
/// Reapplies leading and trailing underscore chars to the ident
|
||||||
/// Example `ident = "pascalCase"; alias = __pascalcase__` -> `__pascalCase__`
|
/// Example `ident = "pascalCase"; alias = __pascalcase__` -> `__pascalCase__`
|
||||||
pub fn preserve_underscore_delim(ident: &str, alias: &str) -> String {
|
pub fn preserve_underscore_delim(ident: &str, alias: &str) -> String {
|
||||||
|
|
|
@ -402,3 +402,20 @@ fn can_generate_nested_types() {
|
||||||
let decoded_call = MyfunCall::decode(encoded_call.as_ref()).unwrap();
|
let decoded_call = MyfunCall::decode(encoded_call.as_ref()).unwrap();
|
||||||
assert_eq!(call, decoded_call);
|
assert_eq!(call, decoded_call);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn can_handle_case_sensitive_calls() {
|
||||||
|
abigen!(
|
||||||
|
StakedOHM,
|
||||||
|
r#"[
|
||||||
|
index()
|
||||||
|
INDEX()
|
||||||
|
]"#,
|
||||||
|
);
|
||||||
|
|
||||||
|
let (client, _mock) = Provider::mocked();
|
||||||
|
let contract = StakedOHM::new(Address::default(), Arc::new(client));
|
||||||
|
|
||||||
|
let _ = contract.index();
|
||||||
|
let _ = contract.INDEX();
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue