core.sr.ht/srht/flask.py

489 lines
17 KiB
Python

DATE_FORMAT = "%Y-%m-%dT%H:%M:%S+00:00"
from flask import Flask, Response, request, url_for, render_template, redirect
from flask import Blueprint, current_app, g, abort, session as flask_session
from flask import make_response
from enum import Enum
from srht.config import cfg, cfgi, cfgkeys, config, get_origin, get_global_domain
from srht.crypto import fernet
from srht.email import mail_exception
from srht.database import db
from srht.markdown import markdown
from srht.validation import Validation
from datetime import datetime, timedelta
from jinja2 import FileSystemLoader, ChoiceLoader, pass_context
from markupsafe import Markup, escape
from prometheus_client import Histogram, CollectorRegistry, REGISTRY, make_wsgi_app
from prometheus_client.multiprocess import MultiProcessCollector
from timeit import default_timer
from urllib.parse import urlparse, quote, quote_plus
from werkzeug.local import LocalProxy
from werkzeug.routing import UnicodeConverter
try:
from werkzeug.middleware.dispatcher import DispatcherMiddleware
except ImportError:
from werkzeug.wsgi import DispatcherMiddleware
import binascii
import bleach
import decimal
import hashlib
import humanize
import inspect
import json
import locale
import os
import psycopg2.errors
import secrets
import sqlalchemy.exc
import sqlalchemy.orm.exc
import sys
import unicodedata
from functools import update_wrapper
class NamespacedSession:
    """
    Mapping-style facade over the Flask session that prefixes every key with
    the current application's site name (``current_app.site``), so that
    values stored by different sites in the same session do not collide.
    """

    @staticmethod
    def _key(key):
        """Returns the namespaced session key, "<site>:<key>"."""
        return f"{current_app.site}:{key}"

    def __getitem__(self, key):
        return flask_session[self._key(key)]

    def __setitem__(self, key, value):
        flask_session[self._key(key)] = value

    def __delitem__(self, key):
        del flask_session[self._key(key)]

    def __contains__(self, key):
        # Without this, `key in session` falls back to the legacy iteration
        # protocol (__getitem__(0), __getitem__(1), ...) and raises KeyError.
        return self._key(key) in flask_session

    def get(self, key, *args, **kwargs):
        """Like dict.get, applied to the namespaced key."""
        return flask_session.get(self._key(key), *args, **kwargs)

    def set(self, key, value):
        """
        Stores value under the namespaced key.

        Flask's stock session object is a mapping with no set() method, so
        delegating to flask_session.set() raised AttributeError; plain item
        assignment is used instead.
        """
        flask_session[self._key(key)] = value

    def setdefault(self, key, *args, **kwargs):
        """Like dict.setdefault, applied to the namespaced key."""
        return flask_session.setdefault(self._key(key), *args, **kwargs)

    def pop(self, key, *args, **kwargs):
        """Like dict.pop, applied to the namespaced key."""
        return flask_session.pop(self._key(key), *args, **kwargs)
# Single shared NamespacedSession; `session` is the proxy other code imports.
_session = NamespacedSession()
session = LocalProxy(lambda: _session)

# Have humanize compute relative times against naive UTC "now".
humanize.time._now = lambda: datetime.utcnow()

try:
    locale.setlocale(locale.LC_ALL, 'en_US')
except locale.Error:
    # Best effort: the en_US locale may not be installed on this host, in
    # which case the process keeps its default locale.
    pass
def date_handler(obj):
    """
    Fallback serializer passed as `default=` to json.dumps.

    - anything with a strftime method is formatted with DATE_FORMAT
    - decimal.Decimal is rendered as a string with two decimal places
    - Enum members are rendered as their name
    - every other object is returned unchanged
    """
    if hasattr(obj, 'strftime'):
        converted = obj.strftime(DATE_FORMAT)
    elif isinstance(obj, decimal.Decimal):
        converted = format(obj, ".2f")
    elif isinstance(obj, Enum):
        converted = obj.name
    else:
        converted = obj
    return converted
def datef(d):
    """
    Template filter which renders a datetime or timedelta as a humanized
    <span>, with the precise value in the title attribute.

    Returns the plain string 'Never' for any falsy value (None, an empty
    string, a zero timedelta). Otherwise returns Markup.
    """
    if not d:
        return 'Never'
    if isinstance(d, timedelta):
        # total_seconds() covers the whole duration. The .seconds attribute
        # is only the sub-day remainder, so a two day delta was being
        # titled "0 seconds".
        title = f'{int(d.total_seconds())} seconds'
        text = humanize.naturaldelta(d)
    else:
        title = d.strftime('%Y-%m-%d %H:%M:%S UTC')
        text = humanize.naturaltime(d)
    return Markup('<span title="{}">{}</span>'.format(title, text))
