use std::collections::HashMap;
use std::convert::TryInto;
use std::num::NonZero;
pub use logru;
use logru::ast::{Sym, Var};
use logru::search::{Resolved, Resolver, ResolveContext, SolutionState};
use logru::term_arena::{AppTerm, ArgRange, Term, TermId};
use logru::universe::SymbolStorage;
use tracing::{debug, warn};
/// A logru resolver that evaluates integer arithmetic.
///
/// Built by `ArithmeticResolver::new`, which interns the names of the
/// supported functions and predicates and maps them to their implementations.
#[derive(Clone)]
pub struct ArithmeticResolver {
    /// Interned predicate symbols (`is`, `isLess`, ...) and what they check.
    pred_map: HashMap<Sym, Pred>,
    /// Interned function symbols (`add`, `sub`, ...) usable inside expressions.
    exp_map: HashMap<Sym, Exp>,
}
impl ArithmeticResolver {
    /// Creates a resolver and interns the names of its arithmetic functions
    /// and predicates in `symbols`.
    ///
    /// * Functions (usable inside expressions, two arguments each): `add`,
    ///   `sub`, `mul`, `div`, `rem`, `pow`.
    /// * Predicates: `is`, `isLess`, `isGreater`, `isLessEq`, `isGreaterEq`,
    ///   `isDivider`, `isNeg` (two arguments each) and `isInRange` (four).
    pub fn new<T: SymbolStorage>(symbols: &mut T) -> Self {
        let exps = [
            ("add", Exp::Add),
            ("sub", Exp::Sub),
            ("mul", Exp::Mul),
            ("div", Exp::Div),
            ("rem", Exp::Rem),
            ("pow", Exp::Pow),
        ];
        let preds = [
            ("is", Pred::Is),
            ("isLess", Pred::IsLess),
            ("isGreater", Pred::IsGreater),
            ("isLessEq", Pred::IsLessEq),
            ("isGreaterEq", Pred::IsGreaterEq),
            ("isDivider", Pred::IsDivider),
            ("isNeg", Pred::IsNeg),
            ("isInRange", Pred::IsInRange),
        ];
        Self {
            exp_map: symbols.build_sym_map(exps),
            pred_map: symbols.build_sym_map(preds),
        }
    }

    /// Evaluates `exp` (after following variable bindings) to an integer.
    ///
    /// Returns `None` in two groups of cases:
    ///
    /// * The term cannot be evaluated. It is an unbound variable or some
    ///   other non-integer term. Or it applies a symbol that is not one of
    ///   this resolver's functions, or whose arguments `get_args_fixed`
    ///   cannot split into exactly two.
    /// * Some step of the computation is invalid. It overflows, divides by
    ///   zero, or uses an exponent that is negative or does not fit in a `u32`.
    fn eval_exp(&self, solution: &SolutionState, exp: TermId) -> Option<i64> {
        match solution.follow_vars(exp).1 {
            Term::Var(_) => None,
            Term::App(AppTerm(sym, arg_range)) => {
                let op = self.exp_map.get(&sym)?;
                let [a1, a2] = solution.terms().get_args_fixed(arg_range)?;
                let v1 = self.eval_exp(solution, a1)?;
                let v2 = self.eval_exp(solution, a2)?;
                // `checked_*` turns overflow and division by zero into `None`
                // (goal failure) instead of a panic.
                let ret = match op {
                    Exp::Add => v1.checked_add(v2)?,
                    Exp::Sub => v1.checked_sub(v2)?,
                    Exp::Mul => v1.checked_mul(v2)?,
                    Exp::Div => v1.checked_div(v2)?,
                    Exp::Rem => v1.checked_rem(v2)?,
                    Exp::Pow => v1.checked_pow(v2.try_into().ok()?)?,
                };
                Some(ret)
            }
            Term::Int(i) => Some(i),
            _ => None,
        }
    }

