From 279a2c316cf384a6b086e9cb2e5e86c85b844e54 Mon Sep 17 00:00:00 2001 From: Matthias Seitz Date: Sun, 5 Dec 2021 21:36:49 +0100 Subject: [PATCH] feat(abigen): support overloaded functions with different casing (#650) * fix: support overloaded functions with different casing * chore: fmt * chore: typos * feat: better method alias handling --- .../ethers-contract-abigen/src/contract.rs | 9 +- .../src/contract/methods.rs | 122 ++++++++++++------ .../ethers-contract-abigen/src/util.rs | 14 +- ethers-contract/tests/abigen.rs | 17 +++ 4 files changed, 121 insertions(+), 41 deletions(-) diff --git a/ethers-contract/ethers-contract-abigen/src/contract.rs b/ethers-contract/ethers-contract-abigen/src/contract.rs index 24730e5e..b7af7dbb 100644 --- a/ethers-contract/ethers-contract-abigen/src/contract.rs +++ b/ethers-contract/ethers-contract-abigen/src/contract.rs @@ -13,6 +13,7 @@ use ethers_core::{ macros::{ethers_contract_crate, ethers_core_crate, ethers_providers_crate}, }; +use crate::contract::methods::MethodAlias; use proc_macro2::{Ident, Literal, TokenStream}; use quote::quote; use serde::Deserialize; @@ -78,7 +79,7 @@ pub struct Context { contract_name: Ident, /// Manually specified method aliases. - method_aliases: BTreeMap, + method_aliases: BTreeMap, /// Derives added to event structs and enums. event_derives: Vec, @@ -204,7 +205,11 @@ impl Context { // method will be re-defined. let mut method_aliases = BTreeMap::new(); for (signature, alias) in args.method_aliases.into_iter() { - let alias = syn::parse_str(&alias)?; + let alias = MethodAlias { + function_name: util::safe_ident(&alias), + struct_name: util::safe_pascal_case_ident(&alias), + }; + if method_aliases.insert(signature.clone(), alias).is_some() { return Err(anyhow!("duplicate method signature '{}' in method aliases", signature,)) } diff --git a/ethers-contract/ethers-contract-abigen/src/contract/methods.rs b/ethers-contract/ethers-contract-abigen/src/contract/methods.rs index 4b68abd3..eb40aa8f 100644 --- a/ethers-contract/ethers-contract-abigen/src/contract/methods.rs +++ b/ethers-contract/ethers-contract-abigen/src/contract/methods.rs @@ -1,4 +1,4 @@ -use std::collections::{btree_map::Entry, BTreeMap}; +use std::collections::{btree_map::Entry, BTreeMap, HashMap}; use anyhow::{Context as _, Result}; use inflector::Inflector; @@ -47,7 +47,7 @@ impl Context { fn expand_call_struct( &self, function: &Function, - alias: Option<&Ident>, + alias: Option<&MethodAlias>, ) -> Result { let call_name = expand_call_struct_name(function, alias); let fields = self.expand_input_pairs(function)?; @@ -82,7 +82,7 @@ impl Context { } /// Expands all structs - fn expand_call_structs(&self, aliases: BTreeMap) -> Result { + fn expand_call_structs(&self, aliases: BTreeMap) -> Result { let mut struct_defs = Vec::new(); let mut struct_names = Vec::new(); let mut variant_names = Vec::new(); @@ -236,7 +236,11 @@ impl Context { } /// Expands a single function with the given alias - fn expand_function(&self, function: &Function, alias: Option) -> Result { + fn expand_function( + &self, + function: &Function, + alias: Option, + ) -> Result { let name = expand_function_name(function, alias.as_ref()); let selector = expand_selector(function.selector()); @@ -275,10 +279,21 @@ impl Context { // The first function or the function with the least amount of arguments should // be named as in the ABI, the following functions suffixed with _with_ + // additional_params[0].name + (_and_(additional_params[1+i].name))* - fn get_method_aliases(&self) -> Result> { + fn get_method_aliases(&self) -> Result> { let mut aliases = self.method_aliases.clone(); + + // it might be the case that there are functions with different capitalization so we sort + // them all by lc name first + let mut all_functions = HashMap::new(); + for function in self.abi.functions() { + all_functions + .entry(function.name.to_lowercase()) + .or_insert_with(Vec::new) + .push(function); + } + // find all duplicates, where no aliases where provided - for functions in self.abi.functions.values() { + for functions in all_functions.values() { if functions.iter().filter(|f| !aliases.contains_key(&f.abi_signature())).count() <= 1 { // no overloads, hence no conflicts continue @@ -318,7 +333,7 @@ impl Context { let mut diffs = Vec::new(); /// helper function that checks if there are any conflicts due to parameter names - fn name_conflicts(idx: usize, diffs: &[(usize, Vec<&Param>, &Function)]) -> bool { + fn name_conflicts(idx: usize, diffs: &[(usize, Vec<&Param>, &&Function)]) -> bool { let diff = &diffs.iter().find(|(i, _, _)| *i == idx).expect("diff exists").1; for (_, other, _) in diffs.iter().filter(|(i, _, _)| *i != idx) { @@ -333,7 +348,6 @@ impl Context { } false } - // compare each overloaded function with the `first_fun` for (idx, overloaded_fun) in functions.into_iter().skip(1) { // attempt to find diff in the input arguments @@ -357,12 +371,36 @@ impl Context { for (idx, diff, overloaded_fun) in &diffs { let alias = match diff.len() { 0 => { - // this should not happen since functions with same name and inputs are - // illegal - anyhow::bail!( - "Function with same name and parameter types defined twice: {}", - overloaded_fun.name - ); + // this may happen if there are functions with different casing, + // like `INDEX`and `index` + if overloaded_fun.name != first_fun.name { + let overloaded_id = overloaded_fun.name.to_snake_case(); + let first_fun_id = first_fun.name.to_snake_case(); + if first_fun_id != overloaded_id { + // no conflict + overloaded_id + } else { + let overloaded_alias = MethodAlias { + function_name: util::safe_ident(&overloaded_fun.name), + struct_name: util::safe_ident(&overloaded_fun.name), + }; + aliases.insert(overloaded_fun.abi_signature(), overloaded_alias); + + let first_fun_alias = MethodAlias { + function_name: util::safe_ident(&first_fun.name), + struct_name: util::safe_ident(&first_fun.name), + }; + aliases.insert(first_fun.abi_signature(), first_fun_alias); + continue + } + } else { + // this should not happen since functions with same name and inputs are + // illegal + anyhow::bail!( + "Function with same name and parameter types defined twice: {}", + overloaded_fun.name + ); + } } 1 => { // single additional input params @@ -404,13 +442,17 @@ impl Context { } } }; - aliases.insert(overloaded_fun.abi_signature(), util::safe_ident(&alias)); + let alias = MethodAlias::new(&alias); + aliases.insert(overloaded_fun.abi_signature(), alias); } if needs_alias_for_first_fun_using_idx { // insert an alias for the root duplicated call let prev_alias = format!("{}{}", first_fun.name.to_snake_case(), first_fun_idx); - aliases.insert(first_fun.abi_signature(), util::safe_ident(&prev_alias)); + + let alias = MethodAlias::new(&prev_alias); + + aliases.insert(first_fun.abi_signature(), alias); } } @@ -426,7 +468,7 @@ impl Context { for function in functions { if let Entry::Vacant(entry) = aliases.entry(function.abi_signature()) { // use the full name as alias - entry.insert(util::ident(name.as_str())); + entry.insert(MethodAlias::new(name.as_str())); } } } @@ -455,27 +497,34 @@ fn expand_selector(selector: Selector) -> TokenStream { quote! { [#( #bytes ),*] } } -fn expand_function_name(function: &Function, alias: Option<&Ident>) -> Ident { +/// Represents the aliases to use when generating method related elements +#[derive(Debug, Clone)] +pub struct MethodAlias { + pub function_name: Ident, + pub struct_name: Ident, +} + +impl MethodAlias { + pub fn new(alias: &str) -> Self { + MethodAlias { + function_name: util::safe_snake_case_ident(alias), + struct_name: util::safe_pascal_case_ident(alias), + } + } +} + +fn expand_function_name(function: &Function, alias: Option<&MethodAlias>) -> Ident { if let Some(alias) = alias { - // snake_case strips leading and trailing underscores so we simply add them back if the - // alias starts/ends with underscores - let alias = alias.to_string(); - let ident = alias.to_snake_case(); - util::ident(&util::preserve_underscore_delim(&ident, &alias)) + alias.function_name.clone() } else { util::safe_ident(&function.name.to_snake_case()) } } /// Expands to the name of the call struct -fn expand_call_struct_name(function: &Function, alias: Option<&Ident>) -> Ident { +fn expand_call_struct_name(function: &Function, alias: Option<&MethodAlias>) -> Ident { let name = if let Some(alias) = alias { - // pascal_case strips leading and trailing underscores so we simply add them back if the - // alias starts/ends with underscores - let alias = alias.to_string(); - let ident = alias.to_pascal_case(); - let alias = util::preserve_underscore_delim(&ident, &alias); - format!("{}Call", alias) + format!("{}Call", alias.struct_name) } else { format!("{}Call", function.name.to_pascal_case()) }; @@ -483,15 +532,12 @@ fn expand_call_struct_name(function: &Function, alias: Option<&Ident>) -> Ident } /// Expands to the name of the call struct -fn expand_call_struct_variant_name(function: &Function, alias: Option<&Ident>) -> Ident { - let name = if let Some(alias) = alias { - let alias = alias.to_string(); - let ident = alias.to_pascal_case(); - util::preserve_underscore_delim(&ident, &alias) +fn expand_call_struct_variant_name(function: &Function, alias: Option<&MethodAlias>) -> Ident { + if let Some(alias) = alias { + alias.struct_name.clone() } else { - function.name.to_pascal_case() - }; - util::ident(&name) + util::safe_ident(&function.name.to_pascal_case()) + } } /// Expands to the tuple struct definition diff --git a/ethers-contract/ethers-contract-abigen/src/util.rs b/ethers-contract/ethers-contract-abigen/src/util.rs index f9435a51..138f744c 100644 --- a/ethers-contract/ethers-contract-abigen/src/util.rs +++ b/ethers-contract/ethers-contract-abigen/src/util.rs @@ -9,7 +9,7 @@ use quote::quote; use syn::{Ident as SynIdent, Path}; -/// Expands a identifier string into an token. +/// Expands a identifier string into a token. pub fn ident(name: &str) -> Ident { Ident::new(name, Span::call_site()) } @@ -22,6 +22,18 @@ pub fn safe_ident(name: &str) -> Ident { syn::parse_str::(name).unwrap_or_else(|_| ident(&format!("{}_", name))) } +/// Expands an identifier as snakecase and preserve any leading or trailing underscores +pub fn safe_snake_case_ident(name: &str) -> Ident { + let i = name.to_snake_case(); + ident(&preserve_underscore_delim(&i, name)) +} + +/// Expands an identifier as pascal case and preserve any leading or trailing underscores +pub fn safe_pascal_case_ident(name: &str) -> Ident { + let i = name.to_pascal_case(); + ident(&preserve_underscore_delim(&i, name)) +} + /// Reapplies leading and trailing underscore chars to the ident /// Example `ident = "pascalCase"; alias = __pascalcase__` -> `__pascalCase__` pub fn preserve_underscore_delim(ident: &str, alias: &str) -> String { diff --git a/ethers-contract/tests/abigen.rs b/ethers-contract/tests/abigen.rs index 7358fa82..fcfa8b39 100644 --- a/ethers-contract/tests/abigen.rs +++ b/ethers-contract/tests/abigen.rs @@ -402,3 +402,20 @@ fn can_generate_nested_types() { let decoded_call = MyfunCall::decode(encoded_call.as_ref()).unwrap(); assert_eq!(call, decoded_call); } + +#[test] +fn can_handle_case_sensitive_calls() { + abigen!( + StakedOHM, + r#"[ + index() + INDEX() + ]"#, + ); + + let (client, _mock) = Provider::mocked(); + let contract = StakedOHM::new(Address::default(), Arc::new(client)); + + let _ = contract.index(); + let _ = contract.INDEX(); +}