# Raw SVG source of every icon loaded so far, keyed by icon name.
icon_cache = {}


def icon(i, cls=""):
    """
    Template helper which returns the inline SVG for icon `i`, wrapped in a
    <span> carrying the icon-<i> class plus any extra classes given in `cls`.

    The SVG is read from static/icons/<i>.svg under the current app's
    mod_path the first time it is requested and served from icon_cache
    afterwards. When an icon is loaded from disk and `g` has no fa_license
    flag yet, the Font Awesome license comment is appended to the returned
    markup (not to the cached copy) and the flag is set on `g`.
    """
    def wrap(svg):
        return Markup(f'<span class="icon icon-{i} {cls}" aria-hidden="true">{svg}</span>')

    if i in icon_cache:
        return wrap(icon_cache[i])

    fa_license = (
        "<!--\n"
        "Font Awesome Free 5.3.1 by @fontawesome - https://fontawesome.com\n"
        "License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License)\n"
        "-->"
    )
    icon_file = os.path.join(current_app.mod_path, 'static', 'icons', i + '.svg')
    with open(icon_file) as handle:
        markup = handle.read()
    icon_cache[i] = markup

    if g and "fa_license" not in g:
        markup += fa_license
        g.fa_license = True
    return wrap(markup)
@pass_context
def coalesce_search_terms(context):
    """
    Builds a query string fragment ("&key=value&key=value...") out of the
    "search" template variable plus every variable named in the optional
    "search_keys" template variable. Variables which are unset or falsy are
    skipped; values are encoded with quote_plus.
    """
    keys = ["search"] + (context.get("search_keys") or [])
    pairs = ((key, context.get(key)) for key in keys)
    return "".join(
        f"&{key}={quote_plus(val)}" for key, val in pairs if val)
@pass_context
def pagination(context):
    """
    Template global which renders pagination.html with the variables of the
    calling template's parent context and returns the result as Markup.
    """
    env = context.environment
    rendered = env.get_template("pagination.html").render(**context.parent)
    return Markup(rendered)
def csrf_token():
    """
    Template global which returns a hidden <input name='_csrf_token'> holding
    the session's CSRF token. The token (64 random bytes, hex encoded) is
    generated and stored in the Flask session on first use.
    """
    session_key = '_csrf_token_v2'
    if session_key not in flask_session:
        random_bytes = os.urandom(64)
        flask_session[session_key] = binascii.hexlify(random_bytes).decode()
    field = (
        "<input\n"
        "type='hidden'\n"
        "name='_csrf_token'\n"
        "value='{}' />"
    )
    return Markup(field.format(escape(flask_session[session_key])))
# "module.function" names of views exempted from the CSRF check
_csrf_bypass_views = set()
# Blueprint objects exempted from the CSRF check
_csrf_bypass_blueprints = set()


def csrf_bypass(f):
    """
    Registers a view function or a Blueprint as exempt from CSRF checks and
    returns it unchanged, so it may be used as a decorator.
    """
    if isinstance(f, Blueprint):
        _csrf_bypass_blueprints.add(f)
        return f
    _csrf_bypass_views.add('.'.join([f.__module__, f.__name__]))
    return f
