-- Lean 4
/-- Lagrange's **four squares theorem** for a prime number. Use `Nat.sum_four_squares` instead. -/
protected theorem sum_four_squares {p : ℕ} (hp : p.Prime) : ∃ a b c d : ℕ, a ^ 2 + b ^ 2 + c ^ 2 + d ^ 2 = p := by
  -- Proof by descent: take the least `m` with `0 < m < p` such that `m * p` is a sum of four
  -- squares. If `m = 1` we are done; if `1 < m`, both the even and the odd case produce a
  -- smaller such multiple, contradicting the minimality of `m`.
  classical
  -- Make `p.Prime` available as a `Fact` instance for lemmas that take it as a typeclass
  -- argument (presumably `exists_sq_add_sq_add_one_eq_mul` below).
  have := Fact.mk hp
  -- Transfer a four-squares identity between the `natAbs` values (cast to `ℤ`) and the
  -- integers themselves; `sq_abs` (`|x| ^ 2 = x ^ 2`) makes the two sides equivalent.
  have natAbs_iff {a b c d : ℤ} {k : ℕ} :
      a.natAbs ^ 2 + b.natAbs ^ 2 + c.natAbs ^ 2 + d.natAbs ^ 2 = k ↔ a ^ 2 + b ^ 2 + c ^ 2 + d ^ 2 = k := by
    rw [← @Nat.cast_inj ℤ]; push_cast [sq_abs]; rfl
  -- Some multiple `m * p` with `0 < m < p` is a sum of four squares, using the witnesses
  -- `a, b, 1, 0` obtained from `exists_sq_add_sq_add_one_eq_mul`.
  have hm : ∃ m < p, 0 < m ∧ ∃ a b c d : ℕ, a ^ 2 + b ^ 2 + c ^ 2 + d ^ 2 = m * p :=
    by
      obtain ⟨a, b, k, hk₀, hkp, hk⟩ := exists_sq_add_sq_add_one_eq_mul p
      refine ⟨k, hkp, hk₀, a, b, 1, 0, ?_⟩
      simpa
  -- Take the minimal possible `m`
  rcases Nat.findX hm with
    ⟨m, ⟨hmp, hm₀, a, b, c, d, habcd⟩, hmin⟩
  -- If `m = 1`, then we are done
  rcases (Nat.one_le_iff_ne_zero.2 hm₀.ne').eq_or_lt with rfl | hm₁
  · use a, b, c, d; simpa using habcd
  -- Otherwise `1 < m`; derive a contradiction with the minimality of `m` (`hmin`).
  exfalso
  have : NeZero m := ⟨hm₀.ne'⟩
  by_cases hm : 2 ∣ m
  · -- If `m` is an even number, then `(m / 2) * p` can be represented as a sum of four squares
    rcases hm with ⟨m, rfl⟩
    rw [mul_pos_iff_of_pos_left two_pos] at hm₀
    have hm₂ : m < 2 * m := by simpa [two_mul]
    apply_fun (Nat.cast : ℕ → ℤ) at habcd
    push_cast [mul_assoc] at habcd
    obtain ⟨_, _, _, _, h⟩ := sum_four_squares_of_two_mul_sum_four_squares habcd
    exact hmin m hm₂ ⟨hm₂.trans hmp, hm₀, _, _, _, _, natAbs_iff.2 h⟩
  · -- For each `x` in `a`, `b`, `c`, `d`, take a number `f x ≡ x [ZMOD m]` with least possible
    -- absolute value
    obtain ⟨f, hf_lt, hf_mod⟩ : ∃ f : ℕ → ℤ, (∀ x, 2 * (f x).natAbs < m) ∧ ∀ x, (f x : ZMod m) = x :=
      by
        refine ⟨fun x ↦ (x : ZMod m).valMinAbs, fun x ↦ ?_, fun x ↦ (x : ZMod m).coe_valMinAbs⟩
        -- `2 * |f x| ≤ 2 * (m / 2)`, and `2 * (m / 2) < m` because `m` is odd (`hm : ¬2 ∣ m`)
        exact
          (mul_le_mul' le_rfl (x : ZMod m).natAbs_valMinAbs_le).trans_lt
            (Nat.mul_div_lt_iff_not_dvd.2 hm)
    -- Since `|f x| ^ 2 = (f x) ^ 2 ≡ x ^ 2 [ZMOD m]`, we have
    -- `m ∣ |f a| ^ 2 + |f b| ^ 2 + |f c| ^ 2 + |f d| ^ 2`
    obtain ⟨r, hr⟩ : m ∣ (f a).natAbs ^ 2 + (f b).natAbs ^ 2 + (f c).natAbs ^ 2 + (f d).natAbs ^ 2 :=
      by
        simp only [← Int.natCast_dvd_natCast, ← ZMod.intCast_zmod_eq_zero_iff_dvd]
        push_cast [hf_mod, sq_abs]
        norm_cast
        simp [habcd]
    -- The quotient `r` is not zero, because otherwise `f a = f b = f c = f d = 0`, hence
    -- `m` divides each `a`, `b`, `c`, `d`, thus `m ∣ p` which is impossible.
    rcases (zero_le r).eq_or_lt with rfl | hr₀
    · replace hr : f a = 0 ∧ f b = 0 ∧ f c = 0 ∧ f d = 0 := by simpa [and_assoc] using hr
      obtain ⟨⟨a, rfl⟩, ⟨b, rfl⟩, ⟨c, rfl⟩, ⟨d, rfl⟩⟩ : m ∣ a ∧ m ∣ b ∧ m ∣ c ∧ m ∣ d := by
        simp only [← ZMod.natCast_eq_zero_iff, ← hf_mod, hr, Int.cast_zero, and_self]
      -- Now `m * p = m * m * (a ^ 2 + b ^ 2 + c ^ 2 + d ^ 2)`; cancelling one `m` gives `m ∣ p`,
      -- which contradicts `1 < m` and `m < p` since `p` is prime
      have : m * m ∣ m * p := habcd ▸ ⟨a ^ 2 + b ^ 2 + c ^ 2 + d ^ 2, by ring⟩
      rw [mul_dvd_mul_iff_left hm₀.ne'] at this
      exact (hp.eq_one_or_self_of_dvd _ this).elim hm₁.ne' hmp.ne
    -- Since `2 * |f x| < m` for each `x`, the quotient `r` is less than `m`
    have hrm : r < m := by
      rw [mul_comm] at hr
      apply lt_of_sum_four_squares_eq_mul hr <;> apply hf_lt
    -- It suffices to write `(m * r) * (m * p)` as a sum of four squares of integers that are all
    -- divisible by `m`: dividing each by `m` writes `r * p` as a sum of four squares with
    -- `0 < r < m`, contradicting the minimality of `m`
    rsuffices ⟨w, x, y, z, hw, hx, hy, hz, h⟩ :
        ∃ w x y z : ℤ, ↑m ∣ w ∧ ↑m ∣ x ∧ ↑m ∣ y ∧ ↑m ∣ z ∧ w ^ 2 + x ^ 2 + y ^ 2 + z ^ 2 = ↑(m * r) * ↑(m * p)
    · have : (w / m) ^ 2 + (x / m) ^ 2 + (y / m) ^ 2 + (z / m) ^ 2 = ↑(r * p) :=
        by
          refine mul_left_cancel₀ (pow_ne_zero 2 (Nat.cast_ne_zero.2 hm₀.ne')) ?_
          conv_rhs => rw [← Nat.cast_pow, ← Nat.cast_mul, sq m, mul_mul_mul_comm, Nat.cast_mul, ← h]
          simp only [mul_add, ← mul_pow, Int.mul_ediv_cancel', *]
      rw [← natAbs_iff] at this
      exact
        hmin r hrm
          ⟨hrm.trans hmp, hr₀, _, _, _, _, this⟩
    -- To do the last step, we apply Euler's four-square identity. First restate `hr` as a sum of
    -- four integer squares equal to `m * r`; the reordering and the sign on `f c` presumably make
    -- every term produced by the identity divisible by `m`
    replace hr : (f b) ^ 2 + (f a) ^ 2 + (f d) ^ 2 + (-f c) ^ 2 = ↑(m * r) :=
      by
        rw [← natAbs_iff, natAbs_neg, ← hr]
        ac_rfl
    -- Multiply the representations of `m * r` (`hr`) and `m * p` (`habcd`), then expand the
    -- product with Euler's four-square identity
    have := congr_arg₂ (· * Nat.cast ·) hr habcd
    simp only [← _root_.euler_four_squares, Nat.cast_add, Nat.cast_pow] at this
    -- Remaining goals: each of the four integers produced by Euler's identity is divisible by
    -- `m`; all are checked in `ZMod m` using `hf_mod`
    refine ⟨_, _, _, _, ?_, ?_, ?_, ?_, this⟩
    · simp [← ZMod.intCast_zmod_eq_zero_iff_dvd, hf_mod, mul_comm]
    · suffices ((a : ZMod m) ^ 2 + (b : ZMod m) ^ 2 + (c : ZMod m) ^ 2 + (d : ZMod m) ^ 2) = 0 by
        simpa [← ZMod.intCast_zmod_eq_zero_iff_dvd, hf_mod, sq, add_comm, add_assoc, add_left_comm] using this
      -- `a ^ 2 + b ^ 2 + c ^ 2 + d ^ 2 = m * p` (`habcd`) vanishes in `ZMod m`
      norm_cast
      simp [habcd]
    · simp [← ZMod.intCast_zmod_eq_zero_iff_dvd, hf_mod, mul_comm]
    · simp [← ZMod.intCast_zmod_eq_zero_iff_dvd, hf_mod, mul_comm]