use std::collections::HashMap;
use std::convert::TryInto;
use std::num::NonZero;
pub use logru;
use logru::ast::{Sym, Var};
use logru::search::{Resolved, Resolver, ResolveContext, SolutionState};
use logru::term_arena::{AppTerm, ArgRange, Term, TermId};
use logru::universe::SymbolStorage;
use tracing::{debug, warn};
/// A `logru` [`Resolver`] that answers built-in integer-arithmetic predicates.
///
/// Goals whose functor is one of the registered predicates (`is`, `isLess`, ...)
/// are dispatched via `pred_map`; arithmetic sub-expressions inside their
/// arguments (`add`, `sub`, ...) are evaluated via `exp_map`.
#[derive(Clone)]
pub struct ArithmeticResolver {
    /// Interned functor symbol -> binary arithmetic operator it denotes.
    exp_map: HashMap<Sym, Exp>,
    /// Interned functor symbol -> built-in predicate this resolver answers.
    pred_map: HashMap<Sym, Pred>,
}
impl ArithmeticResolver {
pub fn new<T: SymbolStorage>(symbols: &mut T) -> Self {
let exps = [
("add", Exp::Add),
("sub", Exp::Sub),
("mul", Exp::Mul),
("div", Exp::Div),
("rem", Exp::Rem),
("pow", Exp::Pow),
];
let preds = [
("is", Pred::Is),
("isLess", Pred::IsLess),
("isGreater", Pred::IsGreater),
("isLessEq", Pred::IsLessEq),
("isGreaterEq", Pred::IsGreaterEq),
("isDivider", Pred::IsDivider),
("isNeg", Pred::IsNeg),
("isInRange", Pred::IsInRange),
];
Self {
exp_map: symbols.build_sym_map(exps),
pred_map: symbols.build_sym_map(preds),
}
}
fn eval_exp(&self, solution: &SolutionState, exp: TermId) -> Option<i64> {
match solution.follow_vars(exp).1 {
Term::Var(_) => None,
Term::App(AppTerm(sym, arg_range)) => {
let op = self.exp_map.get(&sym)?;
let [a1, a2] = solution.terms().get_args_fixed(arg_range)?;
let v1 = self.eval_exp(solution, a1)?;
let v2 = self.eval_exp(solution, a2)?;
let ret = match op {
Exp::Add => v1.checked_add(v2)?,
Exp::Sub => v1.checked_sub(v2)?,
Exp::Mul => v1.checked_mul(v2)?,
Exp::Div => v1.checked_div(v2)?,
Exp::Rem => v1.checked_rem(v2)?,
Exp::Pow => v1.checked_pow(v2.try_into().ok()?)?,
};
Some(ret)
}
Term::Int(i) => Some(i),
_ => None,
}
}
fn resolve_with_args(
&mut self,
context: &mut ResolveContext,
left: TermId,
right: TermId,
op: impl Fn(i64, i64) -> bool,
var_generator: fn(i64) -> VariableSolutions,
) -> Option<Resolved<<Self as Resolver>::Choice>> {
let right_val = self.eval_exp(context.solution(), right)?;
if let Some(left_val) = self.eval_exp(context.solution(), left) {
op(left_val, right_val).then_some(Resolved::Success)
} else {
let (_left_id, left_term) = context.solution().follow_vars(left);
match left_term {
Term::Var(var) => match var_generator(right_val) {
VariableSolutions::None => None,
VariableSolutions::TooMany => {
warn!("Too many solutions for variable {0:?}, skipping all. {0:?} can only be evaluated in this predicate if it resolves to an integer.", var);
None
},
VariableSolutions::Single(val) => {
let result_term = context.solution_mut().terms_mut().int(val);
context
.solution_mut()
.set_var(var, result_term)
.then_some(Resolved::Success)
},
}
Term::Int(left_val) => op(left_val, right_val).then_some(Resolved::Success),
other => {
debug!("Can't evaluate expression {:?}. It must resolve to an integer first.", other);
None
}
}
}
}
fn resolve_both_sides_with_args(
&mut self,
context: &mut ResolveContext,
left: TermId,
right: TermId,
op: impl Fn(i64, i64) -> bool,
var_generator: fn(i64) -> VariableSolutions,
) -> Option<Resolved<<Self as Resolver>::Choice>> {
self.resolve_with_args(context, left, right, &op, var_generator)
.or_else(|| self.resolve_with_args(context, right, left, |l, r| op(r,l), var_generator))
}
fn resolve_is(
&mut self,
args: ArgRange,
context: &mut ResolveContext,
) -> Option<Resolved<<Self as Resolver>::Choice>> {
let [left, right] = context.solution().terms().get_args_fixed(args)?;
self.resolve_both_sides_with_args(context, left, right, |l, r| l == r, |v| VariableSolutions::Single(v))
}
fn resolve_neg(
&mut self,
args: ArgRange,
context: &mut ResolveContext,
) -> Option<Resolved<<Self as Resolver>::Choice>> {
let [left, right] = context.solution().terms().get_args_fixed(args)?;
self.resolve_both_sides_with_args(context, left, right, |l, r| l == -r, |v| VariableSolutions::Single(-v))
}
fn resolve_instantiated_op2(
&mut self,
args: ArgRange,
context: &mut ResolveContext,
op: fn(i64, i64) -> bool,
) -> Option<Resolved<<Self as Resolver>::Choice>> {
let [left, right] = context.solution().terms().get_args_fixed(args)?;
self.resolve_both_sides_with_args(context, left, right, op, |_| VariableSolutions::TooMany)
}
fn resolve_range(
&mut self,
args: ArgRange,
context: &mut ResolveContext,
) -> Option<Resolved<<Self as Resolver>::Choice>> {
let [var, min, max, step] = context.solution().terms().get_args_fixed(args)?;
let eval_expr = |var| {
if let Some(val) = self.eval_exp(context.solution(), var) {
Some(val)
} else {
let (_id, term) = context.solution().follow_vars(var);
match term {
Term::Int(val) => Some(val),
other => {
debug!("Can't evaluate expression {:?}. It must resolve to an integer first.", other);
None
}
}
}
};
let min = eval_expr(min)?;
let max = eval_expr(max)?;
let step = eval_expr(step)?;
let step = step.try_into()
.ok()
.and_then(NonZero::new)
.or_else(|| {
warn!("Range step must be positive, but resolved to {}", step); None
})?;
let range = Range { min, max, step };
if let Some(val) = self.eval_exp(context.solution(), var) {
range.contains(val).then_some(Resolved::Success)
} else {
let (_id, term) = context.solution().follow_vars(var);
match term {
Term::Var(var) => {
let mut iter = range.into_iter();
iter.next()
.and_then(|val| {
let result_term = context.solution_mut().terms_mut().int(val);
context
.solution_mut()
.set_var(var, result_term)
.then_some(Resolved::SuccessRetry((
var,
DynIter::<dyn Iterator<Item=i64>>::new(iter),
)))
})
},
Term::Int(val) => range.contains(val).then_some(Resolved::Success),
other => {
warn!("Can't evaluate expression {:?}. It must resolve to an integer first.", other);
None
}
}
}
}
}
/// Integer range used by the `isInRange` predicate: values from `min` up to
/// `max` (both inclusive), advancing by `step`.
#[derive(Debug)]
struct Range {
    /// First value of the range.
    min: i64,
    /// Inclusive upper bound.
    max: i64,
    /// Distance between consecutive values; zero is unrepresentable.
    step: NonZero<u64>,
}
impl Range {
fn contains(&self, value: i64) -> bool {
value >= self.min && value <= self.max && value.rem_euclid(self.step.get() as i64) == 0
}
fn into_iter(self) -> impl Iterator<Item=i64> {
let mut count = Some(self.min);
std::iter::from_fn(move || {
if let Some(c) = count.take() {
count = match c.checked_add(self.step.get() as i64) {
Some(new) => (new <= self.max).then_some(new),
None => None,
};
Some(c)
} else {
None
}
})
}
}
/// Owning wrapper around a boxed, possibly unsized value (in practice a
/// type-erased iterator). It is used to keep the remaining `isInRange` values
/// inside a choice point.
///
/// It dereferences to the boxed value, so a `DynIter<dyn Iterator<Item = T>>`
/// can be advanced directly with `next()`.
pub struct DynIter<T: ?Sized>(Box<T>);