def paginate_query(query, results_per_page=15):
    """
    Applies the request's ?page=N parameter (1-based) to a query.

    query: object providing count(), offset(), limit() and all()
    results_per_page: page size

    Returns (results, info), where results is the list produced by .all() for
    the selected page and info is a dict with "total_pages", "page" (1-based)
    and "total_results".

    A missing or non-numeric page parameter selects the first page. A page
    number below 1 aborts the request with 400.
    """
    page_arg = request.args.get("page")
    total_results = query.count()
    # Ceiling division: a partially filled last page still counts as a page
    total_pages = -(-total_results // results_per_page)

    page = 0
    if page_arg is not None:
        try:
            page = int(page_arg) - 1
        except ValueError:
            # Not a number; fall back to the first page
            page = 0

    # Reject out-of-range pages before building an offset from them
    if page < 0:
        abort(400)
    if page:
        query = query.offset(page * results_per_page)

    results = query.limit(results_per_page).all()
    return results, {
        "total_pages": total_pages,
        "page": page + 1,
        "total_results": total_results
    }
class ModifiedUnicodeConverter(UnicodeConverter):
    """
    UnicodeConverter whose URL building leaves '/', ':', '~' and '^'
    unescaped; everything else is inherited unchanged.
    """

    def to_url(self, value):
        text = value if isinstance(value, str) else str(value)
        return quote(text, safe='/:~^')
class SrhtFlask(Flask):
    """
    Flask application with the shared sr.ht setup applied: Prometheus
    metrics under /metrics, layered template lookup, CSRF checking for POST
    requests, error pages, common template variables, the unified-login
    cookie, and JSON conversion of dict/list view return values.
    """

    def __init__(self, site, name,
            oauth_service=None, oauth_provider=None, *args, **kwargs):
        """
        site: name of this site; used as the session namespace, for config
            lookups (get_origin, get_global_domain) and for the
            /etc/<site>/templates override directory
        name: import name of the application package; handed to Flask and
            imported to locate the package directory (self.mod_path)
        oauth_service: optional object providing client_id, oauth_url() and
            lookup_user(); when given, the srht.oauth blueprint is registered
        oauth_provider: stored on the app as-is
        Remaining arguments are forwarded to Flask.

        Raises Exception if the package directory cannot be determined or if
        neither [sr.ht]service-key nor [sr.ht]secret-key is configured.
        """
        super().__init__(name, *args, **kwargs)

        self.site = site

        # With prometheus_multiproc_dir set, metrics are aggregated through a
        # dedicated registry; otherwise the default registry is exposed.
        if os.environ.get("prometheus_multiproc_dir"):
            self.metrics_registry = CollectorRegistry()
            MultiProcessCollector(self.metrics_registry)
        else:
            self.metrics_registry = REGISTRY
        self.wsgi_app = DispatcherMiddleware(self.wsgi_app, {
            "/metrics": make_wsgi_app(registry=self.metrics_registry),
        })
        # Ad-hoc namespace exposing each metric as an attribute named after
        # the metric (self.metrics.request_time)
        self.metrics = type("metrics", tuple(), {
            m.describe()[0].name: m
            for m in [
                Histogram("request_time", "Duration of HTTP requests", [
                    "method", "route", "status"
                ]),
            ]
        })

        self.url_map.converters['default'] = ModifiedUnicodeConverter
        self.url_map.converters['string'] = ModifiedUnicodeConverter

        # Template lookup order: working directory, /etc/<site>, the
        # application package, and finally the directory of this module.
        choices = [
            FileSystemLoader("templates"),
            FileSystemLoader(os.path.join("/etc", self.site, "templates")),
        ]

        mod = __import__(name)
        if hasattr(mod, "__path__"):
            path = list(mod.__path__)[0]
        elif hasattr(mod, "__file__"):
            path = os.path.dirname(mod.__file__)
        else:
            raise Exception("Can't find the module's path, how are you running the app???")

        self.mod_path = path
        choices.append(FileSystemLoader(os.path.join(path, "templates")))
        choices.append(FileSystemLoader(os.path.join(
            os.path.dirname(__file__),
            "templates"
        )))

        # The GraphQL schema and default query are optional; when either
        # file cannot be read the corresponding attribute is left unset.
        try:
            with open(os.path.join(path, "schema.graphqls")) as f:
                self.graphql_schema = f.read()
            with open(os.path.join(path, "default_query.graphql")) as f:
                self.graphql_query = f.read()
        except (OSError, UnicodeError):
            pass

        self.jinja_env.filters['date'] = datef
        self.jinja_env.globals['pagination'] = pagination
        self.jinja_env.globals['icon'] = icon
        self.jinja_env.globals['csrf_token'] = csrf_token
        self.jinja_loader = ChoiceLoader(choices)
        self.jinja_env.add_extension('jinja2.ext.do')

        self.secret_key = cfg("sr.ht", "service-key", default=
            cfg("sr.ht", "secret-key", default=None))
        if self.secret_key is None:
            raise Exception("[sr.ht]service-key missing from config")

        self.oauth_service = oauth_service
        self.oauth_provider = oauth_provider

        if self.oauth_service:
            from srht.oauth import oauth_blueprint
            self.register_blueprint(oauth_blueprint)

            from srht.oauth.scope import set_client_id
            set_client_id(self.oauth_service.client_id)

        # TODO: Remove
        self.no_csrf_prefixes = ['/api']

        @self.before_request
        def _csrf_check():
            # Only POST requests are checked. Blueprints and views registered
            # with csrf_bypass, unknown endpoints, and paths under
            # no_csrf_prefixes are exempt.
            if request.method != 'POST':
                return
            if request.blueprint in _csrf_bypass_blueprints:
                return
            view = self.view_functions.get(request.endpoint)
            if not view:
                return
            view = "{0}.{1}".format(view.__module__, view.__name__)
            if view in _csrf_bypass_views:
                return
            # TODO: Remove
            for prefix in self.no_csrf_prefixes:
                if request.path.startswith(prefix):
                    return
            token = flask_session.get('_csrf_token_v2', None)
            if not token:
                abort(403)
            # compare_digest raises TypeError when handed None or text with
            # non-ASCII characters, which turned a missing or garbage form
            # field into a 500. Compare bytes so such requests get a 403.
            submitted = request.form.get('_csrf_token') or ''
            if not secrets.compare_digest(
                    str(token).encode(), submitted.encode()):
                abort(403)

        @self.teardown_appcontext
        def expire_db(err):
            db.session.expire_all()

        @self.errorhandler(500)
        def handle_500(e):
            # The triggering exception, if any, is carried in
            # original_exception; getattr keeps the handler working for
            # exceptions which lack the attribute.
            original = getattr(e, "original_exception", None)
            if isinstance(original, sqlalchemy.exc.InternalError):
                e = original.orig
                if isinstance(e, psycopg2.errors.ReadOnlySqlTransaction):
                    return render_template("read_only.html")

            # Roll back the failed transaction and mail the exception to the
            # admins. A failure in here is deliberately left to propagate.
            from srht.oauth import current_user
            user = None
            if hasattr(db, 'session'):
                db.session.rollback()
                if current_user:
                    user = f"{current_user.canonical_name} " + \
                        f"<{current_user.email}>"
                db.session.close()
            mail_exception(e, user=user)
            return render_template("internal_error.html"), 500

        @self.errorhandler(401)
        def handle_401(e):
            if request.path.startswith("/api"):
                return { "errors": [ { "reason": "401 unauthorized" } ] }, 401
            return render_template("unauthorized.html"), 401

        @self.errorhandler(404)
        def handle_404(e):
            if request.path.startswith("/api"):
                return { "errors": [ { "reason": "404 not found" } ] }, 404
            return render_template("not_found.html"), 404

        @self.context_processor
        def inject():
            # Variables made available to every template
            root = get_origin(self.site, external=True)
            ctx = {
                'root': root,
                'domain': urlparse(root).netloc,
                'app': self,
                'len': len,
                'any': any,
                'str': str,
                'request': request,
                'url_for': url_for,
                'cfg': cfg,
                'cfgi': cfgi,
                'cfgkeys': cfgkeys,
                'get_origin': get_origin,
                'valid': Validation(request),
                'site': site,
                'site_name': cfg("sr.ht", "site-name", default=None),
                'environment': cfg("sr.ht", "environment", default="production"),
                'network': self.get_network(),
                'static_resource': self.static_resource,
                'coalesce_search_terms': coalesce_search_terms,
            }
            try:
                from srht.oauth import current_user
                if current_user:
                    # Re-query the user so the template gets an instance
                    # attached to the current database session
                    user_class = current_user._get_current_object().__class__
                    ctx['current_user'] = (user_class.query
                        .filter(user_class.id == current_user.id)
                    ).one_or_none()
                else:
                    ctx['current_user'] = None
            except (sqlalchemy.orm.exc.DetachedInstanceError,
                    sqlalchemy.exc.InvalidRequestError):
                pass # Can happen while cleaning up from 500 errors
            if self.oauth_service:
                ctx.update({
                    "oauth_url": self.oauth_service.oauth_url(
                        request.full_path),
                    "logout_url": "{}/logout?return_to={}{}".format(
                        get_origin("meta.sr.ht", external=True),
                        root, quote_plus(request.full_path)),
                })
            return ctx

        @self.teardown_appcontext
        def shutdown_session(resp):
            db.session.remove()
            return resp

        @self.template_filter()
        def md(text):
            return markdown(text)

        @self.template_filter()
        def extended_md(text, baselevel=1):
            return markdown(text, baselevel)

        @self.before_request
        def get_session_cookie():
            # TODO: We could probably speed things up by skipping the
            # round-trip until we actually need any user info which isn't
            # present in the user's info cookie
            cookie = request.cookies.get("sr.ht.unified-login.v1")
            if not cookie:
                return
            # NOTE(review): a cookie which fails to decrypt, or an app
            # created without an oauth_service, raises here and fails the
            # request -- confirm that this is intended.
            user_info = json.loads(fernet.decrypt(cookie.encode()).decode())
            g.current_user = self.oauth_service.lookup_user(user_info["name"])

        @self.before_request
        def begin_track_request():
            request._srht_start_time = default_timer()

        @self.after_request
        def track_request(resp):
            # The start time is missing if an earlier before_request handler
            # ended the request before begin_track_request ran
            if not hasattr(request, "_srht_start_time"):
                return resp
            self.metrics.request_time.labels(
                method=request.method,
                route=request.endpoint,
                status=resp.status_code,
            ).observe(max(default_timer() - request._srht_start_time, 0))
            return resp

    def make_response(self, rv):
        """
        Converts a view's return value into a response object.

        A dict or list -- alone, or as the first element of a tuple -- is
        serialized to JSON (using date_handler for otherwise unserializable
        values) and sent as application/json. Everything else is passed to
        Flask untouched.

        If g.set_current_user is truthy, the unified-login cookie is also
        set from g.current_user, or cleared if g.current_user is falsy.
        """
        def jsonify_wrap(obj):
            jsonification = json.dumps(obj, default=date_handler)
            return Response(jsonification, mimetype='application/json')

        if isinstance(rv, tuple) and rv and isinstance(rv[0], (dict, list)):
            # Carry over every trailing element (status and/or headers), not
            # just the second one, so (body, status, headers) keeps its
            # headers.
            response = (jsonify_wrap(rv[0]), *rv[1:])
        elif isinstance(rv, (dict, list)):
            response = jsonify_wrap(rv)
        else:
            response = rv
        response = super().make_response(response)

        global_domain = get_global_domain(self.site)
        if "set_current_user" in g and g.set_current_user:
            cookie_key = "sr.ht.unified-login.v1"
            if not g.current_user:
                # Clear user info cookie
                response.set_cookie(cookie_key, "",
                        domain=global_domain,
                        httponly=True,
                        max_age=0)
            else:
                # Set user info cookie
                user_info = g.current_user.to_dict(first_party=True)
                user_info = {
                    k: v for k, v in user_info.items()
                    if k not in ['bio', 'location', 'url']
                }
                user_info = json.dumps(user_info)
                response.set_cookie(cookie_key,
                        fernet.encrypt(user_info.encode()).decode(),
                        domain=global_domain,
                        httponly=True,
                        max_age=60 * 60 * 24 * 365)

        return response

    def static_resource(self, path):
        """
        Given example.ext, hashes the file and returns example.hash.ext,
        where hash is the first 8 hex digits of the file's SHA-256.

        The file is opened at os.path.join(self.mod_path, path). Results are
        cached per path in self.static_cache, so each file is hashed once.
        """
        if not hasattr(self, "static_cache"):
            self.static_cache = dict()
        cached = self.static_cache.get(path)
        if cached is not None:
            return cached
        sha256 = hashlib.sha256()
        with open(os.path.join(self.mod_path, path), "rb") as f:
            sha256.update(f.read())
        # Keep `path` intact: it is the cache key. Storing the result under
        # the extension-less stem meant the lookup above never hit.
        stem, ext = os.path.splitext(path)
        hashed = f"{stem}.{sha256.hexdigest()[:8]}{ext}"
        self.static_cache[path] = hashed
        return hashed

    def get_network(self):
        """
        Returns the config sections naming *.sr.ht sites, leaving out
        paste.sr.ht, pages.sr.ht and dispatch.sr.ht.
        """
        return [
            s for s in config
            if s.endswith(".sr.ht") and s not in [
                "paste.sr.ht",
                "pages.sr.ht",
                "dispatch.sr.ht",
            ]
        ]
def cross_origin(f):
    """
    Decorator which adds permissive CORS headers to a route's responses.

    OPTIONS is added to the view's required methods and automatic OPTIONS
    handling is switched off, so OPTIONS requests reach the wrapper, which
    answers them with the app's default OPTIONS response plus the CORS
    headers.
    """
    methods = getattr(f, "required_methods", set())
    methods.add("OPTIONS")
    f.required_methods = methods
    f.provide_automatic_options = False

    def cors_view(*args, **kwargs):
        if request.method != "OPTIONS":
            resp = make_response(f(*args, **kwargs))
        else:
            resp = current_app.make_default_options_response()
        for header, value in (
            ("Access-Control-Allow-Origin", "*"),
            ("Access-Control-Allow-Methods", "OPTIONS, GET, POST"),
            ("Access-Control-Allow-Headers", "Content-Type, Authorization"),
        ):
            resp.headers[header] = value
        return resp

    return update_wrapper(cors_view, f)