    /// Resolves a binary predicate `op(left, right)`, where `right` must
    /// evaluate to an integer.
    ///
    /// * If `left` also evaluates, the goal succeeds iff `op(left, right)`.
    /// * If `left` is an unbound variable, `var_generator(right)` decides what
    ///   to do:
    ///   * `Single(v)` binds the variable to `v`.
    ///   * `TooMany` logs a warning and fails.
    ///   * `None` fails.
    /// * Any other `left` term fails with a debug log.
    fn resolve_with_args(
        &mut self,
        context: &mut ResolveContext,
        left: TermId,
        right: TermId,
        op: impl Fn(i64, i64) -> bool,
        var_generator: fn(i64) -> VariableSolutions,
    ) -> Option<Resolved<<Self as Resolver>::Choice>> {
        let right_val = self.eval_exp(context.solution(), right)?;
        if let Some(left_val) = self.eval_exp(context.solution(), left) {
            op(left_val, right_val).then_some(Resolved::Success)
        } else {
            let (_left_id, left_term) = context.solution().follow_vars(left);
            match left_term {
                Term::Var(var) => match var_generator(right_val) {
                    VariableSolutions::None => None,
                    VariableSolutions::TooMany => {
                        warn!("Too many solutions for variable {0:?}, skipping all. {0:?} can only be evaluated in this predicate if it resolves to an integer.", var);
                        None
                    },
                    VariableSolutions::Single(val) => {
                        let result_term = context.solution_mut().terms_mut().int(val);
                        context
                            .solution_mut()
                            .set_var(var, result_term)
                            .then_some(Resolved::Success)
                    },
                }
                Term::Int(left_val) => op(left_val, right_val).then_some(Resolved::Success),
                other => {
                    debug!("Can't evaluate expression {:?}. It must resolve to an integer first.", other);
                    None
                }
            }
        }
    }

    /// Like `resolve_with_args`, but if that fails it retries with the
    /// arguments swapped (and `op` flipped to match). Either side may
    /// therefore be the unbound variable.
    fn resolve_both_sides_with_args(
        &mut self,
        context: &mut ResolveContext,
        left: TermId,
        right: TermId,
        op: impl Fn(i64, i64) -> bool,
        var_generator: fn(i64) -> VariableSolutions,
    ) -> Option<Resolved<<Self as Resolver>::Choice>> {
        self.resolve_with_args(context, left, right, &op, var_generator)
            .or_else(|| self.resolve_with_args(context, right, left, |l, r| op(r, l), var_generator))
    }

    /// `is(Left, Right)`: both sides evaluate to the same integer. An unbound
    /// side is bound to the other side's value.
    fn resolve_is(
        &mut self,
        args: ArgRange,
        context: &mut ResolveContext,
    ) -> Option<Resolved<<Self as Resolver>::Choice>> {
        let [left, right] = context.solution().terms().get_args_fixed(args)?;
        self.resolve_both_sides_with_args(context, left, right, |l, r| l == r, |v| VariableSolutions::Single(v))
    }

    /// `isNeg(Left, Right)`: `Left == -Right`. An unbound side is bound to the
    /// negation of the other side.
    fn resolve_neg(
        &mut self,
        args: ArgRange,
        context: &mut ResolveContext,
    ) -> Option<Resolved<<Self as Resolver>::Choice>> {
        let [left, right] = context.solution().terms().get_args_fixed(args)?;
        // `-i64::MIN` is not representable, so negation must be checked.
        // Plain `-x` panics in debug builds and wraps in release builds,
        // where `isNeg(MIN, MIN)` would wrongly succeed.
        self.resolve_both_sides_with_args(
            context,
            left,
            right,
            |l: i64, r: i64| r.checked_neg() == Some(l),
            |v: i64| v.checked_neg().map_or(VariableSolutions::None, VariableSolutions::Single),
        )
    }

    /// Comparison-style predicates whose two arguments must both evaluate to
    /// integers. An unbound variable is never bound: if the other side is
    /// evaluable a warning is logged, and the goal fails.
    fn resolve_instantiated_op2(
        &mut self,
        args: ArgRange,
        context: &mut ResolveContext,
        op: fn(i64, i64) -> bool,
    ) -> Option<Resolved<<Self as Resolver>::Choice>> {
        let [left, right] = context.solution().terms().get_args_fixed(args)?;
        self.resolve_both_sides_with_args(context, left, right, op, |_| VariableSolutions::TooMany)
    }

