//! Conversions between `BigDecimal` and the PostgreSQL `NUMERIC` binary format.
use bigdecimal::{
num_bigint::{BigInt, Sign},
BigDecimal, Zero,
};
use byteorder::{BigEndian, ReadBytesExt};
use bytes::{BufMut, BytesMut};
use postgres_types::{to_sql_checked, FromSql, IsNull, ToSql, Type};
use std::{cmp, convert::TryInto, error, fmt, io::Cursor};
/// Newtype around [`BigDecimal`] that this file gives `FromSql` / `ToSql`
/// implementations for, accepting the Postgres `NUMERIC` type.
///
/// The wrapped value is public and can be read or replaced through `.0`.
#[derive(Debug, Clone)]
pub struct DecimalWrapper(pub BigDecimal);
/// Error describing why a decimal value could not be converted.
///
/// Carries a static, human-readable reason; its `Display` output is the
/// reason prefixed with `"Invalid Decimal: "`.
#[derive(Debug, Clone)]
pub struct InvalidDecimal(&'static str);

impl fmt::Display for InvalidDecimal {
    /// Writes `Invalid Decimal: <reason>` to the formatter.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let InvalidDecimal(reason) = self;
        write!(f, "Invalid Decimal: {}", reason)
    }
}

// No underlying source error: the default `Error` methods are sufficient.
impl error::Error for InvalidDecimal {}
/// In-memory form of the fields of a Postgres `NUMERIC` value.
///
/// `D` is the container/iterator type holding the base-10000 digit groups;
/// this file uses an iterator of `u16` when decoding and a `Vec<i16>` when
/// encoding.
struct PostgresDecimal<D> {
// `true` when the value is negative.
neg: bool,
// Power of 10000 that applies to the first digit group.
weight: i16,
// Number of base-10 digits to print after the decimal separator.
scale: u16,
// Base-10000 digit groups, most significant group first.
digits: D,
}
/// Builds a [`BigDecimal`] from the decoded fields of a Postgres `NUMERIC`.
///
/// Only `neg`, `weight` and `digits` are used; the display scale carried by
/// `dec` is ignored and the resulting scale is derived from the number of
/// base-10000 groups instead (always a multiple of 4).
///
/// An empty digit sequence yields zero.
///
/// # Errors
///
/// Returns [`InvalidDecimal`] if any group is not a valid base-10000 digit
/// (i.e. is `>= 10000`).
fn from_postgres<D: ExactSizeIterator<Item = u16>>(dec: PostgresDecimal<D>) -> Result<BigDecimal, InvalidDecimal> {
    let PostgresDecimal {
        neg, digits, weight, ..
    } = dec;

    if digits.len() == 0 {
        return Ok(0u64.into());
    }

    let sign = match neg {
        false => Sign::Plus,
        true => Sign::Minus,
    };

    // weight is 0 if the decimal point falls after the first base-10000 digit
    let scale = (digits.len() as i64 - weight as i64 - 1) * 4;

    // no optimized algorithm for base-10 so use base-100 for faster processing
    let mut cents = Vec::with_capacity(digits.len() * 2);
    for digit in digits {
        // Validate before narrowing: `(digit / 100) as u8` wraps for large
        // groups (e.g. 25600 -> 256 -> 0), which would otherwise slip past the
        // radix check below and be decoded as a legitimate digit.
        if digit >= 10_000 {
            return Err(InvalidDecimal("PostgresDecimal contained an out-of-range digit"));
        }
        cents.push((digit / 100) as u8);
        cents.push((digit % 100) as u8);
    }

    let bigint = BigInt::from_radix_be(sign, &cents, 100)
        .ok_or(InvalidDecimal("PostgresDecimal contained an out-of-range digit"))?;

    Ok(BigDecimal::new(bigint, scale))
}
/// Splits a [`BigDecimal`] into the fields of a Postgres `NUMERIC`.
///
/// Zero is encoded as no digit groups with weight and scale of 0. For any
/// other value the base-10 digits of the unscaled integer are regrouped into
/// base-10000 groups aligned on the decimal point, and trailing all-zero
/// groups are dropped.
///
/// # Errors
///
/// Fails if the scale does not fit in a `u16` or the weight does not fit in
/// an `i16`.
fn to_postgres(decimal: &BigDecimal) -> crate::Result<PostgresDecimal<Vec<i16>>> {
    /// Folds base-10 digits (most significant first) into a single number.
    fn fold_base_10(chunk: &[u8]) -> i16 {
        let mut acc = 0i16;
        for &d in chunk {
            acc = acc * 10 + d as i16;
        }
        acc
    }

    if decimal.is_zero() {
        return Ok(PostgresDecimal {
            neg: false,
            weight: 0,
            scale: 0,
            digits: vec![],
        });
    }

    // NOTE: this copies the BigInt internally.
    let (unscaled, exp) = decimal.as_bigint_and_exponent();
    // Base-10 digits of the magnitude, most significant first.
    let (sign, decimal_digits) = unscaled.to_radix_be(10);
    let digit_count = decimal_digits.len();

    // Number of base-10 digits in front of the decimal point; zero or negative
    // for pure fractions. `exp` is the *negative* power of ten.
    let weight_10 = digit_count as i64 - exp;

    // Fractional digits only exist when `exp` is positive, and then `exp` is
    // exactly the display scale.
    let scale: u16 = cmp::max(0, exp).try_into()?;

    // Power of 10000 of the first group. For weight_10 in 1..=4 the first
    // group is still the units group, hence the `- 1` before dividing.
    let weight_wide = if weight_10 > 0 {
        (weight_10 - 1) / 4
    } else {
        weight_10 / 4 - 1
    };
    let weight: i16 = weight_wide.try_into()?;

    // How many base-10 digits belong to the (possibly partial) leading group.
    let lead_len = weight_10.rem_euclid(4) as usize;

    let mut groups: Vec<i16> = Vec::with_capacity((digit_count + 3) / 4);
    if lead_len > digit_count {
        // Every digit sits inside the leading group; pad on the right with
        // the zeros implied by the exponent.
        let padding = (lead_len - digit_count) as u32;
        groups.push(fold_base_10(&decimal_digits) * 10i16.pow(padding));
    } else {
        let (head, tail) = decimal_digits.split_at(lead_len);
        if !head.is_empty() {
            groups.push(fold_base_10(head));
        }
        // Remaining digits form full groups; a short final chunk is padded on
        // the right so it keeps its fractional position.
        for chunk in tail.chunks(4) {
            let padding = 4 - chunk.len() as u32;
            groups.push(fold_base_10(chunk) * 10i16.pow(padding));
        }
    }

    // Trailing zero groups carry no information.
    while groups.last() == Some(&0) {
        groups.pop();
    }

    Ok(PostgresDecimal {
        neg: matches!(sign, Sign::Minus),
        weight,
        scale,
        digits: groups,
    })
}
impl<'a> FromSql<'a> for DecimalWrapper {
// Decimals are represented as follows:
// Header:
// u16 numGroups
// i16 weightFirstGroup (10000^weight)
// u16 sign (0x0000 = positive, 0x4000 = negative, 0xC000 = NaN)
// i16 dscale. Number of digits (in base 10) to print after decimal separator
//
// Pseudo code :
// const Decimals [
// 0.0000000000000000000000000001,
// 0.000000000000000000000001,
// 0.00000000000000000001,
// 0.0000000000000001,
// 0.000000000001,
// 0.00000001,
// 0.0001,
// 1,
// 10000,
// 100000000,
// 1000000000000,
// 10000000000000000,
// 100000000000000000000,
// 1000000000000000000000000,
// 10000000000000000000000000000
// ]
// overflow = false
// result = 0
// for i = 0, weight = weightFirstGroup + 7; i < numGroups; i++, weight--
// group = read.u16
// if weight < 0 or weight > MaxNum
// overflow = true
// else
// result += Decimals[weight] * group
// sign == 0x4000 ? -result : result
// So if we were to take the number: 3950.123456
//
// Stored on Disk:
// 00 03 00 00 00 00 00 06 0F 6E 04 D2 15 E0
//
// Number of groups: 00 03
// Weight of first group: 00 00
// Sign: 00 00
// DScale: 00 06
//
// 0F 6E = 3950
// result = result + 3950 * 1;
// 04 D2 = 1234
// result = result + 1234 * 0.0001;
// 15 E0 = 5600
// result = result + 5600 * 0.00000001;
//
fn from_sql(_: &Type, raw: &[u8]) -> Result<DecimalWrapper, Box<dyn error::Error + 'static + Sync + Send>> {
let mut raw = Cursor::new(raw);
let num_groups = raw.read_u16::<BigEndian>()?;
let weight = raw.read_i16::<BigEndian>()?; // 10000^weight
// Sign: 0x0000 = positive, 0x4000 = negative, 0xC000 = NaN
let sign = raw.read_u16::<BigEndian>()?;
// Number of digits (in base 10) to print after decimal separator
let scale = raw.read_u16::<BigEndian>()?;
// Read all of the groups
let mut groups = Vec::new();
for _ in 0..num_groups as usize {
groups.push(raw.read_u16::<BigEndian>()?);
}
let dec = from_postgres(PostgresDecimal {
neg: sign == 0x4000,
weight,
scale,
digits: groups.into_iter(),
})
.map_err(Box::new)?;
Ok(DecimalWrapper(dec))
}
fn accepts(ty: &Type) -> bool {
matches!(*ty, Type::NUMERIC)
}
}
impl ToSql for DecimalWrapper {
    /// Encodes the wrapped decimal in the binary `NUMERIC` layout: an 8-byte
    /// header (group count, weight, sign word, display scale) followed by one
    /// 16-bit word per base-10000 group, all written big-endian.
    ///
    /// Always reports `IsNull::No` on success. Fails if the value cannot be
    /// converted by `to_postgres` or if the group count does not fit in a
    /// `u16`.
    fn to_sql(&self, _: &Type, out: &mut BytesMut) -> Result<IsNull, Box<dyn error::Error + 'static + Sync + Send>> {
        let encoded = to_postgres(&self.0)?;
        let group_total = encoded.digits.len();

        // 8 header bytes plus two bytes per group.
        out.reserve(8 + group_total * 2);

        // Header: group count, weight of the first group, sign word, dscale.
        out.put_u16(group_total.try_into()?);
        out.put_i16(encoded.weight);
        let sign_word: u16 = if encoded.neg { 0x4000 } else { 0x0000 };
        out.put_u16(sign_word);
        out.put_u16(encoded.scale);

        // Body: the base-10000 groups, most significant first.
        for &group in &encoded.digits {
            out.put_i16(group);
        }

        Ok(IsNull::No)
    }

    /// Only the `NUMERIC` type is accepted.
    fn accepts(ty: &Type) -> bool {
        matches!(*ty, Type::NUMERIC)
    }

    to_sql_checked!();
}