impl<T> DynIter<dyn Iterator<Item = T>> {
    /// Boxes `iter`, erasing its concrete type.
    fn new(iter: impl Iterator<Item = T> + 'static) -> Self {
        Self(Box::new(iter))
    }
}

impl<T: ?Sized> std::ops::Deref for DynIter<T> {
    type Target = T;
    fn deref(&self) -> &Self::Target {
        &*self.0
    }
}

impl<T: ?Sized> std::ops::DerefMut for DynIter<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut *self.0
    }
}

/// The boxed value is not required to implement `Debug`, so only an opaque
/// placeholder is printed. Formatting must never panic, because callers may
/// format choice points while logging.
impl<T: ?Sized> std::fmt::Debug for DynIter<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DynIter").finish_non_exhaustive()
    }
}
/// What an unbound variable may be bound to so that a binary predicate holds,
/// given the evaluated value of the other argument.
enum VariableSolutions {
    /// Exactly this value works; the variable is bound to it.
    Single(i64),
    /// Too many values work to enumerate; the goal is skipped with a warning.
    TooMany,
    /// No value works; the goal fails.
    None,
}
/// Binary integer operators that may appear inside arithmetic expressions.
///
/// Each is evaluated with the matching `checked_*` method on `i64`, so
/// overflow or division by zero makes the expression unevaluable instead of
/// panicking.
#[derive(Clone)]
enum Exp {
    /// `add(A, B)`: `A + B`.
    Add,
    /// `sub(A, B)`: `A - B`.
    Sub,
    /// `mul(A, B)`: `A * B`.
    Mul,
    /// `div(A, B)`: `A / B`, truncated toward zero.
    Div,
    /// `rem(A, B)`: `A % B`, with the sign of `A`.
    Rem,
    /// `pow(A, B)`: `A` to the power `B`; `B` must fit in a `u32`.
    Pow,
}
/// Built-in predicates answered by [`ArithmeticResolver`].
#[derive(Clone, Debug)]
enum Pred {
    /// `is(A, B)`: both sides evaluate to the same integer; an unbound side is
    /// bound to the other side's value.
    Is,
    /// `isLess(A, B)`: `A < B`; both sides must evaluate to integers.
    IsLess,
    /// `isGreater(A, B)`: `A > B`; both sides must evaluate to integers.
    IsGreater,
    /// `isLessEq(A, B)`: `A <= B`; both sides must evaluate to integers.
    IsLessEq,
    /// `isGreaterEq(A, B)`: `A >= B`; both sides must evaluate to integers.
    IsGreaterEq,
    /// `isDivider(A, B)`: `A % B == 0`; both sides must evaluate to integers.
    IsDivider,
    /// `isNeg(A, B)`: `A == -B`; an unbound side is bound to the negation of
    /// the other.
    IsNeg,
    /// `isInRange(X, Min, Max, Step)`: `X` lies in the stepped range; an
    /// unbound `X` is enumerated over the range values.
    IsInRange,
}
impl Resolver for ArithmeticResolver {
    /// Choice point left by `isInRange` on an unbound variable: the variable
    /// being enumerated and the range values not tried yet.
    type Choice = (Var, DynIter<dyn Iterator<Item = i64>>);