    /// `isInRange(Var, Min, Max, Step)`.
    ///
    /// `Min`, `Max` and `Step` must evaluate to integers, and `Step` must be
    /// positive. Then:
    ///
    /// * If `Var` evaluates, the goal succeeds iff `Range::contains` accepts it.
    /// * If `Var` is an unbound variable, it is bound to the first range
    ///   value. The remaining values are returned as a retry choice, which
    ///   `resume` consumes.
    fn resolve_range(
        &mut self,
        args: ArgRange,
        context: &mut ResolveContext,
    ) -> Option<Resolved<<Self as Resolver>::Choice>> {
        let [var, min, max, step] = context.solution().terms().get_args_fixed(args)?;
        let eval_expr = |var| {
            if let Some(val) = self.eval_exp(context.solution(), var) {
                Some(val)
            } else {
                let (_id, term) = context.solution().follow_vars(var);
                match term {
                    Term::Int(val) => Some(val),
                    other => {
                        debug!("Can't evaluate expression {:?}. It must resolve to an integer first.", other);
                        None
                    }
                }
            }
        };

        let min = eval_expr(min)?;
        let max = eval_expr(max)?;
        let step = eval_expr(step)?;
        // Negative steps fail the `u64` conversion and zero fails `NonZero`.
        let step = step
            .try_into()
            .ok()
            .and_then(NonZero::new)
            .or_else(|| {
                warn!("Range step must be positive, but resolved to {}", step);
                None
            })?;
        let range = Range { min, max, step };

        if let Some(val) = self.eval_exp(context.solution(), var) {
            range.contains(val).then_some(Resolved::Success)
        } else {
            let (_id, term) = context.solution().follow_vars(var);
            match term {
                Term::Var(var) => {
                    let mut iter = range.into_iter();
                    iter.next()
                        .and_then(|val| {
                            let result_term = context.solution_mut().terms_mut().int(val);
                            context
                                .solution_mut()
                                .set_var(var, result_term)
                                .then_some(Resolved::SuccessRetry((
                                    var,
                                    DynIter::<dyn Iterator<Item=i64>>::new(iter),
                                )))
                        })
                },
                Term::Int(val) => range.contains(val).then_some(Resolved::Success),
                other => {
                    warn!("Can't evaluate expression {:?}. It must resolve to an integer first.", other);
                    None
                }
            }
        }
    }
}
/// The arithmetic progression `min, min + step, min + 2*step, ...`, bounded
/// above (inclusively) by `max`. It is empty when `min > max`.
#[derive(Debug)]
struct Range {
    min: i64,
    max: i64,
    step: std::num::NonZero<u64>,
}
impl Range {
    /// Returns whether `value` is one of the values yielded by
    /// `Range::into_iter`: inside `min..=max` and a whole number of steps
    /// above `min`.
    fn contains(&self, value: i64) -> bool {
        // Measure divisibility from `min`, not from 0. Otherwise a range such
        // as (1, 10, step 2) would reject 3 even though iteration yields it.
        // `abs_diff` gives the exact distance as a `u64`, so nothing overflows.
        (self.min..=self.max).contains(&value)
            && value.abs_diff(self.min) % self.step.get() == 0
    }

