diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b517dd0..c0dab0e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -71,6 +71,7 @@ [#1184](https://github.com/gakonst/ethers-rs/pull/1184) - Add `From` and From> traits to `ValueOrArray` [#1199](https://github.com/gakonst/ethers-rs/pull/1200) - Fix handling of Websocket connection errors [#1287](https://github.com/gakonst/ethers-rs/pull/1287) +- Add Arithmetic Shift Right operation for I256 [#1323](https://github.com/gakonst/ethers-rs/issues/1323) ## ethers-contract-abigen diff --git a/ethers-core/src/types/i256.rs b/ethers-core/src/types/i256.rs index e568c8b2..6a1c4dba 100644 --- a/ethers-core/src/types/i256.rs +++ b/ethers-core/src/types/i256.rs @@ -10,7 +10,9 @@ use serde::{Deserialize, Serialize}; use std::{ cmp, convert::{TryFrom, TryInto}, - fmt, i128, i64, iter, ops, str, u64, + fmt, i128, i64, iter, ops, + ops::Sub, + str, u64, }; use thiserror::Error; @@ -935,6 +937,35 @@ impl I256 { let (result, _) = self.overflowing_pow(exp); result } + + /// Arithmetic Shift Right operation. Shifts `shift` number of times to the right maintaining + /// the original sign. If the number is positive this is the same as logic shift right. + pub fn asr(self, shift: u32) -> Self { + // Avoid shifting if we are going to know the result regardless of the value. + if shift == 0 { + self + } else if shift >= 255u32 { + match self.sign() { + // It's always going to be zero (i.e. 00000000...00000000) + Sign::Positive => Self::zero(), + // It's always going to be -1 (i.e. 11111111...11111111) + Sign::Negative => Self::from(-1i8), + } + } else { + // Perform the shift. + match self.sign() { + Sign::Positive => self >> shift, + // We need to do: `for 0..shift { self >> 1 | 2^255 }` + // We can avoid the loop by doing: `self >> shift | ~(2^(255 - shift) - 1)` + // where '~' represents ones complement + Sign::Negative => { + let bitwise_or = + Self::from_raw(!U256::from(2u8).pow(U256::from(255u32 - shift)).sub(1u8)); + (self >> shift) | bitwise_or + } + } + } + } } macro_rules! impl_from { @@ -1276,6 +1307,7 @@ mod tests { use crate::abi::Tokenizable; use once_cell::sync::Lazy; use serde_json::json; + use std::ops::Neg; static MIN_ABS: Lazy = Lazy::new(|| U256::from(1) << 255); @@ -1521,6 +1553,33 @@ mod tests { assert_eq!(I256::MIN >> 255, I256::one()); } + #[test] + fn arithmetic_shift_right() { + let value = I256::from_raw(U256::from(2u8).pow(U256::from(254u8))).neg(); + let expected_result = I256::from_raw(U256::MAX.sub(1u8)); + assert_eq!(value.asr(253u32), expected_result, "1011...1111 >> 253 was not 1111...1110"); + + let value = I256::from(-1i8); + let expected_result = I256::from(-1i8); + assert_eq!(value.asr(250u32), expected_result, "-1 >> any_amount was not -1"); + + let value = I256::from_raw(U256::from(2u8).pow(U256::from(254u8))).neg(); + let expected_result = I256::from(-1i8); + assert_eq!(value.asr(255u32), expected_result, "1011...1111 >> 255 was not -1"); + + let value = I256::from_raw(U256::from(2u8).pow(U256::from(254u8))).neg(); + let expected_result = I256::from(-1i8); + assert_eq!(value.asr(1024u32), expected_result, "1011...1111 >> 1024 was not -1"); + + let value = I256::from(1024i32); + let expected_result = I256::from(32i32); + assert_eq!(value.asr(5u32), expected_result, "1024 >> 5 was not 32"); + + let value = I256::MAX; + let expected_result = I256::zero(); + assert_eq!(value.asr(255u32), expected_result, "I256::MAX >> 255 was not 0"); + } + #[test] fn addition() { assert_eq!(I256::MIN.overflowing_add(I256::MIN), (I256::zero(), true));