    /// Answers a goal whose functor is one of the predicates registered in
    /// [`ArithmeticResolver::new`].
    ///
    /// Returns `None` for any other functor, and whenever the predicate does
    /// not hold or its arguments cannot be evaluated.
    fn resolve(
        &mut self,
        _goal_id: TermId,
        AppTerm(sym, args): AppTerm,
        context: &mut ResolveContext,
    ) -> Option<Resolved<Self::Choice>> {
        let pred = self.pred_map.get(&sym)?;
        match pred {
            Pred::Is => self.resolve_is(args, context),
            Pred::IsLess => self.resolve_instantiated_op2(args, context, |a, b| a < b),
            Pred::IsGreater => self.resolve_instantiated_op2(args, context, |a, b| a > b),
            Pred::IsLessEq => self.resolve_instantiated_op2(args, context, |a, b| a <= b),
            Pred::IsGreaterEq => self.resolve_instantiated_op2(args, context, |a, b| a >= b),
            // `%` panics for a zero divisor and for `i64::MIN % -1`. Zero divisors
            // are rejected up front. `wrapping_rem` differs from `%` only in
            // returning 0 for `i64::MIN % -1`, which is correct because -1
            // divides every integer.
            Pred::IsDivider => self.resolve_instantiated_op2(args, context, |a, b| {
                b != 0 && a.wrapping_rem(b) == 0
            }),
            Pred::IsNeg => self.resolve_neg(args, context),
            Pred::IsInRange => self.resolve_range(args, context),
        }
    }