    /// Yields `min`, `min + step`, ... up to and including `max`.
    ///
    /// Yields nothing if `min > max`. Iteration stops (rather than wrapping)
    /// when the next value would overflow `i64`.
    fn into_iter(self) -> impl Iterator<Item = i64> {
        let Range { min, max, step } = self;
        let first = (min <= max).then_some(min);
        std::iter::successors(first, move |&current| {
            current
                .checked_add_unsigned(step.get())
                .filter(|&next| next <= max)
        })
    }
}
/// A boxed, possibly unsized value (such as a `dyn Iterator`) that derefs to
/// its contents. `ArithmeticResolver`'s choice type uses it to hold the
/// not-yet-tried values of an `isInRange` enumeration.
pub struct DynIter<T: ?Sized>(Box<T>);
impl<T> DynIter<dyn Iterator<Item=T>> {
    /// Boxes `iter` as a type-erased iterator.
    fn new(iter: impl Iterator<Item=T> + 'static) -> Self {
        Self(Box::new(iter))
    }
}
impl<T: ?Sized> std::ops::Deref for DynIter<T> {
    type Target = T;
    fn deref(&self) -> &Self::Target {
        &*self.0
    }
}
impl<T: ?Sized> std::ops::DerefMut for DynIter<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut *self.0
    }
}
impl<T: ?Sized> std::fmt::Debug for DynIter<T> {
    /// Prints an opaque placeholder; the boxed contents need not be `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Must not panic: `Debug` is reached through logging and assertion
        // messages.
        f.write_str("DynIter(..)")
    }
}
/// What a two-argument predicate allows for an unbound variable on one side,
/// given the evaluated integer on the other side.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum VariableSolutions {
    /// No integer satisfies the predicate; the goal fails.
    None,
    /// More than one integer would satisfy the predicate. They are not
    /// enumerated: a warning is logged and the goal fails.
    TooMany,
    /// Exactly this integer satisfies the predicate; the variable is bound to it.
    Single(i64),
}
/// A binary arithmetic function usable inside expressions. Every variant is
/// evaluated with checked `i64` arithmetic.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Exp {
    /// `add(A, B)`: `A + B`.
    Add,
    /// `sub(A, B)`: `A - B`.
    Sub,
    /// `mul(A, B)`: `A * B`.
    Mul,
    /// `div(A, B)`: `A / B`, truncating toward zero.
    Div,
    /// `rem(A, B)`: remainder of `A / B`, with the sign of `A`.
    Rem,
    /// `pow(A, B)`: `A` raised to the power `B`, where `B` must fit in a `u32`.
    Pow,
}
/// An arithmetic predicate recognised by `ArithmeticResolver`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Pred {
    /// `is(A, B)`: `A` and `B` evaluate to the same integer.
    Is,
    /// `isLess(A, B)`: `A < B`.
    IsLess,
    /// `isGreater(A, B)`: `A > B`.
    IsGreater,
    /// `isLessEq(A, B)`: `A <= B`.
    IsLessEq,
    /// `isGreaterEq(A, B)`: `A >= B`.
    IsGreaterEq,
    /// `isDivider(A, B)`: `B` divides `A` without remainder.
    IsDivider,
    /// `isNeg(A, B)`: `A == -B`.
    IsNeg,
    /// `isInRange(X, Min, Max, Step)`: `X` is one of `Min, Min + Step, ...`
    /// not exceeding `Max`.
    IsInRange,
}
impl Resolver for ArithmeticResolver {
    /// Retry state of an `isInRange` enumeration: the variable being
    /// enumerated and the range values that have not been tried yet.
    type Choice = (Var, DynIter<dyn Iterator<Item=i64>>);

    /// Dispatches the goal `sym(args...)` to the matching arithmetic predicate.
    ///
    /// Returns `None` if `sym` is not one of this resolver's predicates or
    /// if the predicate fails.
    fn resolve(
        &mut self,
        _goal_id: TermId,
        AppTerm(sym, args): AppTerm,
        context: &mut ResolveContext,
    ) -> Option<Resolved<Self::Choice>> {
        let pred = self.pred_map.get(&sym)?;
        match pred {
            Pred::Is => self.resolve_is(args, context),
            Pred::IsLess => self.resolve_instantiated_op2(args, context, |a, b| a < b),
            Pred::IsGreater => self.resolve_instantiated_op2(args, context, |a, b| a > b),
            Pred::IsLessEq => self.resolve_instantiated_op2(args, context, |a, b| a <= b),
            Pred::IsGreaterEq => self.resolve_instantiated_op2(args, context, |a, b| a >= b),
            // Succeeds when the second argument divides the first.
            // - A zero divisor divides nothing (plain `%` would panic here).
            // - `wrapping_rem` sidesteps the `i64::MIN % -1` overflow panic
            //   and still yields the true remainder, 0.
            Pred::IsDivider => self.resolve_instantiated_op2(args, context, |a: i64, b: i64| {
                b != 0 && a.wrapping_rem(b) == 0
            }),
            Pred::IsNeg => self.resolve_neg(args, context),
            Pred::IsInRange => self.resolve_range(args, context),
        }
    }

