-- Lean 4
/-- Implementation function for `additiveTest`.

Searches `e` for something that blocks it from being translated and returns
* `some (.inl n)` if a blocking constant `n` was found,
* `some (.inr fvarId)` if a free variable listed in `dontTranslate` was found,
* `none` if nothing blocking was found.

Inside `visit`, failure means that in that subexpression there is no constant that blocks `e`
from being translated.
We cache previous applications of the function, using an expression cache using ptr equality
to avoid visiting the same subexpression many times. Note that we only need to cache the
expressions without taking the value of `inApp` into account, since `inApp` only matters when
the expression is a constant. However, for this reason we have to make sure that we never
cache constant expressions, so that's why the `if`s in the implementation are in this order.

Note that this function is still called many times by `applyReplacementFun`
and we're not remembering the cache between these calls. -/
unsafe def additiveTestUnsafe (env : Environment) (e : Expr) (dontTranslate : Array FVarId) :
    Option (Name ⊕ FVarId) :=
  let rec visit (e : Expr) (inApp := false) : OptionT (StateM (PtrSet Expr)) (Name ⊕ FVarId) := do
    -- Constants are handled *before* the cache lookup, so they are never inserted into the cache:
    -- their verdict depends on `inApp`, which the cache does not record.
    if e.isConst then
      -- A constant is harmless when it has no `dontTranslateAttr` entry and either occurs in
      -- function position of an application (`inApp`) or has a translation of its own.
      if (dontTranslateAttr.find? env e.constName).isNone &&
          (inApp || (findTranslation? env e.constName).isSome) then
        failure
      else
        return .inl e.constName
    -- Already visited during this traversal. A success short-circuits the whole search, so a
    -- cached expression can only have failed before; fail again without re-traversing it.
    if (← get).contains e then
      failure
    modify fun s => s.insert e
    match e with
    | x@(.app e a) =>
      -- Check the function part first (marking it as being in application position);
      -- only if that finds nothing, consider the argument.
      visit e true <|> do
        -- make sure that we don't treat `(fun x => α) (n + 1)` as a type that depends on `Nat`
        guard !x.isConstantApplication
        -- `a` is argument number `e.getAppNumArgs + 1` of the head symbol; skip it when the head
        -- is a constant whose `ignoreArgsAttr` entry lists that position.
        if let some n := e.getAppFn.constName? then
          if let some l := ignoreArgsAttr.find? env n then
            if e.getAppNumArgs + 1 ∈ l then
              failure
        visit a
    | .lam _ _ t _ =>
      -- `t` is the third constructor field (the body); the binder type is not visited.
      visit t
    | .forallE _ _ t _ =>
      -- As for `.lam`: only the body is visited.
      visit t
    | .letE _ _ e body _ =>
      -- Visit the bound value first, then the body; the declared type is not visited.
      visit e <|> visit body
    | .mdata _ b =>
      visit b
    | .proj _ _ b =>
      visit b
    | .fvar fvarId =>
      -- A free variable blocks translation only if the caller listed it in `dontTranslate`.
      if dontTranslate.contains fvarId then
        return .inr fvarId
      else
        failure
    | _ =>
      -- Every remaining constructor has nothing to inspect.
      failure
  -- Run the traversal with a fresh, empty pointer-equality cache and discard the final cache.
  Id.run <| (visit e).run' mkPtrSet