    /// Binds the choice's variable to the next untried range value.
    ///
    /// Returns `false` when the values are exhausted or when `set_var` rejects
    /// the binding. The second case matches how `resolve` treats a rejected
    /// binding, instead of reporting success with the variable left unbound.
    fn resume(
        &mut self,
        choice: &mut Self::Choice,
        _goal_id: TermId,
        context: &mut ResolveContext,
    ) -> bool {
        let (var, values) = choice;
        match values.next() {
            Some(val) => {
                let term_id = context.solution_mut().terms_mut().int(val);
                context.solution_mut().set_var(*var, term_id)
            }
            None => false,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use logru::ast::Term;
    use logru::query_dfs;
    use logru::resolve::ResolverExt;
    use logru::search::Solution;
    use logru::textual::TextualUniverse;

    /// A solution that binds no query variables.
    fn proven() -> Solution {
        Solution(vec![])
    }

    /// A solution that binds the single query variable to `term`.
    fn single(term: Term) -> Solution {
        Solution(vec![Some(term)])
    }

    /// Runs `test` with the arithmetic resolver plus the rule
    /// `eq(Exp1, Exp2) :- is(Exp1, Exp2).` and asserts that the first answer is
    /// `expected` and that no further answer exists.
    fn assert_single(test: &str, expected: Option<Solution>) {
        let mut universe = TextualUniverse::new();
        let mut arith = ArithmeticResolver::new(&mut universe.symbols);
        universe.load_str(r"eq(Exp1, Exp2) :- is(Exp1, Exp2).").unwrap();
        let query = universe.prepare_query(test).unwrap();
        let mut answers = query_dfs(arith.by_ref().or_else(universe.resolver()), query.query());
        assert_eq!(answers.next(), expected);
        assert!(answers.next().is_none());
    }

    /// Runs `test` with the arithmetic resolver and asserts that its first
    /// `limit` answers are exactly `expected`.
    fn assert_first_n(test: &str, expected: Vec<Solution>, limit: usize) {
        let mut universe = TextualUniverse::new();
        let mut arith = ArithmeticResolver::new(&mut universe.symbols);
        let query = universe.prepare_query(test).unwrap();
        let answers = query_dfs(arith.by_ref().or_else(universe.resolver()), query.query());
        let found: Vec<_> = answers.take(limit).collect();
        assert_eq!(found, expected);
    }

    #[test]
    fn simple() {
        let universe = TextualUniverse::new();
        let mut query = universe
            .prepare_query("is(X, add(3, mul(3, sub(6, div(10, rem(10, pow(2,3))))))).")
            .unwrap();
        let arith = ArithmeticResolver::new(&mut query.symbols_mut());
        let mut answers = query_dfs(arith.or_else(universe.resolver()), query.query());
        assert_eq!(answers.next(), Some(single(Term::Int(6))));
        assert!(answers.next().is_none());
    }

    #[test]
    fn bothsides() {
        assert_single("eq(add(2, 2), pow(2, 2)).", Some(proven()));
    }

    #[test]
    fn left_var() {
        assert_single("eq(X, pow(2, 2)).", Some(single(Term::Int(4))));
    }

    #[test]
    fn right_var() {
        assert_single("eq(add(2, 2), X).", Some(single(Term::Int(4))));
    }

    #[test]
    fn both_ints() {
        assert_single("eq(2, 2).", Some(proven()));
    }

    #[test]
    fn neg_bothsides() {
        assert_single("isNeg(add(-2, -2), pow(2, 2)).", Some(proven()));
    }

    #[test]
    fn neg_left_var() {
        assert_single("isNeg(X, pow(2, 2)).", Some(single(Term::Int(-4))));
    }

    #[test]
    fn neg_right_var() {
        assert_single("isNeg(add(2, 2), X).", Some(single(Term::Int(-4))));
    }

    #[test]
    fn neg_both_ints() {
        assert_single("isNeg(-2, 2).", Some(proven()));
    }

    #[test]
    fn less_bothsides() {
        assert_single("isLess(add(-2, -2), pow(2, 2)).", Some(proven()));
    }

    #[test]
    fn less_both_ints() {
        assert_single("isLess(-2, 2).", Some(proven()));
    }

    #[test]
    fn range_iter_ends() {
        let range = |min: i64, max: i64, step: u64| Range {
            min,
            max,
            step: NonZero::new(step).unwrap(),
        };
        assert_eq!(
            range(0, 5, 1).into_iter().take(10).collect::<Vec<_>>(),
            vec![0, 1, 2, 3, 4, 5],
        );
        assert_eq!(range(i64::MAX - 5, i64::MAX, 1).into_iter().count(), 6);
        assert_eq!(range(i64::MAX - 1, i64::MAX, 5).into_iter().count(), 1);
    }

    #[test]
    fn range_query_instantiate() {
        let expected = (-2..=2).map(|i| single(Term::Int(i))).collect();
        assert_first_n("isInRange(A, -2, 2, 1).", expected, 10);
    }
}