Lean4
/-- `findSquares s e` collects all terms of the form `a ^ 2` and `a * a` that appear in `e`
and adds them to the set `s`.
A pair `(i, true)` is added to `s` when `atoms[i]^2` appears in `e`,
and `(i, false)` is added to `s` when `atoms[i]*atoms[i]` appears in `e`.

How the search proceeds:
* For `base ^ exponent` where `exponent.numeral?` is `some 2`, only `base` is searched
  further, and then `base` is recorded.
* For `lhs * rhs`, both factors are registered with `AtomM.addAtom`. If they receive the
  same atom index, only `lhs` is searched further, and then it is recorded.
* Every other term (including the non-matching cases above) is searched by folding over
  its immediate subterms with `Expr.foldlM`.
* Terms with loose bound variables are skipped entirely and `s` is returned unchanged. -/
partial def findSquares (s : TreeSet (Nat × Bool) lexOrd.compare) (e : Expr) :
    AtomM (TreeSet (Nat × Bool) lexOrd.compare) := do
  -- Completely traversing the expression is non-ideal,
  -- as we can descend into expressions that could not possibly be seen by `linarith`.
  -- As a result we visit expressions with bvars, which then cause panics.
  -- Ideally this preprocessor would be reimplemented so it only visits things that could be atoms.
  -- In the meantime we just bail out if we ever encounter loose bvars.
  if e.hasLooseBVars then
    return s
  -- Fallback action shared by all non-square cases: keep searching the immediate subterms.
  -- Binding it here does not run it; it only runs in the branches that return it.
  let descend : AtomM (TreeSet (Nat × Bool) lexOrd.compare) := e.foldlM findSquares s
  match e.getAppFnArgs with
  | (``HPow.hPow, #[_, _, _, _, base, exponent]) =>
    if exponent.numeral? == some 2 then
      -- The search inside `base` runs before `base` itself is registered. This order of
      -- `AtomM.addAtom` calls is kept on purpose, since it may affect which index each
      -- atom receives.
      let acc ← findSquares s base
      let (idx, _) ← AtomM.addAtom base
      return acc.insert (idx, true)
    else
      descend
  | (``HMul.hMul, #[_, _, _, _, lhs, rhs]) =>
    -- Both factors are registered as atoms before comparing, even when they differ.
    let (lhsIdx, _) ← AtomM.addAtom lhs
    let (rhsIdx, _) ← AtomM.addAtom rhs
    if lhsIdx != rhsIdx then
      descend
    else
      -- Same atom index on both sides: treat the product as `lhs * lhs`.
      let acc ← findSquares s lhs
      return acc.insert (lhsIdx, false)
  | _ => descend