diff --git a/src/auth/LoginPage.jsx b/src/auth/LoginPage.jsx index 45d7beb..e464172 100644 --- a/src/auth/LoginPage.jsx +++ b/src/auth/LoginPage.jsx @@ -17,6 +17,7 @@ limitations under the License. import React, { useCallback, useRef, useState, useMemo } from "react"; import { useHistory, useLocation, Link } from "react-router-dom"; import { ReactComponent as Logo } from "../icons/LogoLarge.svg"; +import { useClient } from "../ClientContext"; import { FieldRow, InputField, ErrorMessage } from "../input/Input"; import { Button } from "../button"; import { defaultHomeserver, defaultHomeserverHost } from "../matrix-utils"; @@ -27,6 +28,7 @@ import { usePageTitle } from "../usePageTitle"; export function LoginPage() { usePageTitle("Login"); + const { setClient } = useClient(); const [_, login] = useInteractiveLogin(); const [homeserver, setHomeServer] = useState(defaultHomeserver); const usernameRef = useRef(); @@ -44,7 +46,9 @@ export function LoginPage() { setLoading(true); login(homeserver, usernameRef.current.value, passwordRef.current.value) - .then(() => { + .then(([client, session]) => { + setClient(client, session); + if (location.state && location.state.from) { history.push(location.state.from); } else { diff --git a/src/auth/RegisterPage.jsx b/src/auth/RegisterPage.jsx index 6f08acb..d649267 100644 --- a/src/auth/RegisterPage.jsx +++ b/src/auth/RegisterPage.jsx @@ -31,7 +31,8 @@ import { usePageTitle } from "../usePageTitle"; export function RegisterPage() { usePageTitle("Register"); - const { loading, isAuthenticated, isPasswordlessUser, client } = useClient(); + const { loading, isAuthenticated, isPasswordlessUser, client, setClient } = + useClient(); const confirmPasswordRef = useRef(); const history = useHistory(); const location = useLocation(); @@ -68,7 +69,7 @@ export function RegisterPage() { } const recaptchaResponse = await execute(); - const newClient = await register( + const [newClient, session] = await register( userName, password, userName, @@ -84,6 +85,8 @@ export function RegisterPage() { } } } + + setClient(newClient, session); } submit() diff --git a/src/auth/useInteractiveLogin.js b/src/auth/useInteractiveLogin.js index 6cdd4bd..a9fa804 100644 --- a/src/auth/useInteractiveLogin.js +++ b/src/auth/useInteractiveLogin.js @@ -16,49 +16,43 @@ limitations under the License. import matrix, { InteractiveAuth } from "matrix-js-sdk/src/browser-index"; import { useState, useCallback } from "react"; -import { useClient } from "../ClientContext"; import { initClient, defaultHomeserver } from "../matrix-utils"; export function useInteractiveLogin() { - const { setClient } = useClient(); const [state, setState] = useState({ loading: false }); - const auth = useCallback( - async (homeserver, username, password) => { - const authClient = matrix.createClient(homeserver); + const auth = useCallback(async (homeserver, username, password) => { + const authClient = matrix.createClient(homeserver); - const interactiveAuth = new InteractiveAuth({ - matrixClient: authClient, - busyChanged(loading) { - setState((prev) => ({ ...prev, loading })); - }, - async doRequest(_auth, _background) { - return authClient.login("m.login.password", { - identifier: { - type: "m.id.user", - user: username, - }, - password, - }); - }, - }); + const interactiveAuth = new InteractiveAuth({ + matrixClient: authClient, + busyChanged(loading) { + setState((prev) => ({ ...prev, loading })); + }, + async doRequest(_auth, _background) { + return authClient.login("m.login.password", { + identifier: { + type: "m.id.user", + user: username, + }, + password, + }); + }, + }); - const { user_id, access_token, device_id } = - await interactiveAuth.attemptAuth(); + const { user_id, access_token, device_id } = + await interactiveAuth.attemptAuth(); + const session = { user_id, access_token, device_id }; - const client = await initClient({ - baseUrl: defaultHomeserver, - accessToken: access_token, - userId: user_id, - deviceId: device_id, - }); + const client = await initClient({ + baseUrl: defaultHomeserver, + accessToken: access_token, + userId: user_id, + deviceId: device_id, + }); - setClient(client, { user_id, access_token, device_id }); - - return client; - }, - [setClient] - ); + return [client, session]; + }, []); return [state, auth]; } diff --git a/src/auth/useInteractiveRegistration.js b/src/auth/useInteractiveRegistration.js index 583df1c..8e6fbb8 100644 --- a/src/auth/useInteractiveRegistration.js +++ b/src/auth/useInteractiveRegistration.js @@ -16,11 +16,9 @@ limitations under the License. import matrix, { InteractiveAuth } from "matrix-js-sdk/src/browser-index"; import { useState, useEffect, useCallback, useRef } from "react"; -import { useClient } from "../ClientContext"; import { initClient, defaultHomeserver } from "../matrix-utils"; export function useInteractiveRegistration() { - const { setClient } = useClient(); const [state, setState] = useState({ privacyPolicyUrl: "#", loading: false }); const authClientRef = useRef(); @@ -96,16 +94,14 @@ export function useInteractiveRegistration() { session.tempPassword = password; } - setClient(client, session); - const user = client.getUser(client.getUserId()); user.setRawDisplayName(displayName); user.setDisplayName(displayName); - return client; + return [client, session]; }, - [setClient] + [] ); return [state, register]; diff --git a/src/home/UnauthenticatedView.jsx b/src/home/UnauthenticatedView.jsx index 93650b8..50d7d46 100644 --- a/src/home/UnauthenticatedView.jsx +++ b/src/home/UnauthenticatedView.jsx @@ -15,6 +15,7 @@ limitations under the License. */ import React, { useCallback, useState } from "react"; +import { useClient } from "../ClientContext"; import { Header, HeaderLogo, LeftNav, RightNav } from "../Header"; import { UserMenuContainer } from "../UserMenuContainer"; import { useHistory } from "react-router-dom"; @@ -34,12 +35,14 @@ import { generateRandomName } from "../auth/generateRandomName"; import { useShouldShowPtt } from "../useShouldShowPtt"; export function UnauthenticatedView() { + const { setClient } = useClient(); const shouldShowPtt = useShouldShowPtt(); const [loading, setLoading] = useState(false); const [error, setError] = useState(); const [{ privacyPolicyUrl, recaptchaKey }, register] = useInteractiveRegistration(); const { execute, reset, recaptchaId } = useRecaptcha(recaptchaKey); + const onSubmit = useCallback( (e) => { e.preventDefault(); @@ -53,7 +56,7 @@ export function UnauthenticatedView() { setLoading(true); const recaptchaResponse = await execute(); const userName = generateRandomName(); - const client = await register( + const [client, session] = await register( userName, randomString(16), displayName, @@ -62,6 +65,9 @@ export function UnauthenticatedView() { ); const roomIdOrAlias = await createRoom(client, roomName, ptt); + // Only consider the registration successful if we managed to create the room, too + setClient(client, session); + if (roomIdOrAlias) { history.push(`/room/${roomIdOrAlias}`); } diff --git a/src/room/RoomAuthView.jsx b/src/room/RoomAuthView.jsx index 8748790..df7412a 100644 --- a/src/room/RoomAuthView.jsx +++ b/src/room/RoomAuthView.jsx @@ -16,6 +16,7 @@ limitations under the License. import React, { useCallback, useState } from "react"; import styles from "./RoomAuthView.module.css"; +import { useClient } from "../ClientContext"; import { Button } from "../button"; import { Body, Caption, Link, Headline } from "../typography/Typography"; import { Header, HeaderLogo, LeftNav, RightNav } from "../Header"; @@ -29,11 +30,13 @@ import { UserMenuContainer } from "../UserMenuContainer"; import { generateRandomName } from "../auth/generateRandomName"; export function RoomAuthView() { + const { setClient } = useClient(); const [loading, setLoading] = useState(false); const [error, setError] = useState(); const [{ privacyPolicyUrl, recaptchaKey }, register] = useInteractiveRegistration(); const { execute, reset, recaptchaId } = useRecaptcha(recaptchaKey); + const onSubmit = useCallback( (e) => { e.preventDefault(); @@ -45,13 +48,14 @@ export function RoomAuthView() { setLoading(true); const recaptchaResponse = await execute(); const userName = generateRandomName(); - await register( + const [client, session] = await register( userName, randomString(16), displayName, recaptchaResponse, true ); + setClient(client, session); } submit().catch((error) => {