    /// Binds the variable of a pending `isInRange` enumeration to the next
    /// range value.
    ///
    /// Returns `false` when the range is exhausted or the binding is rejected.
    fn resume(
        &mut self,
        choice: &mut Self::Choice,
        _goal_id: TermId,
        context: &mut ResolveContext,
    ) -> bool {
        let Some(val) = choice.1.next() else {
            return false;
        };
        let term_id = context.solution_mut().terms_mut().int(val);
        // Propagate a rejected binding rather than claiming success with the
        // variable unbound. This matches how the first value is bound.
        context.solution_mut().set_var(choice.0, term_id)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use logru::ast::Term;
    use logru::query_dfs;
    use logru::resolve::ResolverExt;
    use logru::search::Solution;
    use logru::textual::TextualUniverse;

    /// Evaluates a nested expression that uses every arithmetic function.
    #[test]
    fn simple() {
        let tu = TextualUniverse::new();
        let mut query = tu
            .prepare_query("is(X, add(3, mul(3, sub(6, div(10, rem(10, pow(2,3))))))).")
            .unwrap();
        let resolver = ArithmeticResolver::new(&mut query.symbols_mut());
        let mut results = query_dfs(resolver.or_else(tu.resolver()), query.query());
        assert_eq!(results.next(), Some(Solution(vec![Some(Term::Int(6))])));
        assert!(results.next().is_none());
    }

    /// Runs `test` in a universe defining `eq(A, B) :- is(A, B).`.
    /// Asserts the first result equals `expected` and nothing follows it.
    fn complex(test: &str, expected: Option<Solution>) {
        let mut tu = TextualUniverse::new();
        let mut arith = ArithmeticResolver::new(&mut tu.symbols);
        tu.load_str(r"eq(Exp1, Exp2) :- is(Exp1, Exp2).").unwrap();
        let query = tu.prepare_query(test).unwrap();
        let mut results = query_dfs(arith.by_ref().or_else(tu.resolver()), query.query());
        assert_eq!(results.next(), expected);
        assert!(results.next().is_none());
    }

    /// Runs `test` and asserts its first `limit` results equal `expected`.
    fn complex_multi(test: &str, expected: Vec<Solution>, limit: usize) {
        let mut tu = TextualUniverse::new();
        let mut arith = ArithmeticResolver::new(&mut tu.symbols);
        let query = tu.prepare_query(test).unwrap();
        let results = query_dfs(arith.by_ref().or_else(tu.resolver()), query.query());
        assert_eq!(results.take(limit).collect::<Vec<_>>(), expected);
    }

    #[test]
    fn bothsides() {
        complex("eq(add(2, 2), pow(2, 2)).", Some(Solution(vec![])));
    }

    #[test]
    fn left_var() {
        complex("eq(X, pow(2, 2)).", Some(Solution(vec![Some(Term::Int(4))])));
    }

    #[test]
    fn right_var() {
        complex("eq(add(2, 2), X).", Some(Solution(vec![Some(Term::Int(4))])));
    }

    #[test]
    fn both_ints() {
        complex("eq(2, 2).", Some(Solution(vec![])));
    }

    #[test]
    fn neg_bothsides() {
        complex("isNeg(add(-2, -2), pow(2, 2)).", Some(Solution(vec![])));
    }

    #[test]
    fn neg_left_var() {
        complex("isNeg(X, pow(2, 2)).", Some(Solution(vec![Some(Term::Int(-4))])));
    }

    #[test]
    fn neg_right_var() {
        complex("isNeg(add(2, 2), X).", Some(Solution(vec![Some(Term::Int(-4))])));
    }

    #[test]
    fn neg_both_ints() {
        complex("isNeg(-2, 2).", Some(Solution(vec![])));
    }

    #[test]
    fn less_bothsides() {
        complex("isLess(add(-2, -2), pow(2, 2)).", Some(Solution(vec![])));
    }

    #[test]
    fn less_both_ints() {
        complex("isLess(-2, 2).", Some(Solution(vec![])));
    }

    /// Range iteration is inclusive of `max` and stops cleanly at `i64::MAX`.
    #[test]
    fn range_iter_ends() {
        assert_eq!(
            Range {
                min: 0,
                max: 5,
                step: NonZero::new(1).unwrap(),
            }.into_iter().take(10).collect::<Vec<_>>(),
            vec![0, 1, 2, 3, 4, 5],
        );
        assert_eq!(
            Range {
                min: i64::MAX - 5,
                max: i64::MAX,
                step: NonZero::new(1).unwrap(),
            }.into_iter().count(),
            6,
        );
        assert_eq!(
            Range {
                min: i64::MAX - 1,
                max: i64::MAX,
                step: NonZero::new(5).unwrap(),
            }.into_iter().count(),
            1,
        );
    }

    /// An unbound variable in `isInRange` enumerates every value of the range.
    #[test]
    fn range_query_instantiate() {
        let int = |i| Solution(vec![Some(Term::Int(i))]);
        complex_multi(
            "isInRange(A, -2, 2, 1).",
            vec![int(-2), int(-1), int(0), int(1), int(2)],
            10,
        );
    }
}