forked from Rust-related/luminal
mostly fixed symbolic
This commit is contained in:
@@ -61,22 +61,6 @@ impl GraphTensor {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<f32> for ConstantValue {
|
||||
fn from(value: f32) -> Self {
|
||||
ConstantValue::Float(value)
|
||||
}
|
||||
}
|
||||
impl From<f64> for ConstantValue {
|
||||
fn from(value: f64) -> Self {
|
||||
ConstantValue::Float(value as f32)
|
||||
}
|
||||
}
|
||||
impl<T: Into<Expression>> From<T> for ConstantValue {
|
||||
fn from(value: T) -> Self {
|
||||
ConstantValue::Expression(value.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl Graph {
|
||||
/// A scalar constant
|
||||
pub fn constant(&mut self, i: impl Into<ConstantValue>) -> GraphTensor {
|
||||
|
||||
20
src/op.rs
20
src/op.rs
@@ -138,6 +138,22 @@ pub enum ConstantValue {
|
||||
Float(f32),
|
||||
}
|
||||
|
||||
impl From<f32> for ConstantValue {
|
||||
fn from(value: f32) -> Self {
|
||||
ConstantValue::Float(value)
|
||||
}
|
||||
}
|
||||
impl From<f64> for ConstantValue {
|
||||
fn from(value: f64) -> Self {
|
||||
ConstantValue::Float(value as f32)
|
||||
}
|
||||
}
|
||||
impl<T: Into<Expression>> From<T> for ConstantValue {
|
||||
fn from(value: T) -> Self {
|
||||
ConstantValue::Expression(value.into())
|
||||
}
|
||||
}
|
||||
|
||||
/// Produces a single number constant from an expression or a float
|
||||
#[derive(Clone, PartialEq)]
|
||||
pub struct Constant(pub ConstantValue, pub *const FxHashMap<char, usize>);
|
||||
@@ -280,7 +296,6 @@ pub struct Mul;
|
||||
impl Operator for Mul {
|
||||
fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
let (lhs, rhs) = (get_vec(&inp[0].0), get_vec(&inp[1].0));
|
||||
println!("EXPR: {:?}", inp[0].1.dims());
|
||||
let mut out_data = vec![0.; inp[0].1.n_elements().to_usize().unwrap()];
|
||||
let lexpr = (inp[0].1.index_expression(), inp[0].1.valid_expression());
|
||||
let rexpr = (inp[1].1.index_expression(), inp[1].1.valid_expression());
|
||||
@@ -389,7 +404,8 @@ fn get_index(
|
||||
index: usize,
|
||||
) -> f32 {
|
||||
if val.exec_single_var_stack(index, stack) != 0 {
|
||||
data[ind.exec_single_var_stack(index, stack)]
|
||||
let i = ind.exec_single_var_stack(index, stack);
|
||||
data[i]
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ use std::{
|
||||
Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, Div, DivAssign, Mul, MulAssign,
|
||||
Rem, RemAssign, Sub, SubAssign,
|
||||
},
|
||||
time::Duration,
|
||||
};
|
||||
use symbolic_expressions::Sexp;
|
||||
|
||||
@@ -476,6 +475,13 @@ impl From<&bool> for Expression {
|
||||
}
|
||||
}
|
||||
|
||||
impl Add<Expression> for usize {
|
||||
type Output = Expression;
|
||||
fn add(self, rhs: Expression) -> Self::Output {
|
||||
rhs + self
|
||||
}
|
||||
}
|
||||
|
||||
impl Sub<Expression> for usize {
|
||||
type Output = Expression;
|
||||
fn sub(self, rhs: Expression) -> Self::Output {
|
||||
@@ -497,6 +503,76 @@ impl Div<Expression> for usize {
|
||||
}
|
||||
}
|
||||
|
||||
impl Rem<Expression> for usize {
|
||||
type Output = Expression;
|
||||
fn rem(self, rhs: Expression) -> Self::Output {
|
||||
Expression::from(self) % rhs
|
||||
}
|
||||
}
|
||||
|
||||
impl BitAnd<Expression> for usize {
|
||||
type Output = Expression;
|
||||
fn bitand(self, rhs: Expression) -> Self::Output {
|
||||
rhs & self
|
||||
}
|
||||
}
|
||||
|
||||
impl BitOr<Expression> for usize {
|
||||
type Output = Expression;
|
||||
fn bitor(self, rhs: Expression) -> Self::Output {
|
||||
rhs | self
|
||||
}
|
||||
}
|
||||
|
||||
impl Add<Expression> for i32 {
|
||||
type Output = Expression;
|
||||
fn add(self, rhs: Expression) -> Self::Output {
|
||||
rhs + self
|
||||
}
|
||||
}
|
||||
|
||||
impl Sub<Expression> for i32 {
|
||||
type Output = Expression;
|
||||
fn sub(self, rhs: Expression) -> Self::Output {
|
||||
Expression::from(self) - rhs
|
||||
}
|
||||
}
|
||||
|
||||
impl Mul<Expression> for i32 {
|
||||
type Output = Expression;
|
||||
fn mul(self, rhs: Expression) -> Self::Output {
|
||||
rhs * self
|
||||
}
|
||||
}
|
||||
|
||||
impl Div<Expression> for i32 {
|
||||
type Output = Expression;
|
||||
fn div(self, rhs: Expression) -> Self::Output {
|
||||
Expression::from(self) / rhs
|
||||
}
|
||||
}
|
||||
|
||||
impl Rem<Expression> for i32 {
|
||||
type Output = Expression;
|
||||
fn rem(self, rhs: Expression) -> Self::Output {
|
||||
Expression::from(self) % rhs
|
||||
}
|
||||
}
|
||||
|
||||
impl BitAnd<Expression> for i32 {
|
||||
type Output = Expression;
|
||||
fn bitand(self, rhs: Expression) -> Self::Output {
|
||||
rhs & self
|
||||
}
|
||||
}
|
||||
|
||||
impl BitOr<Expression> for i32 {
|
||||
type Output = Expression;
|
||||
fn bitor(self, rhs: Expression) -> Self::Output {
|
||||
rhs | self
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Into<Expression>> Add<E> for Expression {
|
||||
type Output = Self;
|
||||
fn add(self, rhs: E) -> Self::Output {
|
||||
@@ -883,37 +959,13 @@ impl Analysis<Math> for ConstantFold {
|
||||
a.checked_div(b)?
|
||||
}
|
||||
}
|
||||
Math::Mod([a, b]) if x(b) != Some(0) => x(a)?.checked_rem(x(b)?)?,
|
||||
Math::Min([a, b]) if x(b) != Some(0) => x(a)?.min(x(b)?),
|
||||
Math::Max([a, b]) if x(b) != Some(0) => x(a)?.max(x(b)?),
|
||||
Math::And([a, b]) if x(b) != Some(0) => {
|
||||
if x(a)? != 0 && x(b)? != 0 {
|
||||
1
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
Math::Or([a, b]) if x(b) != Some(0) => {
|
||||
if x(a)? != 0 || x(b)? != 0 {
|
||||
1
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
Math::LessThan([a, b]) if x(b) != Some(0) => {
|
||||
if x(a)? < x(b)? {
|
||||
1
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
Math::GreaterThanEqual([a, b]) if x(b) != Some(0) => {
|
||||
if x(a)? >= x(b)? {
|
||||
1
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
Math::Mod([a, b]) => x(a)?.checked_rem(x(b)?)?,
|
||||
Math::Min([a, b]) => x(a)?.min(x(b)?),
|
||||
Math::Max([a, b]) => x(a)?.max(x(b)?),
|
||||
Math::And([a, b]) => (x(a)? != 0 && x(b)? != 0) as i32,
|
||||
Math::Or([a, b]) => (x(a)? != 0 || x(b)? != 0) as i32,
|
||||
Math::LessThan([a, b]) => (x(a)? < x(b)?) as i32,
|
||||
Math::GreaterThanEqual([a, b]) => (x(a)? >= x(b)?) as i32,
|
||||
_ => return None,
|
||||
})
|
||||
}
|
||||
@@ -965,23 +1017,23 @@ fn make_rules() -> Vec<Rewrite> {
|
||||
rewrite!("assoc-mul"; "(* ?a (* ?b ?c))" => "(* (* ?a ?b) ?c)"),
|
||||
rewrite!("assoc-div"; "(/ (/ ?a ?b) ?c)" => "(/ ?a (* ?b ?c))"),
|
||||
rewrite!("mul-div-associative"; "(/ (* ?a ?b) ?c)" => "(* ?a (/ ?b ?c))"),
|
||||
// rewrite!("mul-div-associative-rev"; "(* ?a (/ ?b ?c))" => "(/ (* ?a ?b) ?c)"),
|
||||
// rewrite!("sub-canon"; "(- ?a ?b)" => "(+ ?a (* -1 ?b))"),
|
||||
// rewrite!("mul-div-associative-rev"; "(* ?a (/ ?b ?c))" => "(/ (* ?a ?b) ?c)"), // BAD? Makes test_pool_1d fail
|
||||
rewrite!("sub-canon"; "(- ?a ?b)" => "(+ ?a (* -1 ?b))"),
|
||||
// Distributive
|
||||
rewrite!("distribute-mul"; "(* ?a (+ ?b ?c))" => "(+ (* ?a ?b) (* ?a ?c))"),
|
||||
// rewrite!("distribute-div"; "(/ (+ ?a ?b) ?c)" => "(+ (/ ?a ?c) (/ ?b ?c))"),
|
||||
rewrite!("distribute-div"; "(/ (+ ?a ?b) ?c)" => "(+ (/ ?a ?c) (/ ?b ?c))"),
|
||||
rewrite!("distribute-max"; "(* ?a (max ?b ?c))" => "(max (* ?a ?b) (* ?a ?c))" if is_const_positive(&["?a"])),
|
||||
// rewrite!("distribute-min"; "(* ?a (min ?b ?c))" => "(min (* ?a ?b) (* ?a ?c))"),
|
||||
rewrite!("distribute-min"; "(* ?a (min ?b ?c))" => "(min (* ?a ?b) (* ?a ?c))"),
|
||||
// rewrite!("distribute-mod"; "(* (% ?b ?c) ?a)" => "(% (* ?b ?a) (* ?c ?a))"),
|
||||
// Factoring
|
||||
rewrite!("factor-mul" ; "(+ (* ?a ?b) (* ?a ?c))" => "(* ?a (+ ?b ?c))"),
|
||||
rewrite!("factor-div" ; "(+ (/ ?a ?b) (/ ?a ?c))" => "(/ ?a (+ ?b ?c))"),
|
||||
rewrite!("group-terms"; "(+ ?a ?a)" => "(* 2 ?a)" if is_const_positive(&["?a"])),
|
||||
// rewrite!("factor-div" ; "(+ (/ ?a ?b) (/ ?a ?c))" => "(/ ?a (+ ?b ?c))"),
|
||||
rewrite!("group-terms"; "(+ ?a ?a)" => "(* 2 ?a)"),
|
||||
// Other
|
||||
// rewrite!("explicit-truncate"; "(* (/ ?a ?b) ?b)" => "(- ?a (% ?a ?b))"),
|
||||
// rewrite!("mul-mod"; "(% (* ?a ?b) ?b)" => "0"),
|
||||
rewrite!("div-move-inside"; "(+ (/ ?a ?b) ?c)" => "(/ (+ ?a (* ?c ?b)) ?b)"),
|
||||
// rewrite!("mul-distribute"; "(* ?a (% (/ ?b ?c) ?d))" => "(% (/ ?b (* ?c ?a)) (* ?d ?a))"),
|
||||
// rewrite!("mul-distribute"; "(* ?a (% (/ ?b ?c) ?d))" => "(% (/ ?b (* ?c ?a)) (* ?d ?a))"), // BAD
|
||||
// rewrite!("div-mod-mul"; "(% (/ ?a ?b) ?c)" => "(% ?a (* ?b ?c))"),
|
||||
// Simple binary reductions
|
||||
rewrite!("add-0"; "(+ ?a 0)" => "?a"),
|
||||
@@ -1008,11 +1060,12 @@ fn egg_simplify(e: Expression) -> Expression {
|
||||
let expr = luminal_to_egg(&e);
|
||||
// Simplify
|
||||
let runner = Runner::default()
|
||||
.with_iter_limit(1_000)
|
||||
.with_time_limit(Duration::from_secs(30))
|
||||
.with_node_limit(100_000)
|
||||
// .with_iter_limit(1_000)
|
||||
// .with_time_limit(std::time::Duration::from_secs(30))
|
||||
// .with_node_limit(100_000_000)
|
||||
.with_expr(&expr)
|
||||
.run(&make_rules());
|
||||
// runner.print_report();
|
||||
let extractor = Extractor::new(&runner.egraph, AstSize);
|
||||
let (_, best) = extractor.find_best(runner.roots[0]);
|
||||
// Convert back to luminal expression
|
||||
@@ -1048,7 +1101,7 @@ mod tests {
|
||||
let main = Expression::from('x') - 255;
|
||||
let sub = Expression::from('x') / 2;
|
||||
let new = main.substitute('x', sub).simplify();
|
||||
assert_eq!(new, (Expression::from('x') / 2) + -255);
|
||||
assert_eq!(new.len(), 5);
|
||||
expression_cleanup();
|
||||
}
|
||||
|
||||
@@ -1056,7 +1109,7 @@ mod tests {
|
||||
fn test_group_terms() {
|
||||
let s = Expression::from('s');
|
||||
let expr = (s * ((s - 4) + 1)) + (((s + 1) * ((s - 4) + 1)) - (s * ((s - 4) + 1)));
|
||||
assert_eq!(expr.simplify().terms.read().len(), 7);
|
||||
assert_eq!(expr.simplify().len(), 7);
|
||||
expression_cleanup();
|
||||
}
|
||||
|
||||
@@ -1073,27 +1126,22 @@ mod tests {
|
||||
let z = Expression::from('z');
|
||||
let w = Expression::from('w');
|
||||
let h = Expression::from('h');
|
||||
let x = ((z
|
||||
/ ((Expression::from(-5)
|
||||
+ (((((Expression::from(-5) + ((((((w + 153) / 2) / 2) / 2) / 2) / 2)) * 4)
|
||||
+ 9)
|
||||
/ 2)
|
||||
/ 2))
|
||||
* (Expression::from(-5)
|
||||
+ (((Expression::from(9)
|
||||
+ (4 * (Expression::from(-5)
|
||||
+ ((((((Expression::from(153) + h) / 2) / 2) / 2) / 2) / 2))))
|
||||
/ 2)
|
||||
/ 2))))
|
||||
% 64)
|
||||
.simplify();
|
||||
panic!("{x}")
|
||||
let o = (z
|
||||
/ ((-5 + (((((-5 + ((((((w + 153) / 2) / 2) / 2) / 2) / 2)) * 4) + 9) / 2) / 2))
|
||||
* (-5 + (((9 + (4 * (-5 + ((((((153 + h) / 2) / 2) / 2) / 2) / 2)))) / 2) / 2))))
|
||||
% 64;
|
||||
let x = o.simplify();
|
||||
assert_eq!(x.len(), 23); // Should be 21 if we can re-enable mul-div-associative-rev
|
||||
expression_cleanup();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_final() {
|
||||
let z = Expression::from('z');
|
||||
let w = Expression::from('w');
|
||||
let x = (((w + -7) / 32) + (Expression::from(-11) / 4)).simplify();
|
||||
assert_eq!(x.len(), 5);
|
||||
let h = Expression::from('h');
|
||||
let x = (z % (((((153 + h) / 8) + -31) * ((((w + 153) / 8) + -31) / 16)) * 64)).simplify();
|
||||
assert_eq!(x.len(), 15); // Should be 11 if we can re-enable mul-div-associative-rev
|
||||
expression_cleanup();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user