// the return pair: the length of longest increasing/decreasing sequence at root pair<int, int> traverse(TreeNode* root){ if (!root) { return {0, 0}; }
int inc = 1, dec = 1, l = 1, r = 1; if (root->left) { auto left = traverse(root->left); if (root->val == root->left->val + 1) { inc += left.first; l = max(l, 1 + left.first); } elseif (root->val == root->left->val - 1) { dec += left.second; r = max(r, 1 + left.second); } } if (root->right) { auto right = traverse(root->right); if (root->val == root->right->val + 1) { dec += right.first; l = max(l, 1 + right.first); } elseif (root->val == root->right->val - 1) { inc += right.second; r = max(r, 1 + right.second); } } res = max(res, max(inc, dec)); return {l, r}; } };