# Definition for a binary tree node.
# class TreeNode:
# def __init__(self, val=0, left=None, right=None):
# self.val = val
# self.left = left
# self.right = right
class Solution:
def diameterOfBinaryTree(self, root: Optional[TreeNode]) -> int:
def height(root):
if not root:
return 0
l_height = height(root.left)
r_height = height(root.right)
return max(l_height, r_height) + 1
return height(root.left) + height(root.right)
你的实现有两个问题:
- 只计算了经过根节点的最长路径
你最后直接返回了
height(root.left) + height(root.right)
,这只算出了一条“从最深左子节点经根到最深右子节点”的路径长度,但真正的直径(diameter)可能出现在某个子树内部,并不一定经过根节点。 - 高度函数返回的是“节点数”而非“边数”
你定义的
height(root)
对于叶子节点返回1
,这表示“节点数”。可是题目中“路径长度”是按 边 数来算的:- 如果把高度定义为“节点数”,那么真正的边数要
height−1
。 - 你直接把两边的“节点数”相加,结果会比边数多 2。
- 如果把高度定义为“节点数”,那么真正的边数要
改进思路
- 使用一次 DFS(后序遍历)来同时计算高度和更新全局最大直径。
- 在访问每个节点时,用左右子树高度的和(这次我们把高度定义为“边数”)去更新一个全局变量
ans
。 - 最后返回
ans
即可。
# Definition for a binary tree node.
# class TreeNode:
# def __init__(self, val=0, left=None, right=None):
# self.val = val
# self.left = left
# self.right = right
class Solution:
def diameterOfBinaryTree(self, root: Optional[TreeNode]) -> int:
self.ans = 0 # 全局最大直径(以边为单位)
def height(node: Optional[TreeNode]) -> int:
# 返回节点 node 为根的子树的高度(边数)
if not node:
return 0
# 左、右子树的高度(边数)
lh = height(node.left)
rh = height(node.right)
# 更新直径:如果这条“穿过当前节点”的路径更长,就记录下来
self.ans = max(self.ans, lh + rh)
# 返回当前节点的高度(最长子路径 + 1 条边)
return max(lh, rh) + 1
height(root)
return self.ans
解释
self.ans
记录了遇到过的最大lh + rh
值lh
是左子树的最大边数,rh
是右子树的最大边数。- 它们相加正好是“从左最深叶子节点到右最深叶子节点”这条路径的边数。
height
函数返回的是“边的数量”- 对于空节点,定义高度(边数)为 0。
- 对于非空节点,高度 = 子树最高高度 + 1 条从当前节点到子树根的边。
- 递归遍历所有节点
- 每访问到一个节点,就尝试用它作为“中转点”来更新直径。
- 这样就不会漏掉任何出现在子树内部、但不经过根节点的最长路径。
这样既保证了直径一定会被更新到全局最大,又保证计算的是“边数”而不是“节点数”。
边和节点的区别
其实这里的“多算 2”是基于一种把 height
纯粹当做“节点数”来理解的思路——但在我们的递归里,不同的理解会得出不同的结论。你我俩的 height
函数代码是一模一样的,都是
if not node:
return 0
lh = height(node.left)
rh = height(node.right)
return max(lh, rh) + 1
- 把它看成“节点数”:
height(node)
就是从node
到最深叶子一共有多少个节点。 - 把它看成“父节点到子树最深处的边数”:当你在父节点里调用
height(node.left)
时,那个+1
其实正好把那条“从父到左子节点”这一条边算进去了。
所以——
height(node.left)
恰好等于“从当前节点(父)到左子树最深叶子的边数”,height(node.right)
恰好等于“从当前节点(父)到右子树最深叶子的边数”。
把它们相加,直接就是“穿过这个节点的最长路径的边数”,不会再多算 2。
当初说“多算 2”其实是把 height
当成“纯节点数”来减法——节点数要转成边数确实要 –1
,左边 –1
、右边 –1
,再加回中间那两条从父到子,就会让人觉得“多算 2”。但仔细想,height(child)
已经内含了“父→子”这一条边,就不需要再减。
真正让你原来代码跑不出正确答案的,是你只在根节点做了一次 height(root.left)+height(root.right)
,而没有在每个节点都去尝试更新一次直径,所以漏掉了那些不经过根节点的最长路径**。单位(边数)方面,其实你我的写法是一致的。
class Solution:
def diameterOfBinaryTree(self, root: Optional[TreeNode]) -> int:
ans = 0
def dfs(node: Optional[TreeNode]) -> int:
if node is None:
return -1 # 对于叶子来说,链长就是 -1+1=0
l_len = dfs(node.left) + 1 # 左子树最大链长+1
r_len = dfs(node.right) + 1 # 右子树最大链长+1
nonlocal ans
ans = max(ans, l_len + r_len) # 两条链拼成路径
return max(l_len, r_len) # 当前子树最大链长
dfs(root)
return ans
作者:灵茶山艾府
链接:https://leetcode.cn/problems/diameter-of-binary-tree/description/
来源:力扣(LeetCode)
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。