Add vendored packages back

Massimiliano Culpo 2022-11-16 20:43:44 +01:00 committed by Todd Gamblin
parent 5175189412
commit 033cb86fd6
154 changed files with 41107 additions and 0 deletions


@@ -0,0 +1 @@
__version__ = '0.18.0'


@@ -0,0 +1 @@
from _pyrsistent_version import *


@@ -0,0 +1 @@
from altgraph import *


@@ -0,0 +1,321 @@
"""
altgraph.Dot - Interface to the dot language
============================================
The :py:mod:`~altgraph.Dot` module provides a simple interface to the
file format used in the
`graphviz <http://www.research.att.com/sw/tools/graphviz/>`_
program. The module is intended to offload the most tedious part of the process
(the **dot** file generation) while transparently exposing most of its
features.
To display the graphs or to generate image files the
`graphviz <http://www.research.att.com/sw/tools/graphviz/>`_
package needs to be installed on the system, moreover the :command:`dot` and
:command:`dotty` programs must be accessible in the program path so that they
can be run from processes spawned within the module.
Example usage
-------------
Here is a typical usage::
from altgraph import Graph, Dot
# create a graph
edges = [ (1,2), (1,3), (3,4), (3,5), (4,5), (5,4) ]
graph = Graph.Graph(edges)
# create a dot representation of the graph
dot = Dot.Dot(graph)
# display the graph
dot.display()
# save the dot representation into the mydot.dot file
dot.save_dot(file_name='mydot.dot')
# save dot file as gif image into the graph.gif file
dot.save_img(file_name='graph', file_type='gif')
Directed graph and non-directed graph
-------------------------------------
The Dot class can be used for both directed and non-directed graphs
by passing the ``graphtype`` parameter.
Example::
# create a directed graph (default)
dot = Dot.Dot(graph, graphtype="digraph")
# create non-directed graph
dot = Dot.Dot(graph, graphtype="graph")
Customizing the output
----------------------
The graph drawing process may be customized by passing
valid :command:`dot` parameters for the nodes and edges. For a list of all
parameters see the `graphviz <http://www.research.att.com/sw/tools/graphviz/>`_
documentation.
Example::
# customizing the way the overall graph is drawn
dot.style(size='10,10', rankdir='RL', page='5, 5' , ranksep=0.75)
# customizing node drawing
dot.node_style(1, label='BASE_NODE',shape='box', color='blue' )
dot.node_style(2, style='filled', fillcolor='red')
# customizing edge drawing
dot.edge_style(1, 2, style='dotted')
dot.edge_style(3, 5, arrowhead='dot', label='binds', labelangle='90')
dot.edge_style(4, 5, arrowsize=2, style='bold')
.. note::
dotty (invoked via :py:func:`~altgraph.Dot.display`) may not be able to
display all graphics styles. To verify the output save it to an image file
and look at it that way.
Valid attributes
----------------
- dot styles, passed via the :py:meth:`Dot.style` method::
rankdir = 'LR' (draws the graph horizontally, left to right)
ranksep = number (rank separation in inches)
- node attributes, passed via the :py:meth:`Dot.node_style` method::
style = 'filled' | 'invisible' | 'diagonals' | 'rounded'
shape = 'box' | 'ellipse' | 'circle' | 'point' | 'triangle'
- edge attributes, passed via the :py:meth:`Dot.edge_style` method::
style = 'dashed' | 'dotted' | 'solid' | 'invis' | 'bold'
arrowhead = 'box' | 'crow' | 'diamond' | 'dot' | 'inv' | 'none'
| 'tee' | 'vee'
weight = number (the larger the number the closer the nodes will be)
- valid `graphviz colors
<http://www.research.att.com/~erg/graphviz/info/colors.html>`_
- for more details on how to control the graph drawing process see the
`graphviz reference
<http://www.research.att.com/sw/tools/graphviz/refs.html>`_.
"""
import os
import warnings
from altgraph import GraphError
class Dot(object):
"""
A class providing a **graphviz** (dot language) representation
allowing a fine grained control over how the graph is being
displayed.
If the :command:`dot` and :command:`dotty` programs are not in the current
system path their location needs to be specified in the constructor.
"""
def __init__(
self,
graph=None,
nodes=None,
edgefn=None,
nodevisitor=None,
edgevisitor=None,
name="G",
dot="dot",
dotty="dotty",
neato="neato",
graphtype="digraph",
):
"""
Initialization.
"""
self.name, self.attr = name, {}
assert graphtype in ["graph", "digraph"]
self.type = graphtype
self.temp_dot = "tmp_dot.dot"
self.temp_neo = "tmp_neo.dot"
self.dot, self.dotty, self.neato = dot, dotty, neato
# self.nodes: node styles
# self.edges: edge styles
self.nodes, self.edges = {}, {}
if graph is not None and nodes is None:
nodes = graph
if graph is not None and edgefn is None:
def edgefn(node, graph=graph):
return graph.out_nbrs(node)
if nodes is None:
nodes = ()
seen = set()
for node in nodes:
if nodevisitor is None:
style = {}
else:
style = nodevisitor(node)
if style is not None:
self.nodes[node] = {}
self.node_style(node, **style)
seen.add(node)
if edgefn is not None:
for head in seen:
for tail in (n for n in edgefn(head) if n in seen):
if edgevisitor is None:
edgestyle = {}
else:
edgestyle = edgevisitor(head, tail)
if edgestyle is not None:
if head not in self.edges:
self.edges[head] = {}
self.edges[head][tail] = {}
self.edge_style(head, tail, **edgestyle)
def style(self, **attr):
"""
Changes the overall style
"""
self.attr = attr
def display(self, mode="dot"):
"""
Displays the current graph via dotty
"""
if mode == "neato":
self.save_dot(self.temp_neo)
neato_cmd = "%s -o %s %s" % (self.neato, self.temp_dot, self.temp_neo)
os.system(neato_cmd)
else:
self.save_dot(self.temp_dot)
plot_cmd = "%s %s" % (self.dotty, self.temp_dot)
os.system(plot_cmd)
def node_style(self, node, **kwargs):
"""
Sets the style for a node in the dot representation.
"""
if node not in self.edges:
self.edges[node] = {}
self.nodes[node] = kwargs
def all_node_style(self, **kwargs):
"""
Modifies all node styles
"""
for node in self.nodes:
self.node_style(node, **kwargs)
def edge_style(self, head, tail, **kwargs):
"""
Sets the style for an edge in the dot representation.
"""
if tail not in self.nodes:
raise GraphError("invalid node %s" % (tail,))
try:
if tail not in self.edges[head]:
self.edges[head][tail] = {}
self.edges[head][tail] = kwargs
except KeyError:
raise GraphError("invalid edge %s -> %s " % (head, tail))
def iterdot(self):
# write graph title
if self.type == "digraph":
yield "digraph %s {\n" % (self.name,)
elif self.type == "graph":
yield "graph %s {\n" % (self.name,)
else:
raise GraphError("unsupported graphtype %s" % (self.type,))
# write overall graph attributes
for attr_name, attr_value in sorted(self.attr.items()):
yield '%s="%s";' % (attr_name, attr_value)
yield "\n"
# some reusable patterns
cpatt = '%s="%s",' # to separate attributes
epatt = "];\n" # to end attributes
# write node attributes
for node_name, node_attr in sorted(self.nodes.items()):
yield '\t"%s" [' % (node_name,)
for attr_name, attr_value in sorted(node_attr.items()):
yield cpatt % (attr_name, attr_value)
yield epatt
# write edge attributes
for head in sorted(self.edges):
for tail in sorted(self.edges[head]):
if self.type == "digraph":
yield '\t"%s" -> "%s" [' % (head, tail)
else:
yield '\t"%s" -- "%s" [' % (head, tail)
for attr_name, attr_value in sorted(self.edges[head][tail].items()):
yield cpatt % (attr_name, attr_value)
yield epatt
# finish file
yield "}\n"
def __iter__(self):
return self.iterdot()
def save_dot(self, file_name=None):
"""
Saves the current graph representation into a file
"""
if not file_name:
warnings.warn("always pass a file_name", DeprecationWarning)
file_name = self.temp_dot
with open(file_name, "w") as fp:
for chunk in self.iterdot():
fp.write(chunk)
def save_img(self, file_name=None, file_type="gif", mode="dot"):
"""
Saves the dot file as an image file
"""
if not file_name:
warnings.warn("always pass a file_name", DeprecationWarning)
file_name = "out"
if mode == "neato":
self.save_dot(self.temp_neo)
neato_cmd = "%s -o %s %s" % (self.neato, self.temp_dot, self.temp_neo)
os.system(neato_cmd)
plot_cmd = self.dot
else:
self.save_dot(self.temp_dot)
plot_cmd = self.dot
file_name = "%s.%s" % (file_name, file_type)
create_cmd = "%s -T%s %s -o %s" % (
plot_cmd,
file_type,
self.temp_dot,
file_name,
)
os.system(create_cmd)
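
A minimal usage sketch for the module above (editorial illustration, not part of the vendored file; assumes ``altgraph`` is importable and the graphviz tools are on the PATH):

from altgraph import Graph, Dot

# build a small directed graph and style it
graph = Graph.Graph([(1, 2), (1, 3), (3, 4)])
dot = Dot.Dot(graph, graphtype="digraph")
dot.style(rankdir="LR")                       # overall layout, left to right
dot.node_style(1, shape="box", color="blue")  # per-node attributes
dot.edge_style(1, 2, style="dotted")          # per-edge attributes
dot.save_dot(file_name="example.dot")         # write the dot source
# dot.save_img(file_name="example", file_type="png")  # requires graphviz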


@@ -0,0 +1,682 @@
"""
altgraph.Graph - Base Graph class
=================================
..
#--Version 2.1
#--Bob Ippolito October, 2004
#--Version 2.0
#--Istvan Albert June, 2004
#--Version 1.0
#--Nathan Denny, May 27, 1999
"""
from collections import deque
from altgraph import GraphError
class Graph(object):
"""
The Graph class represents a directed graph with *N* nodes and *E* edges.
Naming conventions:
- the prefixes such as *out*, *inc* and *all* will refer to methods
that operate on the outgoing, incoming or all edges of that node.
For example: :py:meth:`inc_degree` will refer to the degree of the node
computed over the incoming edges (the number of neighbours linking to
the node).
- the prefixes such as *forw* and *back* will refer to the
orientation of the edges used in the method with respect to the node.
For example: :py:meth:`forw_bfs` will start at the node then use the
outgoing edges to traverse the graph (goes forward).
"""
def __init__(self, edges=None):
"""
Initialization
"""
self.next_edge = 0
self.nodes, self.edges = {}, {}
self.hidden_edges, self.hidden_nodes = {}, {}
if edges is not None:
for item in edges:
if len(item) == 2:
head, tail = item
self.add_edge(head, tail)
elif len(item) == 3:
head, tail, data = item
self.add_edge(head, tail, data)
else:
raise GraphError("Cannot create edge from %s" % (item,))
def __repr__(self):
return "<Graph: %d nodes, %d edges>" % (
self.number_of_nodes(),
self.number_of_edges(),
)
def add_node(self, node, node_data=None):
"""
Adds a new node to the graph. Arbitrary data can be attached to the
node via the node_data parameter. Adding the same node twice will be
silently ignored.
The node must be a hashable value.
"""
#
# the nodes will contain tuples that will store incoming edges,
# outgoing edges and data
#
# index 0 -> incoming edges
# index 1 -> outgoing edges
if node in self.hidden_nodes:
# Node is present, but hidden
return
if node not in self.nodes:
self.nodes[node] = ([], [], node_data)
def add_edge(self, head_id, tail_id, edge_data=1, create_nodes=True):
"""
Adds a directed edge going from head_id to tail_id.
Arbitrary data can be attached to the edge via edge_data.
If *create_nodes* is true, missing nodes are created automatically
when an edge references them.
:param head_id: head node
:param tail_id: tail node
:param edge_data: (optional) data attached to the edge
:param create_nodes: (optional) creates the head_id or tail_id
node in case they did not exist
"""
# shortcut
edge = self.next_edge
# add nodes if on automatic node creation
if create_nodes:
self.add_node(head_id)
self.add_node(tail_id)
# update the corresponding incoming and outgoing lists in the nodes
# index 0 -> incoming edges
# index 1 -> outgoing edges
try:
self.nodes[tail_id][0].append(edge)
self.nodes[head_id][1].append(edge)
except KeyError:
raise GraphError("Invalid nodes %s -> %s" % (head_id, tail_id))
# store edge information
self.edges[edge] = (head_id, tail_id, edge_data)
self.next_edge += 1
def hide_edge(self, edge):
"""
Hides an edge from the graph. The edge may be unhidden at some later
time.
"""
try:
head_id, tail_id, edge_data = self.hidden_edges[edge] = self.edges[edge]
self.nodes[tail_id][0].remove(edge)
self.nodes[head_id][1].remove(edge)
del self.edges[edge]
except KeyError:
raise GraphError("Invalid edge %s" % edge)
def hide_node(self, node):
"""
Hides a node from the graph. The incoming and outgoing edges of the
node will also be hidden. The node may be unhidden at some later time.
"""
try:
all_edges = self.all_edges(node)
self.hidden_nodes[node] = (self.nodes[node], all_edges)
for edge in all_edges:
self.hide_edge(edge)
del self.nodes[node]
except KeyError:
raise GraphError("Invalid node %s" % node)
def restore_node(self, node):
"""
Restores a previously hidden node back into the graph and restores
all of its incoming and outgoing edges.
"""
try:
self.nodes[node], all_edges = self.hidden_nodes[node]
for edge in all_edges:
self.restore_edge(edge)
del self.hidden_nodes[node]
except KeyError:
raise GraphError("Invalid node %s" % node)
def restore_edge(self, edge):
"""
Restores a previously hidden edge back into the graph.
"""
try:
head_id, tail_id, data = self.hidden_edges[edge]
self.nodes[tail_id][0].append(edge)
self.nodes[head_id][1].append(edge)
self.edges[edge] = head_id, tail_id, data
del self.hidden_edges[edge]
except KeyError:
raise GraphError("Invalid edge %s" % edge)
def restore_all_edges(self):
"""
Restores all hidden edges.
"""
for edge in list(self.hidden_edges.keys()):
try:
self.restore_edge(edge)
except GraphError:
pass
def restore_all_nodes(self):
"""
Restores all hidden nodes.
"""
for node in list(self.hidden_nodes.keys()):
self.restore_node(node)
def __contains__(self, node):
"""
Test whether a node is in the graph
"""
return node in self.nodes
def edge_by_id(self, edge):
"""
Returns the (head, tail) node pair for the given edge id
"""
try:
head, tail, data = self.edges[edge]
except KeyError:
head, tail = None, None
raise GraphError("Invalid edge %s" % edge)
return (head, tail)
def edge_by_node(self, head, tail):
"""
Returns the edge that connects the head_id and tail_id nodes
"""
for edge in self.out_edges(head):
if self.tail(edge) == tail:
return edge
return None
def number_of_nodes(self):
"""
Returns the number of nodes
"""
return len(self.nodes)
def number_of_edges(self):
"""
Returns the number of edges
"""
return len(self.edges)
def __iter__(self):
"""
Iterates over all nodes in the graph
"""
return iter(self.nodes)
def node_list(self):
"""
Return a list of the node ids for all visible nodes in the graph.
"""
return list(self.nodes.keys())
def edge_list(self):
"""
Returns a list of the edge ids for all visible edges in the graph.
"""
return list(self.edges.keys())
def number_of_hidden_edges(self):
"""
Returns the number of hidden edges
"""
return len(self.hidden_edges)
def number_of_hidden_nodes(self):
"""
Returns the number of hidden nodes
"""
return len(self.hidden_nodes)
def hidden_node_list(self):
"""
Returns the list with the hidden nodes
"""
return list(self.hidden_nodes.keys())
def hidden_edge_list(self):
"""
Returns a list with the hidden edges
"""
return list(self.hidden_edges.keys())
def describe_node(self, node):
"""
return node, node data, outgoing edges, incoming edges for node
"""
incoming, outgoing, data = self.nodes[node]
return node, data, outgoing, incoming
def describe_edge(self, edge):
"""
return edge, edge data, head, tail for edge
"""
head, tail, data = self.edges[edge]
return edge, data, head, tail
def node_data(self, node):
"""
Returns the data associated with a node
"""
return self.nodes[node][2]
def edge_data(self, edge):
"""
Returns the data associated with an edge
"""
return self.edges[edge][2]
def update_edge_data(self, edge, edge_data):
"""
Replace the edge data for a specific edge
"""
self.edges[edge] = self.edges[edge][0:2] + (edge_data,)
def head(self, edge):
"""
Returns the node of the head of the edge.
"""
return self.edges[edge][0]
def tail(self, edge):
"""
Returns node of the tail of the edge.
"""
return self.edges[edge][1]
def out_nbrs(self, node):
"""
List of nodes connected by outgoing edges
"""
return [self.tail(n) for n in self.out_edges(node)]
def inc_nbrs(self, node):
"""
List of nodes connected by incoming edges
"""
return [self.head(n) for n in self.inc_edges(node)]
def all_nbrs(self, node):
"""
List of nodes connected by incoming and outgoing edges
"""
return list(dict.fromkeys(self.inc_nbrs(node) + self.out_nbrs(node)))
def out_edges(self, node):
"""
Returns a list of the outgoing edges
"""
try:
return list(self.nodes[node][1])
except KeyError:
raise GraphError("Invalid node %s" % node)
def inc_edges(self, node):
"""
Returns a list of the incoming edges
"""
try:
return list(self.nodes[node][0])
except KeyError:
raise GraphError("Invalid node %s" % node)
def all_edges(self, node):
"""
Returns a set of the incoming and outgoing edges.
"""
return set(self.inc_edges(node) + self.out_edges(node))
def out_degree(self, node):
"""
Returns the number of outgoing edges
"""
return len(self.out_edges(node))
def inc_degree(self, node):
"""
Returns the number of incoming edges
"""
return len(self.inc_edges(node))
def all_degree(self, node):
"""
The total degree of a node
"""
return self.inc_degree(node) + self.out_degree(node)
def _topo_sort(self, forward=True):
"""
Topological sort.
Returns a list of nodes where the successors (based on outgoing and
incoming edges selected by the forward parameter) of any given node
appear in the sequence after that node.
"""
topo_list = []
queue = deque()
indeg = {}
# select the operation that will be performed
if forward:
get_edges = self.out_edges
get_degree = self.inc_degree
get_next = self.tail
else:
get_edges = self.inc_edges
get_degree = self.out_degree
get_next = self.head
for node in self.node_list():
degree = get_degree(node)
if degree:
indeg[node] = degree
else:
queue.append(node)
while queue:
curr_node = queue.popleft()
topo_list.append(curr_node)
for edge in get_edges(curr_node):
tail_id = get_next(edge)
if tail_id in indeg:
indeg[tail_id] -= 1
if indeg[tail_id] == 0:
queue.append(tail_id)
if len(topo_list) == len(self.node_list()):
valid = True
else:
# the graph has cycles, invalid topological sort
valid = False
return (valid, topo_list)
def forw_topo_sort(self):
"""
Topological sort.
Returns a list of nodes where the successors (based on outgoing edges)
of any given node appear in the sequence after that node.
"""
return self._topo_sort(forward=True)
def back_topo_sort(self):
"""
Reverse topological sort.
Returns a list of nodes where the successors (based on incoming edges)
of any given node appear in the sequence after that node.
"""
return self._topo_sort(forward=False)
def _bfs_subgraph(self, start_id, forward=True):
"""
Private method creates a subgraph in a bfs order.
The forward parameter specifies whether it is a forward or backward
traversal.
"""
if forward:
get_bfs = self.forw_bfs
get_nbrs = self.out_nbrs
else:
get_bfs = self.back_bfs
get_nbrs = self.inc_nbrs
g = Graph()
bfs_list = get_bfs(start_id)
for node in bfs_list:
g.add_node(node)
for node in bfs_list:
for nbr_id in get_nbrs(node):
if forward:
g.add_edge(node, nbr_id)
else:
g.add_edge(nbr_id, node)
return g
def forw_bfs_subgraph(self, start_id):
"""
Creates and returns a subgraph consisting of the breadth first
reachable nodes based on their outgoing edges.
"""
return self._bfs_subgraph(start_id, forward=True)
def back_bfs_subgraph(self, start_id):
"""
Creates and returns a subgraph consisting of the breadth first
reachable nodes based on the incoming edges.
"""
return self._bfs_subgraph(start_id, forward=False)
def iterdfs(self, start, end=None, forward=True):
"""
Yields nodes in some depth-first traversal order.
The forward parameter specifies whether it is a forward or backward
traversal.
"""
visited, stack = {start}, deque([start])
if forward:
get_edges = self.out_edges
get_next = self.tail
else:
get_edges = self.inc_edges
get_next = self.head
while stack:
curr_node = stack.pop()
yield curr_node
if curr_node == end:
break
for edge in sorted(get_edges(curr_node)):
tail = get_next(edge)
if tail not in visited:
visited.add(tail)
stack.append(tail)
def iterdata(self, start, end=None, forward=True, condition=None):
"""
Perform a depth-first walk of the graph (as ``iterdfs``)
and yield the item data of every node where condition matches. The
condition callback is only called when node_data is not None.
"""
visited, stack = {start}, deque([start])
if forward:
get_edges = self.out_edges
get_next = self.tail
else:
get_edges = self.inc_edges
get_next = self.head
get_data = self.node_data
while stack:
curr_node = stack.pop()
curr_data = get_data(curr_node)
if curr_data is not None:
if condition is not None and not condition(curr_data):
continue
yield curr_data
if curr_node == end:
break
for edge in get_edges(curr_node):
tail = get_next(edge)
if tail not in visited:
visited.add(tail)
stack.append(tail)
def _iterbfs(self, start, end=None, forward=True):
"""
The forward parameter specifies whether it is a forward or backward
traversal. Yields tuples where the first value is the node id and the
second value is the hop count.
"""
queue, visited = deque([(start, 0)]), {start}
# the direction of the bfs depends on the edges that are sampled
if forward:
get_edges = self.out_edges
get_next = self.tail
else:
get_edges = self.inc_edges
get_next = self.head
while queue:
curr_node, curr_step = queue.popleft()
yield (curr_node, curr_step)
if curr_node == end:
break
for edge in get_edges(curr_node):
tail = get_next(edge)
if tail not in visited:
visited.add(tail)
queue.append((tail, curr_step + 1))
def forw_bfs(self, start, end=None):
"""
Returns a list of nodes in some forward BFS order.
Starting from the start node the breadth first search proceeds along
outgoing edges.
"""
return [node for node, step in self._iterbfs(start, end, forward=True)]
def back_bfs(self, start, end=None):
"""
Returns a list of nodes in some backward BFS order.
Starting from the start node the breadth first search proceeds along
incoming edges.
"""
return [node for node, _ in self._iterbfs(start, end, forward=False)]
def forw_dfs(self, start, end=None):
"""
Returns a list of nodes in some forward DFS order.
Starting with the start node the depth first search proceeds along
outgoing edges.
"""
return list(self.iterdfs(start, end, forward=True))
def back_dfs(self, start, end=None):
"""
Returns a list of nodes in some backward DFS order.
Starting from the start node the depth first search proceeds along
incoming edges.
"""
return list(self.iterdfs(start, end, forward=False))
def connected(self):
"""
Returns :py:data:`True` if every node in the graph can be reached from
every other node.
"""
node_list = self.node_list()
for node in node_list:
bfs_list = self.forw_bfs(node)
if len(bfs_list) != len(node_list):
return False
return True
def clust_coef(self, node):
"""
Computes and returns the local clustering coefficient of node.
The local clustering coefficient is the ratio of the actual number of
edges between the neighbours of a node to the maximum possible number of
edges between those neighbours.
See "Local Clustering Coefficient" on
<http://en.wikipedia.org/wiki/Clustering_coefficient>
for a formal definition.
"""
num = 0
nbr_set = set(self.out_nbrs(node))
if node in nbr_set:
nbr_set.remove(node) # loop defense
for nbr in nbr_set:
sec_set = set(self.out_nbrs(nbr))
if nbr in sec_set:
sec_set.remove(nbr) # loop defense
num += len(nbr_set & sec_set)
nbr_num = len(nbr_set)
if nbr_num:
clust_coef = float(num) / (nbr_num * (nbr_num - 1))
else:
clust_coef = 0.0
return clust_coef
def get_hops(self, start, end=None, forward=True):
"""
Computes the hop distance to all nodes centered around a node.
First order neighbours are at hop 1, their neighbours are at hop 2, etc.
Uses :py:meth:`forw_bfs` or :py:meth:`back_bfs` depending on the value
of the forward parameter. If the distance between all neighbouring
nodes is 1 the hop number corresponds to the shortest distance between
the nodes.
:param start: the starting node
:param end: ending node (optional). When not specified will search the
whole graph.
:param forward: directionality parameter (optional).
If ``True`` (default) it uses :py:meth:`forw_bfs`, otherwise :py:meth:`back_bfs`.
:return: returns a list of tuples where each tuple contains the
node and the hop.
Typical usage::
>>> print(graph.get_hops(1, 8))
[(1, 0), (2, 1), (3, 1), (4, 2), (5, 3), (7, 4), (8, 5)]
# node 1 is at 0 hops
# node 2 is at 1 hop
# ...
# node 8 is at 5 hops
"""
if forward:
return list(self._iterbfs(start=start, end=end, forward=True))
else:
return list(self._iterbfs(start=start, end=end, forward=False))
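
A short sketch of the traversal and hide/restore API above (editorial illustration, assuming the module is importable as ``altgraph.Graph``):

from altgraph.Graph import Graph

g = Graph([(1, 2), (1, 3), (2, 4), (3, 4)])
print(g.forw_bfs(1))          # breadth-first from node 1, e.g. [1, 2, 3, 4]
print(g.inc_degree(4))        # 2: node 4 has two incoming edges
valid, order = g.forw_topo_sort()
print(valid, order)           # True and a topological ordering
g.hide_node(2)                # hides node 2 together with its edges
print(g.number_of_nodes())    # 3
g.restore_node(2)             # brings node 2 and its edges back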


@@ -0,0 +1,171 @@
"""
altgraph.GraphAlgo - Graph algorithms
=====================================
"""
from altgraph import GraphError
def dijkstra(graph, start, end=None):
"""
Dijkstra's algorithm for shortest paths
`David Eppstein, UC Irvine, 4 April 2002
<http://www.ics.uci.edu/~eppstein/161/python/>`_
`Python Cookbook Recipe
<http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/119466>`_
Find shortest paths from the start node to all nodes nearer than or
equal to the end node.
Dijkstra's algorithm is only guaranteed to work correctly when all edge
lengths are positive. This code does not verify this property for all
edges (only the edges examined until the end vertex is reached), but will
correctly compute shortest paths even for some graphs with negative edges,
and will raise an exception if it discovers that a negative edge has
caused it to make a mistake.
Adapted to altgraph by Istvan Albert, Pennsylvania State University -
June 9, 2004
"""
D = {} # dictionary of final distances
P = {} # dictionary of predecessors
Q = _priorityDictionary() # estimated distances of non-final vertices
Q[start] = 0
for v in Q:
D[v] = Q[v]
if v == end:
break
for w in graph.out_nbrs(v):
edge_id = graph.edge_by_node(v, w)
vwLength = D[v] + graph.edge_data(edge_id)
if w in D:
if vwLength < D[w]:
raise GraphError(
"Dijkstra: found better path to already-final vertex"
)
elif w not in Q or vwLength < Q[w]:
Q[w] = vwLength
P[w] = v
return (D, P)
def shortest_path(graph, start, end):
"""
Find a single shortest path from the *start* node to the *end* node.
The input has the same conventions as dijkstra(). The output is a list of
the nodes in order along the shortest path.
**Note that the distances must be stored in the edge data as numeric data**
"""
D, P = dijkstra(graph, start, end)
Path = []
while 1:
Path.append(end)
if end == start:
break
end = P[end]
Path.reverse()
return Path
#
# Utility classes and functions
#
class _priorityDictionary(dict):
"""
Priority dictionary using binary heaps (internal use only)
David Eppstein, UC Irvine, 8 Mar 2002
Implements a data structure that acts almost like a dictionary, with
two modifications:
1. D.smallest() returns the value x minimizing D[x]. For this to
work correctly, all values D[x] stored in the dictionary must be
comparable.
2. iterating "for x in D" finds and removes the items from D in sorted
order. Each item is not removed until the next item is requested,
so D[x] will still return a useful value until the next iteration
of the for-loop. Each operation takes logarithmic amortized time.
"""
def __init__(self):
"""
Initialize priorityDictionary by creating binary heap of pairs
(value,key). Note that changing or removing a dict entry will not
remove the old pair from the heap until it is found by smallest()
or until the heap is rebuilt.
"""
self.__heap = []
dict.__init__(self)
def smallest(self):
"""
Find smallest item after removing deleted items from front of heap.
"""
if len(self) == 0:
raise IndexError("smallest of empty priorityDictionary")
heap = self.__heap
while heap[0][1] not in self or self[heap[0][1]] != heap[0][0]:
lastItem = heap.pop()
insertionPoint = 0
while 1:
smallChild = 2 * insertionPoint + 1
if (
smallChild + 1 < len(heap)
and heap[smallChild] > heap[smallChild + 1]
):
smallChild += 1
if smallChild >= len(heap) or lastItem <= heap[smallChild]:
heap[insertionPoint] = lastItem
break
heap[insertionPoint] = heap[smallChild]
insertionPoint = smallChild
return heap[0][1]
def __iter__(self):
"""
Create destructive sorted iterator of priorityDictionary.
"""
def iterfn():
while len(self) > 0:
x = self.smallest()
yield x
del self[x]
return iterfn()
def __setitem__(self, key, val):
"""
Change value stored in dictionary and add corresponding pair to heap.
Rebuilds the heap if the number of deleted items gets large, to avoid
memory leakage.
"""
dict.__setitem__(self, key, val)
heap = self.__heap
if len(heap) > 2 * len(self):
self.__heap = [(v, k) for k, v in self.items()]
self.__heap.sort()
else:
newPair = (val, key)
insertionPoint = len(heap)
heap.append(None)
while insertionPoint > 0 and newPair < heap[(insertionPoint - 1) // 2]:
heap[insertionPoint] = heap[(insertionPoint - 1) // 2]
insertionPoint = (insertionPoint - 1) // 2
heap[insertionPoint] = newPair
def setdefault(self, key, val):
"""
Reimplement setdefault to pass through our customized __setitem__.
"""
if key not in self:
self[key] = val
return self[key]
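
A sketch of the shortest-path helpers above (editorial illustration; as the docstring notes, the edge data must be numeric since it is used as the edge length):

from altgraph.Graph import Graph
from altgraph import GraphAlgo

g = Graph()
for head, tail, length in [(1, 2, 1), (2, 3, 1), (1, 3, 5)]:
    g.add_edge(head, tail, edge_data=length)  # numeric edge data = length
print(GraphAlgo.shortest_path(g, 1, 3))       # [1, 2, 3]: cost 2 beats the direct cost 5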


@@ -0,0 +1,73 @@
"""
altgraph.GraphStat - Functions providing various graph statistics
=================================================================
"""
def degree_dist(graph, limits=(0, 0), bin_num=10, mode="out"):
"""
Computes the degree distribution for a graph.
Returns a list of tuples where the first element of the tuple is the
center of the bin representing a range of degrees and the second element
of the tuple is the number of nodes whose degree falls in that range.
Example::
....
"""
deg = []
if mode == "inc":
get_deg = graph.inc_degree
else:
get_deg = graph.out_degree
for node in graph:
deg.append(get_deg(node))
if not deg:
return []
results = _binning(values=deg, limits=limits, bin_num=bin_num)
return results
_EPS = 1.0 / (2.0 ** 32)
def _binning(values, limits=(0, 0), bin_num=10):
"""
Bins data that falls between certain limits; if the limits are (0, 0) the
minimum and maximum values are used.
Returns a list of tuples where the first element of the tuple is the
center of the bin and the second element of the tuple is the count.
"""
if limits == (0, 0):
min_val, max_val = min(values) - _EPS, max(values) + _EPS
else:
min_val, max_val = limits
# get bin size
bin_size = (max_val - min_val) / float(bin_num)
bins = [0] * (bin_num)
# will ignore these outliers for now
for value in values:
try:
if (value - min_val) >= 0:
index = int((value - min_val) / float(bin_size))
bins[index] += 1
except IndexError:
pass
# make it ready for an x,y plot
result = []
center = (bin_size / 2) + min_val
for i, y in enumerate(bins):
x = center + bin_size * i
result.append((x, y))
return result
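
A sketch of the binning above (editorial illustration): each returned tuple pairs a bin center with the number of nodes whose degree falls in that bin.

from altgraph.Graph import Graph
from altgraph import GraphStat

g = Graph([(1, 2), (1, 3), (1, 4), (2, 3)])
# out-degrees are [3, 1, 0, 0]; ten bins spanning that range
print(GraphStat.degree_dist(g, bin_num=10, mode="out"))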


@@ -0,0 +1,139 @@
"""
altgraph.GraphUtil - Utility classes and functions
==================================================
"""
import random
from collections import deque
from altgraph import Graph, GraphError
def generate_random_graph(node_num, edge_num, self_loops=False, multi_edges=False):
"""
Generates and returns a :py:class:`~altgraph.Graph.Graph` instance with
*node_num* nodes randomly connected by *edge_num* edges.
"""
g = Graph.Graph()
if not multi_edges:
if self_loops:
max_edges = node_num * node_num
else:
max_edges = node_num * (node_num - 1)
if edge_num > max_edges:
raise GraphError("inconsistent arguments to 'generate_random_graph'")
nodes = range(node_num)
for node in nodes:
g.add_node(node)
while 1:
head = random.choice(nodes)
tail = random.choice(nodes)
# loop defense
if head == tail and not self_loops:
continue
# multiple edge defense
if g.edge_by_node(head, tail) is not None and not multi_edges:
continue
# add the edge
g.add_edge(head, tail)
if g.number_of_edges() >= edge_num:
break
return g
def generate_scale_free_graph(steps, growth_num, self_loops=False, multi_edges=False):
"""
Generates and returns a :py:class:`~altgraph.Graph.Graph` instance that
will have *steps* \\* *growth_num* nodes and a scale free (powerlaw)
connectivity. Starting with a fully connected graph with *growth_num*
nodes at every step *growth_num* nodes are added to the graph and are
connected to existing nodes with a probability proportional to the degree
of these existing nodes.
"""
# The code doesn't seem to do what the documentation claims.
graph = Graph.Graph()
# initialize the graph
store = []
for i in range(growth_num):
for j in range(i + 1, growth_num):
store.append(i)
store.append(j)
graph.add_edge(i, j)
# generate
for node in range(growth_num, steps * growth_num):
graph.add_node(node)
while graph.out_degree(node) < growth_num:
nbr = random.choice(store)
# loop defense
if node == nbr and not self_loops:
continue
# multi edge defense
if graph.edge_by_node(node, nbr) and not multi_edges:
continue
graph.add_edge(node, nbr)
for nbr in graph.out_nbrs(node):
store.append(node)
store.append(nbr)
return graph
def filter_stack(graph, head, filters):
"""
Perform a walk in a depth-first order starting
at *head*.
Returns (visited, removes, orphans).
* visited: the set of visited nodes
* removes: the set of nodes where the node
data does not match all *filters*
* orphans: tuples of (last_good, node),
where node is not in removes, is directly
reachable from a node in *removes* and
*last_good* is the closest upstream node that is not
in *removes*.
"""
visited, removes, orphans = {head}, set(), set()
stack = deque([(head, head)])
get_data = graph.node_data
get_edges = graph.out_edges
get_tail = graph.tail
while stack:
last_good, node = stack.pop()
data = get_data(node)
if data is not None:
for filtfunc in filters:
if not filtfunc(data):
removes.add(node)
break
else:
last_good = node
for edge in get_edges(node):
tail = get_tail(edge)
if last_good is not node:
orphans.add((last_good, tail))
if tail not in visited:
visited.add(tail)
stack.append((last_good, tail))
orphans = [(lg, tl) for (lg, tl) in orphans if tl not in removes]
return visited, removes, orphans
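
A sketch of the generators above (editorial illustration; results are random by construction):

from altgraph import GraphUtil

g = GraphUtil.generate_random_graph(10, 20)   # 10 nodes, 20 random edges
print(g)                                      # <Graph: 10 nodes, 20 edges>
sf = GraphUtil.generate_scale_free_graph(5, 4)
print(sf.number_of_nodes())                   # steps * growth_num = 20 nodes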


@@ -0,0 +1,18 @@
Copyright (c) 2004 Istvan Albert unless otherwise noted.
Copyright (c) 2006-2010 Bob Ippolito
Copyright (c) 2010-2020 Ronald Oussoren, et al.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to
deal in the Software without restriction, including without limitation the
rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
IN THE SOFTWARE.


@@ -0,0 +1,212 @@
"""
altgraph.ObjectGraph - Graph of objects with an identifier
==========================================================
A graph of objects that have a "graphident" attribute.
graphident is the key for the object in the graph
"""
from altgraph import GraphError
from altgraph.Graph import Graph
from altgraph.GraphUtil import filter_stack
class ObjectGraph(object):
"""
A graph of objects that have a "graphident" attribute.
graphident is the key for the object in the graph
"""
def __init__(self, graph=None, debug=0):
if graph is None:
graph = Graph()
self.graphident = self
self.graph = graph
self.debug = debug
self.indent = 0
graph.add_node(self, None)
def __repr__(self):
return "<%s>" % (type(self).__name__,)
def flatten(self, condition=None, start=None):
"""
Iterate over the subgraph that is entirely reachable by condition
starting from the given start node or the ObjectGraph root
"""
if start is None:
start = self
start = self.getRawIdent(start)
return self.graph.iterdata(start=start, condition=condition)
def nodes(self):
for ident in self.graph:
node = self.graph.node_data(ident)
if node is not None:
yield self.graph.node_data(ident)
def get_edges(self, node):
if node is None:
node = self
start = self.getRawIdent(node)
_, _, outraw, incraw = self.graph.describe_node(start)
def iter_edges(lst, n):
seen = set()
for tpl in (self.graph.describe_edge(e) for e in lst):
ident = tpl[n]
if ident not in seen:
yield self.findNode(ident)
seen.add(ident)
return iter_edges(outraw, 3), iter_edges(incraw, 2)
def edgeData(self, fromNode, toNode):
if fromNode is None:
fromNode = self
start = self.getRawIdent(fromNode)
stop = self.getRawIdent(toNode)
edge = self.graph.edge_by_node(start, stop)
return self.graph.edge_data(edge)
def updateEdgeData(self, fromNode, toNode, edgeData):
if fromNode is None:
fromNode = self
start = self.getRawIdent(fromNode)
stop = self.getRawIdent(toNode)
edge = self.graph.edge_by_node(start, stop)
self.graph.update_edge_data(edge, edgeData)
def filterStack(self, filters):
"""
Filter the ObjectGraph in-place by removing all edges to nodes that
do not match every filter in the given filter list
Returns a tuple containing the number of:
(nodes_visited, nodes_removed, nodes_orphaned)
"""
visited, removes, orphans = filter_stack(self.graph, self, filters)
for last_good, tail in orphans:
self.graph.add_edge(last_good, tail, edge_data="orphan")
for node in removes:
self.graph.hide_node(node)
return len(visited) - 1, len(removes), len(orphans)
def removeNode(self, node):
"""
Remove the given node from the graph if it exists
"""
ident = self.getIdent(node)
if ident is not None:
self.graph.hide_node(ident)
def removeReference(self, fromnode, tonode):
"""
Remove all edges from fromnode to tonode
"""
if fromnode is None:
fromnode = self
fromident = self.getIdent(fromnode)
toident = self.getIdent(tonode)
if fromident is not None and toident is not None:
while True:
edge = self.graph.edge_by_node(fromident, toident)
if edge is None:
break
self.graph.hide_edge(edge)
def getIdent(self, node):
"""
Get the graph identifier for a node
"""
ident = self.getRawIdent(node)
if ident is not None:
return ident
node = self.findNode(node)
if node is None:
return None
return node.graphident
def getRawIdent(self, node):
"""
Get the identifier for a node object
"""
if node is self:
return node
ident = getattr(node, "graphident", None)
return ident
def __contains__(self, node):
return self.findNode(node) is not None
def findNode(self, node):
"""
Find the node on the graph
"""
ident = self.getRawIdent(node)
if ident is None:
ident = node
try:
return self.graph.node_data(ident)
except KeyError:
return None
def addNode(self, node):
"""
Add a node to the graph referenced by the root
"""
self.msg(4, "addNode", node)
try:
self.graph.restore_node(node.graphident)
except GraphError:
self.graph.add_node(node.graphident, node)
def createReference(self, fromnode, tonode, edge_data=None):
"""
Create a reference from fromnode to tonode
"""
if fromnode is None:
fromnode = self
fromident, toident = self.getIdent(fromnode), self.getIdent(tonode)
if fromident is None or toident is None:
return
self.msg(4, "createReference", fromnode, tonode, edge_data)
self.graph.add_edge(fromident, toident, edge_data=edge_data)
def createNode(self, cls, name, *args, **kw):
"""
Add a node of type cls to the graph if it does not already exist
by the given name
"""
m = self.findNode(name)
if m is None:
m = cls(name, *args, **kw)
self.addNode(m)
return m
def msg(self, level, s, *args):
"""
Print a debug message with the given level
"""
if s and level <= self.debug:
print("%s%s %s" % (" " * self.indent, s, " ".join(map(repr, args))))
def msgin(self, level, s, *args):
"""
Print a debug message and indent
"""
if level <= self.debug:
self.msg(level, s, *args)
self.indent = self.indent + 1
def msgout(self, level, s, *args):
"""
Dedent and print a debug message
"""
if level <= self.debug:
self.indent = self.indent - 1
self.msg(level, s, *args)
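
A sketch of the identifier-based API above (editorial illustration; the node class and names are hypothetical, the only requirement is a ``graphident`` attribute):

from altgraph.ObjectGraph import ObjectGraph

class Node:
    def __init__(self, name):
        self.graphident = name  # the key used by ObjectGraph

og = ObjectGraph()
a = og.createNode(Node, "a")             # creates the node (or returns it)
b = og.createNode(Node, "b")
og.createReference(a, b, edge_data="depends-on")
print(og.findNode("a") is a)             # True: lookup by identifier
outgoing, incoming = og.get_edges(a)
print([n.graphident for n in outgoing])  # ['b']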

View file

@ -0,0 +1,148 @@
"""
altgraph - a python graph library
=================================
altgraph is a fork of `graphlib <http://pygraphlib.sourceforge.net>`_ tailored
to use newer Python 2.3+ features, including additional support used by the
py2app suite (modulegraph and macholib, specifically).
altgraph is a python based graph (network) representation and manipulation
package. It started out as an extension to the
`graph_lib module
<http://www.ece.arizona.edu/~denny/python_nest/graph_lib_1.0.1.html>`_
written by Nathan Denny; it has since been significantly optimized and expanded.
The :class:`altgraph.Graph.Graph` class is loosely modeled after the
`LEDA <http://www.algorithmic-solutions.com/enleda.htm>`_
(Library of Efficient Datatypes) representation. The library
includes methods for constructing graphs, BFS and DFS traversals,
topological sort, finding connected components, shortest paths as well as a
number of graph statistics functions. The library can also visualize graphs
via `graphviz <http://www.research.att.com/sw/tools/graphviz/>`_.
The package contains the following modules:
- the :py:mod:`altgraph.Graph` module contains the
:class:`~altgraph.Graph.Graph` class that stores the graph data
- the :py:mod:`altgraph.GraphAlgo` module implements graph algorithms
operating on graphs (:py:class:`~altgraph.Graph.Graph` instances)
- the :py:mod:`altgraph.GraphStat` module contains functions for
computing statistical measures on graphs
- the :py:mod:`altgraph.GraphUtil` module contains functions for
generating, reading and saving graphs
- the :py:mod:`altgraph.Dot` module contains functions for displaying
graphs via `graphviz <http://www.research.att.com/sw/tools/graphviz/>`_
- the :py:mod:`altgraph.ObjectGraph` module implements a graph of
objects with a unique identifier
Installation
------------
Download and unpack the archive then type::
python setup.py install
This will install the library in the default location. For instructions on
how to customize the install procedure read the output of::
python setup.py --help install
To verify that the code works run the test suite::
python setup.py test
Example usage
-------------
Let's assume that we want to analyze the graph below. Our script might then
look the following way::
from altgraph import Graph, GraphAlgo, Dot
# these are the edges
edges = [ (1,2), (2,4), (1,3), (2,4), (3,4), (4,5), (6,5),
(6,14), (14,15), (6, 15), (5,7), (7, 8), (7,13), (12,8),
(8,13), (11,12), (11,9), (13,11), (9,13), (13,10) ]
# creates the graph
graph = Graph.Graph()
for head, tail in edges:
graph.add_edge(head, tail)
# do a forward bfs starting at node 1
print(graph.forw_bfs(1))
This will print the nodes in some breadth first order::
[1, 2, 3, 4, 5, 7, 8, 13, 11, 10, 12, 9]
If we wanted to get the hop-distance from node 1 to node 8
we could write::
print(graph.get_hops(1, 8))
This will print the following::
[(1, 0), (2, 1), (3, 1), (4, 2), (5, 3), (7, 4), (8, 5)]
Node 1 is at 0 hops since it is the starting node, nodes 2,3 are 1 hop away ...
node 8 is 5 hops away. To find the shortest distance between two nodes you
can use::
print(GraphAlgo.shortest_path(graph, 1, 12))
It will print the nodes on one of the shortest paths (if there is more than one)::
[1, 2, 4, 5, 7, 13, 11, 12]
To display the graph we can use the GraphViz backend::
dot = Dot.Dot(graph)
# display the graph on the monitor
dot.display()
# save it in an image file
dot.save_img(file_name='graph', file_type='gif')
..
@author: U{Istvan Albert<http://www.personal.psu.edu/staff/i/u/iua1/>}
@license: MIT License
Copyright (c) 2004 Istvan Albert unless otherwise noted.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to
deal in the Software without restriction, including without limitation the
rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
IN THE SOFTWARE.
@requires: Python 2.3 or higher
@newfield contributor: Contributors:
@contributor: U{Reka Albert <http://www.phys.psu.edu/~ralbert/>}
"""
import pkg_resources
__version__ = pkg_resources.require("altgraph")[0].version
class GraphError(ValueError):
pass


@@ -0,0 +1,79 @@
# SPDX-License-Identifier: MIT
import sys
from functools import partial
from . import converters, exceptions, filters, setters, validators
from ._cmp import cmp_using
from ._config import get_run_validators, set_run_validators
from ._funcs import asdict, assoc, astuple, evolve, has, resolve_types
from ._make import (
NOTHING,
Attribute,
Factory,
attrib,
attrs,
fields,
fields_dict,
make_class,
validate,
)
from ._version_info import VersionInfo
__version__ = "22.1.0"
__version_info__ = VersionInfo._from_version_string(__version__)
__title__ = "attrs"
__description__ = "Classes Without Boilerplate"
__url__ = "https://www.attrs.org/"
__uri__ = __url__
__doc__ = __description__ + " <" + __uri__ + ">"
__author__ = "Hynek Schlawack"
__email__ = "hs@ox.cx"
__license__ = "MIT"
__copyright__ = "Copyright (c) 2015 Hynek Schlawack"
s = attributes = attrs
ib = attr = attrib
dataclass = partial(attrs, auto_attribs=True) # happy Easter ;)
__all__ = [
"Attribute",
"Factory",
"NOTHING",
"asdict",
"assoc",
"astuple",
"attr",
"attrib",
"attributes",
"attrs",
"cmp_using",
"converters",
"evolve",
"exceptions",
"fields",
"fields_dict",
"filters",
"get_run_validators",
"has",
"ib",
"make_class",
"resolve_types",
"s",
"set_run_validators",
"setters",
"validate",
"validators",
]
if sys.version_info[:2] >= (3, 6):
from ._next_gen import define, field, frozen, mutable # noqa: F401
__all__.extend(("define", "field", "frozen", "mutable"))
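

A minimal sketch of the public API wired up above (editorial illustration, assuming the vendored package is importable under its usual name ``attr``):

import attr

@attr.s(auto_attribs=True)
class Point:
    x: int
    y: int = 0

p = Point(1)
print(p)                   # Point(x=1, y=0)
print(attr.asdict(p))      # {'x': 1, 'y': 0}
print(attr.has(Point))     # True: Point is an attrs class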


@@ -0,0 +1,486 @@
import sys
from typing import (
Any,
Callable,
ClassVar,
Dict,
Generic,
List,
Mapping,
Optional,
Protocol,
Sequence,
Tuple,
Type,
TypeVar,
Union,
overload,
)
# `import X as X` is required to make these public
from . import converters as converters
from . import exceptions as exceptions
from . import filters as filters
from . import setters as setters
from . import validators as validators
from ._cmp import cmp_using as cmp_using
from ._version_info import VersionInfo
__version__: str
__version_info__: VersionInfo
__title__: str
__description__: str
__url__: str
__uri__: str
__author__: str
__email__: str
__license__: str
__copyright__: str
_T = TypeVar("_T")
_C = TypeVar("_C", bound=type)
_EqOrderType = Union[bool, Callable[[Any], Any]]
_ValidatorType = Callable[[Any, Attribute[_T], _T], Any]
_ConverterType = Callable[[Any], Any]
_FilterType = Callable[[Attribute[_T], _T], bool]
_ReprType = Callable[[Any], str]
_ReprArgType = Union[bool, _ReprType]
_OnSetAttrType = Callable[[Any, Attribute[Any], Any], Any]
_OnSetAttrArgType = Union[
_OnSetAttrType, List[_OnSetAttrType], setters._NoOpType
]
_FieldTransformer = Callable[
[type, List[Attribute[Any]]], List[Attribute[Any]]
]
# FIXME: in reality, if multiple validators are passed they must be in a list
# or tuple, but those are invariant and so would prevent subtypes of
# _ValidatorType from working when passed in a list or tuple.
_ValidatorArgType = Union[_ValidatorType[_T], Sequence[_ValidatorType[_T]]]
# A protocol to be able to statically accept an attrs class.
class AttrsInstance(Protocol):
__attrs_attrs__: ClassVar[Any]
# _make --
NOTHING: object
# NOTE: Factory lies about its return type to make this possible:
# `x: List[int] # = Factory(list)`
# Work around mypy issue #4554 in the common case by using an overload.
if sys.version_info >= (3, 8):
from typing import Literal
@overload
def Factory(factory: Callable[[], _T]) -> _T: ...
@overload
def Factory(
factory: Callable[[Any], _T],
takes_self: Literal[True],
) -> _T: ...
@overload
def Factory(
factory: Callable[[], _T],
takes_self: Literal[False],
) -> _T: ...
else:
@overload
def Factory(factory: Callable[[], _T]) -> _T: ...
@overload
def Factory(
factory: Union[Callable[[Any], _T], Callable[[], _T]],
takes_self: bool = ...,
) -> _T: ...
# Static type inference support via __dataclass_transform__ implemented as per:
# https://github.com/microsoft/pyright/blob/1.1.135/specs/dataclass_transforms.md
# This annotation must be applied to all overloads of "define" and "attrs"
#
# NOTE: This is a typing construct and does not exist at runtime. Extensions
# wrapping attrs decorators should declare a separate __dataclass_transform__
# signature in the extension module using the specification linked above to
# provide pyright support.
def __dataclass_transform__(
*,
eq_default: bool = True,
order_default: bool = False,
kw_only_default: bool = False,
field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()),
) -> Callable[[_T], _T]: ...
class Attribute(Generic[_T]):
name: str
default: Optional[_T]
validator: Optional[_ValidatorType[_T]]
repr: _ReprArgType
cmp: _EqOrderType
eq: _EqOrderType
order: _EqOrderType
hash: Optional[bool]
init: bool
converter: Optional[_ConverterType]
metadata: Dict[Any, Any]
type: Optional[Type[_T]]
kw_only: bool
on_setattr: _OnSetAttrType
def evolve(self, **changes: Any) -> "Attribute[Any]": ...
# NOTE: We had several choices for the annotation to use for type arg:
# 1) Type[_T]
# - Pros: Handles simple cases correctly
# - Cons: Might produce less informative errors in the case of conflicting
# TypeVars e.g. `attr.ib(default='bad', type=int)`
# 2) Callable[..., _T]
# - Pros: Better error messages than #1 for conflicting TypeVars
# - Cons: Terrible error messages for validator checks.
# e.g. attr.ib(type=int, validator=validate_str)
# -> error: Cannot infer function type argument
# 3) type (and do all of the work in the mypy plugin)
# - Pros: Simple here, and we could customize the plugin with our own errors.
# - Cons: Would need to write mypy plugin code to handle all the cases.
# We chose option #1.
# `attr` lies about its return type to make the following possible:
# attr() -> Any
# attr(8) -> int
# attr(validator=<some callable>) -> Whatever the callable expects.
# This makes this type of assignments possible:
# x: int = attr(8)
#
# This form catches explicit None or no default but with no other arguments
# returns Any.
@overload
def attrib(
default: None = ...,
validator: None = ...,
repr: _ReprArgType = ...,
cmp: Optional[_EqOrderType] = ...,
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
type: None = ...,
converter: None = ...,
factory: None = ...,
kw_only: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> Any: ...
# This form catches an explicit None or no default and infers the type from the
# other arguments.
@overload
def attrib(
default: None = ...,
validator: Optional[_ValidatorArgType[_T]] = ...,
repr: _ReprArgType = ...,
cmp: Optional[_EqOrderType] = ...,
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
type: Optional[Type[_T]] = ...,
converter: Optional[_ConverterType] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> _T: ...
# This form catches an explicit default argument.
@overload
def attrib(
default: _T,
validator: Optional[_ValidatorArgType[_T]] = ...,
repr: _ReprArgType = ...,
cmp: Optional[_EqOrderType] = ...,
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
type: Optional[Type[_T]] = ...,
converter: Optional[_ConverterType] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> _T: ...
# This form covers type=non-Type: e.g. forward references (str), Any
@overload
def attrib(
default: Optional[_T] = ...,
validator: Optional[_ValidatorArgType[_T]] = ...,
repr: _ReprArgType = ...,
cmp: Optional[_EqOrderType] = ...,
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
type: object = ...,
converter: Optional[_ConverterType] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> Any: ...
@overload
def field(
*,
default: None = ...,
validator: None = ...,
repr: _ReprArgType = ...,
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
converter: None = ...,
factory: None = ...,
kw_only: bool = ...,
eq: Optional[bool] = ...,
order: Optional[bool] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> Any: ...
# This form catches an explicit None or no default and infers the type from the
# other arguments.
@overload
def field(
*,
default: None = ...,
validator: Optional[_ValidatorArgType[_T]] = ...,
repr: _ReprArgType = ...,
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
converter: Optional[_ConverterType] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> _T: ...
# This form catches an explicit default argument.
@overload
def field(
*,
default: _T,
validator: Optional[_ValidatorArgType[_T]] = ...,
repr: _ReprArgType = ...,
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
converter: Optional[_ConverterType] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> _T: ...
# This form covers type=non-Type: e.g. forward references (str), Any
@overload
def field(
*,
default: Optional[_T] = ...,
validator: Optional[_ValidatorArgType[_T]] = ...,
repr: _ReprArgType = ...,
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
converter: Optional[_ConverterType] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> Any: ...
@overload
@__dataclass_transform__(order_default=True, field_descriptors=(attrib, field))
def attrs(
maybe_cls: _C,
these: Optional[Dict[str, Any]] = ...,
repr_ns: Optional[str] = ...,
repr: bool = ...,
cmp: Optional[_EqOrderType] = ...,
hash: Optional[bool] = ...,
init: bool = ...,
slots: bool = ...,
frozen: bool = ...,
weakref_slot: bool = ...,
str: bool = ...,
auto_attribs: bool = ...,
kw_only: bool = ...,
cache_hash: bool = ...,
auto_exc: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
auto_detect: bool = ...,
collect_by_mro: bool = ...,
getstate_setstate: Optional[bool] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
field_transformer: Optional[_FieldTransformer] = ...,
match_args: bool = ...,
) -> _C: ...
@overload
@__dataclass_transform__(order_default=True, field_descriptors=(attrib, field))
def attrs(
maybe_cls: None = ...,
these: Optional[Dict[str, Any]] = ...,
repr_ns: Optional[str] = ...,
repr: bool = ...,
cmp: Optional[_EqOrderType] = ...,
hash: Optional[bool] = ...,
init: bool = ...,
slots: bool = ...,
frozen: bool = ...,
weakref_slot: bool = ...,
str: bool = ...,
auto_attribs: bool = ...,
kw_only: bool = ...,
cache_hash: bool = ...,
auto_exc: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
auto_detect: bool = ...,
collect_by_mro: bool = ...,
getstate_setstate: Optional[bool] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
field_transformer: Optional[_FieldTransformer] = ...,
match_args: bool = ...,
) -> Callable[[_C], _C]: ...
@overload
@__dataclass_transform__(field_descriptors=(attrib, field))
def define(
maybe_cls: _C,
*,
these: Optional[Dict[str, Any]] = ...,
repr: bool = ...,
hash: Optional[bool] = ...,
init: bool = ...,
slots: bool = ...,
frozen: bool = ...,
weakref_slot: bool = ...,
str: bool = ...,
auto_attribs: bool = ...,
kw_only: bool = ...,
cache_hash: bool = ...,
auto_exc: bool = ...,
eq: Optional[bool] = ...,
order: Optional[bool] = ...,
auto_detect: bool = ...,
getstate_setstate: Optional[bool] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
field_transformer: Optional[_FieldTransformer] = ...,
match_args: bool = ...,
) -> _C: ...
@overload
@__dataclass_transform__(field_descriptors=(attrib, field))
def define(
maybe_cls: None = ...,
*,
these: Optional[Dict[str, Any]] = ...,
repr: bool = ...,
hash: Optional[bool] = ...,
init: bool = ...,
slots: bool = ...,
frozen: bool = ...,
weakref_slot: bool = ...,
str: bool = ...,
auto_attribs: bool = ...,
kw_only: bool = ...,
cache_hash: bool = ...,
auto_exc: bool = ...,
eq: Optional[bool] = ...,
order: Optional[bool] = ...,
auto_detect: bool = ...,
getstate_setstate: Optional[bool] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
field_transformer: Optional[_FieldTransformer] = ...,
match_args: bool = ...,
) -> Callable[[_C], _C]: ...
mutable = define
frozen = define # they differ only in their defaults
def fields(cls: Type[AttrsInstance]) -> Any: ...
def fields_dict(cls: Type[AttrsInstance]) -> Dict[str, Attribute[Any]]: ...
def validate(inst: AttrsInstance) -> None: ...
def resolve_types(
cls: _C,
globalns: Optional[Dict[str, Any]] = ...,
localns: Optional[Dict[str, Any]] = ...,
attribs: Optional[List[Attribute[Any]]] = ...,
) -> _C: ...
# TODO: add support for returning a proper attrs class from the mypy plugin
# we use Any instead of _CountingAttr so that e.g. `make_class('Foo',
# [attr.ib()])` is valid
def make_class(
name: str,
attrs: Union[List[str], Tuple[str, ...], Dict[str, Any]],
bases: Tuple[type, ...] = ...,
repr_ns: Optional[str] = ...,
repr: bool = ...,
cmp: Optional[_EqOrderType] = ...,
hash: Optional[bool] = ...,
init: bool = ...,
slots: bool = ...,
frozen: bool = ...,
weakref_slot: bool = ...,
str: bool = ...,
auto_attribs: bool = ...,
kw_only: bool = ...,
cache_hash: bool = ...,
auto_exc: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
collect_by_mro: bool = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
field_transformer: Optional[_FieldTransformer] = ...,
) -> type: ...
# _funcs --
# TODO: add support for returning TypedDict from the mypy plugin
# FIXME: asdict/astuple do not honor their factory args. Waiting on one of
# these:
# https://github.com/python/mypy/issues/4236
# https://github.com/python/typing/issues/253
# XXX: remember to fix attrs.asdict/astuple too!
def asdict(
inst: AttrsInstance,
recurse: bool = ...,
filter: Optional[_FilterType[Any]] = ...,
dict_factory: Type[Mapping[Any, Any]] = ...,
retain_collection_types: bool = ...,
value_serializer: Optional[
Callable[[type, Attribute[Any], Any], Any]
] = ...,
tuple_keys: Optional[bool] = ...,
) -> Dict[str, Any]: ...
# TODO: add support for returning NamedTuple from the mypy plugin
def astuple(
inst: AttrsInstance,
recurse: bool = ...,
filter: Optional[_FilterType[Any]] = ...,
tuple_factory: Type[Sequence[Any]] = ...,
retain_collection_types: bool = ...,
) -> Tuple[Any, ...]: ...
def has(cls: type) -> bool: ...
def assoc(inst: _T, **changes: Any) -> _T: ...
def evolve(inst: _T, **changes: Any) -> _T: ...
# _config --
def set_run_validators(run: bool) -> None: ...
def get_run_validators() -> bool: ...
# aliases --
s = attributes = attrs
ib = attr = attrib
dataclass = attrs # Technically, partial(attrs, auto_attribs=True) ;)
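
A minimal, illustrative sketch of ``make_class`` from the stub above (the class name ``C`` is hypothetical; the vendored ``attr`` package is assumed to be importable)::

    import attr

    # Build an attrs class dynamically from a list of attribute names.
    C = attr.make_class("C", ["x", "y"])
    assert attr.asdict(C(1, 2)) == {"x": 1, "y": 2}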

@@ -0,0 +1,155 @@
# SPDX-License-Identifier: MIT
import functools
import types
from ._make import _make_ne
_operation_names = {"eq": "==", "lt": "<", "le": "<=", "gt": ">", "ge": ">="}
def cmp_using(
eq=None,
lt=None,
le=None,
gt=None,
ge=None,
require_same_type=True,
class_name="Comparable",
):
"""
Create a class that can be passed into `attr.ib`'s ``eq``, ``order``, and
``cmp`` arguments to customize field comparison.
The resulting class will have a full set of ordering methods if
at least one of ``{lt, le, gt, ge}`` and ``eq`` are provided.
:param Optional[callable] eq: `callable` used to evaluate equality
of two objects.
:param Optional[callable] lt: `callable` used to evaluate whether
one object is less than another object.
:param Optional[callable] le: `callable` used to evaluate whether
one object is less than or equal to another object.
:param Optional[callable] gt: `callable` used to evaluate whether
one object is greater than another object.
:param Optional[callable] ge: `callable` used to evaluate whether
one object is greater than or equal to another object.
:param bool require_same_type: When `True`, equality and ordering methods
will return `NotImplemented` if objects are not of the same type.
:param Optional[str] class_name: Name of class. Defaults to 'Comparable'.
See `comparison` for more details.
.. versionadded:: 21.1.0
"""
body = {
"__slots__": ["value"],
"__init__": _make_init(),
"_requirements": [],
"_is_comparable_to": _is_comparable_to,
}
# Add operations.
num_order_functions = 0
has_eq_function = False
if eq is not None:
has_eq_function = True
body["__eq__"] = _make_operator("eq", eq)
body["__ne__"] = _make_ne()
if lt is not None:
num_order_functions += 1
body["__lt__"] = _make_operator("lt", lt)
if le is not None:
num_order_functions += 1
body["__le__"] = _make_operator("le", le)
if gt is not None:
num_order_functions += 1
body["__gt__"] = _make_operator("gt", gt)
if ge is not None:
num_order_functions += 1
body["__ge__"] = _make_operator("ge", ge)
type_ = types.new_class(
class_name, (object,), {}, lambda ns: ns.update(body)
)
# Add same type requirement.
if require_same_type:
type_._requirements.append(_check_same_type)
# Add total ordering if at least one operation was defined.
if 0 < num_order_functions < 4:
if not has_eq_function:
# functools.total_ordering requires __eq__ to be defined,
# so raise an early error here to keep a nice stack trace.
raise ValueError(
"eq must be defined in order to complete ordering from "
"lt, le, gt, ge."
)
type_ = functools.total_ordering(type_)
return type_
def _make_init():
"""
Create __init__ method.
"""
def __init__(self, value):
"""
Initialize object with *value*.
"""
self.value = value
return __init__
def _make_operator(name, func):
"""
Create operator method.
"""
def method(self, other):
if not self._is_comparable_to(other):
return NotImplemented
result = func(self.value, other.value)
if result is NotImplemented:
return NotImplemented
return result
method.__name__ = "__%s__" % (name,)
method.__doc__ = "Return a %s b. Computed by attrs." % (
_operation_names[name],
)
return method
def _is_comparable_to(self, other):
"""
Check whether `other` is comparable to `self`.
"""
for func in self._requirements:
if not func(self, other):
return False
return True
def _check_same_type(self, other):
"""
Return True if *self* and *other* are of the same type, False otherwise.
"""
return other.value.__class__ is self.value.__class__
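
A minimal, illustrative sketch of the intended use of ``cmp_using`` (``Reading`` is a hypothetical class; the vendored ``attr`` package is assumed to be importable)::

    import attr

    @attr.s
    class Reading:
        # Compare the float with a tolerance instead of exact equality;
        # cmp_using wraps the lambda in a class usable as a field's `eq`.
        value = attr.ib(eq=attr.cmp_using(eq=lambda a, b: abs(a - b) < 1e-9))

    assert Reading(0.1 + 0.2) == Reading(0.3)  # equal despite float rounding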

@@ -0,0 +1,13 @@
from typing import Any, Callable, Optional, Type
_CompareWithType = Callable[[Any, Any], bool]
def cmp_using(
eq: Optional[_CompareWithType],
lt: Optional[_CompareWithType],
le: Optional[_CompareWithType],
gt: Optional[_CompareWithType],
ge: Optional[_CompareWithType],
require_same_type: bool,
class_name: str,
) -> Type: ...

@@ -0,0 +1,185 @@
# SPDX-License-Identifier: MIT
import inspect
import platform
import sys
import threading
import types
import warnings
from collections.abc import Mapping, Sequence # noqa
PYPY = platform.python_implementation() == "PyPy"
PY36 = sys.version_info[:2] >= (3, 6)
HAS_F_STRINGS = PY36
PY310 = sys.version_info[:2] >= (3, 10)
if PYPY or PY36:
ordered_dict = dict
else:
from collections import OrderedDict
ordered_dict = OrderedDict
def just_warn(*args, **kw):
warnings.warn(
"Running interpreter doesn't sufficiently support code object "
"introspection. Some features like bare super() or accessing "
"__class__ will not work with slotted classes.",
RuntimeWarning,
stacklevel=2,
)
class _AnnotationExtractor:
"""
Extract type annotations from a callable, returning None whenever there
is none.
"""
__slots__ = ["sig"]
def __init__(self, callable):
try:
self.sig = inspect.signature(callable)
except (ValueError, TypeError): # inspect failed
self.sig = None
def get_first_param_type(self):
"""
Return the type annotation of the first argument if it's not empty.
"""
if not self.sig:
return None
params = list(self.sig.parameters.values())
if params and params[0].annotation is not inspect.Parameter.empty:
return params[0].annotation
return None
def get_return_type(self):
"""
Return the return type if it's not empty.
"""
if (
self.sig
and self.sig.return_annotation is not inspect.Signature.empty
):
return self.sig.return_annotation
return None
def make_set_closure_cell():
"""Return a function of two arguments (cell, value) which sets
the value stored in the closure cell `cell` to `value`.
"""
# pypy makes this easy. (It also supports the logic below, but
# why not do the easy/fast thing?)
if PYPY:
def set_closure_cell(cell, value):
cell.__setstate__((value,))
return set_closure_cell
# Otherwise gotta do it the hard way.
# Create a function that will set its first cellvar to `value`.
def set_first_cellvar_to(value):
x = value
return
# This function will be eliminated as dead code, but
# not before its reference to `x` forces `x` to be
# represented as a closure cell rather than a local.
def force_x_to_be_a_cell(): # pragma: no cover
return x
try:
# Extract the code object and make sure our assumptions about
# the closure behavior are correct.
co = set_first_cellvar_to.__code__
if co.co_cellvars != ("x",) or co.co_freevars != ():
raise AssertionError # pragma: no cover
# Convert this code object to a code object that sets the
# function's first _freevar_ (not cellvar) to the argument.
if sys.version_info >= (3, 8):
def set_closure_cell(cell, value):
cell.cell_contents = value
else:
args = [co.co_argcount]
args.append(co.co_kwonlyargcount)
args.extend(
[
co.co_nlocals,
co.co_stacksize,
co.co_flags,
co.co_code,
co.co_consts,
co.co_names,
co.co_varnames,
co.co_filename,
co.co_name,
co.co_firstlineno,
co.co_lnotab,
# These two arguments are reversed:
co.co_cellvars,
co.co_freevars,
]
)
set_first_freevar_code = types.CodeType(*args)
def set_closure_cell(cell, value):
# Create a function using the set_first_freevar_code,
# whose first closure cell is `cell`. Calling it will
# change the value of that cell.
setter = types.FunctionType(
set_first_freevar_code, {}, "setter", (), (cell,)
)
# And call it to set the cell.
setter(value)
# Make sure it works on this interpreter:
def make_func_with_cell():
x = None
def func():
return x # pragma: no cover
return func
cell = make_func_with_cell().__closure__[0]
set_closure_cell(cell, 100)
if cell.cell_contents != 100:
raise AssertionError # pragma: no cover
except Exception:
return just_warn
else:
return set_closure_cell
set_closure_cell = make_set_closure_cell()
# Thread-local global to track attrs instances which are already being repr'd.
# This is needed because there is no other (thread-safe) way to pass info
# about the instances that are already being repr'd through the call stack
# in order to ensure we don't perform infinite recursion.
#
# For instance, if an instance contains a dict which contains that instance,
# we need to know that we're already repr'ing the outside instance from within
# the dict's repr() call.
#
# This lives here rather than in _make.py so that the functions in _make.py
# don't have a direct reference to the thread-local in their globals dict.
# If they have such a reference, it breaks cloudpickle.
repr_context = threading.local()
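
On CPython 3.8 and newer, ``make_set_closure_cell`` reduces to plain assignment to ``cell_contents``; an illustrative sketch of that fast path::

    def make_cell():
        x = None

        def inner():
            return x

        return inner.__closure__[0]

    cell = make_cell()
    cell.cell_contents = 42  # direct assignment, as used on Python >= 3.8
    assert cell.cell_contents == 42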

@@ -0,0 +1,31 @@
# SPDX-License-Identifier: MIT
__all__ = ["set_run_validators", "get_run_validators"]
_run_validators = True
def set_run_validators(run):
"""
Set whether or not validators are run. By default, they are run.
.. deprecated:: 21.3.0 It will not be removed, but it also will not be
moved to the new ``attrs`` namespace. Use `attrs.validators.set_disabled()`
instead.
"""
if not isinstance(run, bool):
raise TypeError("'run' must be bool.")
global _run_validators
_run_validators = run
def get_run_validators():
"""
Return whether or not validators are run.
.. deprecated:: 21.3.0 It will not be removed, but it also will not be
moved to new ``attrs`` namespace. Use `attrs.validators.get_disabled()`
instead.
"""
return _run_validators
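
An illustrative sketch of this legacy toggle (the deprecation notes above point to the ``attrs.validators`` helpers as the preferred spelling)::

    import attr

    attr.set_run_validators(False)   # disable validators globally
    assert attr.get_run_validators() is False
    attr.set_run_validators(True)    # restore the default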

@@ -0,0 +1,420 @@
# SPDX-License-Identifier: MIT
import copy
from ._make import NOTHING, _obj_setattr, fields
from .exceptions import AttrsAttributeNotFoundError
def asdict(
inst,
recurse=True,
filter=None,
dict_factory=dict,
retain_collection_types=False,
value_serializer=None,
):
"""
Return the ``attrs`` attribute values of *inst* as a dict.
Optionally recurse into other ``attrs``-decorated classes.
:param inst: Instance of an ``attrs``-decorated class.
:param bool recurse: Recurse into classes that are also
``attrs``-decorated.
:param callable filter: A callable whose return code determines whether an
attribute or element is included (``True``) or dropped (``False``). Is
called with the `attrs.Attribute` as the first argument and the
value as the second argument.
:param callable dict_factory: A callable to produce dictionaries from. For
example, to produce ordered dictionaries instead of normal Python
dictionaries, pass in ``collections.OrderedDict``.
:param bool retain_collection_types: Do not convert to ``list`` when
encountering an attribute whose type is ``tuple`` or ``set``. Only
meaningful if ``recurse`` is ``True``.
:param Optional[callable] value_serializer: A hook that is called for every
attribute or dict key/value. It receives the current instance, field
and value and must return the (updated) value. The hook is run *after*
the optional *filter* has been applied.
:rtype: return type of *dict_factory*
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
class.
.. versionadded:: 16.0.0 *dict_factory*
.. versionadded:: 16.1.0 *retain_collection_types*
.. versionadded:: 20.3.0 *value_serializer*
.. versionadded:: 21.3.0 If a dict has a collection for a key, it is
serialized as a tuple.
"""
attrs = fields(inst.__class__)
rv = dict_factory()
for a in attrs:
v = getattr(inst, a.name)
if filter is not None and not filter(a, v):
continue
if value_serializer is not None:
v = value_serializer(inst, a, v)
if recurse is True:
if has(v.__class__):
rv[a.name] = asdict(
v,
recurse=True,
filter=filter,
dict_factory=dict_factory,
retain_collection_types=retain_collection_types,
value_serializer=value_serializer,
)
elif isinstance(v, (tuple, list, set, frozenset)):
cf = v.__class__ if retain_collection_types is True else list
rv[a.name] = cf(
[
_asdict_anything(
i,
is_key=False,
filter=filter,
dict_factory=dict_factory,
retain_collection_types=retain_collection_types,
value_serializer=value_serializer,
)
for i in v
]
)
elif isinstance(v, dict):
df = dict_factory
rv[a.name] = df(
(
_asdict_anything(
kk,
is_key=True,
filter=filter,
dict_factory=df,
retain_collection_types=retain_collection_types,
value_serializer=value_serializer,
),
_asdict_anything(
vv,
is_key=False,
filter=filter,
dict_factory=df,
retain_collection_types=retain_collection_types,
value_serializer=value_serializer,
),
)
for kk, vv in v.items()
)
else:
rv[a.name] = v
else:
rv[a.name] = v
return rv
def _asdict_anything(
val,
is_key,
filter,
dict_factory,
retain_collection_types,
value_serializer,
):
"""
``asdict`` only works on attrs instances; this works on anything.
"""
if getattr(val.__class__, "__attrs_attrs__", None) is not None:
# Attrs class.
rv = asdict(
val,
recurse=True,
filter=filter,
dict_factory=dict_factory,
retain_collection_types=retain_collection_types,
value_serializer=value_serializer,
)
elif isinstance(val, (tuple, list, set, frozenset)):
if retain_collection_types is True:
cf = val.__class__
elif is_key:
cf = tuple
else:
cf = list
rv = cf(
[
_asdict_anything(
i,
is_key=False,
filter=filter,
dict_factory=dict_factory,
retain_collection_types=retain_collection_types,
value_serializer=value_serializer,
)
for i in val
]
)
elif isinstance(val, dict):
df = dict_factory
rv = df(
(
_asdict_anything(
kk,
is_key=True,
filter=filter,
dict_factory=df,
retain_collection_types=retain_collection_types,
value_serializer=value_serializer,
),
_asdict_anything(
vv,
is_key=False,
filter=filter,
dict_factory=df,
retain_collection_types=retain_collection_types,
value_serializer=value_serializer,
),
)
for kk, vv in val.items()
)
else:
rv = val
if value_serializer is not None:
rv = value_serializer(None, None, rv)
return rv
def astuple(
inst,
recurse=True,
filter=None,
tuple_factory=tuple,
retain_collection_types=False,
):
"""
Return the ``attrs`` attribute values of *inst* as a tuple.
Optionally recurse into other ``attrs``-decorated classes.
:param inst: Instance of an ``attrs``-decorated class.
:param bool recurse: Recurse into classes that are also
``attrs``-decorated.
:param callable filter: A callable whose return code determines whether an
attribute or element is included (``True``) or dropped (``False``). Is
called with the `attrs.Attribute` as the first argument and the
value as the second argument.
:param callable tuple_factory: A callable to produce tuples from. For
example, to produce lists instead of tuples.
:param bool retain_collection_types: Do not convert to ``list``
or ``dict`` when encountering an attribute whose type is
``tuple``, ``dict`` or ``set``. Only meaningful if ``recurse`` is
``True``.
:rtype: return type of *tuple_factory*
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
class.
.. versionadded:: 16.2.0
"""
attrs = fields(inst.__class__)
rv = []
retain = retain_collection_types # Very long. :/
for a in attrs:
v = getattr(inst, a.name)
if filter is not None and not filter(a, v):
continue
if recurse is True:
if has(v.__class__):
rv.append(
astuple(
v,
recurse=True,
filter=filter,
tuple_factory=tuple_factory,
retain_collection_types=retain,
)
)
elif isinstance(v, (tuple, list, set, frozenset)):
cf = v.__class__ if retain is True else list
rv.append(
cf(
[
astuple(
j,
recurse=True,
filter=filter,
tuple_factory=tuple_factory,
retain_collection_types=retain,
)
if has(j.__class__)
else j
for j in v
]
)
)
elif isinstance(v, dict):
df = v.__class__ if retain is True else dict
rv.append(
df(
(
astuple(
kk,
tuple_factory=tuple_factory,
retain_collection_types=retain,
)
if has(kk.__class__)
else kk,
astuple(
vv,
tuple_factory=tuple_factory,
retain_collection_types=retain,
)
if has(vv.__class__)
else vv,
)
for kk, vv in v.items()
)
)
else:
rv.append(v)
else:
rv.append(v)
return rv if tuple_factory is list else tuple_factory(rv)
def has(cls):
"""
Check whether *cls* is a class with ``attrs`` attributes.
:param type cls: Class to introspect.
:raise TypeError: If *cls* is not a class.
:rtype: bool
"""
return getattr(cls, "__attrs_attrs__", None) is not None
def assoc(inst, **changes):
"""
Copy *inst* and apply *changes*.
:param inst: Instance of a class with ``attrs`` attributes.
:param changes: Keyword changes in the new copy.
:return: A copy of inst with *changes* incorporated.
:raise attr.exceptions.AttrsAttributeNotFoundError: If *attr_name* couldn't
be found on *cls*.
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
class.
.. deprecated:: 17.1.0
Use `attrs.evolve` instead if you can.
This function will not be removed due to the slightly different approach
compared to `attrs.evolve`.
"""
import warnings
warnings.warn(
"assoc is deprecated and will be removed after 2018/01.",
DeprecationWarning,
stacklevel=2,
)
new = copy.copy(inst)
attrs = fields(inst.__class__)
for k, v in changes.items():
a = getattr(attrs, k, NOTHING)
if a is NOTHING:
raise AttrsAttributeNotFoundError(
"{k} is not an attrs attribute on {cl}.".format(
k=k, cl=new.__class__
)
)
_obj_setattr(new, k, v)
return new
def evolve(inst, **changes):
"""
Create a new instance, based on *inst* with *changes* applied.
:param inst: Instance of a class with ``attrs`` attributes.
:param changes: Keyword changes in the new copy.
:return: A copy of inst with *changes* incorporated.
:raise TypeError: If *attr_name* couldn't be found in the class
``__init__``.
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
class.
.. versionadded:: 17.1.0
"""
cls = inst.__class__
attrs = fields(cls)
for a in attrs:
if not a.init:
continue
attr_name = a.name # To deal with private attributes.
init_name = attr_name if attr_name[0] != "_" else attr_name[1:]
if init_name not in changes:
changes[init_name] = getattr(inst, attr_name)
return cls(**changes)
def resolve_types(cls, globalns=None, localns=None, attribs=None):
"""
Resolve any strings and forward annotations in type annotations.
This is only required if you need concrete types in `Attribute`'s *type*
field. In other words, you don't need to resolve your types if you only
use them for static type checking.
With no arguments, names will be looked up in the module in which the class
was created. If this is not what you want, e.g. if the name only exists
inside a method, you may pass *globalns* or *localns* to specify other
dictionaries in which to look up these names. See the docs of
`typing.get_type_hints` for more details.
:param type cls: Class to resolve.
:param Optional[dict] globalns: Dictionary containing global variables.
:param Optional[dict] localns: Dictionary containing local variables.
:param Optional[list] attribs: List of attribs for the given class.
This is necessary when calling from inside a ``field_transformer``
since *cls* is not an ``attrs`` class yet.
:raise TypeError: If *cls* is not a class.
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
class and you didn't pass any attribs.
:raise NameError: If types cannot be resolved because of missing variables.
:returns: *cls* so you can use this function also as a class decorator.
Please note that you have to apply it **after** `attrs.define`. That
means the decorator has to come in the line **before** `attrs.define`.
.. versionadded:: 20.1.0
.. versionadded:: 21.1.0 *attribs*
"""
# Since calling get_type_hints is expensive we cache whether we've
# done it already.
if getattr(cls, "__attrs_types_resolved__", None) != cls:
import typing
hints = typing.get_type_hints(cls, globalns=globalns, localns=localns)
for field in fields(cls) if attribs is None else attribs:
if field.name in hints:
# Since fields have been frozen we must work around it.
_obj_setattr(field, "type", hints[field.name])
# We store the class we resolved so that subclasses know they haven't
# been resolved.
cls.__attrs_types_resolved__ = cls
# Return the class so you can use it as a decorator too.
return cls
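
A short, illustrative sketch of these helpers working together (``User`` is a hypothetical class)::

    import attr

    @attr.s
    class User:
        name = attr.ib()
        age = attr.ib()

    u = User(name="ada", age=36)
    assert attr.asdict(u) == {"name": "ada", "age": 36}
    assert attr.astuple(u) == ("ada", 36)
    assert attr.evolve(u, age=37) == User("ada", 37)  # new instance via __init__
    assert attr.has(User) and not attr.has(int)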

File diff suppressed because it is too large

@@ -0,0 +1,220 @@
# SPDX-License-Identifier: MIT
"""
These are Python 3.6+-only and keyword-only APIs that call `attr.s` and
`attr.ib` with different default values.
"""
from functools import partial
from . import setters
from ._funcs import asdict as _asdict
from ._funcs import astuple as _astuple
from ._make import (
NOTHING,
_frozen_setattrs,
_ng_default_on_setattr,
attrib,
attrs,
)
from .exceptions import UnannotatedAttributeError
def define(
maybe_cls=None,
*,
these=None,
repr=None,
hash=None,
init=None,
slots=True,
frozen=False,
weakref_slot=True,
str=False,
auto_attribs=None,
kw_only=False,
cache_hash=False,
auto_exc=True,
eq=None,
order=False,
auto_detect=True,
getstate_setstate=None,
on_setattr=None,
field_transformer=None,
match_args=True,
):
r"""
Define an ``attrs`` class.
Differences from the classic `attr.s` that it uses underneath:
- Automatically detect whether or not *auto_attribs* should be `True` (c.f.
*auto_attribs* parameter).
- If *frozen* is `False`, run converters and validators when setting an
attribute by default.
- *slots=True*
.. caution::
Usually this has only upsides and few visible effects in everyday
programming. But it *can* lead to some surprising behaviors, so please
make sure to read :term:`slotted classes`.
- *auto_exc=True*
- *auto_detect=True*
- *order=False*
- Some options that were only relevant on Python 2 or were kept around for
backwards-compatibility have been removed.
Please note that these are all defaults and you can change them as you
wish.
:param Optional[bool] auto_attribs: If set to `True` or `False`, it behaves
exactly like `attr.s`. If left `None`, `attr.s` will try to guess:
1. If any attributes are annotated and no unannotated `attrs.fields`\ s
are found, it assumes *auto_attribs=True*.
2. Otherwise it assumes *auto_attribs=False* and tries to collect
`attrs.fields`\ s.
For now, please refer to `attr.s` for the rest of the parameters.
.. versionadded:: 20.1.0
.. versionchanged:: 21.3.0 Converters are also run ``on_setattr``.
"""
def do_it(cls, auto_attribs):
return attrs(
maybe_cls=cls,
these=these,
repr=repr,
hash=hash,
init=init,
slots=slots,
frozen=frozen,
weakref_slot=weakref_slot,
str=str,
auto_attribs=auto_attribs,
kw_only=kw_only,
cache_hash=cache_hash,
auto_exc=auto_exc,
eq=eq,
order=order,
auto_detect=auto_detect,
collect_by_mro=True,
getstate_setstate=getstate_setstate,
on_setattr=on_setattr,
field_transformer=field_transformer,
match_args=match_args,
)
def wrap(cls):
"""
Making this a wrapper ensures this code runs during class creation.
We also ensure that frozen-ness of classes is inherited.
"""
nonlocal frozen, on_setattr
had_on_setattr = on_setattr not in (None, setters.NO_OP)
# By default, mutable classes convert & validate on setattr.
if frozen is False and on_setattr is None:
on_setattr = _ng_default_on_setattr
# However, if we subclass a frozen class, we inherit the immutability
# and disable on_setattr.
for base_cls in cls.__bases__:
if base_cls.__setattr__ is _frozen_setattrs:
if had_on_setattr:
raise ValueError(
"Frozen classes can't use on_setattr "
"(frozen-ness was inherited)."
)
on_setattr = setters.NO_OP
break
if auto_attribs is not None:
return do_it(cls, auto_attribs)
try:
return do_it(cls, True)
except UnannotatedAttributeError:
return do_it(cls, False)
# maybe_cls's type depends on the usage of the decorator. It's a class
# if it's used as `@attrs` but ``None`` if used as `@attrs()`.
if maybe_cls is None:
return wrap
else:
return wrap(maybe_cls)
mutable = define
frozen = partial(define, frozen=True, on_setattr=None)
def field(
*,
default=NOTHING,
validator=None,
repr=True,
hash=None,
init=True,
metadata=None,
converter=None,
factory=None,
kw_only=False,
eq=None,
order=None,
on_setattr=None,
):
"""
Identical to `attr.ib`, except keyword-only and with some arguments
removed.
.. versionadded:: 20.1.0
"""
return attrib(
default=default,
validator=validator,
repr=repr,
hash=hash,
init=init,
metadata=metadata,
converter=converter,
factory=factory,
kw_only=kw_only,
eq=eq,
order=order,
on_setattr=on_setattr,
)
def asdict(inst, *, recurse=True, filter=None, value_serializer=None):
"""
Same as `attr.asdict`, except that collection types are always retained
and dict is always used as *dict_factory*.
.. versionadded:: 21.3.0
"""
return _asdict(
inst=inst,
recurse=recurse,
filter=filter,
value_serializer=value_serializer,
retain_collection_types=True,
)
def astuple(inst, *, recurse=True, filter=None):
"""
Same as `attr.astuple`, except that collection types are always retained
and `tuple` is always used as the *tuple_factory*.
.. versionadded:: 21.3.0
"""
return _astuple(
inst=inst, recurse=recurse, filter=filter, retain_collection_types=True
)
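
A brief, illustrative sketch of the next-generation API defined here (``Point`` is a hypothetical class)::

    import attr

    @attr.define                      # slots=True; auto_attribs is inferred
    class Point:
        x: int
        y: int = attr.field(default=0)

    p = Point(1)
    p.x = 2                           # mutable classes validate/convert on setattr
    assert attr.asdict(p) == {"x": 2, "y": 0}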

@@ -0,0 +1,86 @@
# SPDX-License-Identifier: MIT
from functools import total_ordering
from ._funcs import astuple
from ._make import attrib, attrs
@total_ordering
@attrs(eq=False, order=False, slots=True, frozen=True)
class VersionInfo:
"""
A version object that can be compared to a tuple of length 1--4:
>>> attr.VersionInfo(19, 1, 0, "final") <= (19, 2)
True
>>> attr.VersionInfo(19, 1, 0, "final") < (19, 1, 1)
True
>>> vi = attr.VersionInfo(19, 2, 0, "final")
>>> vi < (19, 1, 1)
False
>>> vi < (19,)
False
>>> vi == (19, 2,)
True
>>> vi == (19, 2, 1)
False
.. versionadded:: 19.2
"""
year = attrib(type=int)
minor = attrib(type=int)
micro = attrib(type=int)
releaselevel = attrib(type=str)
@classmethod
def _from_version_string(cls, s):
"""
Parse *s* and return a _VersionInfo.
"""
v = s.split(".")
if len(v) == 3:
v.append("final")
return cls(
year=int(v[0]), minor=int(v[1]), micro=int(v[2]), releaselevel=v[3]
)
def _ensure_tuple(self, other):
"""
Ensure *other* is a tuple of a valid length.
Returns a possibly transformed *other* and ourselves as a tuple of
the same length as *other*.
"""
if self.__class__ is other.__class__:
other = astuple(other)
if not isinstance(other, tuple):
raise NotImplementedError
if not (1 <= len(other) <= 4):
raise NotImplementedError
return astuple(self)[: len(other)], other
def __eq__(self, other):
try:
us, them = self._ensure_tuple(other)
except NotImplementedError:
return NotImplemented
return us == them
def __lt__(self, other):
try:
us, them = self._ensure_tuple(other)
except NotImplementedError:
return NotImplemented
# Since alphabetically "dev0" < "final" < "post1" < "post2", we don't
# have to do anything special with releaselevel for now.
return us < them

@@ -0,0 +1,9 @@
class VersionInfo:
@property
def year(self) -> int: ...
@property
def minor(self) -> int: ...
@property
def micro(self) -> int: ...
@property
def releaselevel(self) -> str: ...

@@ -0,0 +1,144 @@
# SPDX-License-Identifier: MIT
"""
Commonly useful converters.
"""
import typing
from ._compat import _AnnotationExtractor
from ._make import NOTHING, Factory, pipe
__all__ = [
"default_if_none",
"optional",
"pipe",
"to_bool",
]
def optional(converter):
"""
A converter that allows an attribute to be optional. An optional attribute
is one which can be set to ``None``.
Type annotations will be inferred from the wrapped converter's, if it
has any.
:param callable converter: the converter that is used for non-``None``
values.
.. versionadded:: 17.1.0
"""
def optional_converter(val):
if val is None:
return None
return converter(val)
xtr = _AnnotationExtractor(converter)
t = xtr.get_first_param_type()
if t:
optional_converter.__annotations__["val"] = typing.Optional[t]
rt = xtr.get_return_type()
if rt:
optional_converter.__annotations__["return"] = typing.Optional[rt]
return optional_converter
def default_if_none(default=NOTHING, factory=None):
"""
A converter that replaces ``None`` values with *default* or the
result of *factory*.
:param default: Value to be used if ``None`` is passed. Passing an instance
of `attrs.Factory` is supported, however the ``takes_self`` option
is *not*.
:param callable factory: A callable that takes no parameters whose result
is used if ``None`` is passed.
:raises TypeError: If **neither** *default* **nor** *factory* is passed.
:raises TypeError: If **both** *default* and *factory* are passed.
:raises ValueError: If an instance of `attrs.Factory` is passed with
``takes_self=True``.
.. versionadded:: 18.2.0
"""
if default is NOTHING and factory is None:
raise TypeError("Must pass either `default` or `factory`.")
if default is not NOTHING and factory is not None:
raise TypeError(
"Must pass either `default` or `factory` but not both."
)
if factory is not None:
default = Factory(factory)
if isinstance(default, Factory):
if default.takes_self:
raise ValueError(
"`takes_self` is not supported by default_if_none."
)
def default_if_none_converter(val):
if val is not None:
return val
return default.factory()
else:
def default_if_none_converter(val):
if val is not None:
return val
return default
return default_if_none_converter
def to_bool(val):
"""
Convert "boolean" strings (e.g., from env. vars.) to real booleans.
Values mapping to :code:`True`:
- :code:`True`
- :code:`"true"` / :code:`"t"`
- :code:`"yes"` / :code:`"y"`
- :code:`"on"`
- :code:`"1"`
- :code:`1`
Values mapping to :code:`False`:
- :code:`False`
- :code:`"false"` / :code:`"f"`
- :code:`"no"` / :code:`"n"`
- :code:`"off"`
- :code:`"0"`
- :code:`0`
:raises ValueError: for any other value.
.. versionadded:: 21.3.0
"""
if isinstance(val, str):
val = val.lower()
truthy = {True, "true", "t", "yes", "y", "on", "1", 1}
falsy = {False, "false", "f", "no", "n", "off", "0", 0}
try:
if val in truthy:
return True
if val in falsy:
return False
except TypeError:
# Raised when "val" is not hashable (e.g., lists)
pass
raise ValueError("Cannot convert value to bool: {}".format(val))
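
An illustrative sketch combining the three converters (``Config`` is a hypothetical class)::

    import attr
    from attr import converters

    @attr.s
    class Config:
        retries = attr.ib(converter=converters.optional(int), default=None)
        verbose = attr.ib(converter=converters.to_bool, default="off")
        name = attr.ib(converter=converters.default_if_none("anonymous"),
                       default=None)

    c = Config(retries="3")
    assert (c.retries, c.verbose, c.name) == (3, False, "anonymous")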

@@ -0,0 +1,13 @@
from typing import Callable, Optional, TypeVar, overload
from . import _ConverterType
_T = TypeVar("_T")
def pipe(*validators: _ConverterType) -> _ConverterType: ...
def optional(converter: _ConverterType) -> _ConverterType: ...
@overload
def default_if_none(default: _T) -> _ConverterType: ...
@overload
def default_if_none(*, factory: Callable[[], _T]) -> _ConverterType: ...
def to_bool(val: str) -> bool: ...

@@ -0,0 +1,92 @@
# SPDX-License-Identifier: MIT
class FrozenError(AttributeError):
"""
An attempt has been made to modify a frozen/immutable instance or
attribute.
It mirrors the behavior of ``namedtuples`` by using the same error message
and subclassing `AttributeError`.
.. versionadded:: 20.1.0
"""
msg = "can't set attribute"
args = [msg]
class FrozenInstanceError(FrozenError):
"""
An attempt has been made to modify a frozen instance.
.. versionadded:: 16.1.0
"""
class FrozenAttributeError(FrozenError):
"""
An attempt has been made to modify a frozen attribute.
.. versionadded:: 20.1.0
"""
class AttrsAttributeNotFoundError(ValueError):
"""
An ``attrs`` function couldn't find an attribute that the user asked for.
.. versionadded:: 16.2.0
"""
class NotAnAttrsClassError(ValueError):
"""
A non-``attrs`` class has been passed into an ``attrs`` function.
.. versionadded:: 16.2.0
"""
class DefaultAlreadySetError(RuntimeError):
"""
A default has been set using ``attr.ib()`` and an attempt has been made
to reset it using the decorator.
.. versionadded:: 17.1.0
"""
class UnannotatedAttributeError(RuntimeError):
"""
A class with ``auto_attribs=True`` has an ``attr.ib()`` without a type
annotation.
.. versionadded:: 17.3.0
"""
class PythonTooOldError(RuntimeError):
"""
An ``attrs`` feature that requires a newer Python version has been
used.
.. versionadded:: 18.2.0
"""
class NotCallableError(TypeError):
"""
An ``attr.ib()`` requiring a callable has been set with a value
that is not callable.
.. versionadded:: 19.2.0
"""
def __init__(self, msg, value):
super(TypeError, self).__init__(msg, value)
self.msg = msg
self.value = value
def __str__(self):
return str(self.msg)
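
An illustrative sketch of the ``AttributeError`` mirroring described above (``Locked`` is a hypothetical class)::

    import attr

    @attr.s(frozen=True)
    class Locked:
        x = attr.ib()

    try:
        Locked(1).x = 2
    except attr.exceptions.FrozenInstanceError as err:
        # FrozenError subclasses AttributeError, like namedtuples do
        assert isinstance(err, AttributeError)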

@@ -0,0 +1,17 @@
from typing import Any
class FrozenError(AttributeError):
msg: str = ...
class FrozenInstanceError(FrozenError): ...
class FrozenAttributeError(FrozenError): ...
class AttrsAttributeNotFoundError(ValueError): ...
class NotAnAttrsClassError(ValueError): ...
class DefaultAlreadySetError(RuntimeError): ...
class UnannotatedAttributeError(RuntimeError): ...
class PythonTooOldError(RuntimeError): ...
class NotCallableError(TypeError):
msg: str = ...
value: Any = ...
def __init__(self, msg: str, value: Any) -> None: ...

@@ -0,0 +1,51 @@
# SPDX-License-Identifier: MIT
"""
Commonly useful filters for `attr.asdict`.
"""
from ._make import Attribute
def _split_what(what):
"""
Returns a tuple of `frozenset`s of classes and attributes.
"""
return (
frozenset(cls for cls in what if isinstance(cls, type)),
frozenset(cls for cls in what if isinstance(cls, Attribute)),
)
def include(*what):
"""
Include *what*.
:param what: What to include.
:type what: `list` of `type` or `attrs.Attribute`\\ s
:rtype: `callable`
"""
cls, attrs = _split_what(what)
def include_(attribute, value):
return value.__class__ in cls or attribute in attrs
return include_
def exclude(*what):
"""
Exclude *what*.
:param what: What to exclude.
:type what: `list` of classes or `attrs.Attribute`\\ s.
:rtype: `callable`
"""
cls, attrs = _split_what(what)
def exclude_(attribute, value):
return value.__class__ not in cls and attribute not in attrs
return exclude_
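
An illustrative sketch of ``include``/``exclude`` used with `attr.asdict` (``Row`` is a hypothetical class)::

    import attr
    from attr import filters

    @attr.s
    class Row:
        id = attr.ib()
        secret = attr.ib()

    row = Row(1, "hunter2")
    f = attr.fields(Row)
    assert attr.asdict(row, filter=filters.exclude(f.secret)) == {"id": 1}
    assert attr.asdict(row, filter=filters.include(f.id)) == {"id": 1}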

@@ -0,0 +1,6 @@
from typing import Any, Union
from . import Attribute, _FilterType
def include(*what: Union[type, Attribute[Any]]) -> _FilterType[Any]: ...
def exclude(*what: Union[type, Attribute[Any]]) -> _FilterType[Any]: ...

@@ -0,0 +1,73 @@
# SPDX-License-Identifier: MIT
"""
Commonly used hooks for on_setattr.
"""
from . import _config
from .exceptions import FrozenAttributeError
def pipe(*setters):
"""
Run all *setters* and return the return value of the last one.
.. versionadded:: 20.1.0
"""
def wrapped_pipe(instance, attrib, new_value):
rv = new_value
for setter in setters:
rv = setter(instance, attrib, rv)
return rv
return wrapped_pipe
def frozen(_, __, ___):
"""
Prevent an attribute from being modified.
.. versionadded:: 20.1.0
"""
raise FrozenAttributeError()
def validate(instance, attrib, new_value):
"""
Run *attrib*'s validator on *new_value* if it has one.
.. versionadded:: 20.1.0
"""
if _config._run_validators is False:
return new_value
v = attrib.validator
if not v:
return new_value
v(instance, attrib, new_value)
return new_value
def convert(instance, attrib, new_value):
"""
Run *attrib*'s converter -- if it has one -- on *new_value* and return the
result.
.. versionadded:: 20.1.0
"""
c = attrib.converter
if c:
return c(new_value)
return new_value
# Sentinel for disabling class-wide *on_setattr* hooks for certain attributes.
# autodata stopped working, so the docstring is inlined in the API docs.
NO_OP = object()
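
An illustrative sketch of a per-attribute ``on_setattr`` hook (``Guarded`` is a hypothetical class)::

    import attr
    from attr import setters
    from attr.validators import instance_of

    @attr.s
    class Guarded:
        x = attr.ib(validator=instance_of(int), on_setattr=setters.validate)

    g = Guarded(1)
    g.x = 2                  # the validator runs again on assignment
    try:
        g.x = "nope"         # rejected by instance_of(int)
    except TypeError:
        pass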

@@ -0,0 +1,19 @@
from typing import Any, NewType, NoReturn, TypeVar, cast
from . import Attribute, _OnSetAttrType
_T = TypeVar("_T")
def frozen(
instance: Any, attribute: Attribute[Any], new_value: Any
) -> NoReturn: ...
def pipe(*setters: _OnSetAttrType) -> _OnSetAttrType: ...
def validate(instance: Any, attribute: Attribute[_T], new_value: _T) -> _T: ...
# convert is allowed to return Any, because they can be chained using pipe.
def convert(
instance: Any, attribute: Attribute[Any], new_value: Any
) -> Any: ...
_NoOpType = NewType("_NoOpType", object)
NO_OP: _NoOpType

@@ -0,0 +1,594 @@
# SPDX-License-Identifier: MIT
"""
Commonly useful validators.
"""
import operator
import re
from contextlib import contextmanager
from ._config import get_run_validators, set_run_validators
from ._make import _AndValidator, and_, attrib, attrs
from .exceptions import NotCallableError
try:
Pattern = re.Pattern
except AttributeError: # Python <3.7 lacks a Pattern type.
Pattern = type(re.compile(""))
__all__ = [
"and_",
"deep_iterable",
"deep_mapping",
"disabled",
"ge",
"get_disabled",
"gt",
"in_",
"instance_of",
"is_callable",
"le",
"lt",
"matches_re",
"max_len",
"min_len",
"optional",
"provides",
"set_disabled",
]
def set_disabled(disabled):
"""
Globally disable or enable running validators.
By default, they are run.
:param disabled: If ``True``, disable running all validators.
:type disabled: bool
.. warning::
This function is not thread-safe!
.. versionadded:: 21.3.0
"""
set_run_validators(not disabled)
def get_disabled():
"""
Return a bool indicating whether validators are currently disabled or not.
:return: ``True`` if validators are currently disabled.
:rtype: bool
.. versionadded:: 21.3.0
"""
return not get_run_validators()
@contextmanager
def disabled():
"""
Context manager that disables running validators within its context.
.. warning::
This context manager is not thread-safe!
.. versionadded:: 21.3.0
"""
set_run_validators(False)
try:
yield
finally:
set_run_validators(True)
@attrs(repr=False, slots=True, hash=True)
class _InstanceOfValidator:
type = attrib()
def __call__(self, inst, attr, value):
"""
We use a callable class to be able to change the ``__repr__``.
"""
if not isinstance(value, self.type):
raise TypeError(
"'{name}' must be {type!r} (got {value!r} that is a "
"{actual!r}).".format(
name=attr.name,
type=self.type,
actual=value.__class__,
value=value,
),
attr,
self.type,
value,
)
def __repr__(self):
return "<instance_of validator for type {type!r}>".format(
type=self.type
)
def instance_of(type):
"""
A validator that raises a `TypeError` if the initializer is called
with a wrong type for this particular attribute (checks are performed using
`isinstance`, therefore it's also valid to pass a tuple of types).
:param type: The type to check for.
:type type: type or tuple of types
:raises TypeError: With a human readable error message, the attribute
(of type `attrs.Attribute`), the expected type, and the value it
got.
"""
return _InstanceOfValidator(type)
@attrs(repr=False, frozen=True, slots=True)
class _MatchesReValidator:
pattern = attrib()
match_func = attrib()
def __call__(self, inst, attr, value):
"""
We use a callable class to be able to change the ``__repr__``.
"""
if not self.match_func(value):
raise ValueError(
"'{name}' must match regex {pattern!r}"
" ({value!r} doesn't)".format(
name=attr.name, pattern=self.pattern.pattern, value=value
),
attr,
self.pattern,
value,
)
def __repr__(self):
return "<matches_re validator for pattern {pattern!r}>".format(
pattern=self.pattern
)
def matches_re(regex, flags=0, func=None):
r"""
A validator that raises `ValueError` if the initializer is called
with a string that doesn't match *regex*.
:param regex: a regex string or precompiled pattern to match against
:param int flags: flags that will be passed to the underlying re function
(default 0)
:param callable func: which underlying `re` function to call. Valid options
are `re.fullmatch`, `re.search`, and `re.match`; the default ``None``
means `re.fullmatch`. For performance reasons, the pattern is always
precompiled using `re.compile`.
.. versionadded:: 19.2.0
.. versionchanged:: 21.3.0 *regex* can be a pre-compiled pattern.
"""
valid_funcs = (re.fullmatch, None, re.search, re.match)
if func not in valid_funcs:
raise ValueError(
"'func' must be one of {}.".format(
", ".join(
sorted(
e and e.__name__ or "None" for e in set(valid_funcs)
)
)
)
)
if isinstance(regex, Pattern):
if flags:
raise TypeError(
"'flags' can only be used with a string pattern; "
"pass flags to re.compile() instead"
)
pattern = regex
else:
pattern = re.compile(regex, flags)
if func is re.match:
match_func = pattern.match
elif func is re.search:
match_func = pattern.search
else:
match_func = pattern.fullmatch
return _MatchesReValidator(pattern, match_func)
@attrs(repr=False, slots=True, hash=True)
class _ProvidesValidator:
interface = attrib()
def __call__(self, inst, attr, value):
"""
We use a callable class to be able to change the ``__repr__``.
"""
if not self.interface.providedBy(value):
raise TypeError(
"'{name}' must provide {interface!r} which {value!r} "
"doesn't.".format(
name=attr.name, interface=self.interface, value=value
),
attr,
self.interface,
value,
)
def __repr__(self):
return "<provides validator for interface {interface!r}>".format(
interface=self.interface
)
def provides(interface):
"""
A validator that raises a `TypeError` if the initializer is called
with an object that does not provide the requested *interface* (checks are
performed using ``interface.providedBy(value)``; see `zope.interface
<https://zopeinterface.readthedocs.io/en/latest/>`_).
:param interface: The interface to check for.
:type interface: ``zope.interface.Interface``
:raises TypeError: With a human readable error message, the attribute
(of type `attrs.Attribute`), the expected interface, and the
value it got.
"""
return _ProvidesValidator(interface)
@attrs(repr=False, slots=True, hash=True)
class _OptionalValidator:
validator = attrib()
def __call__(self, inst, attr, value):
if value is None:
return
self.validator(inst, attr, value)
def __repr__(self):
return "<optional validator for {what} or None>".format(
what=repr(self.validator)
)
def optional(validator):
"""
A validator that makes an attribute optional. An optional attribute is one
which can be set to ``None`` in addition to satisfying the requirements of
the sub-validator.
:param validator: A validator (or a list of validators) that is used for
non-``None`` values.
:type validator: callable or `list` of callables.
.. versionadded:: 15.1.0
.. versionchanged:: 17.1.0 *validator* can be a list of validators.
"""
if isinstance(validator, list):
return _OptionalValidator(_AndValidator(validator))
return _OptionalValidator(validator)
@attrs(repr=False, slots=True, hash=True)
class _InValidator:
options = attrib()
def __call__(self, inst, attr, value):
try:
in_options = value in self.options
except TypeError: # e.g. `1 in "abc"`
in_options = False
if not in_options:
raise ValueError(
"'{name}' must be in {options!r} (got {value!r})".format(
name=attr.name, options=self.options, value=value
),
attr,
self.options,
value,
)
def __repr__(self):
return "<in_ validator with options {options!r}>".format(
options=self.options
)
def in_(options):
"""
A validator that raises a `ValueError` if the initializer is called
with a value that does not belong in the options provided. The check is
performed using ``value in options``.
:param options: Allowed options.
:type options: list, tuple, `enum.Enum`, ...
:raises ValueError: With a human readable error message, the attribute (of
type `attrs.Attribute`), the expected options, and the value it
got.
.. versionadded:: 17.1.0
.. versionchanged:: 22.1.0
The ValueError was incomplete until now and only contained the human
readable error message. Now it contains all the information that has
been promised since 17.1.0.
"""
return _InValidator(options)
@attrs(repr=False, slots=False, hash=True)
class _IsCallableValidator:
def __call__(self, inst, attr, value):
"""
We use a callable class to be able to change the ``__repr__``.
"""
if not callable(value):
message = (
"'{name}' must be callable "
"(got {value!r} that is a {actual!r})."
)
raise NotCallableError(
msg=message.format(
name=attr.name, value=value, actual=value.__class__
),
value=value,
)
def __repr__(self):
return "<is_callable validator>"
def is_callable():
"""
A validator that raises an `attr.exceptions.NotCallableError` if the
initializer is called with a value for this particular attribute
that is not callable.
.. versionadded:: 19.1.0
:raises `attr.exceptions.NotCallableError`: With a human readable error
message containing the attribute (`attrs.Attribute`) name,
and the value it got.
"""
return _IsCallableValidator()
@attrs(repr=False, slots=True, hash=True)
class _DeepIterable:
member_validator = attrib(validator=is_callable())
iterable_validator = attrib(
default=None, validator=optional(is_callable())
)
def __call__(self, inst, attr, value):
"""
We use a callable class to be able to change the ``__repr__``.
"""
if self.iterable_validator is not None:
self.iterable_validator(inst, attr, value)
for member in value:
self.member_validator(inst, attr, member)
def __repr__(self):
iterable_identifier = (
""
if self.iterable_validator is None
else " {iterable!r}".format(iterable=self.iterable_validator)
)
return (
"<deep_iterable validator for{iterable_identifier}"
" iterables of {member!r}>"
).format(
iterable_identifier=iterable_identifier,
member=self.member_validator,
)
def deep_iterable(member_validator, iterable_validator=None):
"""
A validator that performs deep validation of an iterable.
:param member_validator: Validator(s) to apply to iterable members
:param iterable_validator: Validator to apply to iterable itself
(optional)
.. versionadded:: 19.1.0
:raises TypeError: if any sub-validators fail
"""
if isinstance(member_validator, (list, tuple)):
member_validator = and_(*member_validator)
return _DeepIterable(member_validator, iterable_validator)
@attrs(repr=False, slots=True, hash=True)
class _DeepMapping:
key_validator = attrib(validator=is_callable())
value_validator = attrib(validator=is_callable())
mapping_validator = attrib(default=None, validator=optional(is_callable()))
def __call__(self, inst, attr, value):
"""
We use a callable class to be able to change the ``__repr__``.
"""
if self.mapping_validator is not None:
self.mapping_validator(inst, attr, value)
for key in value:
self.key_validator(inst, attr, key)
self.value_validator(inst, attr, value[key])
def __repr__(self):
return (
"<deep_mapping validator for objects mapping {key!r} to {value!r}>"
).format(key=self.key_validator, value=self.value_validator)
def deep_mapping(key_validator, value_validator, mapping_validator=None):
"""
A validator that performs deep validation of a dictionary.
:param key_validator: Validator to apply to dictionary keys
:param value_validator: Validator to apply to dictionary values
:param mapping_validator: Validator to apply to top-level mapping
attribute (optional)
.. versionadded:: 19.1.0
:raises TypeError: if any sub-validators fail
"""
return _DeepMapping(key_validator, value_validator, mapping_validator)
@attrs(repr=False, frozen=True, slots=True)
class _NumberValidator:
bound = attrib()
compare_op = attrib()
compare_func = attrib()
def __call__(self, inst, attr, value):
"""
We use a callable class to be able to change the ``__repr__``.
"""
if not self.compare_func(value, self.bound):
raise ValueError(
"'{name}' must be {op} {bound}: {value}".format(
name=attr.name,
op=self.compare_op,
bound=self.bound,
value=value,
)
)
def __repr__(self):
return "<Validator for x {op} {bound}>".format(
op=self.compare_op, bound=self.bound
)
def lt(val):
"""
A validator that raises `ValueError` if the initializer is called
with a number larger than or equal to *val*.
:param val: Exclusive upper bound for values
.. versionadded:: 21.3.0
"""
return _NumberValidator(val, "<", operator.lt)
def le(val):
"""
A validator that raises `ValueError` if the initializer is called
with a number greater than *val*.
:param val: Inclusive upper bound for values
.. versionadded:: 21.3.0
"""
return _NumberValidator(val, "<=", operator.le)
def ge(val):
"""
A validator that raises `ValueError` if the initializer is called
with a number smaller than *val*.
:param val: Inclusive lower bound for values
.. versionadded:: 21.3.0
"""
return _NumberValidator(val, ">=", operator.ge)
def gt(val):
"""
A validator that raises `ValueError` if the initializer is called
with a number smaller than or equal to *val*.
:param val: Exclusive lower bound for values
.. versionadded:: 21.3.0
"""
return _NumberValidator(val, ">", operator.gt)
@attrs(repr=False, frozen=True, slots=True)
class _MaxLengthValidator:
max_length = attrib()
def __call__(self, inst, attr, value):
"""
We use a callable class to be able to change the ``__repr__``.
"""
if len(value) > self.max_length:
raise ValueError(
"Length of '{name}' must be <= {max}: {len}".format(
name=attr.name, max=self.max_length, len=len(value)
)
)
def __repr__(self):
return "<max_len validator for {max}>".format(max=self.max_length)
def max_len(length):
"""
A validator that raises `ValueError` if the initializer is called
with a string or iterable that is longer than *length*.
:param int length: Maximum length of the string or iterable
.. versionadded:: 21.3.0
"""
return _MaxLengthValidator(length)
@attrs(repr=False, frozen=True, slots=True)
class _MinLengthValidator:
min_length = attrib()
def __call__(self, inst, attr, value):
"""
We use a callable class to be able to change the ``__repr__``.
"""
if len(value) < self.min_length:
raise ValueError(
"Length of '{name}' must be >= {min}: {len}".format(
name=attr.name, min=self.min_length, len=len(value)
)
)
def __repr__(self):
return "<min_len validator for {min}>".format(min=self.min_length)
def min_len(length):
"""
A validator that raises `ValueError` if the initializer is called
with a string or iterable that is shorter than *length*.
:param int length: Minimum length of the string or iterable
.. versionadded:: 22.1.0
"""
return _MinLengthValidator(length)
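
An illustrative sketch combining several of the validators above (``Job`` is a hypothetical class)::

    import attr
    from attr import validators as v

    @attr.s
    class Job:
        priority = attr.ib(validator=[v.instance_of(int), v.ge(0), v.le(9)])
        tags = attr.ib(
            validator=v.deep_iterable(
                member_validator=v.instance_of(str),
                iterable_validator=v.instance_of(list),
            )
        )

    Job(priority=3, tags=["urgent"])    # passes all checks
    # Job(priority=42, tags=["urgent"]) # would raise ValueError from le(9)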

@@ -0,0 +1,80 @@
from typing import (
Any,
AnyStr,
Callable,
Container,
ContextManager,
Iterable,
List,
Mapping,
Match,
Optional,
Pattern,
Tuple,
Type,
TypeVar,
Union,
overload,
)
from . import _ValidatorType
from . import _ValidatorArgType
_T = TypeVar("_T")
_T1 = TypeVar("_T1")
_T2 = TypeVar("_T2")
_T3 = TypeVar("_T3")
_I = TypeVar("_I", bound=Iterable)
_K = TypeVar("_K")
_V = TypeVar("_V")
_M = TypeVar("_M", bound=Mapping)
def set_disabled(run: bool) -> None: ...
def get_disabled() -> bool: ...
def disabled() -> ContextManager[None]: ...
# To be more precise on instance_of use some overloads.
# If there are more than 3 items in the tuple then we fall back to Any
@overload
def instance_of(type: Type[_T]) -> _ValidatorType[_T]: ...
@overload
def instance_of(type: Tuple[Type[_T]]) -> _ValidatorType[_T]: ...
@overload
def instance_of(
type: Tuple[Type[_T1], Type[_T2]]
) -> _ValidatorType[Union[_T1, _T2]]: ...
@overload
def instance_of(
type: Tuple[Type[_T1], Type[_T2], Type[_T3]]
) -> _ValidatorType[Union[_T1, _T2, _T3]]: ...
@overload
def instance_of(type: Tuple[type, ...]) -> _ValidatorType[Any]: ...
def provides(interface: Any) -> _ValidatorType[Any]: ...
def optional(
validator: Union[_ValidatorType[_T], List[_ValidatorType[_T]]]
) -> _ValidatorType[Optional[_T]]: ...
def in_(options: Container[_T]) -> _ValidatorType[_T]: ...
def and_(*validators: _ValidatorType[_T]) -> _ValidatorType[_T]: ...
def matches_re(
regex: Union[Pattern[AnyStr], AnyStr],
flags: int = ...,
func: Optional[
Callable[[AnyStr, AnyStr, int], Optional[Match[AnyStr]]]
] = ...,
) -> _ValidatorType[AnyStr]: ...
def deep_iterable(
member_validator: _ValidatorArgType[_T],
iterable_validator: Optional[_ValidatorType[_I]] = ...,
) -> _ValidatorType[_I]: ...
def deep_mapping(
key_validator: _ValidatorType[_K],
value_validator: _ValidatorType[_V],
mapping_validator: Optional[_ValidatorType[_M]] = ...,
) -> _ValidatorType[_M]: ...
def is_callable() -> _ValidatorType[_T]: ...
def lt(val: _T) -> _ValidatorType[_T]: ...
def le(val: _T) -> _ValidatorType[_T]: ...
def ge(val: _T) -> _ValidatorType[_T]: ...
def gt(val: _T) -> _ValidatorType[_T]: ...
def max_len(length: int) -> _ValidatorType[_T]: ...
def min_len(length: int) -> _ValidatorType[_T]: ...

@@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2015 Hynek Schlawack and the attrs contributors
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

@@ -0,0 +1,70 @@
# SPDX-License-Identifier: MIT
from attr import (
NOTHING,
Attribute,
Factory,
__author__,
__copyright__,
__description__,
__doc__,
__email__,
__license__,
__title__,
__url__,
__version__,
__version_info__,
assoc,
cmp_using,
define,
evolve,
field,
fields,
fields_dict,
frozen,
has,
make_class,
mutable,
resolve_types,
validate,
)
from attr._next_gen import asdict, astuple
from . import converters, exceptions, filters, setters, validators
__all__ = [
"__author__",
"__copyright__",
"__description__",
"__doc__",
"__email__",
"__license__",
"__title__",
"__url__",
"__version__",
"__version_info__",
"asdict",
"assoc",
"astuple",
"Attribute",
"cmp_using",
"converters",
"define",
"evolve",
"exceptions",
"Factory",
"field",
"fields_dict",
"fields",
"filters",
"frozen",
"has",
"make_class",
"mutable",
"NOTHING",
"resolve_types",
"setters",
"validate",
"validators",
]
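
An illustrative sketch of the re-exported ``attrs`` namespace; note that this ``asdict`` always retains collection types (``C`` is hypothetical)::

    import attrs

    @attrs.define
    class C:
        xs: tuple

    # The tuple stays a tuple; attr.asdict would convert it to a list
    # by default.
    assert attrs.asdict(C((1, 2))) == {"xs": (1, 2)}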

@@ -0,0 +1,66 @@
from typing import (
Any,
Callable,
Dict,
Mapping,
Optional,
Sequence,
Tuple,
Type,
)
# Because we need to type our own stuff, we have to make everything from
# attr explicitly public too.
from attr import __author__ as __author__
from attr import __copyright__ as __copyright__
from attr import __description__ as __description__
from attr import __email__ as __email__
from attr import __license__ as __license__
from attr import __title__ as __title__
from attr import __url__ as __url__
from attr import __version__ as __version__
from attr import __version_info__ as __version_info__
from attr import _FilterType
from attr import assoc as assoc
from attr import Attribute as Attribute
from attr import cmp_using as cmp_using
from attr import converters as converters
from attr import define as define
from attr import evolve as evolve
from attr import exceptions as exceptions
from attr import Factory as Factory
from attr import field as field
from attr import fields as fields
from attr import fields_dict as fields_dict
from attr import filters as filters
from attr import frozen as frozen
from attr import has as has
from attr import make_class as make_class
from attr import mutable as mutable
from attr import NOTHING as NOTHING
from attr import resolve_types as resolve_types
from attr import setters as setters
from attr import validate as validate
from attr import validators as validators
# TODO: see definition of attr.asdict/astuple
def asdict(
inst: Any,
recurse: bool = ...,
filter: Optional[_FilterType[Any]] = ...,
dict_factory: Type[Mapping[Any, Any]] = ...,
retain_collection_types: bool = ...,
value_serializer: Optional[
Callable[[type, Attribute[Any], Any], Any]
] = ...,
tuple_keys: bool = ...,
) -> Dict[str, Any]: ...
# TODO: add support for returning NamedTuple from the mypy plugin
def astuple(
inst: Any,
recurse: bool = ...,
filter: Optional[_FilterType[Any]] = ...,
tuple_factory: Type[Sequence[Any]] = ...,
retain_collection_types: bool = ...,
) -> Tuple[Any, ...]: ...
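
Editor's note: the stub above types the ``value_serializer`` hook of
``asdict``. A minimal sketch of how it is used (the ``serialize``
helper is hypothetical, not part of the vendored sources)::

    import datetime

    from attrs import asdict, define

    @define
    class Event:
        when: datetime.datetime

    def serialize(inst, field, value):
        # Render datetimes as ISO strings; pass everything else through.
        if isinstance(value, datetime.datetime):
            return value.isoformat()
        return value

    event = Event(datetime.datetime(2022, 11, 16))
    asdict(event, value_serializer=serialize)
    # -> {'when': '2022-11-16T00:00:00'}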


@ -0,0 +1,3 @@
# SPDX-License-Identifier: MIT
from attr.converters import * # noqa


@ -0,0 +1,3 @@
# SPDX-License-Identifier: MIT
from attr.exceptions import * # noqa


@ -0,0 +1,3 @@
# SPDX-License-Identifier: MIT
from attr.filters import * # noqa


@ -0,0 +1,3 @@
# SPDX-License-Identifier: MIT
from attr.setters import * # noqa


@ -0,0 +1,3 @@
# SPDX-License-Identifier: MIT
from attr.validators import * # noqa


@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright {yyyy} {name of copyright owner}
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


@ -0,0 +1,54 @@
from .distro import (
NORMALIZED_DISTRO_ID,
NORMALIZED_LSB_ID,
NORMALIZED_OS_ID,
LinuxDistribution,
__version__,
build_number,
codename,
distro_release_attr,
distro_release_info,
id,
info,
like,
linux_distribution,
lsb_release_attr,
lsb_release_info,
major_version,
minor_version,
name,
os_release_attr,
os_release_info,
uname_attr,
uname_info,
version,
version_parts,
)
__all__ = [
"NORMALIZED_DISTRO_ID",
"NORMALIZED_LSB_ID",
"NORMALIZED_OS_ID",
"LinuxDistribution",
"build_number",
"codename",
"distro_release_attr",
"distro_release_info",
"id",
"info",
"like",
"linux_distribution",
"lsb_release_attr",
"lsb_release_info",
"major_version",
"minor_version",
"name",
"os_release_attr",
"os_release_info",
"uname_attr",
"uname_info",
"version",
"version_parts",
]
__version__ = __version__
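
Editor's note: a minimal sketch of the re-exported ``distro`` API (not
part of the vendored sources); the printed values naturally depend on
the host::

    import distro

    # On an Ubuntu 20.04 host this might print "ubuntu 20.04 focal".
    print(distro.id(), distro.version(), distro.codename())

    # info() bundles the same data into a single dictionary.
    release_info = distro.info()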


@ -0,0 +1,4 @@
from .distro import main
if __name__ == "__main__":
main()

File diff suppressed because it is too large.


@ -0,0 +1,28 @@
Copyright 2007 Pallets
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


@ -0,0 +1,45 @@
"""Jinja is a template engine written in pure Python. It provides a
non-XML syntax that supports inline expressions and an optional
sandboxed environment.
"""
from .bccache import BytecodeCache as BytecodeCache
from .bccache import FileSystemBytecodeCache as FileSystemBytecodeCache
from .bccache import MemcachedBytecodeCache as MemcachedBytecodeCache
from .environment import Environment as Environment
from .environment import Template as Template
from .exceptions import TemplateAssertionError as TemplateAssertionError
from .exceptions import TemplateError as TemplateError
from .exceptions import TemplateNotFound as TemplateNotFound
from .exceptions import TemplateRuntimeError as TemplateRuntimeError
from .exceptions import TemplatesNotFound as TemplatesNotFound
from .exceptions import TemplateSyntaxError as TemplateSyntaxError
from .exceptions import UndefinedError as UndefinedError
from .filters import contextfilter
from .filters import environmentfilter
from .filters import evalcontextfilter
from .loaders import BaseLoader as BaseLoader
from .loaders import ChoiceLoader as ChoiceLoader
from .loaders import DictLoader as DictLoader
from .loaders import FileSystemLoader as FileSystemLoader
from .loaders import FunctionLoader as FunctionLoader
from .loaders import ModuleLoader as ModuleLoader
from .loaders import PackageLoader as PackageLoader
from .loaders import PrefixLoader as PrefixLoader
from .runtime import ChainableUndefined as ChainableUndefined
from .runtime import DebugUndefined as DebugUndefined
from .runtime import make_logging_undefined as make_logging_undefined
from .runtime import StrictUndefined as StrictUndefined
from .runtime import Undefined as Undefined
from .utils import clear_caches as clear_caches
from .utils import contextfunction
from .utils import environmentfunction
from .utils import escape
from .utils import evalcontextfunction
from .utils import is_undefined as is_undefined
from .utils import Markup
from .utils import pass_context as pass_context
from .utils import pass_environment as pass_environment
from .utils import pass_eval_context as pass_eval_context
from .utils import select_autoescape as select_autoescape
__version__ = "3.0.3"


@ -0,0 +1,6 @@
import re
# generated by scripts/generate_identifier_pattern.py
pattern = re.compile(
r"[\w·̀-ͯ·҃-֑҇-ׇֽֿׁׂׅׄؐ-ًؚ-ٰٟۖ-ۜ۟-۪ۤۧۨ-ܑۭܰ-݊ަ-ް߫-߳ࠖ-࠙ࠛ-ࠣࠥ-ࠧࠩ-࡙࠭-࡛ࣔ-ࣣ࣡-ःऺ-़ा-ॏ॑-ॗॢॣঁ-ঃ়া-ৄেৈো-্ৗৢৣਁ-ਃ਼ਾ-ੂੇੈੋ-੍ੑੰੱੵઁ-ઃ઼ા-ૅે-ૉો-્ૢૣଁ-ଃ଼ା-ୄେୈୋ-୍ୖୗୢୣஂா-ூெ-ைொ-்ௗఀ-ఃా-ౄె-ైొ-్ౕౖౢౣಁ-ಃ಼ಾ-ೄೆ-ೈೊ-್ೕೖೢೣഁ-ഃാ-ൄെ-ൈൊ-്ൗൢൣංඃ්ා-ුූෘ-ෟෲෳัิ-ฺ็-๎ັິ-ູົຼ່-ໍ༹༘༙༵༷༾༿ཱ-྄྆྇ྍ-ྗྙ-ྼ࿆ါ-ှၖ-ၙၞ-ၠၢ-ၤၧ-ၭၱ-ၴႂ-ႍႏႚ-ႝ፝-፟ᜒ-᜔ᜲ-᜴ᝒᝓᝲᝳ឴-៓៝᠋-᠍ᢅᢆᢩᤠ-ᤫᤰ-᤻ᨗ-ᨛᩕ-ᩞ᩠-᩿᩼᪰-᪽ᬀ-ᬄ᬴-᭄᭫-᭳ᮀ-ᮂᮡ-ᮭ᯦-᯳ᰤ-᰷᳐-᳔᳒-᳨᳭ᳲ-᳴᳸᳹᷀-᷵᷻-᷿‿⁀⁔⃐-⃥⃜⃡-⃰℘℮⳯-⵿⳱ⷠ-〪ⷿ-゙゚〯꙯ꙴ-꙽ꚞꚟ꛰꛱ꠂ꠆ꠋꠣ-ꠧꢀꢁꢴ-ꣅ꣠-꣱ꤦ-꤭ꥇ-꥓ꦀ-ꦃ꦳-꧀ꧥꨩ-ꨶꩃꩌꩍꩻ-ꩽꪰꪲ-ꪴꪷꪸꪾ꪿꫁ꫫ-ꫯꫵ꫶ꯣ-ꯪ꯬꯭ﬞ︀-️︠-︯︳︴﹍-﹏_𐇽𐋠𐍶-𐍺𐨁-𐨃𐨅𐨆𐨌-𐨏𐨸-𐨿𐨺𐫦𐫥𑀀-𑀂𑀸-𑁆𑁿-𑂂𑂰-𑂺𑄀-𑄂𑄧-𑅳𑄴𑆀-𑆂𑆳-𑇊𑇀-𑇌𑈬-𑈷𑈾𑋟-𑋪𑌀-𑌃𑌼𑌾-𑍄𑍇𑍈𑍋-𑍍𑍗𑍢𑍣𑍦-𑍬𑍰-𑍴𑐵-𑑆𑒰-𑓃𑖯-𑖵𑖸-𑗀𑗜𑗝𑘰-𑙀𑚫-𑚷𑜝-𑜫𑰯-𑰶𑰸-𑰿𑲒-𑲧𑲩-𑲶𖫰-𖫴𖬰-𖬶𖽑-𖽾𖾏-𖾒𛲝𛲞𝅥-𝅩𝅭-𝅲𝅻-𝆂𝆅-𝆋𝆪-𝆭𝉂-𝉄𝨀-𝨶𝨻-𝩬𝩵𝪄𝪛-𝪟𝪡-𝪯𞀀-𞀆𞀈-𞀘𞀛-𞀡𞀣𞀤𞀦-𞣐𞀪-𞣖𞥄-𞥊󠄀-󠇯]+" # noqa: B950
)


@ -0,0 +1,75 @@
import inspect
import typing as t
from functools import wraps
from .utils import _PassArg
from .utils import pass_eval_context
V = t.TypeVar("V")
def async_variant(normal_func): # type: ignore
def decorator(async_func): # type: ignore
pass_arg = _PassArg.from_obj(normal_func)
need_eval_context = pass_arg is None
if pass_arg is _PassArg.environment:
def is_async(args: t.Any) -> bool:
return t.cast(bool, args[0].is_async)
else:
def is_async(args: t.Any) -> bool:
return t.cast(bool, args[0].environment.is_async)
@wraps(normal_func)
def wrapper(*args, **kwargs): # type: ignore
b = is_async(args)
if need_eval_context:
args = args[1:]
if b:
return async_func(*args, **kwargs)
return normal_func(*args, **kwargs)
if need_eval_context:
wrapper = pass_eval_context(wrapper)
wrapper.jinja_async_variant = True
return wrapper
return decorator
_common_primitives = {int, float, bool, str, list, dict, tuple, type(None)}
async def auto_await(value: t.Union[t.Awaitable["V"], "V"]) -> "V":
# Avoid a costly call to isawaitable
if type(value) in _common_primitives:
return t.cast("V", value)
if inspect.isawaitable(value):
return await t.cast("t.Awaitable[V]", value)
return t.cast("V", value)
async def auto_aiter(
iterable: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
) -> "t.AsyncIterator[V]":
if hasattr(iterable, "__aiter__"):
async for item in t.cast("t.AsyncIterable[V]", iterable):
yield item
else:
for item in t.cast("t.Iterable[V]", iterable):
yield item
async def auto_to_list(
value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
) -> t.List["V"]:
return [x async for x in auto_aiter(value)]
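
Editor's note: a minimal sketch of the helpers above (not part of the
vendored sources). ``auto_to_list`` accepts plain and async iterables
alike, which is how filters can support async rendering::

    import asyncio

    from jinja2.async_utils import auto_to_list

    async def agen():
        yield 1
        yield 2

    async def demo():
        assert await auto_to_list([1, 2]) == [1, 2]
        assert await auto_to_list(agen()) == [1, 2]

    asyncio.run(demo())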


@ -0,0 +1,364 @@
"""The optional bytecode cache system. This is useful if you have very
complex template situations and the compilation of all those templates
slows down your application too much.
Situations where this is useful are often forking web applications that
are initialized on the first request.
"""
import errno
import fnmatch
import marshal
import os
import pickle
import stat
import sys
import tempfile
import typing as t
from hashlib import sha1
from io import BytesIO
from types import CodeType
if t.TYPE_CHECKING:
import typing_extensions as te
from .environment import Environment
class _MemcachedClient(te.Protocol):
def get(self, key: str) -> bytes:
...
def set(self, key: str, value: bytes, timeout: t.Optional[int] = None) -> None:
...
bc_version = 5
# Magic bytes to identify Jinja bytecode cache files. Contains the
# Python major and minor version to avoid loading incompatible bytecode
# if a project upgrades its Python version.
bc_magic = (
b"j2"
+ pickle.dumps(bc_version, 2)
+ pickle.dumps((sys.version_info[0] << 24) | sys.version_info[1], 2)
)
class Bucket:
"""Buckets are used to store the bytecode for one template. It's created
and initialized by the bytecode cache and passed to the loading functions.
The buckets get an internal checksum from the cache assigned and use this
to automatically reject outdated cache material. Individual bytecode
cache subclasses don't have to care about cache invalidation.
"""
def __init__(self, environment: "Environment", key: str, checksum: str) -> None:
self.environment = environment
self.key = key
self.checksum = checksum
self.reset()
def reset(self) -> None:
"""Resets the bucket (unloads the bytecode)."""
self.code: t.Optional[CodeType] = None
def load_bytecode(self, f: t.BinaryIO) -> None:
"""Loads bytecode from a file or file like object."""
# make sure the magic header is correct
magic = f.read(len(bc_magic))
if magic != bc_magic:
self.reset()
return
# the source code of the file changed, we need to reload
checksum = pickle.load(f)
if self.checksum != checksum:
self.reset()
return
# if marshal_load fails then we need to reload
try:
self.code = marshal.load(f)
except (EOFError, ValueError, TypeError):
self.reset()
return
def write_bytecode(self, f: t.BinaryIO) -> None:
"""Dump the bytecode into the file or file like object passed."""
if self.code is None:
raise TypeError("can't write empty bucket")
f.write(bc_magic)
pickle.dump(self.checksum, f, 2)
marshal.dump(self.code, f)
def bytecode_from_string(self, string: bytes) -> None:
"""Load bytecode from bytes."""
self.load_bytecode(BytesIO(string))
def bytecode_to_string(self) -> bytes:
"""Return the bytecode as bytes."""
out = BytesIO()
self.write_bytecode(out)
return out.getvalue()
class BytecodeCache:
"""To implement your own bytecode cache you have to subclass this class
and override :meth:`load_bytecode` and :meth:`dump_bytecode`. Both of
these methods are passed a :class:`~jinja2.bccache.Bucket`.
A very basic bytecode cache that saves the bytecode on the file system::
from os import path
class MyCache(BytecodeCache):
def __init__(self, directory):
self.directory = directory
def load_bytecode(self, bucket):
filename = path.join(self.directory, bucket.key)
if path.exists(filename):
with open(filename, 'rb') as f:
bucket.load_bytecode(f)
def dump_bytecode(self, bucket):
filename = path.join(self.directory, bucket.key)
with open(filename, 'wb') as f:
bucket.write_bytecode(f)
A more advanced version of a filesystem based bytecode cache is part of
Jinja.
"""
def load_bytecode(self, bucket: Bucket) -> None:
"""Subclasses have to override this method to load bytecode into a
bucket. If it is not able to find code in the cache for the
bucket, it must not do anything.
"""
raise NotImplementedError()
def dump_bytecode(self, bucket: Bucket) -> None:
"""Subclasses have to override this method to write the bytecode
from a bucket back to the cache. If it is unable to do so, it must not
fail silently but raise an exception.
"""
raise NotImplementedError()
def clear(self) -> None:
"""Clears the cache. This method is not used by Jinja but should be
implemented to allow applications to clear the bytecode cache used
by a particular environment.
"""
def get_cache_key(
self, name: str, filename: t.Optional[str] = None
) -> str:
"""Returns the unique hash key for this template name."""
hash = sha1(name.encode("utf-8"))
if filename is not None:
hash.update(f"|{filename}".encode())
return hash.hexdigest()
def get_source_checksum(self, source: str) -> str:
"""Returns a checksum for the source."""
return sha1(source.encode("utf-8")).hexdigest()
def get_bucket(
self,
environment: "Environment",
name: str,
filename: t.Optional[str],
source: str,
) -> Bucket:
"""Return a cache bucket for the given template. All arguments are
mandatory but filename may be `None`.
"""
key = self.get_cache_key(name, filename)
checksum = self.get_source_checksum(source)
bucket = Bucket(environment, key, checksum)
self.load_bytecode(bucket)
return bucket
def set_bucket(self, bucket: Bucket) -> None:
"""Put the bucket into the cache."""
self.dump_bytecode(bucket)
class FileSystemBytecodeCache(BytecodeCache):
"""A bytecode cache that stores bytecode on the filesystem. It accepts
two arguments: The directory where the cache items are stored and a
pattern string that is used to build the filename.
If no directory is specified a default cache directory is selected. On
Windows the user's temp directory is used, on UNIX systems a directory
is created for the user in the system temp directory.
The pattern can be used to have multiple separate caches operate on the
same directory. The default pattern is ``'__jinja2_%s.cache'``. ``%s``
is replaced with the cache key.
>>> bcc = FileSystemBytecodeCache('/tmp/jinja_cache', '%s.cache')
This bytecode cache supports clearing of the cache using the clear method.
"""
def __init__(
self, directory: t.Optional[str] = None, pattern: str = "__jinja2_%s.cache"
) -> None:
if directory is None:
directory = self._get_default_cache_dir()
self.directory = directory
self.pattern = pattern
def _get_default_cache_dir(self) -> str:
def _unsafe_dir() -> "te.NoReturn":
raise RuntimeError(
"Cannot determine safe temp directory. You "
"need to explicitly provide one."
)
tmpdir = tempfile.gettempdir()
# On Windows the temporary directory is already user-specific unless
# explicitly forced otherwise, so we can just use it.
if os.name == "nt":
return tmpdir
if not hasattr(os, "getuid"):
_unsafe_dir()
dirname = f"_jinja2-cache-{os.getuid()}"
actual_dir = os.path.join(tmpdir, dirname)
try:
os.mkdir(actual_dir, stat.S_IRWXU)
except OSError as e:
if e.errno != errno.EEXIST:
raise
try:
os.chmod(actual_dir, stat.S_IRWXU)
actual_dir_stat = os.lstat(actual_dir)
if (
actual_dir_stat.st_uid != os.getuid()
or not stat.S_ISDIR(actual_dir_stat.st_mode)
or stat.S_IMODE(actual_dir_stat.st_mode) != stat.S_IRWXU
):
_unsafe_dir()
except OSError as e:
if e.errno != errno.EEXIST:
raise
actual_dir_stat = os.lstat(actual_dir)
if (
actual_dir_stat.st_uid != os.getuid()
or not stat.S_ISDIR(actual_dir_stat.st_mode)
or stat.S_IMODE(actual_dir_stat.st_mode) != stat.S_IRWXU
):
_unsafe_dir()
return actual_dir
def _get_cache_filename(self, bucket: Bucket) -> str:
return os.path.join(self.directory, self.pattern % (bucket.key,))
def load_bytecode(self, bucket: Bucket) -> None:
filename = self._get_cache_filename(bucket)
if os.path.exists(filename):
with open(filename, "rb") as f:
bucket.load_bytecode(f)
def dump_bytecode(self, bucket: Bucket) -> None:
with open(self._get_cache_filename(bucket), "wb") as f:
bucket.write_bytecode(f)
def clear(self) -> None:
# imported lazily here because google app-engine doesn't support
# write access on the file system and the function does not exist
# normally.
from os import remove
files = fnmatch.filter(os.listdir(self.directory), self.pattern % ("*",))
for filename in files:
try:
remove(os.path.join(self.directory, filename))
except OSError:
pass
class MemcachedBytecodeCache(BytecodeCache):
"""This class implements a bytecode cache that uses a memcache cache for
storing the information. It does not enforce a specific memcache library
(tummy's memcache or cmemcache) but will accept any class that provides
the minimal interface required.
Libraries compatible with this class:
- `cachelib <https://github.com/pallets/cachelib>`_
- `python-memcached <https://pypi.org/project/python-memcached/>`_
(Unfortunately the django cache interface is not compatible because it
does not support storing binary data, only text. You can however pass
the underlying cache client to the bytecode cache which is available
as `django.core.cache.cache._client`.)
The minimal interface for the client passed to the constructor is this:
.. class:: MinimalClientInterface
.. method:: set(key, value[, timeout])
Stores the bytecode in the cache. `value` is a string and
`timeout` the timeout of the key. If timeout is not provided
a default timeout or no timeout should be assumed, if it's
provided it's an integer with the number of seconds the cache
item should exist.
.. method:: get(key)
Returns the value for the cache key. If the item does not
exist in the cache the return value must be `None`.
The other arguments to the constructor are the prefix for all keys that
is added before the actual cache key and the timeout for the bytecode in
the cache system. We recommend a high (or no) timeout.
This bytecode cache does not support clearing of used items in the cache.
The clear method is a no-operation function.
.. versionadded:: 2.7
Added support for ignoring memcache errors through the
`ignore_memcache_errors` parameter.
"""
def __init__(
self,
client: "_MemcachedClient",
prefix: str = "jinja2/bytecode/",
timeout: t.Optional[int] = None,
ignore_memcache_errors: bool = True,
):
self.client = client
self.prefix = prefix
self.timeout = timeout
self.ignore_memcache_errors = ignore_memcache_errors
def load_bytecode(self, bucket: Bucket) -> None:
try:
code = self.client.get(self.prefix + bucket.key)
except Exception:
if not self.ignore_memcache_errors:
raise
else:
bucket.bytecode_from_string(code)
def dump_bytecode(self, bucket: Bucket) -> None:
key = self.prefix + bucket.key
value = bucket.bytecode_to_string()
try:
if self.timeout is not None:
self.client.set(key, value, self.timeout)
else:
self.client.set(key, value)
except Exception:
if not self.ignore_memcache_errors:
raise
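
Editor's note: a minimal sketch of wiring a bytecode cache into an
environment (not part of the vendored sources; the cache directory is
assumed to exist)::

    from jinja2 import Environment, FileSystemBytecodeCache, FileSystemLoader

    # Compiled templates are persisted on disk, so a freshly forked
    # worker can skip recompilation on its first request.
    env = Environment(
        loader=FileSystemLoader("templates"),
        bytecode_cache=FileSystemBytecodeCache("/tmp/jinja_cache"),
    )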

File diff suppressed because it is too large.


@ -0,0 +1,20 @@
#: list of lorem ipsum words used by the lipsum() helper function
LOREM_IPSUM_WORDS = """\
a ac accumsan ad adipiscing aenean aliquam aliquet amet ante aptent arcu at
auctor augue bibendum blandit class commodo condimentum congue consectetuer
consequat conubia convallis cras cubilia cum curabitur curae cursus dapibus
diam dictum dictumst dignissim dis dolor donec dui duis egestas eget eleifend
elementum elit enim erat eros est et etiam eu euismod facilisi facilisis fames
faucibus felis fermentum feugiat fringilla fusce gravida habitant habitasse hac
hendrerit hymenaeos iaculis id imperdiet in inceptos integer interdum ipsum
justo lacinia lacus laoreet lectus leo libero ligula litora lobortis lorem
luctus maecenas magna magnis malesuada massa mattis mauris metus mi molestie
mollis montes morbi mus nam nascetur natoque nec neque netus nibh nisi nisl non
nonummy nostra nulla nullam nunc odio orci ornare parturient pede pellentesque
penatibus per pharetra phasellus placerat platea porta porttitor posuere
potenti praesent pretium primis proin pulvinar purus quam quis quisque rhoncus
ridiculus risus rutrum sagittis sapien scelerisque sed sem semper senectus sit
sociis sociosqu sodales sollicitudin suscipit suspendisse taciti tellus tempor
tempus tincidunt torquent tortor tristique turpis ullamcorper ultrices
ultricies urna ut varius vehicula vel velit venenatis vestibulum vitae vivamus
viverra volutpat vulputate"""
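
Editor's note: this word list feeds ``generate_lorem_ipsum``, which is
exposed to templates as ``lipsum()``. A minimal sketch (not part of the
vendored sources)::

    from jinja2.utils import generate_lorem_ipsum

    # Two paragraphs of five to ten words each, returned as escaped HTML.
    paragraphs = generate_lorem_ipsum(n=2, min=5, max=10)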


@ -0,0 +1,259 @@
import platform
import sys
import typing as t
from types import CodeType
from types import TracebackType
from .exceptions import TemplateSyntaxError
from .utils import internal_code
from .utils import missing
if t.TYPE_CHECKING:
from .runtime import Context
def rewrite_traceback_stack(source: t.Optional[str] = None) -> BaseException:
"""Rewrite the current exception to replace any tracebacks from
within compiled template code with tracebacks that look like they
came from the template source.
This must be called within an ``except`` block.
:param source: For ``TemplateSyntaxError``, the original source if
known.
:return: The original exception with the rewritten traceback.
"""
_, exc_value, tb = sys.exc_info()
exc_value = t.cast(BaseException, exc_value)
tb = t.cast(TracebackType, tb)
if isinstance(exc_value, TemplateSyntaxError) and not exc_value.translated:
exc_value.translated = True
exc_value.source = source
# Remove the old traceback, otherwise the frames from the
# compiler still show up.
exc_value.with_traceback(None)
# Outside of runtime, so the frame isn't executing template
# code, but it still needs to point at the template.
tb = fake_traceback(
exc_value, None, exc_value.filename or "<unknown>", exc_value.lineno
)
else:
# Skip the frame for the render function.
tb = tb.tb_next
stack = []
# Build the stack of traceback object, replacing any in template
# code with the source file and line information.
while tb is not None:
# Skip frames decorated with @internalcode. These are internal
# calls that aren't useful in template debugging output.
if tb.tb_frame.f_code in internal_code:
tb = tb.tb_next
continue
template = tb.tb_frame.f_globals.get("__jinja_template__")
if template is not None:
lineno = template.get_corresponding_lineno(tb.tb_lineno)
fake_tb = fake_traceback(exc_value, tb, template.filename, lineno)
stack.append(fake_tb)
else:
stack.append(tb)
tb = tb.tb_next
tb_next = None
# Assign tb_next in reverse to avoid circular references.
for tb in reversed(stack):
tb_next = tb_set_next(tb, tb_next)
return exc_value.with_traceback(tb_next)
def fake_traceback( # type: ignore
exc_value: BaseException, tb: t.Optional[TracebackType], filename: str, lineno: int
) -> TracebackType:
"""Produce a new traceback object that looks like it came from the
template source instead of the compiled code. The filename, line
number, and location name will point to the template, and the local
variables will be the current template context.
:param exc_value: The original exception to be re-raised to create
the new traceback.
:param tb: The original traceback to get the local variables and
code info from.
:param filename: The template filename.
:param lineno: The line number in the template source.
"""
if tb is not None:
# Replace the real locals with the context that would be
# available at that point in the template.
locals = get_template_locals(tb.tb_frame.f_locals)
locals.pop("__jinja_exception__", None)
else:
locals = {}
globals = {
"__name__": filename,
"__file__": filename,
"__jinja_exception__": exc_value,
}
# Raise an exception at the correct line number.
code: CodeType = compile(
"\n" * (lineno - 1) + "raise __jinja_exception__", filename, "exec"
)
# Build a new code object that points to the template file and
# replaces the location with a block name.
location = "template"
if tb is not None:
function = tb.tb_frame.f_code.co_name
if function == "root":
location = "top-level template code"
elif function.startswith("block_"):
location = f"block {function[6:]!r}"
if sys.version_info >= (3, 8):
code = code.replace(co_name=location)
else:
code = CodeType(
code.co_argcount,
code.co_kwonlyargcount,
code.co_nlocals,
code.co_stacksize,
code.co_flags,
code.co_code,
code.co_consts,
code.co_names,
code.co_varnames,
code.co_filename,
location,
code.co_firstlineno,
code.co_lnotab,
code.co_freevars,
code.co_cellvars,
)
# Execute the new code, which is guaranteed to raise, and return
# the new traceback without this frame.
try:
exec(code, globals, locals)
except BaseException:
return sys.exc_info()[2].tb_next # type: ignore
def get_template_locals(real_locals: t.Mapping[str, t.Any]) -> t.Dict[str, t.Any]:
"""Based on the runtime locals, get the context that would be
available at that point in the template.
"""
# Start with the current template context.
ctx: "t.Optional[Context]" = real_locals.get("context")
if ctx is not None:
data: t.Dict[str, t.Any] = ctx.get_all().copy()
else:
data = {}
# Might be in a derived context that only sets local variables
# rather than pushing a context. Local variables follow the scheme
# l_depth_name. Find the highest-depth local that has a value for
# each name.
local_overrides: t.Dict[str, t.Tuple[int, t.Any]] = {}
for name, value in real_locals.items():
if not name.startswith("l_") or value is missing:
# Not a template variable, or no longer relevant.
continue
try:
_, depth_str, name = name.split("_", 2)
depth = int(depth_str)
except ValueError:
continue
cur_depth = local_overrides.get(name, (-1,))[0]
if cur_depth < depth:
local_overrides[name] = (depth, value)
# Modify the context with any derived context.
for name, (_, value) in local_overrides.items():
if value is missing:
data.pop(name, None)
else:
data[name] = value
return data
if sys.version_info >= (3, 7):
# tb_next is directly assignable as of Python 3.7
def tb_set_next(
tb: TracebackType, tb_next: t.Optional[TracebackType]
) -> TracebackType:
tb.tb_next = tb_next
return tb
elif platform.python_implementation() == "PyPy":
# PyPy might have special support, and won't work with ctypes.
try:
import tputil # type: ignore
except ImportError:
# Without tproxy support, use the original traceback.
def tb_set_next(
tb: TracebackType, tb_next: t.Optional[TracebackType]
) -> TracebackType:
return tb
else:
# With tproxy support, create a proxy around the traceback that
# returns the new tb_next.
def tb_set_next(
tb: TracebackType, tb_next: t.Optional[TracebackType]
) -> TracebackType:
def controller(op): # type: ignore
if op.opname == "__getattribute__" and op.args[0] == "tb_next":
return tb_next
return op.delegate()
return tputil.make_proxy(controller, obj=tb) # type: ignore
else:
# Use ctypes to assign tb_next at the C level since it's read-only
# from Python.
import ctypes
class _CTraceback(ctypes.Structure):
_fields_ = [
# Extra PyObject slots when compiled with Py_TRACE_REFS.
("PyObject_HEAD", ctypes.c_byte * object().__sizeof__()),
# Only care about tb_next as an object, not a traceback.
("tb_next", ctypes.py_object),
]
def tb_set_next(
tb: TracebackType, tb_next: t.Optional[TracebackType]
) -> TracebackType:
c_tb = _CTraceback.from_address(id(tb))
# Clear out the old tb_next.
if tb.tb_next is not None:
c_tb_next = ctypes.py_object(tb.tb_next)
c_tb.tb_next = ctypes.py_object()
ctypes.pythonapi.Py_DecRef(c_tb_next)
# Assign the new tb_next.
if tb_next is not None:
c_tb_next = ctypes.py_object(tb_next)
ctypes.pythonapi.Py_IncRef(c_tb_next)
c_tb.tb_next = c_tb_next
return tb
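
Editor's note: a minimal sketch of the effect of this module (not part
of the vendored sources). A runtime error raised inside a template
surfaces with a traceback frame pointing at the template source::

    import traceback

    from jinja2 import Environment

    try:
        Environment().from_string("{{ 1 // 0 }}").render()
    except ZeroDivisionError:
        # The rewritten traceback points at '<template>', line 1,
        # rather than at the internally compiled Python module.
        traceback.print_exc()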


@ -0,0 +1,48 @@
import typing as t
from .filters import FILTERS as DEFAULT_FILTERS # noqa: F401
from .tests import TESTS as DEFAULT_TESTS # noqa: F401
from .utils import Cycler
from .utils import generate_lorem_ipsum
from .utils import Joiner
from .utils import Namespace
if t.TYPE_CHECKING:
import typing_extensions as te
# defaults for the parser / lexer
BLOCK_START_STRING = "{%"
BLOCK_END_STRING = "%}"
VARIABLE_START_STRING = "{{"
VARIABLE_END_STRING = "}}"
COMMENT_START_STRING = "{#"
COMMENT_END_STRING = "#}"
LINE_STATEMENT_PREFIX: t.Optional[str] = None
LINE_COMMENT_PREFIX: t.Optional[str] = None
TRIM_BLOCKS = False
LSTRIP_BLOCKS = False
NEWLINE_SEQUENCE: "te.Literal['\\n', '\\r\\n', '\\r']" = "\n"
KEEP_TRAILING_NEWLINE = False
# default filters, tests and namespace
DEFAULT_NAMESPACE = {
"range": range,
"dict": dict,
"lipsum": generate_lorem_ipsum,
"cycler": Cycler,
"joiner": Joiner,
"namespace": Namespace,
}
# default policies
DEFAULT_POLICIES: t.Dict[str, t.Any] = {
"compiler.ascii_str": True,
"urlize.rel": "noopener",
"urlize.target": None,
"urlize.extra_schemes": None,
"truncate.leeway": 5,
"json.dumps_function": None,
"json.dumps_kwargs": {"sort_keys": True},
"ext.i18n.trimmed": False,
}
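
Editor's note: these defaults are per-environment and can be overridden
through the ``Environment`` constructor. A minimal sketch (not part of
the vendored sources)::

    from jinja2 import Environment

    # Alternate delimiters, e.g. to avoid clashing with LaTeX braces.
    env = Environment(
        block_start_string="<%",
        block_end_string="%>",
        variable_start_string="<<",
        variable_end_string=">>",
    )
    print(env.from_string("<< 6 * 7 >>").render())  # 42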

File diff suppressed because it is too large.


@ -0,0 +1,166 @@
import typing as t
if t.TYPE_CHECKING:
from .runtime import Undefined
class TemplateError(Exception):
"""Baseclass for all template errors."""
def __init__(self, message: t.Optional[str] = None) -> None:
super().__init__(message)
@property
def message(self) -> t.Optional[str]:
return self.args[0] if self.args else None
class TemplateNotFound(IOError, LookupError, TemplateError):
"""Raised if a template does not exist.
.. versionchanged:: 2.11
If the given name is :class:`Undefined` and no message was
provided, an :exc:`UndefinedError` is raised.
"""
# Silence the Python warning about message being deprecated since
# it's not valid here.
message: t.Optional[str] = None
def __init__(
self,
name: t.Optional[t.Union[str, "Undefined"]],
message: t.Optional[str] = None,
) -> None:
IOError.__init__(self, name)
if message is None:
from .runtime import Undefined
if isinstance(name, Undefined):
name._fail_with_undefined_error()
message = name
self.message = message
self.name = name
self.templates = [name]
def __str__(self) -> str:
return str(self.message)
class TemplatesNotFound(TemplateNotFound):
"""Like :class:`TemplateNotFound` but raised if multiple templates
are selected. This is a subclass of :class:`TemplateNotFound`
exception, so just catching the base exception will catch both.
.. versionchanged:: 2.11
If a name in the list of names is :class:`Undefined`, a message
about it being undefined is shown rather than the empty string.
.. versionadded:: 2.2
"""
def __init__(
self,
names: t.Sequence[t.Union[str, "Undefined"]] = (),
message: t.Optional[str] = None,
) -> None:
if message is None:
from .runtime import Undefined
parts = []
for name in names:
if isinstance(name, Undefined):
parts.append(name._undefined_message)
else:
parts.append(name)
parts_str = ", ".join(map(str, parts))
message = f"none of the templates given were found: {parts_str}"
super().__init__(names[-1] if names else None, message)
self.templates = list(names)
class TemplateSyntaxError(TemplateError):
"""Raised to tell the user that there is a problem with the template."""
def __init__(
self,
message: str,
lineno: int,
name: t.Optional[str] = None,
filename: t.Optional[str] = None,
) -> None:
super().__init__(message)
self.lineno = lineno
self.name = name
self.filename = filename
self.source: t.Optional[str] = None
# this is set to True if the debug.translate_syntax_error
# function translated the syntax error into a new traceback
self.translated = False
def __str__(self) -> str:
# for translated errors we only return the message
if self.translated:
return t.cast(str, self.message)
# otherwise attach some stuff
location = f"line {self.lineno}"
name = self.filename or self.name
if name:
location = f'File "{name}", {location}'
lines = [t.cast(str, self.message), " " + location]
# if the source is set, add the line to the output
if self.source is not None:
try:
line = self.source.splitlines()[self.lineno - 1]
except IndexError:
pass
else:
lines.append(" " + line.strip())
return "\n".join(lines)
def __reduce__(self): # type: ignore
# https://bugs.python.org/issue1692335 Exceptions that take
# multiple required arguments have problems with pickling.
# Without this, raises TypeError: __init__() missing 1 required
# positional argument: 'lineno'
return self.__class__, (self.message, self.lineno, self.name, self.filename)
class TemplateAssertionError(TemplateSyntaxError):
"""Like a template syntax error, but covers cases where something in the
template caused an error at compile time that wasn't necessarily caused
by a syntax error. However it's a direct subclass of
:exc:`TemplateSyntaxError` and has the same attributes.
"""
class TemplateRuntimeError(TemplateError):
"""A generic runtime error in the template engine. Under some situations
Jinja may raise this exception.
"""
class UndefinedError(TemplateRuntimeError):
"""Raised if a template tries to operate on :class:`Undefined`."""
class SecurityError(TemplateRuntimeError):
"""Raised if a template tries to do something insecure if the
sandbox is enabled.
"""
class FilterArgumentError(TemplateRuntimeError):
"""This error is raised if a filter was called with inappropriate
arguments
"""


@ -0,0 +1,879 @@
"""Extension API for adding custom tags and behavior."""
import pprint
import re
import typing as t
import warnings
from markupsafe import Markup
from . import defaults
from . import nodes
from .environment import Environment
from .exceptions import TemplateAssertionError
from .exceptions import TemplateSyntaxError
from .runtime import concat # type: ignore
from .runtime import Context
from .runtime import Undefined
from .utils import import_string
from .utils import pass_context
if t.TYPE_CHECKING:
import typing_extensions as te
from .lexer import Token
from .lexer import TokenStream
from .parser import Parser
class _TranslationsBasic(te.Protocol):
def gettext(self, message: str) -> str:
...
def ngettext(self, singular: str, plural: str, n: int) -> str:
...
class _TranslationsContext(_TranslationsBasic):
def pgettext(self, context: str, message: str) -> str:
...
def npgettext(self, context: str, singular: str, plural: str, n: int) -> str:
...
_SupportedTranslations = t.Union[_TranslationsBasic, _TranslationsContext]
# I18N functions available in Jinja templates. If the I18N library
# provides ugettext, it will be assigned to gettext.
GETTEXT_FUNCTIONS: t.Tuple[str, ...] = (
"_",
"gettext",
"ngettext",
"pgettext",
"npgettext",
)
_ws_re = re.compile(r"\s*\n\s*")
class Extension:
"""Extensions can be used to add extra functionality to the Jinja template
system at the parser level. Custom extensions are bound to an environment
but may not store environment specific data on `self`. The reason for
this is that an extension can be bound to another environment (for
overlays) by creating a copy and reassigning the `environment` attribute.
As extensions are created by the environment they cannot accept any
arguments for configuration. One may want to work around that by using
a factory function, but that is not possible as extensions are identified
by their import name. The correct way to configure the extension is
to store the configuration values on the environment. Because the
environment then ends up acting as central configuration storage, the
attribute names may clash, which is why extensions have to ensure that
the names they choose for configuration are not too generic. ``prefix``,
for example, is a terrible name; ``fragment_cache_prefix``, on the other
hand, is a good name, as it includes the name of the extension
(fragment cache).
"""
identifier: t.ClassVar[str]
def __init_subclass__(cls) -> None:
cls.identifier = f"{cls.__module__}.{cls.__name__}"
#: if this extension parses this is the list of tags it's listening to.
tags: t.Set[str] = set()
#: the priority of that extension. This is especially useful for
#: extensions that preprocess values. A lower value means higher
#: priority.
#:
#: .. versionadded:: 2.4
priority = 100
def __init__(self, environment: Environment) -> None:
self.environment = environment
def bind(self, environment: Environment) -> "Extension":
"""Create a copy of this extension bound to another environment."""
rv = t.cast(Extension, object.__new__(self.__class__))
rv.__dict__.update(self.__dict__)
rv.environment = environment
return rv
def preprocess(
self, source: str, name: t.Optional[str], filename: t.Optional[str] = None
) -> str:
"""This method is called before the actual lexing and can be used to
preprocess the source. The `filename` is optional. The return value
must be the preprocessed source.
"""
return source
def filter_stream(
self, stream: "TokenStream"
) -> t.Union["TokenStream", t.Iterable["Token"]]:
"""It's passed a :class:`~jinja2.lexer.TokenStream` that can be used
to filter tokens returned. This method has to return an iterable of
:class:`~jinja2.lexer.Token`\\s, but it doesn't have to return a
:class:`~jinja2.lexer.TokenStream`.
"""
return stream
def parse(self, parser: "Parser") -> t.Union[nodes.Node, t.List[nodes.Node]]:
"""If any of the :attr:`tags` matched this method is called with the
parser as first argument. The token the parser stream is pointing at
is the name token that matched. This method has to return one or a
list of multiple nodes.
"""
raise NotImplementedError()
def attr(
self, name: str, lineno: t.Optional[int] = None
) -> nodes.ExtensionAttribute:
"""Return an attribute node for the current extension. This is useful
to pass constants on extensions to generated template code.
::
self.attr('_my_attribute', lineno=lineno)
"""
return nodes.ExtensionAttribute(self.identifier, name, lineno=lineno)
def call_method(
self,
name: str,
args: t.Optional[t.List[nodes.Expr]] = None,
kwargs: t.Optional[t.List[nodes.Keyword]] = None,
dyn_args: t.Optional[nodes.Expr] = None,
dyn_kwargs: t.Optional[nodes.Expr] = None,
lineno: t.Optional[int] = None,
) -> nodes.Call:
"""Call a method of the extension. This is a shortcut for
:meth:`attr` + :class:`jinja2.nodes.Call`.
"""
if args is None:
args = []
if kwargs is None:
kwargs = []
return nodes.Call(
self.attr(name, lineno=lineno),
args,
kwargs,
dyn_args,
dyn_kwargs,
lineno=lineno,
)
@pass_context
def _gettext_alias(
__context: Context, *args: t.Any, **kwargs: t.Any
) -> t.Union[t.Any, Undefined]:
return __context.call(__context.resolve("gettext"), *args, **kwargs)
def _make_new_gettext(func: t.Callable[[str], str]) -> t.Callable[..., str]:
@pass_context
def gettext(__context: Context, __string: str, **variables: t.Any) -> str:
rv = __context.call(func, __string)
if __context.eval_ctx.autoescape:
rv = Markup(rv)
# Always treat as a format string, even if there are no
# variables. This makes translation strings more consistent
# and predictable. This requires literal ``%`` in the message
# to have been escaped as ``%%`` (see ``_parse_block`` below).
return rv % variables # type: ignore
return gettext
def _make_new_ngettext(func: t.Callable[[str, str, int], str]) -> t.Callable[..., str]:
@pass_context
def ngettext(
__context: Context,
__singular: str,
__plural: str,
__num: int,
**variables: t.Any,
) -> str:
variables.setdefault("num", __num)
rv = __context.call(func, __singular, __plural, __num)
if __context.eval_ctx.autoescape:
rv = Markup(rv)
# Always treat as a format string, see gettext comment above.
return rv % variables # type: ignore
return ngettext
def _make_new_pgettext(func: t.Callable[[str, str], str]) -> t.Callable[..., str]:
@pass_context
def pgettext(
__context: Context, __string_ctx: str, __string: str, **variables: t.Any
) -> str:
variables.setdefault("context", __string_ctx)
rv = __context.call(func, __string_ctx, __string)
if __context.eval_ctx.autoescape:
rv = Markup(rv)
# Always treat as a format string, see gettext comment above.
return rv % variables # type: ignore
return pgettext
def _make_new_npgettext(
func: t.Callable[[str, str, str, int], str]
) -> t.Callable[..., str]:
@pass_context
def npgettext(
__context: Context,
__string_ctx: str,
__singular: str,
__plural: str,
__num: int,
**variables: t.Any,
) -> str:
variables.setdefault("context", __string_ctx)
variables.setdefault("num", __num)
rv = __context.call(func, __string_ctx, __singular, __plural, __num)
if __context.eval_ctx.autoescape:
rv = Markup(rv)
# Always treat as a format string, see gettext comment above.
return rv % variables # type: ignore
return npgettext
class InternationalizationExtension(Extension):
"""This extension adds gettext support to Jinja."""
tags = {"trans"}
# TODO: the i18n extension is currently reevaluating values in a few
# situations. Take this example:
# {% trans count=something() %}{{ count }} foo{% pluralize
# %}{{ count }} foos{% endtrans %}
# something is called twice here. One time for the gettext value and
# the other time for the n-parameter of the ngettext function.
def __init__(self, environment: Environment) -> None:
super().__init__(environment)
environment.globals["_"] = _gettext_alias
environment.extend(
install_gettext_translations=self._install,
install_null_translations=self._install_null,
install_gettext_callables=self._install_callables,
uninstall_gettext_translations=self._uninstall,
extract_translations=self._extract,
newstyle_gettext=False,
)
def _install(
self, translations: "_SupportedTranslations", newstyle: t.Optional[bool] = None
) -> None:
# ugettext and ungettext are preferred in case the I18N library
# is providing compatibility with older Python versions.
gettext = getattr(translations, "ugettext", None)
if gettext is None:
gettext = translations.gettext
ngettext = getattr(translations, "ungettext", None)
if ngettext is None:
ngettext = translations.ngettext
pgettext = getattr(translations, "pgettext", None)
npgettext = getattr(translations, "npgettext", None)
self._install_callables(
gettext, ngettext, newstyle=newstyle, pgettext=pgettext, npgettext=npgettext
)
def _install_null(self, newstyle: t.Optional[bool] = None) -> None:
import gettext
translations = gettext.NullTranslations()
if hasattr(translations, "pgettext"):
pgettext = translations.pgettext # type: ignore
else:
# Python < 3.8 lacks NullTranslations.pgettext, so fall back.
def pgettext(c: str, s: str) -> str:
return s
if hasattr(translations, "npgettext"):
npgettext = translations.npgettext # type: ignore
else:
def npgettext(c: str, s: str, p: str, n: int) -> str:
return s if n == 1 else p
self._install_callables(
gettext=translations.gettext,
ngettext=translations.ngettext,
newstyle=newstyle,
pgettext=pgettext,
npgettext=npgettext,
)
def _install_callables(
self,
gettext: t.Callable[[str], str],
ngettext: t.Callable[[str, str, int], str],
newstyle: t.Optional[bool] = None,
pgettext: t.Optional[t.Callable[[str, str], str]] = None,
npgettext: t.Optional[t.Callable[[str, str, str, int], str]] = None,
) -> None:
if newstyle is not None:
self.environment.newstyle_gettext = newstyle # type: ignore
if self.environment.newstyle_gettext: # type: ignore
gettext = _make_new_gettext(gettext)
ngettext = _make_new_ngettext(ngettext)
if pgettext is not None:
pgettext = _make_new_pgettext(pgettext)
if npgettext is not None:
npgettext = _make_new_npgettext(npgettext)
self.environment.globals.update(
gettext=gettext, ngettext=ngettext, pgettext=pgettext, npgettext=npgettext
)
def _uninstall(self, translations: "_SupportedTranslations") -> None:
for key in ("gettext", "ngettext", "pgettext", "npgettext"):
self.environment.globals.pop(key, None)
def _extract(
self,
source: t.Union[str, nodes.Template],
gettext_functions: t.Sequence[str] = GETTEXT_FUNCTIONS,
) -> t.Iterator[
t.Tuple[int, str, t.Union[t.Optional[str], t.Tuple[t.Optional[str], ...]]]
]:
if isinstance(source, str):
source = self.environment.parse(source)
return extract_from_ast(source, gettext_functions)
def parse(self, parser: "Parser") -> t.Union[nodes.Node, t.List[nodes.Node]]:
"""Parse a translatable tag."""
lineno = next(parser.stream).lineno
num_called_num = False
# find all the variables referenced. Additionally a variable can be
# defined in the body of the trans block too, but this is checked at
# a later stage.
plural_expr: t.Optional[nodes.Expr] = None
plural_expr_assignment: t.Optional[nodes.Assign] = None
variables: t.Dict[str, nodes.Expr] = {}
trimmed = None
while parser.stream.current.type != "block_end":
if variables:
parser.stream.expect("comma")
# skip colon for python compatibility
if parser.stream.skip_if("colon"):
break
token = parser.stream.expect("name")
if token.value in variables:
parser.fail(
f"translatable variable {token.value!r} defined twice.",
token.lineno,
exc=TemplateAssertionError,
)
# expressions
if parser.stream.current.type == "assign":
next(parser.stream)
variables[token.value] = var = parser.parse_expression()
elif trimmed is None and token.value in ("trimmed", "notrimmed"):
trimmed = token.value == "trimmed"
continue
else:
variables[token.value] = var = nodes.Name(token.value, "load")
if plural_expr is None:
if isinstance(var, nodes.Call):
plural_expr = nodes.Name("_trans", "load")
variables[token.value] = plural_expr
plural_expr_assignment = nodes.Assign(
nodes.Name("_trans", "store"), var
)
else:
plural_expr = var
num_called_num = token.value == "num"
parser.stream.expect("block_end")
plural = None
have_plural = False
referenced = set()
# now parse until endtrans or pluralize
singular_names, singular = self._parse_block(parser, True)
if singular_names:
referenced.update(singular_names)
if plural_expr is None:
plural_expr = nodes.Name(singular_names[0], "load")
num_called_num = singular_names[0] == "num"
# if we have a pluralize block, we parse that too
if parser.stream.current.test("name:pluralize"):
have_plural = True
next(parser.stream)
if parser.stream.current.type != "block_end":
token = parser.stream.expect("name")
if token.value not in variables:
parser.fail(
f"unknown variable {token.value!r} for pluralization",
token.lineno,
exc=TemplateAssertionError,
)
plural_expr = variables[token.value]
num_called_num = token.value == "num"
parser.stream.expect("block_end")
plural_names, plural = self._parse_block(parser, False)
next(parser.stream)
referenced.update(plural_names)
else:
next(parser.stream)
# register free names as simple name expressions
for name in referenced:
if name not in variables:
variables[name] = nodes.Name(name, "load")
if not have_plural:
plural_expr = None
elif plural_expr is None:
parser.fail("pluralize without variables", lineno)
if trimmed is None:
trimmed = self.environment.policies["ext.i18n.trimmed"]
if trimmed:
singular = self._trim_whitespace(singular)
if plural:
plural = self._trim_whitespace(plural)
node = self._make_node(
singular,
plural,
variables,
plural_expr,
bool(referenced),
num_called_num and have_plural,
)
node.set_lineno(lineno)
if plural_expr_assignment is not None:
return [plural_expr_assignment, node]
else:
return node
def _trim_whitespace(self, string: str, _ws_re: t.Pattern[str] = _ws_re) -> str:
return _ws_re.sub(" ", string.strip())
def _parse_block(
self, parser: "Parser", allow_pluralize: bool
) -> t.Tuple[t.List[str], str]:
"""Parse until the next block tag with a given name."""
referenced = []
buf = []
while True:
if parser.stream.current.type == "data":
buf.append(parser.stream.current.value.replace("%", "%%"))
next(parser.stream)
elif parser.stream.current.type == "variable_begin":
next(parser.stream)
name = parser.stream.expect("name").value
referenced.append(name)
buf.append(f"%({name})s")
parser.stream.expect("variable_end")
elif parser.stream.current.type == "block_begin":
next(parser.stream)
if parser.stream.current.test("name:endtrans"):
break
elif parser.stream.current.test("name:pluralize"):
if allow_pluralize:
break
parser.fail(
"a translatable section can have only one pluralize section"
)
parser.fail(
"control structures in translatable sections are not allowed"
)
elif parser.stream.eos:
parser.fail("unclosed translation block")
else:
raise RuntimeError("internal parser error")
return referenced, concat(buf)
def _make_node(
self,
singular: str,
plural: t.Optional[str],
variables: t.Dict[str, nodes.Expr],
plural_expr: t.Optional[nodes.Expr],
vars_referenced: bool,
num_called_num: bool,
) -> nodes.Output:
"""Generates a useful node from the data provided."""
newstyle = self.environment.newstyle_gettext # type: ignore
node: nodes.Expr
        # No variables referenced? Then the %% escaping added while parsing
        # is unnecessary for old-style gettext invocations; undo it.
if not vars_referenced and not newstyle:
singular = singular.replace("%%", "%")
if plural:
plural = plural.replace("%%", "%")
# singular only:
if plural_expr is None:
gettext = nodes.Name("gettext", "load")
node = nodes.Call(gettext, [nodes.Const(singular)], [], None, None)
# singular and plural
else:
ngettext = nodes.Name("ngettext", "load")
node = nodes.Call(
ngettext,
[nodes.Const(singular), nodes.Const(plural), plural_expr],
[],
None,
None,
)
# in case newstyle gettext is used, the method is powerful
# enough to handle the variable expansion and autoescape
# handling itself
if newstyle:
for key, value in variables.items():
                # the gettext function adds ``num`` itself later when the
                # plural variable is called ``num``, so just skip it here.
if num_called_num and key == "num":
continue
node.kwargs.append(nodes.Keyword(key, value))
# otherwise do that here
else:
# mark the return value as safe if we are in an
# environment with autoescaping turned on
node = nodes.MarkSafeIfAutoescape(node)
if variables:
node = nodes.Mod(
node,
nodes.Dict(
[
nodes.Pair(nodes.Const(key), value)
for key, value in variables.items()
]
),
)
return nodes.Output([node])
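# A minimal usage sketch of the InternationalizationExtension above;
# ``gettext.NullTranslations`` is a stand-in for real message catalogs.
def _example_i18n_usage():  # hypothetical helper, for illustration only
    import gettext
    from jinja2 import Environment

    env = Environment(extensions=["jinja2.ext.i18n"])
    env.install_gettext_translations(gettext.NullTranslations())  # type: ignore[attr-defined]
    tmpl = env.from_string(
        "{% trans count=n %}{{ count }} item"
        "{% pluralize %}{{ count }} items{% endtrans %}"
    )
    return tmpl.render(n=2)  # -> "2 items"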
class ExprStmtExtension(Extension):
"""Adds a `do` tag to Jinja that works like the print statement just
that it doesn't print the return value.
"""
tags = {"do"}
def parse(self, parser: "Parser") -> nodes.ExprStmt:
node = nodes.ExprStmt(lineno=next(parser.stream).lineno)
node.node = parser.parse_tuple()
return node
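# A minimal usage sketch of the ``do`` tag: the expression is evaluated for
# its side effect and the return value is discarded.
def _example_do_tag():  # hypothetical helper, for illustration only
    from jinja2 import Environment

    env = Environment(extensions=["jinja2.ext.do"])
    tmpl = env.from_string(
        "{% set items = [1, 2] %}{% do items.append(3) %}{{ items }}"
    )
    return tmpl.render()  # -> "[1, 2, 3]"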
class LoopControlExtension(Extension):
"""Adds break and continue to the template engine."""
tags = {"break", "continue"}
def parse(self, parser: "Parser") -> t.Union[nodes.Break, nodes.Continue]:
token = next(parser.stream)
if token.value == "break":
return nodes.Break(lineno=token.lineno)
return nodes.Continue(lineno=token.lineno)
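# A minimal usage sketch of ``break``/``continue`` from the extension above.
def _example_loopcontrols():  # hypothetical helper, for illustration only
    from jinja2 import Environment

    env = Environment(extensions=["jinja2.ext.loopcontrols"])
    tmpl = env.from_string(
        "{% for n in range(10) %}"
        "{% if n > 2 %}{% break %}{% endif %}{{ n }}"
        "{% endfor %}"
    )
    return tmpl.render()  # -> "012"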
class WithExtension(Extension):
def __init__(self, environment: Environment) -> None:
super().__init__(environment)
warnings.warn(
"The 'with' extension is deprecated and will be removed in"
" Jinja 3.1. This is built in now.",
DeprecationWarning,
stacklevel=3,
)
class AutoEscapeExtension(Extension):
def __init__(self, environment: Environment) -> None:
super().__init__(environment)
warnings.warn(
"The 'autoescape' extension is deprecated and will be"
" removed in Jinja 3.1. This is built in now.",
DeprecationWarning,
stacklevel=3,
)
class DebugExtension(Extension):
"""A ``{% debug %}`` tag that dumps the available variables,
filters, and tests.
.. code-block:: html+jinja
<pre>{% debug %}</pre>
.. code-block:: text
{'context': {'cycler': <class 'jinja2.utils.Cycler'>,
...,
'namespace': <class 'jinja2.utils.Namespace'>},
'filters': ['abs', 'attr', 'batch', 'capitalize', 'center', 'count', 'd',
..., 'urlencode', 'urlize', 'wordcount', 'wordwrap', 'xmlattr'],
'tests': ['!=', '<', '<=', '==', '>', '>=', 'callable', 'defined',
..., 'odd', 'sameas', 'sequence', 'string', 'undefined', 'upper']}
.. versionadded:: 2.11.0
"""
tags = {"debug"}
def parse(self, parser: "Parser") -> nodes.Output:
lineno = parser.stream.expect("name:debug").lineno
context = nodes.ContextReference()
result = self.call_method("_render", [context], lineno=lineno)
return nodes.Output([result], lineno=lineno)
def _render(self, context: Context) -> str:
result = {
"context": context.get_all(),
"filters": sorted(self.environment.filters.keys()),
"tests": sorted(self.environment.tests.keys()),
}
# Set the depth since the intent is to show the top few names.
return pprint.pformat(result, depth=3, compact=True)
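# A minimal usage sketch of the ``{% debug %}`` tag defined above.
def _example_debug_tag():  # hypothetical helper, for illustration only
    from jinja2 import Environment

    env = Environment(extensions=["jinja2.ext.debug"])
    # Renders a pprint'ed dict of the context, filters, and tests.
    return env.from_string("<pre>{% debug %}</pre>").render(answer=42)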
def extract_from_ast(
ast: nodes.Template,
gettext_functions: t.Sequence[str] = GETTEXT_FUNCTIONS,
babel_style: bool = True,
) -> t.Iterator[
t.Tuple[int, str, t.Union[t.Optional[str], t.Tuple[t.Optional[str], ...]]]
]:
"""Extract localizable strings from the given template node. Per
default this function returns matches in babel style that means non string
parameters as well as keyword arguments are returned as `None`. This
allows Babel to figure out what you really meant if you are using
gettext functions that allow keyword arguments for placeholder expansion.
If you don't want that behavior set the `babel_style` parameter to `False`
which causes only strings to be returned and parameters are always stored
in tuples. As a consequence invalid gettext calls (calls without a single
string parameter or string parameters after non-string parameters) are
skipped.
This example explains the behavior:
>>> from jinja2 import Environment
>>> env = Environment()
>>> node = env.parse('{{ (_("foo"), _(), ngettext("foo", "bar", 42)) }}')
>>> list(extract_from_ast(node))
[(1, '_', 'foo'), (1, '_', ()), (1, 'ngettext', ('foo', 'bar', None))]
>>> list(extract_from_ast(node, babel_style=False))
[(1, '_', ('foo',)), (1, 'ngettext', ('foo', 'bar'))]
For every string found this function yields a ``(lineno, function,
message)`` tuple, where:
* ``lineno`` is the number of the line on which the string was found,
* ``function`` is the name of the ``gettext`` function used (if the
string was extracted from embedded Python code), and
* ``message`` is the string, or a tuple of strings for functions
with multiple string arguments.
    This extraction function operates on the AST and is therefore unable
    to extract any comments. For comment support you have to use the Babel
    extraction interface or extract comments yourself.
"""
out: t.Union[t.Optional[str], t.Tuple[t.Optional[str], ...]]
for node in ast.find_all(nodes.Call):
if (
not isinstance(node.node, nodes.Name)
or node.node.name not in gettext_functions
):
continue
strings: t.List[t.Optional[str]] = []
for arg in node.args:
if isinstance(arg, nodes.Const) and isinstance(arg.value, str):
strings.append(arg.value)
else:
strings.append(None)
for _ in node.kwargs:
strings.append(None)
if node.dyn_args is not None:
strings.append(None)
if node.dyn_kwargs is not None:
strings.append(None)
if not babel_style:
out = tuple(x for x in strings if x is not None)
if not out:
continue
else:
if len(strings) == 1:
out = strings[0]
else:
out = tuple(strings)
yield node.lineno, node.node.name, out
class _CommentFinder:
"""Helper class to find comments in a token stream. Can only
find comments for gettext calls forwards. Once the comment
from line 4 is found, a comment for line 1 will not return a
usable value.
"""
def __init__(
self, tokens: t.Sequence[t.Tuple[int, str, str]], comment_tags: t.Sequence[str]
) -> None:
self.tokens = tokens
self.comment_tags = comment_tags
self.offset = 0
self.last_lineno = 0
def find_backwards(self, offset: int) -> t.List[str]:
try:
for _, token_type, token_value in reversed(
self.tokens[self.offset : offset]
):
if token_type in ("comment", "linecomment"):
try:
prefix, comment = token_value.split(None, 1)
except ValueError:
continue
if prefix in self.comment_tags:
return [comment.rstrip()]
return []
finally:
self.offset = offset
def find_comments(self, lineno: int) -> t.List[str]:
if not self.comment_tags or self.last_lineno > lineno:
return []
for idx, (token_lineno, _, _) in enumerate(self.tokens[self.offset :]):
if token_lineno > lineno:
return self.find_backwards(self.offset + idx)
return self.find_backwards(len(self.tokens))
def babel_extract(
fileobj: t.BinaryIO,
keywords: t.Sequence[str],
comment_tags: t.Sequence[str],
options: t.Dict[str, t.Any],
) -> t.Iterator[
t.Tuple[
int, str, t.Union[t.Optional[str], t.Tuple[t.Optional[str], ...]], t.List[str]
]
]:
"""Babel extraction method for Jinja templates.
.. versionchanged:: 2.3
Basic support for translation comments was added. If `comment_tags`
is now set to a list of keywords for extraction, the extractor will
try to find the best preceding comment that begins with one of the
keywords. For best results, make sure to not have more than one
gettext call in one line of code and the matching comment in the
same line or the line before.
.. versionchanged:: 2.5.1
The `newstyle_gettext` flag can be set to `True` to enable newstyle
gettext calls.
.. versionchanged:: 2.7
A `silent` option can now be provided. If set to `False` template
syntax errors are propagated instead of being ignored.
:param fileobj: the file-like object the messages should be extracted from
:param keywords: a list of keywords (i.e. function names) that should be
recognized as translation functions
:param comment_tags: a list of translator tags to search for and include
in the results.
:param options: a dictionary of additional options (optional)
:return: an iterator over ``(lineno, funcname, message, comments)`` tuples.
(comments will be empty currently)
"""
extensions: t.Dict[t.Type[Extension], None] = {}
for extension_name in options.get("extensions", "").split(","):
extension_name = extension_name.strip()
if not extension_name:
continue
extensions[import_string(extension_name)] = None
if InternationalizationExtension not in extensions:
extensions[InternationalizationExtension] = None
def getbool(options: t.Mapping[str, str], key: str, default: bool = False) -> bool:
return options.get(key, str(default)).lower() in {"1", "on", "yes", "true"}
silent = getbool(options, "silent", True)
environment = Environment(
options.get("block_start_string", defaults.BLOCK_START_STRING),
options.get("block_end_string", defaults.BLOCK_END_STRING),
options.get("variable_start_string", defaults.VARIABLE_START_STRING),
options.get("variable_end_string", defaults.VARIABLE_END_STRING),
options.get("comment_start_string", defaults.COMMENT_START_STRING),
options.get("comment_end_string", defaults.COMMENT_END_STRING),
options.get("line_statement_prefix") or defaults.LINE_STATEMENT_PREFIX,
options.get("line_comment_prefix") or defaults.LINE_COMMENT_PREFIX,
getbool(options, "trim_blocks", defaults.TRIM_BLOCKS),
getbool(options, "lstrip_blocks", defaults.LSTRIP_BLOCKS),
defaults.NEWLINE_SEQUENCE,
getbool(options, "keep_trailing_newline", defaults.KEEP_TRAILING_NEWLINE),
tuple(extensions),
cache_size=0,
auto_reload=False,
)
if getbool(options, "trimmed"):
environment.policies["ext.i18n.trimmed"] = True
if getbool(options, "newstyle_gettext"):
environment.newstyle_gettext = True # type: ignore
source = fileobj.read().decode(options.get("encoding", "utf-8"))
try:
node = environment.parse(source)
tokens = list(environment.lex(environment.preprocess(source)))
except TemplateSyntaxError:
if not silent:
raise
# skip templates with syntax errors
return
finder = _CommentFinder(tokens, comment_tags)
for lineno, func, message in extract_from_ast(node, keywords):
yield lineno, func, message, finder.find_comments(lineno)
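# A minimal sketch of calling babel_extract directly on an in-memory
# template, the way Babel's extraction machinery would; the template text
# and the "NOTE:" comment tag are assumptions for illustration.
def _example_babel_extract():  # hypothetical helper, for illustration only
    import io

    source = b"{# NOTE: a translator comment #}\n{{ _('Hello World') }}"
    return list(babel_extract(io.BytesIO(source), ("gettext", "_"), ("NOTE:",), {}))
    # -> [(2, '_', 'Hello World', ['a translator comment'])]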
#: nicer import names
i18n = InternationalizationExtension
do = ExprStmtExtension
loopcontrols = LoopControlExtension
with_ = WithExtension
autoescape = AutoEscapeExtension
debug = DebugExtension

File diff suppressed because it is too large

View file

@ -0,0 +1,318 @@
import typing as t
from . import nodes
from .visitor import NodeVisitor
VAR_LOAD_PARAMETER = "param"
VAR_LOAD_RESOLVE = "resolve"
VAR_LOAD_ALIAS = "alias"
VAR_LOAD_UNDEFINED = "undefined"
def find_symbols(
nodes: t.Iterable[nodes.Node], parent_symbols: t.Optional["Symbols"] = None
) -> "Symbols":
sym = Symbols(parent=parent_symbols)
visitor = FrameSymbolVisitor(sym)
for node in nodes:
visitor.visit(node)
return sym
def symbols_for_node(
node: nodes.Node, parent_symbols: t.Optional["Symbols"] = None
) -> "Symbols":
sym = Symbols(parent=parent_symbols)
sym.analyze_node(node)
return sym
class Symbols:
def __init__(
self, parent: t.Optional["Symbols"] = None, level: t.Optional[int] = None
) -> None:
if level is None:
if parent is None:
level = 0
else:
level = parent.level + 1
self.level: int = level
self.parent = parent
self.refs: t.Dict[str, str] = {}
self.loads: t.Dict[str, t.Any] = {}
self.stores: t.Set[str] = set()
def analyze_node(self, node: nodes.Node, **kwargs: t.Any) -> None:
visitor = RootVisitor(self)
visitor.visit(node, **kwargs)
def _define_ref(
self, name: str, load: t.Optional[t.Tuple[str, t.Optional[str]]] = None
) -> str:
ident = f"l_{self.level}_{name}"
self.refs[name] = ident
if load is not None:
self.loads[ident] = load
return ident
def find_load(self, target: str) -> t.Optional[t.Any]:
if target in self.loads:
return self.loads[target]
if self.parent is not None:
return self.parent.find_load(target)
return None
def find_ref(self, name: str) -> t.Optional[str]:
if name in self.refs:
return self.refs[name]
if self.parent is not None:
return self.parent.find_ref(name)
return None
def ref(self, name: str) -> str:
rv = self.find_ref(name)
if rv is None:
raise AssertionError(
"Tried to resolve a name to a reference that was"
f" unknown to the frame ({name!r})"
)
return rv
def copy(self) -> "Symbols":
rv = t.cast(Symbols, object.__new__(self.__class__))
rv.__dict__.update(self.__dict__)
rv.refs = self.refs.copy()
rv.loads = self.loads.copy()
rv.stores = self.stores.copy()
return rv
def store(self, name: str) -> None:
self.stores.add(name)
        # If we have not seen the name referenced yet, we need to figure
# out what to set it to.
if name not in self.refs:
# If there is a parent scope we check if the name has a
# reference there. If it does it means we might have to alias
# to a variable there.
if self.parent is not None:
outer_ref = self.parent.find_ref(name)
if outer_ref is not None:
self._define_ref(name, load=(VAR_LOAD_ALIAS, outer_ref))
return
# Otherwise we can just set it to undefined.
self._define_ref(name, load=(VAR_LOAD_UNDEFINED, None))
def declare_parameter(self, name: str) -> str:
self.stores.add(name)
return self._define_ref(name, load=(VAR_LOAD_PARAMETER, None))
def load(self, name: str) -> None:
if self.find_ref(name) is None:
self._define_ref(name, load=(VAR_LOAD_RESOLVE, name))
def branch_update(self, branch_symbols: t.Sequence["Symbols"]) -> None:
stores: t.Dict[str, int] = {}
for branch in branch_symbols:
for target in branch.stores:
if target in self.stores:
continue
stores[target] = stores.get(target, 0) + 1
for sym in branch_symbols:
self.refs.update(sym.refs)
self.loads.update(sym.loads)
self.stores.update(sym.stores)
for name, branch_count in stores.items():
if branch_count == len(branch_symbols):
continue
target = self.find_ref(name) # type: ignore
assert target is not None, "should not happen"
if self.parent is not None:
outer_target = self.parent.find_ref(name)
if outer_target is not None:
self.loads[target] = (VAR_LOAD_ALIAS, outer_target)
continue
self.loads[target] = (VAR_LOAD_RESOLVE, name)
def dump_stores(self) -> t.Dict[str, str]:
rv: t.Dict[str, str] = {}
node: t.Optional["Symbols"] = self
while node is not None:
for name in sorted(node.stores):
if name not in rv:
rv[name] = self.find_ref(name) # type: ignore
node = node.parent
return rv
def dump_param_targets(self) -> t.Set[str]:
rv = set()
node: t.Optional["Symbols"] = self
while node is not None:
            for target, (instr, _) in node.loads.items():
if instr == VAR_LOAD_PARAMETER:
rv.add(target)
node = node.parent
return rv
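# A minimal sketch of how the Symbols table above is consulted: analyze a
# parsed template and dump the names it stores.
def _example_symbols():  # hypothetical helper, for illustration only
    from jinja2 import Environment

    env = Environment()
    sym = symbols_for_node(env.parse("{% set x = y %}{{ x }}"))
    # "x" is stored at level 0; "y" remains an unresolved load.
    return sym.dump_stores()  # -> {'x': 'l_0_x'}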
class RootVisitor(NodeVisitor):
def __init__(self, symbols: "Symbols") -> None:
self.sym_visitor = FrameSymbolVisitor(symbols)
def _simple_visit(self, node: nodes.Node, **kwargs: t.Any) -> None:
for child in node.iter_child_nodes():
self.sym_visitor.visit(child)
visit_Template = _simple_visit
visit_Block = _simple_visit
visit_Macro = _simple_visit
visit_FilterBlock = _simple_visit
visit_Scope = _simple_visit
visit_If = _simple_visit
visit_ScopedEvalContextModifier = _simple_visit
def visit_AssignBlock(self, node: nodes.AssignBlock, **kwargs: t.Any) -> None:
for child in node.body:
self.sym_visitor.visit(child)
def visit_CallBlock(self, node: nodes.CallBlock, **kwargs: t.Any) -> None:
for child in node.iter_child_nodes(exclude=("call",)):
self.sym_visitor.visit(child)
def visit_OverlayScope(self, node: nodes.OverlayScope, **kwargs: t.Any) -> None:
for child in node.body:
self.sym_visitor.visit(child)
def visit_For(
self, node: nodes.For, for_branch: str = "body", **kwargs: t.Any
) -> None:
if for_branch == "body":
self.sym_visitor.visit(node.target, store_as_param=True)
branch = node.body
elif for_branch == "else":
branch = node.else_
elif for_branch == "test":
self.sym_visitor.visit(node.target, store_as_param=True)
if node.test is not None:
self.sym_visitor.visit(node.test)
return
else:
raise RuntimeError("Unknown for branch")
if branch:
for item in branch:
self.sym_visitor.visit(item)
def visit_With(self, node: nodes.With, **kwargs: t.Any) -> None:
for target in node.targets:
self.sym_visitor.visit(target)
for child in node.body:
self.sym_visitor.visit(child)
def generic_visit(self, node: nodes.Node, *args: t.Any, **kwargs: t.Any) -> None:
raise NotImplementedError(f"Cannot find symbols for {type(node).__name__!r}")
class FrameSymbolVisitor(NodeVisitor):
"""A visitor for `Frame.inspect`."""
def __init__(self, symbols: "Symbols") -> None:
self.symbols = symbols
def visit_Name(
self, node: nodes.Name, store_as_param: bool = False, **kwargs: t.Any
) -> None:
"""All assignments to names go through this function."""
if store_as_param or node.ctx == "param":
self.symbols.declare_parameter(node.name)
elif node.ctx == "store":
self.symbols.store(node.name)
elif node.ctx == "load":
self.symbols.load(node.name)
def visit_NSRef(self, node: nodes.NSRef, **kwargs: t.Any) -> None:
self.symbols.load(node.name)
def visit_If(self, node: nodes.If, **kwargs: t.Any) -> None:
self.visit(node.test, **kwargs)
original_symbols = self.symbols
def inner_visit(nodes: t.Iterable[nodes.Node]) -> "Symbols":
self.symbols = rv = original_symbols.copy()
for subnode in nodes:
self.visit(subnode, **kwargs)
self.symbols = original_symbols
return rv
body_symbols = inner_visit(node.body)
elif_symbols = inner_visit(node.elif_)
else_symbols = inner_visit(node.else_ or ())
self.symbols.branch_update([body_symbols, elif_symbols, else_symbols])
def visit_Macro(self, node: nodes.Macro, **kwargs: t.Any) -> None:
self.symbols.store(node.name)
def visit_Import(self, node: nodes.Import, **kwargs: t.Any) -> None:
self.generic_visit(node, **kwargs)
self.symbols.store(node.target)
def visit_FromImport(self, node: nodes.FromImport, **kwargs: t.Any) -> None:
self.generic_visit(node, **kwargs)
for name in node.names:
if isinstance(name, tuple):
self.symbols.store(name[1])
else:
self.symbols.store(name)
def visit_Assign(self, node: nodes.Assign, **kwargs: t.Any) -> None:
"""Visit assignments in the correct order."""
self.visit(node.node, **kwargs)
self.visit(node.target, **kwargs)
def visit_For(self, node: nodes.For, **kwargs: t.Any) -> None:
"""Visiting stops at for blocks. However the block sequence
is visited as part of the outer scope.
"""
self.visit(node.iter, **kwargs)
def visit_CallBlock(self, node: nodes.CallBlock, **kwargs: t.Any) -> None:
self.visit(node.call, **kwargs)
def visit_FilterBlock(self, node: nodes.FilterBlock, **kwargs: t.Any) -> None:
self.visit(node.filter, **kwargs)
def visit_With(self, node: nodes.With, **kwargs: t.Any) -> None:
for target in node.values:
self.visit(target)
def visit_AssignBlock(self, node: nodes.AssignBlock, **kwargs: t.Any) -> None:
"""Stop visiting at block assigns."""
self.visit(node.target, **kwargs)
def visit_Scope(self, node: nodes.Scope, **kwargs: t.Any) -> None:
"""Stop visiting at scopes."""
def visit_Block(self, node: nodes.Block, **kwargs: t.Any) -> None:
"""Stop visiting at blocks."""
def visit_OverlayScope(self, node: nodes.OverlayScope, **kwargs: t.Any) -> None:
"""Do not visit into overlay scopes."""

View file

@ -0,0 +1,869 @@
"""Implements a Jinja / Python combination lexer. The ``Lexer`` class
is used to do some preprocessing. It filters out invalid operators like
the bitshift operators we don't allow in templates. It separates
template code and Python code in expressions.
"""
import re
import typing as t
from ast import literal_eval
from collections import deque
from sys import intern
from ._identifier import pattern as name_re
from .exceptions import TemplateSyntaxError
from .utils import LRUCache
if t.TYPE_CHECKING:
import typing_extensions as te
from .environment import Environment
# cache for the lexers. Exists in order to be able to have multiple
# environments with the same lexer
_lexer_cache: t.MutableMapping[t.Tuple, "Lexer"] = LRUCache(50) # type: ignore
# static regular expressions
whitespace_re = re.compile(r"\s+")
newline_re = re.compile(r"(\r\n|\r|\n)")
string_re = re.compile(
r"('([^'\\]*(?:\\.[^'\\]*)*)'" r'|"([^"\\]*(?:\\.[^"\\]*)*)")', re.S
)
integer_re = re.compile(
r"""
(
0b(_?[0-1])+ # binary
|
0o(_?[0-7])+ # octal
|
0x(_?[\da-f])+ # hex
|
[1-9](_?\d)* # decimal
|
0(_?0)* # decimal zero
)
""",
re.IGNORECASE | re.VERBOSE,
)
float_re = re.compile(
r"""
(?<!\.) # doesn't start with a .
(\d+_)*\d+ # digits, possibly _ separated
(
(\.(\d+_)*\d+)? # optional fractional part
e[+\-]?(\d+_)*\d+ # exponent part
|
\.(\d+_)*\d+ # required fractional part
)
""",
re.IGNORECASE | re.VERBOSE,
)
# intern the tokens and keep references to them
TOKEN_ADD = intern("add")
TOKEN_ASSIGN = intern("assign")
TOKEN_COLON = intern("colon")
TOKEN_COMMA = intern("comma")
TOKEN_DIV = intern("div")
TOKEN_DOT = intern("dot")
TOKEN_EQ = intern("eq")
TOKEN_FLOORDIV = intern("floordiv")
TOKEN_GT = intern("gt")
TOKEN_GTEQ = intern("gteq")
TOKEN_LBRACE = intern("lbrace")
TOKEN_LBRACKET = intern("lbracket")
TOKEN_LPAREN = intern("lparen")
TOKEN_LT = intern("lt")
TOKEN_LTEQ = intern("lteq")
TOKEN_MOD = intern("mod")
TOKEN_MUL = intern("mul")
TOKEN_NE = intern("ne")
TOKEN_PIPE = intern("pipe")
TOKEN_POW = intern("pow")
TOKEN_RBRACE = intern("rbrace")
TOKEN_RBRACKET = intern("rbracket")
TOKEN_RPAREN = intern("rparen")
TOKEN_SEMICOLON = intern("semicolon")
TOKEN_SUB = intern("sub")
TOKEN_TILDE = intern("tilde")
TOKEN_WHITESPACE = intern("whitespace")
TOKEN_FLOAT = intern("float")
TOKEN_INTEGER = intern("integer")
TOKEN_NAME = intern("name")
TOKEN_STRING = intern("string")
TOKEN_OPERATOR = intern("operator")
TOKEN_BLOCK_BEGIN = intern("block_begin")
TOKEN_BLOCK_END = intern("block_end")
TOKEN_VARIABLE_BEGIN = intern("variable_begin")
TOKEN_VARIABLE_END = intern("variable_end")
TOKEN_RAW_BEGIN = intern("raw_begin")
TOKEN_RAW_END = intern("raw_end")
TOKEN_COMMENT_BEGIN = intern("comment_begin")
TOKEN_COMMENT_END = intern("comment_end")
TOKEN_COMMENT = intern("comment")
TOKEN_LINESTATEMENT_BEGIN = intern("linestatement_begin")
TOKEN_LINESTATEMENT_END = intern("linestatement_end")
TOKEN_LINECOMMENT_BEGIN = intern("linecomment_begin")
TOKEN_LINECOMMENT_END = intern("linecomment_end")
TOKEN_LINECOMMENT = intern("linecomment")
TOKEN_DATA = intern("data")
TOKEN_INITIAL = intern("initial")
TOKEN_EOF = intern("eof")
# bind operators to token types
operators = {
"+": TOKEN_ADD,
"-": TOKEN_SUB,
"/": TOKEN_DIV,
"//": TOKEN_FLOORDIV,
"*": TOKEN_MUL,
"%": TOKEN_MOD,
"**": TOKEN_POW,
"~": TOKEN_TILDE,
"[": TOKEN_LBRACKET,
"]": TOKEN_RBRACKET,
"(": TOKEN_LPAREN,
")": TOKEN_RPAREN,
"{": TOKEN_LBRACE,
"}": TOKEN_RBRACE,
"==": TOKEN_EQ,
"!=": TOKEN_NE,
">": TOKEN_GT,
">=": TOKEN_GTEQ,
"<": TOKEN_LT,
"<=": TOKEN_LTEQ,
"=": TOKEN_ASSIGN,
".": TOKEN_DOT,
":": TOKEN_COLON,
"|": TOKEN_PIPE,
",": TOKEN_COMMA,
";": TOKEN_SEMICOLON,
}
reverse_operators = {v: k for k, v in operators.items()}
assert len(operators) == len(reverse_operators), "operators dropped"
operator_re = re.compile(
f"({'|'.join(re.escape(x) for x in sorted(operators, key=lambda x: -len(x)))})"
)
ignored_tokens = frozenset(
[
TOKEN_COMMENT_BEGIN,
TOKEN_COMMENT,
TOKEN_COMMENT_END,
TOKEN_WHITESPACE,
TOKEN_LINECOMMENT_BEGIN,
TOKEN_LINECOMMENT_END,
TOKEN_LINECOMMENT,
]
)
ignore_if_empty = frozenset(
[TOKEN_WHITESPACE, TOKEN_DATA, TOKEN_COMMENT, TOKEN_LINECOMMENT]
)
def _describe_token_type(token_type: str) -> str:
if token_type in reverse_operators:
return reverse_operators[token_type]
return {
TOKEN_COMMENT_BEGIN: "begin of comment",
TOKEN_COMMENT_END: "end of comment",
TOKEN_COMMENT: "comment",
TOKEN_LINECOMMENT: "comment",
TOKEN_BLOCK_BEGIN: "begin of statement block",
TOKEN_BLOCK_END: "end of statement block",
TOKEN_VARIABLE_BEGIN: "begin of print statement",
TOKEN_VARIABLE_END: "end of print statement",
TOKEN_LINESTATEMENT_BEGIN: "begin of line statement",
TOKEN_LINESTATEMENT_END: "end of line statement",
TOKEN_DATA: "template data / text",
TOKEN_EOF: "end of template",
}.get(token_type, token_type)
def describe_token(token: "Token") -> str:
"""Returns a description of the token."""
if token.type == TOKEN_NAME:
return token.value
return _describe_token_type(token.type)
def describe_token_expr(expr: str) -> str:
"""Like `describe_token` but for token expressions."""
if ":" in expr:
type, value = expr.split(":", 1)
if type == TOKEN_NAME:
return value
else:
type = expr
return _describe_token_type(type)
def count_newlines(value: str) -> int:
"""Count the number of newline characters in the string. This is
useful for extensions that filter a stream.
"""
return len(newline_re.findall(value))
def compile_rules(environment: "Environment") -> t.List[t.Tuple[str, str]]:
"""Compiles all the rules from the environment into a list of rules."""
e = re.escape
rules = [
(
len(environment.comment_start_string),
TOKEN_COMMENT_BEGIN,
e(environment.comment_start_string),
),
(
len(environment.block_start_string),
TOKEN_BLOCK_BEGIN,
e(environment.block_start_string),
),
(
len(environment.variable_start_string),
TOKEN_VARIABLE_BEGIN,
e(environment.variable_start_string),
),
]
if environment.line_statement_prefix is not None:
rules.append(
(
len(environment.line_statement_prefix),
TOKEN_LINESTATEMENT_BEGIN,
r"^[ \t\v]*" + e(environment.line_statement_prefix),
)
)
if environment.line_comment_prefix is not None:
rules.append(
(
len(environment.line_comment_prefix),
TOKEN_LINECOMMENT_BEGIN,
r"(?:^|(?<=\S))[^\S\r\n]*" + e(environment.line_comment_prefix),
)
)
return [x[1:] for x in sorted(rules, reverse=True)]
class Failure:
"""Class that raises a `TemplateSyntaxError` if called.
Used by the `Lexer` to specify known errors.
"""
def __init__(
self, message: str, cls: t.Type[TemplateSyntaxError] = TemplateSyntaxError
) -> None:
self.message = message
self.error_class = cls
def __call__(self, lineno: int, filename: str) -> "te.NoReturn":
raise self.error_class(self.message, lineno, filename)
class Token(t.NamedTuple):
lineno: int
type: str
value: str
def __str__(self) -> str:
return describe_token(self)
def test(self, expr: str) -> bool:
"""Test a token against a token expression. This can either be a
token type or ``'token_type:token_value'``. This can only test
against string values and types.
"""
# here we do a regular string equality check as test_any is usually
        # passed an iterable of non-interned strings.
if self.type == expr:
return True
if ":" in expr:
return expr.split(":", 1) == [self.type, self.value]
return False
def test_any(self, *iterable: str) -> bool:
"""Test against multiple token expressions."""
return any(self.test(expr) for expr in iterable)
class TokenStreamIterator:
"""The iterator for tokenstreams. Iterate over the stream
until the eof token is reached.
"""
def __init__(self, stream: "TokenStream") -> None:
self.stream = stream
def __iter__(self) -> "TokenStreamIterator":
return self
def __next__(self) -> Token:
token = self.stream.current
if token.type is TOKEN_EOF:
self.stream.close()
raise StopIteration
next(self.stream)
return token
class TokenStream:
"""A token stream is an iterable that yields :class:`Token`\\s. The
parser however does not iterate over it but calls :meth:`next` to go
one token ahead. The current active token is stored as :attr:`current`.
"""
def __init__(
self,
generator: t.Iterable[Token],
name: t.Optional[str],
filename: t.Optional[str],
):
self._iter = iter(generator)
self._pushed: "te.Deque[Token]" = deque()
self.name = name
self.filename = filename
self.closed = False
self.current = Token(1, TOKEN_INITIAL, "")
next(self)
def __iter__(self) -> TokenStreamIterator:
return TokenStreamIterator(self)
def __bool__(self) -> bool:
return bool(self._pushed) or self.current.type is not TOKEN_EOF
@property
def eos(self) -> bool:
"""Are we at the end of the stream?"""
return not self
def push(self, token: Token) -> None:
"""Push a token back to the stream."""
self._pushed.append(token)
def look(self) -> Token:
"""Look at the next token."""
old_token = next(self)
result = self.current
self.push(result)
self.current = old_token
return result
def skip(self, n: int = 1) -> None:
"""Got n tokens ahead."""
for _ in range(n):
next(self)
def next_if(self, expr: str) -> t.Optional[Token]:
"""Perform the token test and return the token if it matched.
Otherwise the return value is `None`.
"""
if self.current.test(expr):
return next(self)
return None
def skip_if(self, expr: str) -> bool:
"""Like :meth:`next_if` but only returns `True` or `False`."""
return self.next_if(expr) is not None
def __next__(self) -> Token:
"""Go one token ahead and return the old one.
Use the built-in :func:`next` instead of calling this directly.
"""
rv = self.current
if self._pushed:
self.current = self._pushed.popleft()
elif self.current.type is not TOKEN_EOF:
try:
self.current = next(self._iter)
except StopIteration:
self.close()
return rv
def close(self) -> None:
"""Close the stream."""
self.current = Token(self.current.lineno, TOKEN_EOF, "")
self._iter = iter(())
self.closed = True
def expect(self, expr: str) -> Token:
"""Expect a given token type and return it. This accepts the same
argument as :meth:`jinja2.lexer.Token.test`.
"""
if not self.current.test(expr):
expr = describe_token_expr(expr)
if self.current.type is TOKEN_EOF:
raise TemplateSyntaxError(
f"unexpected end of template, expected {expr!r}.",
self.current.lineno,
self.name,
self.filename,
)
raise TemplateSyntaxError(
f"expected token {expr!r}, got {describe_token(self.current)!r}",
self.current.lineno,
self.name,
self.filename,
)
return next(self)
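# A minimal sketch of driving a TokenStream by hand. The stream is normally
# produced by Lexer.tokenize, but any iterable of Token instances works.
def _example_token_stream():  # hypothetical helper, for illustration only
    tokens = [
        Token(1, TOKEN_NAME, "foo"),
        Token(1, TOKEN_ASSIGN, "="),
        Token(1, TOKEN_INTEGER, "42"),
    ]
    stream = TokenStream(tokens, "<example>", None)
    name = stream.expect("name")  # consumes "foo"
    stream.skip_if("assign")  # consumes "="
    return name.value, stream.current.value  # -> ("foo", "42")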
def get_lexer(environment: "Environment") -> "Lexer":
"""Return a lexer which is probably cached."""
key = (
environment.block_start_string,
environment.block_end_string,
environment.variable_start_string,
environment.variable_end_string,
environment.comment_start_string,
environment.comment_end_string,
environment.line_statement_prefix,
environment.line_comment_prefix,
environment.trim_blocks,
environment.lstrip_blocks,
environment.newline_sequence,
environment.keep_trailing_newline,
)
lexer = _lexer_cache.get(key)
if lexer is None:
_lexer_cache[key] = lexer = Lexer(environment)
return lexer
class OptionalLStrip(tuple):
"""A special tuple for marking a point in the state that can have
lstrip applied.
"""
__slots__ = ()
# Even though it looks like a no-op, creating instances fails
# without this.
def __new__(cls, *members, **kwargs): # type: ignore
return super().__new__(cls, members)
class _Rule(t.NamedTuple):
pattern: t.Pattern[str]
tokens: t.Union[str, t.Tuple[str, ...], t.Tuple[Failure]]
command: t.Optional[str]
class Lexer:
"""Class that implements a lexer for a given environment. Automatically
    created by the environment class; usually you don't have to create one
    yourself.
Note that the lexer is not automatically bound to an environment.
Multiple environments can share the same lexer.
"""
def __init__(self, environment: "Environment") -> None:
# shortcuts
e = re.escape
def c(x: str) -> t.Pattern[str]:
return re.compile(x, re.M | re.S)
# lexing rules for tags
tag_rules: t.List[_Rule] = [
_Rule(whitespace_re, TOKEN_WHITESPACE, None),
_Rule(float_re, TOKEN_FLOAT, None),
_Rule(integer_re, TOKEN_INTEGER, None),
_Rule(name_re, TOKEN_NAME, None),
_Rule(string_re, TOKEN_STRING, None),
_Rule(operator_re, TOKEN_OPERATOR, None),
]
# assemble the root lexing rule. because "|" is ungreedy
# we have to sort by length so that the lexer continues working
# as expected when we have parsing rules like <% for block and
# <%= for variables. (if someone wants asp like syntax)
# variables are just part of the rules if variable processing
# is required.
root_tag_rules = compile_rules(environment)
block_start_re = e(environment.block_start_string)
block_end_re = e(environment.block_end_string)
comment_end_re = e(environment.comment_end_string)
variable_end_re = e(environment.variable_end_string)
# block suffix if trimming is enabled
block_suffix_re = "\\n?" if environment.trim_blocks else ""
# If lstrip is enabled, it should not be applied if there is any
# non-whitespace between the newline and block.
self.lstrip_unless_re = c(r"[^ \t]") if environment.lstrip_blocks else None
self.newline_sequence = environment.newline_sequence
self.keep_trailing_newline = environment.keep_trailing_newline
root_raw_re = (
fr"(?P<raw_begin>{block_start_re}(\-|\+|)\s*raw\s*"
fr"(?:\-{block_end_re}\s*|{block_end_re}))"
)
root_parts_re = "|".join(
[root_raw_re] + [fr"(?P<{n}>{r}(\-|\+|))" for n, r in root_tag_rules]
)
# global lexing rules
self.rules: t.Dict[str, t.List[_Rule]] = {
"root": [
# directives
_Rule(
c(fr"(.*?)(?:{root_parts_re})"),
OptionalLStrip(TOKEN_DATA, "#bygroup"), # type: ignore
"#bygroup",
),
# data
_Rule(c(".+"), TOKEN_DATA, None),
],
# comments
TOKEN_COMMENT_BEGIN: [
_Rule(
c(
fr"(.*?)((?:\+{comment_end_re}|\-{comment_end_re}\s*"
fr"|{comment_end_re}{block_suffix_re}))"
),
(TOKEN_COMMENT, TOKEN_COMMENT_END),
"#pop",
),
_Rule(c(r"(.)"), (Failure("Missing end of comment tag"),), None),
],
# blocks
TOKEN_BLOCK_BEGIN: [
_Rule(
c(
fr"(?:\+{block_end_re}|\-{block_end_re}\s*"
fr"|{block_end_re}{block_suffix_re})"
),
TOKEN_BLOCK_END,
"#pop",
),
]
+ tag_rules,
# variables
TOKEN_VARIABLE_BEGIN: [
_Rule(
c(fr"\-{variable_end_re}\s*|{variable_end_re}"),
TOKEN_VARIABLE_END,
"#pop",
)
]
+ tag_rules,
# raw block
TOKEN_RAW_BEGIN: [
_Rule(
c(
fr"(.*?)((?:{block_start_re}(\-|\+|))\s*endraw\s*"
fr"(?:\+{block_end_re}|\-{block_end_re}\s*"
fr"|{block_end_re}{block_suffix_re}))"
),
OptionalLStrip(TOKEN_DATA, TOKEN_RAW_END), # type: ignore
"#pop",
),
_Rule(c(r"(.)"), (Failure("Missing end of raw directive"),), None),
],
# line statements
TOKEN_LINESTATEMENT_BEGIN: [
_Rule(c(r"\s*(\n|$)"), TOKEN_LINESTATEMENT_END, "#pop")
]
+ tag_rules,
# line comments
TOKEN_LINECOMMENT_BEGIN: [
_Rule(
c(r"(.*?)()(?=\n|$)"),
(TOKEN_LINECOMMENT, TOKEN_LINECOMMENT_END),
"#pop",
)
],
}
def _normalize_newlines(self, value: str) -> str:
"""Replace all newlines with the configured sequence in strings
and template data.
"""
return newline_re.sub(self.newline_sequence, value)
def tokenize(
self,
source: str,
name: t.Optional[str] = None,
filename: t.Optional[str] = None,
state: t.Optional[str] = None,
) -> TokenStream:
"""Calls tokeniter + tokenize and wraps it in a token stream."""
stream = self.tokeniter(source, name, filename, state)
return TokenStream(self.wrap(stream, name, filename), name, filename)
def wrap(
self,
stream: t.Iterable[t.Tuple[int, str, str]],
name: t.Optional[str] = None,
filename: t.Optional[str] = None,
) -> t.Iterator[Token]:
"""This is called with the stream as returned by `tokenize` and wraps
every token in a :class:`Token` and converts the value.
"""
for lineno, token, value_str in stream:
if token in ignored_tokens:
continue
value: t.Any = value_str
if token == TOKEN_LINESTATEMENT_BEGIN:
token = TOKEN_BLOCK_BEGIN
elif token == TOKEN_LINESTATEMENT_END:
token = TOKEN_BLOCK_END
# we are not interested in those tokens in the parser
elif token in (TOKEN_RAW_BEGIN, TOKEN_RAW_END):
continue
elif token == TOKEN_DATA:
value = self._normalize_newlines(value_str)
elif token == "keyword":
token = value_str
elif token == TOKEN_NAME:
value = value_str
if not value.isidentifier():
raise TemplateSyntaxError(
"Invalid character in identifier", lineno, name, filename
)
elif token == TOKEN_STRING:
# try to unescape string
try:
value = (
self._normalize_newlines(value_str[1:-1])
.encode("ascii", "backslashreplace")
.decode("unicode-escape")
)
except Exception as e:
msg = str(e).split(":")[-1].strip()
raise TemplateSyntaxError(msg, lineno, name, filename) from e
elif token == TOKEN_INTEGER:
value = int(value_str.replace("_", ""), 0)
elif token == TOKEN_FLOAT:
# remove all "_" first to support more Python versions
value = literal_eval(value_str.replace("_", ""))
elif token == TOKEN_OPERATOR:
token = operators[value_str]
yield Token(lineno, token, value)
def tokeniter(
self,
source: str,
name: t.Optional[str],
filename: t.Optional[str] = None,
state: t.Optional[str] = None,
) -> t.Iterator[t.Tuple[int, str, str]]:
"""This method tokenizes the text and returns the tokens in a
generator. Use this method if you just want to tokenize a template.
.. versionchanged:: 3.0
Only ``\\n``, ``\\r\\n`` and ``\\r`` are treated as line
breaks.
"""
lines = newline_re.split(source)[::2]
if not self.keep_trailing_newline and lines[-1] == "":
del lines[-1]
source = "\n".join(lines)
pos = 0
lineno = 1
stack = ["root"]
if state is not None and state != "root":
assert state in ("variable", "block"), "invalid state"
stack.append(state + "_begin")
statetokens = self.rules[stack[-1]]
source_length = len(source)
balancing_stack: t.List[str] = []
lstrip_unless_re = self.lstrip_unless_re
newlines_stripped = 0
line_starting = True
while True:
# tokenizer loop
for regex, tokens, new_state in statetokens:
m = regex.match(source, pos)
# if no match we try again with the next rule
if m is None:
continue
# we only match blocks and variables if braces / parentheses
# are balanced. continue parsing with the lower rule which
# is the operator rule. do this only if the end tags look
# like operators
if balancing_stack and tokens in (
TOKEN_VARIABLE_END,
TOKEN_BLOCK_END,
TOKEN_LINESTATEMENT_END,
):
continue
# tuples support more options
if isinstance(tokens, tuple):
groups = m.groups()
if isinstance(tokens, OptionalLStrip):
# Rule supports lstrip. Match will look like
# text, block type, whitespace control, type, control, ...
text = groups[0]
# Skipping the text and first type, every other group is the
# whitespace control for each type. One of the groups will be
# -, +, or empty string instead of None.
strip_sign = next(g for g in groups[2::2] if g is not None)
if strip_sign == "-":
# Strip all whitespace between the text and the tag.
stripped = text.rstrip()
newlines_stripped = text[len(stripped) :].count("\n")
groups = [stripped, *groups[1:]]
elif (
# Not marked for preserving whitespace.
strip_sign != "+"
# lstrip is enabled.
and lstrip_unless_re is not None
# Not a variable expression.
and not m.groupdict().get(TOKEN_VARIABLE_BEGIN)
):
# The start of text between the last newline and the tag.
l_pos = text.rfind("\n") + 1
if l_pos > 0 or line_starting:
# If there's only whitespace between the newline and the
# tag, strip it.
if not lstrip_unless_re.search(text, l_pos):
groups = [text[:l_pos], *groups[1:]]
for idx, token in enumerate(tokens):
# failure group
if token.__class__ is Failure:
raise token(lineno, filename)
# bygroup is a bit more complex, in that case we
# yield for the current token the first named
# group that matched
elif token == "#bygroup":
for key, value in m.groupdict().items():
if value is not None:
yield lineno, key, value
lineno += value.count("\n")
break
else:
raise RuntimeError(
f"{regex!r} wanted to resolve the token dynamically"
" but no group matched"
)
# normal group
else:
data = groups[idx]
if data or token not in ignore_if_empty:
yield lineno, token, data
lineno += data.count("\n") + newlines_stripped
newlines_stripped = 0
                # plain string tokens are just yielded as-is.
else:
data = m.group()
# update brace/parentheses balance
if tokens == TOKEN_OPERATOR:
if data == "{":
balancing_stack.append("}")
elif data == "(":
balancing_stack.append(")")
elif data == "[":
balancing_stack.append("]")
elif data in ("}", ")", "]"):
if not balancing_stack:
raise TemplateSyntaxError(
f"unexpected '{data}'", lineno, name, filename
)
expected_op = balancing_stack.pop()
if expected_op != data:
raise TemplateSyntaxError(
f"unexpected '{data}', expected '{expected_op}'",
lineno,
name,
filename,
)
# yield items
if data or tokens not in ignore_if_empty:
yield lineno, tokens, data
lineno += data.count("\n")
line_starting = m.group()[-1:] == "\n"
# fetch new position into new variable so that we can check
                # if there is an internal parsing error which would result
# in an infinite loop
pos2 = m.end()
# handle state changes
if new_state is not None:
# remove the uppermost state
if new_state == "#pop":
stack.pop()
# resolve the new state by group checking
elif new_state == "#bygroup":
for key, value in m.groupdict().items():
if value is not None:
stack.append(key)
break
else:
raise RuntimeError(
f"{regex!r} wanted to resolve the new state dynamically"
f" but no group matched"
)
# direct state name given
else:
stack.append(new_state)
statetokens = self.rules[stack[-1]]
# we are still at the same position and no stack change.
# this means a loop without break condition, avoid that and
# raise error
elif pos2 == pos:
raise RuntimeError(
f"{regex!r} yielded empty string without stack change"
)
                # publish the new position and start again
pos = pos2
break
            # if the loop terminated without a break we haven't found a single
            # match; either we are at the end of the file or we have a problem
else:
# end of text
if pos >= source_length:
return
# something went wrong
raise TemplateSyntaxError(
f"unexpected char {source[pos]!r} at {pos}", lineno, name, filename
)
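# A minimal sketch of tokenizing through the environment, which routes
# through get_lexer() and the rules compiled above. Note that raw tokeniter
# output still contains whitespace tokens that wrap() later filters out.
def _example_lex():  # hypothetical helper, for illustration only
    from jinja2 import Environment

    env = Environment()
    return list(env.lex("Hello {{ name }}!"))
    # yields (lineno, token_type, value) tuples, e.g.
    # (1, 'data', 'Hello '), (1, 'variable_begin', '{{'), (1, 'name', 'name'), ...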

View file

@ -0,0 +1,652 @@
"""API and implementations for loading templates from different data
sources.
"""
import importlib.util
import os
import sys
import typing as t
import weakref
import zipimport
from collections import abc
from hashlib import sha1
from importlib import import_module
from types import ModuleType
from .exceptions import TemplateNotFound
from .utils import internalcode
from .utils import open_if_exists
if t.TYPE_CHECKING:
from .environment import Environment
from .environment import Template
def split_template_path(template: str) -> t.List[str]:
"""Split a path into segments and perform a sanity check. If it detects
'..' in the path it will raise a `TemplateNotFound` error.
"""
pieces = []
for piece in template.split("/"):
if (
os.path.sep in piece
or (os.path.altsep and os.path.altsep in piece)
or piece == os.path.pardir
):
raise TemplateNotFound(template)
elif piece and piece != ".":
pieces.append(piece)
return pieces
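# A minimal sketch of the sanity check above: forward-slash names are split
# into segments and traversal attempts raise TemplateNotFound.
def _example_split_template_path():  # hypothetical helper, for illustration only
    assert split_template_path("foo/bar.html") == ["foo", "bar.html"]
    try:
        split_template_path("../secret")
    except TemplateNotFound:
        pass  # '..' segments are rejected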
class BaseLoader:
"""Baseclass for all loaders. Subclass this and override `get_source` to
implement a custom loading mechanism. The environment provides a
`get_template` method that calls the loader's `load` method to get the
:class:`Template` object.
A very basic example for a loader that looks up templates on the file
system could look like this::
from jinja2 import BaseLoader, TemplateNotFound
from os.path import join, exists, getmtime
class MyLoader(BaseLoader):
def __init__(self, path):
self.path = path
def get_source(self, environment, template):
path = join(self.path, template)
if not exists(path):
raise TemplateNotFound(template)
mtime = getmtime(path)
with open(path) as f:
source = f.read()
return source, path, lambda: mtime == getmtime(path)
"""
#: if set to `False` it indicates that the loader cannot provide access
#: to the source of templates.
#:
#: .. versionadded:: 2.4
has_source_access = True
def get_source(
self, environment: "Environment", template: str
) -> t.Tuple[str, t.Optional[str], t.Optional[t.Callable[[], bool]]]:
"""Get the template source, filename and reload helper for a template.
It's passed the environment and template name and has to return a
tuple in the form ``(source, filename, uptodate)`` or raise a
`TemplateNotFound` error if it can't locate the template.
The source part of the returned tuple must be the source of the
template as a string. The filename should be the name of the
file on the filesystem if it was loaded from there, otherwise
``None``. The filename is used by Python for the tracebacks
if no loader extension is used.
The last item in the tuple is the `uptodate` function. If auto
reloading is enabled it's always called to check if the template
changed. No arguments are passed so the function must store the
old state somewhere (for example in a closure). If it returns `False`
the template will be reloaded.
"""
if not self.has_source_access:
raise RuntimeError(
f"{type(self).__name__} cannot provide access to the source"
)
raise TemplateNotFound(template)
def list_templates(self) -> t.List[str]:
"""Iterates over all templates. If the loader does not support that
it should raise a :exc:`TypeError` which is the default behavior.
"""
raise TypeError("this loader cannot iterate over all templates")
@internalcode
def load(
self,
environment: "Environment",
name: str,
globals: t.Optional[t.MutableMapping[str, t.Any]] = None,
) -> "Template":
"""Loads a template. This method looks up the template in the cache
or loads one by calling :meth:`get_source`. Subclasses should not
override this method as loaders working on collections of other
loaders (such as :class:`PrefixLoader` or :class:`ChoiceLoader`)
will not call this method but `get_source` directly.
"""
code = None
if globals is None:
globals = {}
# first we try to get the source for this template together
# with the filename and the uptodate function.
source, filename, uptodate = self.get_source(environment, name)
# try to load the code from the bytecode cache if there is a
# bytecode cache configured.
bcc = environment.bytecode_cache
if bcc is not None:
bucket = bcc.get_bucket(environment, name, filename, source)
code = bucket.code
# if we don't have code so far (not cached, no longer up to
# date) etc. we compile the template
if code is None:
code = environment.compile(source, name, filename)
# if the bytecode cache is available and the bucket doesn't
# have a code so far, we give the bucket the new code and put
# it back to the bytecode cache.
if bcc is not None and bucket.code is None:
bucket.code = code
bcc.set_bucket(bucket)
return environment.template_class.from_code(
environment, code, globals, uptodate
)
class FileSystemLoader(BaseLoader):
"""Load templates from a directory in the file system.
The path can be relative or absolute. Relative paths are relative to
the current working directory.
.. code-block:: python
loader = FileSystemLoader("templates")
A list of paths can be given. The directories will be searched in
order, stopping at the first matching template.
.. code-block:: python
loader = FileSystemLoader(["/override/templates", "/default/templates"])
:param searchpath: A path, or list of paths, to the directory that
contains the templates.
:param encoding: Use this encoding to read the text from template
files.
:param followlinks: Follow symbolic links in the path.
.. versionchanged:: 2.8
Added the ``followlinks`` parameter.
"""
def __init__(
self,
searchpath: t.Union[str, os.PathLike, t.Sequence[t.Union[str, os.PathLike]]],
encoding: str = "utf-8",
followlinks: bool = False,
) -> None:
if not isinstance(searchpath, abc.Iterable) or isinstance(searchpath, str):
searchpath = [searchpath]
self.searchpath = [os.fspath(p) for p in searchpath]
self.encoding = encoding
self.followlinks = followlinks
def get_source(
self, environment: "Environment", template: str
) -> t.Tuple[str, str, t.Callable[[], bool]]:
pieces = split_template_path(template)
for searchpath in self.searchpath:
filename = os.path.join(searchpath, *pieces)
f = open_if_exists(filename)
if f is None:
continue
try:
contents = f.read().decode(self.encoding)
finally:
f.close()
mtime = os.path.getmtime(filename)
def uptodate() -> bool:
try:
return os.path.getmtime(filename) == mtime
except OSError:
return False
return contents, filename, uptodate
raise TemplateNotFound(template)
def list_templates(self) -> t.List[str]:
found = set()
for searchpath in self.searchpath:
walk_dir = os.walk(searchpath, followlinks=self.followlinks)
for dirpath, _, filenames in walk_dir:
for filename in filenames:
template = (
os.path.join(dirpath, filename)[len(searchpath) :]
.strip(os.path.sep)
.replace(os.path.sep, "/")
)
if template[:2] == "./":
template = template[2:]
if template not in found:
found.add(template)
return sorted(found)
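# A minimal sketch wiring the FileSystemLoader above into an environment;
# the "templates" directory and "index.html" name are assumptions.
def _example_filesystem_loader():  # hypothetical helper, for illustration only
    from jinja2 import Environment

    env = Environment(loader=FileSystemLoader("templates"))
    # Resolves to templates/index.html relative to the working directory.
    return env.get_template("index.html").render()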
class PackageLoader(BaseLoader):
"""Load templates from a directory in a Python package.
:param package_name: Import name of the package that contains the
template directory.
:param package_path: Directory within the imported package that
contains the templates.
:param encoding: Encoding of template files.
The following example looks up templates in the ``pages`` directory
within the ``project.ui`` package.
.. code-block:: python
loader = PackageLoader("project.ui", "pages")
Only packages installed as directories (standard pip behavior) or
zip/egg files (less common) are supported. The Python API for
introspecting data in packages is too limited to support other
installation methods the way this loader requires.
There is limited support for :pep:`420` namespace packages. The
template directory is assumed to only be in one namespace
contributor. Zip files contributing to a namespace are not
supported.
.. versionchanged:: 3.0
No longer uses ``setuptools`` as a dependency.
.. versionchanged:: 3.0
Limited PEP 420 namespace package support.
"""
def __init__(
self,
package_name: str,
package_path: "str" = "templates",
encoding: str = "utf-8",
) -> None:
package_path = os.path.normpath(package_path).rstrip(os.path.sep)
# normpath preserves ".", which isn't valid in zip paths.
if package_path == os.path.curdir:
package_path = ""
elif package_path[:2] == os.path.curdir + os.path.sep:
package_path = package_path[2:]
self.package_path = package_path
self.package_name = package_name
self.encoding = encoding
# Make sure the package exists. This also makes namespace
# packages work, otherwise get_loader returns None.
import_module(package_name)
spec = importlib.util.find_spec(package_name)
assert spec is not None, "An import spec was not found for the package."
loader = spec.loader
assert loader is not None, "A loader was not found for the package."
self._loader = loader
self._archive = None
template_root = None
if isinstance(loader, zipimport.zipimporter):
self._archive = loader.archive
pkgdir = next(iter(spec.submodule_search_locations)) # type: ignore
template_root = os.path.join(pkgdir, package_path)
else:
roots: t.List[str] = []
# One element for regular packages, multiple for namespace
# packages, or None for single module file.
if spec.submodule_search_locations:
roots.extend(spec.submodule_search_locations)
# A single module file, use the parent directory instead.
elif spec.origin is not None:
roots.append(os.path.dirname(spec.origin))
for root in roots:
root = os.path.join(root, package_path)
if os.path.isdir(root):
template_root = root
break
if template_root is None:
raise ValueError(
f"The {package_name!r} package was not installed in a"
" way that PackageLoader understands."
)
self._template_root = template_root
def get_source(
self, environment: "Environment", template: str
) -> t.Tuple[str, str, t.Optional[t.Callable[[], bool]]]:
p = os.path.join(self._template_root, *split_template_path(template))
up_to_date: t.Optional[t.Callable[[], bool]]
if self._archive is None:
# Package is a directory.
if not os.path.isfile(p):
raise TemplateNotFound(template)
with open(p, "rb") as f:
source = f.read()
mtime = os.path.getmtime(p)
def up_to_date() -> bool:
return os.path.isfile(p) and os.path.getmtime(p) == mtime
else:
# Package is a zip file.
try:
source = self._loader.get_data(p) # type: ignore
except OSError as e:
raise TemplateNotFound(template) from e
# Could use the zip's mtime for all template mtimes, but
# would need to safely reload the module if it's out of
# date, so just report it as always current.
up_to_date = None
return source.decode(self.encoding), p, up_to_date
def list_templates(self) -> t.List[str]:
results: t.List[str] = []
if self._archive is None:
# Package is a directory.
offset = len(self._template_root)
for dirpath, _, filenames in os.walk(self._template_root):
dirpath = dirpath[offset:].lstrip(os.path.sep)
results.extend(
os.path.join(dirpath, name).replace(os.path.sep, "/")
for name in filenames
)
else:
if not hasattr(self._loader, "_files"):
raise TypeError(
"This zip import does not have the required"
" metadata to list templates."
)
# Package is a zip file.
prefix = (
self._template_root[len(self._archive) :].lstrip(os.path.sep)
+ os.path.sep
)
offset = len(prefix)
for name in self._loader._files.keys(): # type: ignore
# Find names under the templates directory that aren't directories.
if name.startswith(prefix) and name[-1] != os.path.sep:
results.append(name[offset:].replace(os.path.sep, "/"))
results.sort()
return results
class DictLoader(BaseLoader):
"""Loads a template from a Python dict mapping template names to
template source. This loader is useful for unittesting:
>>> loader = DictLoader({'index.html': 'source here'})
    Because auto reloading is rarely useful, this is disabled by default.
"""
def __init__(self, mapping: t.Mapping[str, str]) -> None:
self.mapping = mapping
def get_source(
self, environment: "Environment", template: str
) -> t.Tuple[str, None, t.Callable[[], bool]]:
if template in self.mapping:
source = self.mapping[template]
return source, None, lambda: source == self.mapping.get(template)
raise TemplateNotFound(template)
def list_templates(self) -> t.List[str]:
return sorted(self.mapping)
class FunctionLoader(BaseLoader):
"""A loader that is passed a function which does the loading. The
function receives the name of the template and has to return either
a string with the template source, a tuple in the form ``(source,
filename, uptodatefunc)`` or `None` if the template does not exist.
>>> def load_template(name):
... if name == 'index.html':
... return '...'
...
>>> loader = FunctionLoader(load_template)
The `uptodatefunc` is a function that is called if autoreload is enabled
and has to return `True` if the template is still up to date. For more
details have a look at :meth:`BaseLoader.get_source` which has the same
return value.
"""
def __init__(
self,
load_func: t.Callable[
[str],
t.Optional[
t.Union[
str, t.Tuple[str, t.Optional[str], t.Optional[t.Callable[[], bool]]]
]
],
],
) -> None:
self.load_func = load_func
def get_source(
self, environment: "Environment", template: str
) -> t.Tuple[str, t.Optional[str], t.Optional[t.Callable[[], bool]]]:
rv = self.load_func(template)
if rv is None:
raise TemplateNotFound(template)
if isinstance(rv, str):
return rv, None, None
return rv
class PrefixLoader(BaseLoader):
"""A loader that is passed a dict of loaders where each loader is bound
    to a prefix. The prefix is delimited from the template by a slash by
    default, which can be changed by setting the `delimiter` argument to
something else::
loader = PrefixLoader({
'app1': PackageLoader('mypackage.app1'),
'app2': PackageLoader('mypackage.app2')
})
    Loading ``'app1/index.html'`` loads the file from the app1 package;
    loading ``'app2/index.html'`` loads the file from the second.
"""
def __init__(
self, mapping: t.Mapping[str, BaseLoader], delimiter: str = "/"
) -> None:
self.mapping = mapping
self.delimiter = delimiter
def get_loader(self, template: str) -> t.Tuple[BaseLoader, str]:
try:
prefix, name = template.split(self.delimiter, 1)
loader = self.mapping[prefix]
except (ValueError, KeyError) as e:
raise TemplateNotFound(template) from e
return loader, name
def get_source(
self, environment: "Environment", template: str
) -> t.Tuple[str, t.Optional[str], t.Optional[t.Callable[[], bool]]]:
loader, name = self.get_loader(template)
try:
return loader.get_source(environment, name)
except TemplateNotFound as e:
# re-raise the exception with the correct filename here.
# (the one that includes the prefix)
raise TemplateNotFound(template) from e
@internalcode
def load(
self,
environment: "Environment",
name: str,
globals: t.Optional[t.MutableMapping[str, t.Any]] = None,
) -> "Template":
loader, local_name = self.get_loader(name)
try:
return loader.load(environment, local_name, globals)
except TemplateNotFound as e:
# re-raise the exception with the correct filename here.
# (the one that includes the prefix)
raise TemplateNotFound(name) from e
def list_templates(self) -> t.List[str]:
result = []
for prefix, loader in self.mapping.items():
for template in loader.list_templates():
result.append(prefix + self.delimiter + template)
return result
class ChoiceLoader(BaseLoader):
"""This loader works like the `PrefixLoader` just that no prefix is
specified. If a template could not be found by one loader the next one
is tried.
>>> loader = ChoiceLoader([
... FileSystemLoader('/path/to/user/templates'),
... FileSystemLoader('/path/to/system/templates')
... ])
This is useful if you want to allow users to override builtin templates
from a different location.
"""
def __init__(self, loaders: t.Sequence[BaseLoader]) -> None:
self.loaders = loaders
def get_source(
self, environment: "Environment", template: str
) -> t.Tuple[str, t.Optional[str], t.Optional[t.Callable[[], bool]]]:
for loader in self.loaders:
try:
return loader.get_source(environment, template)
except TemplateNotFound:
pass
raise TemplateNotFound(template)
@internalcode
def load(
self,
environment: "Environment",
name: str,
globals: t.Optional[t.MutableMapping[str, t.Any]] = None,
) -> "Template":
for loader in self.loaders:
try:
return loader.load(environment, name, globals)
except TemplateNotFound:
pass
raise TemplateNotFound(name)
def list_templates(self) -> t.List[str]:
found = set()
for loader in self.loaders:
found.update(loader.list_templates())
return sorted(found)
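def _demo_choice_loader() -> None:
    # Usage sketch, not part of the vendored module: the override pattern
    # from the docstring, with DictLoader standing in for real template
    # directories. The first loader shadows the second; misses fall through.
    loader = ChoiceLoader(
        [
            DictLoader({"page.html": "user override"}),
            DictLoader({"page.html": "built-in", "base.html": "base"}),
        ]
    )
    assert loader.get_source(None, "page.html")[0] == "user override"  # type: ignore[arg-type]
    assert loader.get_source(None, "base.html")[0] == "base"  # type: ignore[arg-type]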
class _TemplateModule(ModuleType):
"""Like a normal module but with support for weak references"""
class ModuleLoader(BaseLoader):
"""This loader loads templates from precompiled templates.
Example usage:
>>> loader = ChoiceLoader([
... ModuleLoader('/path/to/compiled/templates'),
... FileSystemLoader('/path/to/templates')
... ])
Templates can be precompiled with :meth:`Environment.compile_templates`.
"""
has_source_access = False
def __init__(
self, path: t.Union[str, os.PathLike, t.Sequence[t.Union[str, os.PathLike]]]
) -> None:
package_name = f"_jinja2_module_templates_{id(self):x}"
# create a fake module that looks for the templates in the
# path given.
mod = _TemplateModule(package_name)
if not isinstance(path, abc.Iterable) or isinstance(path, str):
path = [path]
mod.__path__ = [os.fspath(p) for p in path] # type: ignore
sys.modules[package_name] = weakref.proxy(
mod, lambda x: sys.modules.pop(package_name, None)
)
# the only strong reference, the sys.modules entry is weak
# so that the garbage collector can remove it once the
# loader that created it goes out of business.
self.module = mod
self.package_name = package_name
@staticmethod
def get_template_key(name: str) -> str:
return "tmpl_" + sha1(name.encode("utf-8")).hexdigest()
@staticmethod
def get_module_filename(name: str) -> str:
return ModuleLoader.get_template_key(name) + ".py"
@internalcode
def load(
self,
environment: "Environment",
name: str,
globals: t.Optional[t.MutableMapping[str, t.Any]] = None,
) -> "Template":
key = self.get_template_key(name)
module = f"{self.package_name}.{key}"
mod = getattr(self.module, module, None)
if mod is None:
try:
mod = __import__(module, None, None, ["root"])
except ImportError as e:
raise TemplateNotFound(name) from e
# remove the entry from sys.modules, we only want the attribute
# on the module object we have stored on the loader.
sys.modules.pop(module, None)
if globals is None:
globals = {}
return environment.template_class.from_module_dict(
environment, mod.__dict__, globals
)
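def _demo_module_loader(src_dir: str, compiled_dir: str) -> None:
    # Usage sketch, not part of the vendored module; both directories are
    # placeholders and src_dir is assumed to contain index.html.
    # compile_templates() with zip=None writes one plain .py module per
    # template, which is the on-disk layout ModuleLoader expects.
    from jinja2 import Environment

    Environment(loader=FileSystemLoader(src_dir)).compile_templates(
        compiled_dir, zip=None
    )
    env = Environment(
        loader=ChoiceLoader([ModuleLoader(compiled_dir), FileSystemLoader(src_dir)])
    )
    env.get_template("index.html")  # served from the precompiled module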

View file

@ -0,0 +1,111 @@
"""Functions that expose information about templates that might be
interesting for introspection.
"""
import typing as t
from . import nodes
from .compiler import CodeGenerator
from .compiler import Frame
if t.TYPE_CHECKING:
from .environment import Environment
class TrackingCodeGenerator(CodeGenerator):
"""We abuse the code generator for introspection."""
def __init__(self, environment: "Environment") -> None:
super().__init__(environment, "<introspection>", "<introspection>")
self.undeclared_identifiers: t.Set[str] = set()
def write(self, x: str) -> None:
"""Don't write."""
def enter_frame(self, frame: Frame) -> None:
"""Remember all undeclared identifiers."""
super().enter_frame(frame)
for _, (action, param) in frame.symbols.loads.items():
if action == "resolve" and param not in self.environment.globals:
self.undeclared_identifiers.add(param)
def find_undeclared_variables(ast: nodes.Template) -> t.Set[str]:
"""Returns a set of all variables in the AST that will be looked up from
the context at runtime. Because at compile time it's not known which
variables will be used depending on the path the execution takes at
runtime, all variables are returned.
>>> from jinja2 import Environment, meta
>>> env = Environment()
>>> ast = env.parse('{% set foo = 42 %}{{ bar + foo }}')
>>> meta.find_undeclared_variables(ast) == {'bar'}
True
.. admonition:: Implementation
Internally the code generator is used for finding undeclared variables.
This is good to know because the code generator might raise a
:exc:`TemplateAssertionError` during compilation and as a matter of
fact this function can currently raise that exception as well.
"""
codegen = TrackingCodeGenerator(ast.environment) # type: ignore
codegen.visit(ast)
return codegen.undeclared_identifiers
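def _demo_find_undeclared_variables() -> None:
    # Sketch reinforcing the docstring above: variables from *all* execution
    # paths are reported, even branches that may never run.
    from jinja2 import Environment

    env = Environment()
    ast = env.parse("{% if a %}{{ b }}{% else %}{{ c }}{% endif %}")
    assert find_undeclared_variables(ast) == {"a", "b", "c"}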
_ref_types = (nodes.Extends, nodes.FromImport, nodes.Import, nodes.Include)
_RefType = t.Union[nodes.Extends, nodes.FromImport, nodes.Import, nodes.Include]
def find_referenced_templates(ast: nodes.Template) -> t.Iterator[t.Optional[str]]:
"""Finds all the referenced templates from the AST. This will return an
iterator over all the hardcoded template extensions, inclusions and
imports. If dynamic inheritance or inclusion is used, `None` will be
yielded.
>>> from jinja2 import Environment, meta
>>> env = Environment()
>>> ast = env.parse('{% extends "layout.html" %}{% include helper %}')
>>> list(meta.find_referenced_templates(ast))
['layout.html', None]
This function is useful for dependency tracking. For example if you want
to rebuild parts of the website after a layout template has changed.
"""
template_name: t.Any
for node in ast.find_all(_ref_types):
template: nodes.Expr = node.template # type: ignore
if not isinstance(template, nodes.Const):
# a tuple with some non consts in there
if isinstance(template, (nodes.Tuple, nodes.List)):
for template_name in template.items:
# something const, only yield the strings and ignore
# non-string consts that really just make no sense
if isinstance(template_name, nodes.Const):
if isinstance(template_name.value, str):
yield template_name.value
# something dynamic in there
else:
yield None
# something dynamic we don't know about here
else:
yield None
continue
# constant is a basestring, direct template name
if isinstance(template.value, str):
yield template.value
# a tuple or list (latter *should* not happen) made of consts,
# yield the consts that are strings. We could warn here for
# non string values
elif isinstance(node, nodes.Include) and isinstance(
template.value, (tuple, list)
):
for template_name in template.value:
if isinstance(template_name, str):
yield template_name
# something else we don't care about, we could warn here
else:
yield None
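def _demo_dependency_map() -> None:
    # Sketch of the dependency-tracking use case from the docstring: map each
    # template to the templates it references, with None meaning "a dynamic
    # name that could be anything".
    from jinja2 import Environment

    env = Environment()
    sources = {
        "page.html": '{% extends "layout.html" %}',
        "layout.html": '{% include ["head.html", helper] %}',
    }
    deps = {
        name: set(find_referenced_templates(env.parse(src)))
        for name, src in sources.items()
    }
    assert deps["page.html"] == {"layout.html"}
    assert deps["layout.html"] == {"head.html", None}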

View file

@ -0,0 +1,124 @@
import typing as t
from ast import literal_eval
from ast import parse
from itertools import chain
from itertools import islice
from . import nodes
from .compiler import CodeGenerator
from .compiler import Frame
from .compiler import has_safe_repr
from .environment import Environment
from .environment import Template
def native_concat(values: t.Iterable[t.Any]) -> t.Optional[t.Any]:
"""Return a native Python type from the list of compiled nodes. If
the result is a single node, its value is returned. Otherwise, the
nodes are concatenated as strings. If the result can be parsed with
:func:`ast.literal_eval`, the parsed value is returned. Otherwise,
the string is returned.
:param values: Iterable of outputs to concatenate.
"""
head = list(islice(values, 2))
if not head:
return None
if len(head) == 1:
raw = head[0]
if not isinstance(raw, str):
return raw
else:
raw = "".join([str(v) for v in chain(head, values)])
try:
return literal_eval(
# In Python 3.10+ ast.literal_eval removes leading spaces/tabs
# from the given string. For backwards compatibility we need to
# parse the string ourselves without removing leading spaces/tabs.
parse(raw, mode="eval")
)
except (ValueError, SyntaxError, MemoryError):
return raw
class NativeCodeGenerator(CodeGenerator):
"""A code generator which renders Python types by not adding
``str()`` around output nodes.
"""
@staticmethod
def _default_finalize(value: t.Any) -> t.Any:
return value
def _output_const_repr(self, group: t.Iterable[t.Any]) -> str:
return repr("".join([str(v) for v in group]))
def _output_child_to_const(
self, node: nodes.Expr, frame: Frame, finalize: CodeGenerator._FinalizeInfo
) -> t.Any:
const = node.as_const(frame.eval_ctx)
if not has_safe_repr(const):
raise nodes.Impossible()
if isinstance(node, nodes.TemplateData):
return const
return finalize.const(const) # type: ignore
def _output_child_pre(
self, node: nodes.Expr, frame: Frame, finalize: CodeGenerator._FinalizeInfo
) -> None:
if finalize.src is not None:
self.write(finalize.src)
def _output_child_post(
self, node: nodes.Expr, frame: Frame, finalize: CodeGenerator._FinalizeInfo
) -> None:
if finalize.src is not None:
self.write(")")
class NativeEnvironment(Environment):
"""An environment that renders templates to native Python types."""
code_generator_class = NativeCodeGenerator
class NativeTemplate(Template):
environment_class = NativeEnvironment
def render(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
"""Render the template to produce a native Python type. If the
result is a single node, its value is returned. Otherwise, the
nodes are concatenated as strings. If the result can be parsed
with :func:`ast.literal_eval`, the parsed value is returned.
Otherwise, the string is returned.
"""
ctx = self.new_context(dict(*args, **kwargs))
try:
return native_concat(self.root_render_func(ctx)) # type: ignore
except Exception:
return self.environment.handle_exception()
async def render_async(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
if not self.environment.is_async:
raise RuntimeError(
"The environment was not created with async mode enabled."
)
ctx = self.new_context(dict(*args, **kwargs))
try:
return native_concat(
[n async for n in self.root_render_func(ctx)] # type: ignore
)
except Exception:
return self.environment.handle_exception()
NativeEnvironment.template_class = NativeTemplate
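def _demo_native_render() -> None:
    # Sketch of the behavior render() documents above: a single output node
    # keeps its Python type, and literal-looking strings are parsed back
    # into native values by native_concat().
    env = NativeEnvironment()
    assert env.from_string("{{ x + 1 }}").render(x=4) == 5
    assert env.from_string("[{{ x }}, {{ y }}]").render(x=1, y=2) == [1, 2]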

File diff suppressed because it is too large

View file

@ -0,0 +1,47 @@
"""The optimizer tries to constant fold expressions and modify the AST
in place so that it should be faster to evaluate.
Because the AST does not contain all the scoping information and the
compiler has to find that out, we cannot do all the optimizations we
want. For example, loop unrolling doesn't work because unrolled loops
would have a different scope. The solution would be a second syntax tree
that stored the scoping rules.
"""
import typing as t
from . import nodes
from .visitor import NodeTransformer
if t.TYPE_CHECKING:
from .environment import Environment
def optimize(node: nodes.Node, environment: "Environment") -> nodes.Node:
"""The context hint can be used to perform an static optimization
based on the context given."""
optimizer = Optimizer(environment)
return t.cast(nodes.Node, optimizer.visit(node))
class Optimizer(NodeTransformer):
def __init__(self, environment: "t.Optional[Environment]") -> None:
self.environment = environment
def generic_visit(
self, node: nodes.Node, *args: t.Any, **kwargs: t.Any
) -> nodes.Node:
node = super().generic_visit(node, *args, **kwargs)
# Do constant folding. Some other nodes besides Expr have
# as_const, but folding them causes errors later on.
if isinstance(node, nodes.Expr):
try:
return nodes.Const.from_untrusted(
node.as_const(args[0] if args else None),
lineno=node.lineno,
environment=self.environment,
)
except nodes.Impossible:
pass
return node
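def _demo_constant_folding() -> None:
    # Sketch, not part of the vendored module: a constant subexpression
    # collapses into a single Const node, so the compiled template never
    # re-evaluates it at render time.
    from jinja2 import Environment

    env = Environment()
    ast = optimize(env.parse("{{ 1 + 2 * 3 }}"), env)
    assert next(ast.find_all(nodes.Const)).value == 7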

File diff suppressed because it is too large

View file

File diff suppressed because it is too large

View file

@ -0,0 +1,428 @@
"""A sandbox layer that ensures unsafe operations cannot be performed.
Useful when the template itself comes from an untrusted source.
"""
import operator
import types
import typing as t
from _string import formatter_field_name_split # type: ignore
from collections import abc
from collections import deque
from string import Formatter
from markupsafe import EscapeFormatter
from markupsafe import Markup
from .environment import Environment
from .exceptions import SecurityError
from .runtime import Context
from .runtime import Undefined
F = t.TypeVar("F", bound=t.Callable[..., t.Any])
#: maximum number of items a range may produce
MAX_RANGE = 100000
#: Unsafe function attributes.
UNSAFE_FUNCTION_ATTRIBUTES: t.Set[str] = set()
#: Unsafe method attributes. Function attributes are unsafe for methods too.
UNSAFE_METHOD_ATTRIBUTES: t.Set[str] = set()
#: unsafe generator attributes.
UNSAFE_GENERATOR_ATTRIBUTES = {"gi_frame", "gi_code"}
#: unsafe attributes on coroutines
UNSAFE_COROUTINE_ATTRIBUTES = {"cr_frame", "cr_code"}
#: unsafe attributes on async generators
UNSAFE_ASYNC_GENERATOR_ATTRIBUTES = {"ag_code", "ag_frame"}
_mutable_spec: t.Tuple[t.Tuple[t.Type, t.FrozenSet[str]], ...] = (
(
abc.MutableSet,
frozenset(
[
"add",
"clear",
"difference_update",
"discard",
"pop",
"remove",
"symmetric_difference_update",
"update",
]
),
),
(
abc.MutableMapping,
frozenset(["clear", "pop", "popitem", "setdefault", "update"]),
),
(
abc.MutableSequence,
frozenset(["append", "reverse", "insert", "sort", "extend", "remove"]),
),
(
deque,
frozenset(
[
"append",
"appendleft",
"clear",
"extend",
"extendleft",
"pop",
"popleft",
"remove",
"rotate",
]
),
),
)
def inspect_format_method(callable: t.Callable) -> t.Optional[str]:
if not isinstance(
callable, (types.MethodType, types.BuiltinMethodType)
) or callable.__name__ not in ("format", "format_map"):
return None
obj = callable.__self__
if isinstance(obj, str):
return obj
return None
def safe_range(*args: int) -> range:
"""A range that can't generate ranges with a length of more than
MAX_RANGE items.
"""
rng = range(*args)
if len(rng) > MAX_RANGE:
raise OverflowError(
"Range too big. The sandbox blocks ranges larger than"
f" MAX_RANGE ({MAX_RANGE})."
)
return rng
def unsafe(f: F) -> F:
"""Marks a function or method as unsafe.
.. code-block:: python
@unsafe
def delete(self):
pass
"""
f.unsafe_callable = True # type: ignore
return f
def is_internal_attribute(obj: t.Any, attr: str) -> bool:
"""Test if the attribute given is an internal python attribute. For
example this function returns `True` for the `func_code` attribute of
python objects. This is useful if the environment method
:meth:`~SandboxedEnvironment.is_safe_attribute` is overridden.
>>> from jinja2.sandbox import is_internal_attribute
>>> is_internal_attribute(str, "mro")
True
>>> is_internal_attribute(str, "upper")
False
"""
if isinstance(obj, types.FunctionType):
if attr in UNSAFE_FUNCTION_ATTRIBUTES:
return True
elif isinstance(obj, types.MethodType):
if attr in UNSAFE_FUNCTION_ATTRIBUTES or attr in UNSAFE_METHOD_ATTRIBUTES:
return True
elif isinstance(obj, type):
if attr == "mro":
return True
elif isinstance(obj, (types.CodeType, types.TracebackType, types.FrameType)):
return True
elif isinstance(obj, types.GeneratorType):
if attr in UNSAFE_GENERATOR_ATTRIBUTES:
return True
elif hasattr(types, "CoroutineType") and isinstance(obj, types.CoroutineType):
if attr in UNSAFE_COROUTINE_ATTRIBUTES:
return True
elif hasattr(types, "AsyncGeneratorType") and isinstance(
obj, types.AsyncGeneratorType
):
if attr in UNSAFE_ASYNC_GENERATOR_ATTRIBUTES:
return True
return attr.startswith("__")
def modifies_known_mutable(obj: t.Any, attr: str) -> bool:
"""This function checks if an attribute on a builtin mutable object
(list, dict, set or deque) or the corresponding ABCs would modify it
if called.
>>> modifies_known_mutable({}, "clear")
True
>>> modifies_known_mutable({}, "keys")
False
>>> modifies_known_mutable([], "append")
True
>>> modifies_known_mutable([], "index")
False
If called with an unsupported object, ``False`` is returned.
>>> modifies_known_mutable("foo", "upper")
False
"""
for typespec, unsafe in _mutable_spec:
if isinstance(obj, typespec):
return attr in unsafe
return False
class SandboxedEnvironment(Environment):
"""The sandboxed environment. It works like the regular environment but
tells the compiler to generate sandboxed code. Additionally subclasses of
this environment may override the methods that tell the runtime what
attributes or functions are safe to access.
If the template tries to access insecure code a :exc:`SecurityError` is
raised. However also other exceptions may occur during the rendering so
the caller has to ensure that all exceptions are caught.
"""
sandboxed = True
#: default callback table for the binary operators. A copy of this is
#: available on each instance of a sandboxed environment as
#: :attr:`binop_table`
default_binop_table: t.Dict[str, t.Callable[[t.Any, t.Any], t.Any]] = {
"+": operator.add,
"-": operator.sub,
"*": operator.mul,
"/": operator.truediv,
"//": operator.floordiv,
"**": operator.pow,
"%": operator.mod,
}
#: default callback table for the unary operators. A copy of this is
#: available on each instance of a sandboxed environment as
#: :attr:`unop_table`
default_unop_table: t.Dict[str, t.Callable[[t.Any], t.Any]] = {
"+": operator.pos,
"-": operator.neg,
}
#: a set of binary operators that should be intercepted. Each operator
#: that is added to this set (empty by default) is delegated to the
#: :meth:`call_binop` method that will perform the operator. The default
#: operator callback is specified by :attr:`binop_table`.
#:
#: The following binary operators are interceptable:
#: ``//``, ``%``, ``+``, ``*``, ``-``, ``/``, and ``**``
#:
#: The default operation from the operator table corresponds to the
#: builtin function. Intercepted calls are always slower than the native
#: operator call, so make sure only to intercept the ones you are
#: interested in.
#:
#: .. versionadded:: 2.6
intercepted_binops: t.FrozenSet[str] = frozenset()
#: a set of unary operators that should be intercepted. Each operator
#: that is added to this set (empty by default) is delegated to the
#: :meth:`call_unop` method that will perform the operator. The default
#: operator callback is specified by :attr:`unop_table`.
#:
#: The following unary operators are interceptable: ``+``, ``-``
#:
#: The default operation from the operator table corresponds to the
#: builtin function. Intercepted calls are always slower than the native
#: operator call, so make sure only to intercept the ones you are
#: interested in.
#:
#: .. versionadded:: 2.6
intercepted_unops: t.FrozenSet[str] = frozenset()
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs)
self.globals["range"] = safe_range
self.binop_table = self.default_binop_table.copy()
self.unop_table = self.default_unop_table.copy()
def is_safe_attribute(self, obj: t.Any, attr: str, value: t.Any) -> bool:
"""The sandboxed environment will call this method to check if the
attribute of an object is safe to access. Per default all attributes
starting with an underscore are considered private as well as the
special attributes of internal python objects as returned by the
:func:`is_internal_attribute` function.
"""
return not (attr.startswith("_") or is_internal_attribute(obj, attr))
def is_safe_callable(self, obj: t.Any) -> bool:
"""Check if an object is safely callable. By default callables
are considered safe unless decorated with :func:`unsafe`.
This also recognizes the Django convention of setting
``func.alters_data = True``.
"""
return not (
getattr(obj, "unsafe_callable", False) or getattr(obj, "alters_data", False)
)
def call_binop(
self, context: Context, operator: str, left: t.Any, right: t.Any
) -> t.Any:
"""For intercepted binary operator calls (:meth:`intercepted_binops`)
this function is executed instead of the builtin operator. This can
be used to fine tune the behavior of certain operators.
.. versionadded:: 2.6
"""
return self.binop_table[operator](left, right)
def call_unop(self, context: Context, operator: str, arg: t.Any) -> t.Any:
"""For intercepted unary operator calls (:meth:`intercepted_unops`)
this function is executed instead of the builtin operator. This can
be used to fine tune the behavior of certain operators.
.. versionadded:: 2.6
"""
return self.unop_table[operator](arg)
def getitem(
self, obj: t.Any, argument: t.Union[str, t.Any]
) -> t.Union[t.Any, Undefined]:
"""Subscribe an object from sandboxed code."""
try:
return obj[argument]
except (TypeError, LookupError):
if isinstance(argument, str):
try:
attr = str(argument)
except Exception:
pass
else:
try:
value = getattr(obj, attr)
except AttributeError:
pass
else:
if self.is_safe_attribute(obj, argument, value):
return value
return self.unsafe_undefined(obj, argument)
return self.undefined(obj=obj, name=argument)
def getattr(self, obj: t.Any, attribute: str) -> t.Union[t.Any, Undefined]:
"""Subscribe an object from sandboxed code and prefer the
attribute. The attribute passed *must* be a bytestring.
"""
try:
value = getattr(obj, attribute)
except AttributeError:
try:
return obj[attribute]
except (TypeError, LookupError):
pass
else:
if self.is_safe_attribute(obj, attribute, value):
return value
return self.unsafe_undefined(obj, attribute)
return self.undefined(obj=obj, name=attribute)
def unsafe_undefined(self, obj: t.Any, attribute: str) -> Undefined:
"""Return an undefined object for unsafe attributes."""
return self.undefined(
f"access to attribute {attribute!r} of"
f" {type(obj).__name__!r} object is unsafe.",
name=attribute,
obj=obj,
exc=SecurityError,
)
def format_string(
self,
s: str,
args: t.Tuple[t.Any, ...],
kwargs: t.Dict[str, t.Any],
format_func: t.Optional[t.Callable] = None,
) -> str:
"""If a format call is detected, then this is routed through this
method so that our safety sandbox can be used for it.
"""
formatter: SandboxedFormatter
if isinstance(s, Markup):
formatter = SandboxedEscapeFormatter(self, escape=s.escape)
else:
formatter = SandboxedFormatter(self)
if format_func is not None and format_func.__name__ == "format_map":
if len(args) != 1 or kwargs:
raise TypeError(
"format_map() takes exactly one argument"
f" {len(args) + (kwargs is not None)} given"
)
kwargs = args[0]
args = ()
rv = formatter.vformat(s, args, kwargs)
return type(s)(rv)
def call(
__self, # noqa: B902
__context: Context,
__obj: t.Any,
*args: t.Any,
**kwargs: t.Any,
) -> t.Any:
"""Call an object from sandboxed code."""
fmt = inspect_format_method(__obj)
if fmt is not None:
return __self.format_string(fmt, args, kwargs, __obj)
# the double prefixes are to avoid double keyword argument
# errors when proxying the call.
if not __self.is_safe_callable(__obj):
raise SecurityError(f"{__obj!r} is not safely callable")
return __context.call(__obj, *args, **kwargs)
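class _PowerlessSandbox(SandboxedEnvironment):
    # Hedged sketch, not part of the vendored module: the interception hooks
    # documented on intercepted_binops above route '**' through call_binop(),
    # which this subclass uses to forbid the operator entirely.
    intercepted_binops = frozenset(["**"])

    def call_binop(
        self, context: Context, operator: str, left: t.Any, right: t.Any
    ) -> t.Any:
        if operator == "**":
            raise SecurityError("the power operator is unavailable")
        return super().call_binop(context, operator, left, right)


def _demo_sandbox() -> None:
    # Sketch of the attribute policy described above: private attributes
    # resolve to an unsafe undefined that raises SecurityError when used.
    env = SandboxedEnvironment()
    try:
        env.from_string("{{ obj.__class__ }}").render(obj=42)
    except SecurityError:
        pass  # expected: '__class__' is rejected by is_safe_attribute()
    try:
        _PowerlessSandbox().from_string("{{ 2 ** 10 }}").render()
    except SecurityError:
        pass  # expected: '**' is intercepted and rejected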
class ImmutableSandboxedEnvironment(SandboxedEnvironment):
"""Works exactly like the regular `SandboxedEnvironment` but does not
permit modifications on the builtin mutable objects `list`, `set`, and
`dict` by using the :func:`modifies_known_mutable` function.
"""
def is_safe_attribute(self, obj: t.Any, attr: str, value: t.Any) -> bool:
if not super().is_safe_attribute(obj, attr, value):
return False
return not modifies_known_mutable(obj, attr)
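def _demo_immutable_sandbox() -> None:
    # Sketch of modifies_known_mutable() in action: mutating methods on the
    # builtin containers are rejected, while read-only access still works.
    env = ImmutableSandboxedEnvironment()
    assert env.from_string("{{ items[0] }}").render(items=[1, 2, 3]) == "1"
    try:
        env.from_string("{{ items.append(4) }}").render(items=[1, 2, 3])
    except SecurityError:
        pass  # expected: 'append' would modify a known mutable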
class SandboxedFormatter(Formatter):
def __init__(self, env: Environment, **kwargs: t.Any) -> None:
self._env = env
super().__init__(**kwargs) # type: ignore
def get_field(
self, field_name: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]
) -> t.Tuple[t.Any, str]:
first, rest = formatter_field_name_split(field_name)
obj = self.get_value(first, args, kwargs)
for is_attr, i in rest:
if is_attr:
obj = self._env.getattr(obj, i)
else:
obj = self._env.getitem(obj, i)
return obj, first
class SandboxedEscapeFormatter(SandboxedFormatter, EscapeFormatter):
pass

View file

@ -0,0 +1,255 @@
"""Built-in template tests used with the ``is`` operator."""
import operator
import typing as t
from collections import abc
from numbers import Number
from .runtime import Undefined
from .utils import pass_environment
if t.TYPE_CHECKING:
from .environment import Environment
def test_odd(value: int) -> bool:
"""Return true if the variable is odd."""
return value % 2 == 1
def test_even(value: int) -> bool:
"""Return true if the variable is even."""
return value % 2 == 0
def test_divisibleby(value: int, num: int) -> bool:
"""Check if a variable is divisible by a number."""
return value % num == 0
def test_defined(value: t.Any) -> bool:
"""Return true if the variable is defined:
.. sourcecode:: jinja
{% if variable is defined %}
value of variable: {{ variable }}
{% else %}
variable is not defined
{% endif %}
See the :func:`default` filter for a simple way to set undefined
variables.
"""
return not isinstance(value, Undefined)
def test_undefined(value: t.Any) -> bool:
"""Like :func:`defined` but the other way round."""
return isinstance(value, Undefined)
@pass_environment
def test_filter(env: "Environment", value: str) -> bool:
"""Check if a filter exists by name. Useful if a filter may be
optionally available.
.. code-block:: jinja
{% if 'markdown' is filter %}
{{ value | markdown }}
{% else %}
{{ value }}
{% endif %}
.. versionadded:: 3.0
"""
return value in env.filters
@pass_environment
def test_test(env: "Environment", value: str) -> bool:
"""Check if a test exists by name. Useful if a test may be
optionally available.
.. code-block:: jinja
{% if 'loud' is test %}
{% if value is loud %}
{{ value|upper }}
{% else %}
{{ value|lower }}
{% endif %}
{% else %}
{{ value }}
{% endif %}
.. versionadded:: 3.0
"""
return value in env.tests
def test_none(value: t.Any) -> bool:
"""Return true if the variable is none."""
return value is None
def test_boolean(value: t.Any) -> bool:
"""Return true if the object is a boolean value.
.. versionadded:: 2.11
"""
return value is True or value is False
def test_false(value: t.Any) -> bool:
"""Return true if the object is False.
.. versionadded:: 2.11
"""
return value is False
def test_true(value: t.Any) -> bool:
"""Return true if the object is True.
.. versionadded:: 2.11
"""
return value is True
# NOTE: The existing 'number' test matches booleans and floats
def test_integer(value: t.Any) -> bool:
"""Return true if the object is an integer.
.. versionadded:: 2.11
"""
return isinstance(value, int) and value is not True and value is not False
# NOTE: The existing 'number' test matches booleans and integers
def test_float(value: t.Any) -> bool:
"""Return true if the object is a float.
.. versionadded:: 2.11
"""
return isinstance(value, float)
def test_lower(value: str) -> bool:
"""Return true if the variable is lowercased."""
return str(value).islower()
def test_upper(value: str) -> bool:
"""Return true if the variable is uppercased."""
return str(value).isupper()
def test_string(value: t.Any) -> bool:
"""Return true if the object is a string."""
return isinstance(value, str)
def test_mapping(value: t.Any) -> bool:
"""Return true if the object is a mapping (dict etc.).
.. versionadded:: 2.6
"""
return isinstance(value, abc.Mapping)
def test_number(value: t.Any) -> bool:
"""Return true if the variable is a number."""
return isinstance(value, Number)
def test_sequence(value: t.Any) -> bool:
"""Return true if the variable is a sequence. Sequences are variables
that are iterable.
"""
try:
len(value)
value.__getitem__
except Exception:
return False
return True
def test_sameas(value: t.Any, other: t.Any) -> bool:
"""Check if an object points to the same memory address than another
object:
.. sourcecode:: jinja
{% if foo.attribute is sameas false %}
the foo attribute really is the `False` singleton
{% endif %}
"""
return value is other
def test_iterable(value: t.Any) -> bool:
"""Check if it's possible to iterate over an object."""
try:
iter(value)
except TypeError:
return False
return True
def test_escaped(value: t.Any) -> bool:
"""Check if the value is escaped."""
return hasattr(value, "__html__")
def test_in(value: t.Any, seq: t.Container) -> bool:
"""Check if value is in seq.
.. versionadded:: 2.10
"""
return value in seq
TESTS = {
"odd": test_odd,
"even": test_even,
"divisibleby": test_divisibleby,
"defined": test_defined,
"undefined": test_undefined,
"filter": test_filter,
"test": test_test,
"none": test_none,
"boolean": test_boolean,
"false": test_false,
"true": test_true,
"integer": test_integer,
"float": test_float,
"lower": test_lower,
"upper": test_upper,
"string": test_string,
"mapping": test_mapping,
"number": test_number,
"sequence": test_sequence,
"iterable": test_iterable,
"callable": callable,
"sameas": test_sameas,
"escaped": test_escaped,
"in": test_in,
"==": operator.eq,
"eq": operator.eq,
"equalto": operator.eq,
"!=": operator.ne,
"ne": operator.ne,
">": operator.gt,
"gt": operator.gt,
"greaterthan": operator.gt,
"ge": operator.ge,
">=": operator.ge,
"<": operator.lt,
"lt": operator.lt,
"lessthan": operator.lt,
"<=": operator.le,
"le": operator.le,
}
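def _demo_custom_test() -> None:
    # Sketch, not part of the vendored module: env.tests is the live copy of
    # the TESTS table above, and entries added to it become usable with the
    # `is` operator; 'loud' is a hypothetical test for the example.
    from jinja2 import Environment

    env = Environment()
    env.tests["loud"] = lambda value: str(value).isupper()
    out = env.from_string(
        "{{ 'HEY' is loud }} {{ 41 is odd }} {{ 'loud' is test }}"
    ).render()
    assert out == "True True True"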

View file

@ -0,0 +1,854 @@
import enum
import json
import os
import re
import typing as t
import warnings
from collections import abc
from collections import deque
from random import choice
from random import randrange
from threading import Lock
from types import CodeType
from urllib.parse import quote_from_bytes
import markupsafe
if t.TYPE_CHECKING:
import typing_extensions as te
F = t.TypeVar("F", bound=t.Callable[..., t.Any])
# special singleton representing missing values for the runtime
missing: t.Any = type("MissingType", (), {"__repr__": lambda x: "missing"})()
internal_code: t.MutableSet[CodeType] = set()
concat = "".join
def pass_context(f: F) -> F:
"""Pass the :class:`~jinja2.runtime.Context` as the first argument
to the decorated function when called while rendering a template.
Can be used on functions, filters, and tests.
If only ``Context.eval_context`` is needed, use
:func:`pass_eval_context`. If only ``Context.environment`` is
needed, use :func:`pass_environment`.
.. versionadded:: 3.0.0
Replaces ``contextfunction`` and ``contextfilter``.
"""
f.jinja_pass_arg = _PassArg.context # type: ignore
return f
def pass_eval_context(f: F) -> F:
"""Pass the :class:`~jinja2.nodes.EvalContext` as the first argument
to the decorated function when called while rendering a template.
See :ref:`eval-context`.
Can be used on functions, filters, and tests.
If only ``EvalContext.environment`` is needed, use
:func:`pass_environment`.
.. versionadded:: 3.0.0
Replaces ``evalcontextfunction`` and ``evalcontextfilter``.
"""
f.jinja_pass_arg = _PassArg.eval_context # type: ignore
return f
def pass_environment(f: F) -> F:
"""Pass the :class:`~jinja2.Environment` as the first argument to
the decorated function when called while rendering a template.
Can be used on functions, filters, and tests.
.. versionadded:: 3.0.0
Replaces ``environmentfunction`` and ``environmentfilter``.
"""
f.jinja_pass_arg = _PassArg.environment # type: ignore
return f
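def _demo_pass_environment() -> None:
    # Sketch, not part of the vendored module: a filter decorated with
    # pass_environment receives the active Environment as its first argument;
    # 'starting_with' is a hypothetical filter name used for the example.
    from jinja2 import Environment

    @pass_environment
    def starting_with(env, prefix):
        return [name for name in env.filters if name.startswith(prefix)]

    env = Environment()
    env.filters["starting_with"] = starting_with
    assert "upper" in env.from_string("{{ 'up' | starting_with }}").render()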
class _PassArg(enum.Enum):
context = enum.auto()
eval_context = enum.auto()
environment = enum.auto()
@classmethod
def from_obj(cls, obj: F) -> t.Optional["_PassArg"]:
if hasattr(obj, "jinja_pass_arg"):
return obj.jinja_pass_arg # type: ignore
for prefix in "context", "eval_context", "environment":
squashed = prefix.replace("_", "")
for name in f"{squashed}function", f"{squashed}filter":
if getattr(obj, name, False) is True:
warnings.warn(
f"{name!r} is deprecated and will stop working"
f" in Jinja 3.1. Use 'pass_{prefix}' instead.",
DeprecationWarning,
stacklevel=2,
)
return cls[prefix]
return None
def contextfunction(f: F) -> F:
"""Pass the context as the first argument to the decorated function.
.. deprecated:: 3.0
Will be removed in Jinja 3.1. Use :func:`~jinja2.pass_context`
instead.
"""
warnings.warn(
"'contextfunction' is renamed to 'pass_context', the old name"
" will be removed in Jinja 3.1.",
DeprecationWarning,
stacklevel=2,
)
return pass_context(f)
def evalcontextfunction(f: F) -> F:
"""Pass the eval context as the first argument to the decorated
function.
.. deprecated:: 3.0
Will be removed in Jinja 3.1. Use
:func:`~jinja2.pass_eval_context` instead.
.. versionadded:: 2.4
"""
warnings.warn(
"'evalcontextfunction' is renamed to 'pass_eval_context', the"
" old name will be removed in Jinja 3.1.",
DeprecationWarning,
stacklevel=2,
)
return pass_eval_context(f)
def environmentfunction(f: F) -> F:
"""Pass the environment as the first argument to the decorated
function.
.. deprecated:: 3.0
Will be removed in Jinja 3.1. Use
:func:`~jinja2.pass_environment` instead.
"""
warnings.warn(
"'environmentfunction' is renamed to 'pass_environment', the"
" old name will be removed in Jinja 3.1.",
DeprecationWarning,
stacklevel=2,
)
return pass_environment(f)
def internalcode(f: F) -> F:
"""Marks the function as internally used"""
internal_code.add(f.__code__)
return f
def is_undefined(obj: t.Any) -> bool:
"""Check if the object passed is undefined. This does nothing more than
performing an instance check against :class:`Undefined` but looks nicer.
This can be used for custom filters or tests that want to react to
undefined variables. For example a custom default filter can look like
this::
def default(var, default=''):
if is_undefined(var):
return default
return var
"""
from .runtime import Undefined
return isinstance(obj, Undefined)
def consume(iterable: t.Iterable[t.Any]) -> None:
"""Consumes an iterable without doing anything with it."""
for _ in iterable:
pass
def clear_caches() -> None:
"""Jinja keeps internal caches for environments and lexers. These are
used so that Jinja doesn't have to recreate environments and lexers all
the time. Normally you don't have to care about that but if you are
measuring memory consumption you may want to clean the caches.
"""
from .environment import get_spontaneous_environment
from .lexer import _lexer_cache
get_spontaneous_environment.cache_clear()
_lexer_cache.clear()
def import_string(import_name: str, silent: bool = False) -> t.Any:
"""Imports an object based on a string. This is useful if you want to
use import paths as endpoints or something similar. An import path can
be specified either in dotted notation (``xml.sax.saxutils.escape``)
or with a colon as object delimiter (``xml.sax.saxutils:escape``).
If `silent` is True, the return value will be `None` if the import
fails.
:return: imported object
"""
try:
if ":" in import_name:
module, obj = import_name.split(":", 1)
elif "." in import_name:
module, _, obj = import_name.rpartition(".")
else:
return __import__(import_name)
return getattr(__import__(module, None, None, [obj]), obj)
except (ImportError, AttributeError):
if not silent:
raise
def open_if_exists(filename: str, mode: str = "rb") -> t.Optional[t.IO]:
"""Returns a file descriptor for the filename if that file exists,
otherwise ``None``.
"""
if not os.path.isfile(filename):
return None
return open(filename, mode)
def object_type_repr(obj: t.Any) -> str:
"""Returns the name of the object's type. For some recognized
singletons the name of the object is returned instead. (For
example for `None` and `Ellipsis`).
"""
if obj is None:
return "None"
elif obj is Ellipsis:
return "Ellipsis"
cls = type(obj)
if cls.__module__ == "builtins":
return f"{cls.__name__} object"
return f"{cls.__module__}.{cls.__name__} object"
def pformat(obj: t.Any) -> str:
"""Format an object using :func:`pprint.pformat`."""
from pprint import pformat # type: ignore
return pformat(obj)
_http_re = re.compile(
r"""
^
(
(https?://|www\.) # scheme or www
(([\w%-]+\.)+)? # subdomain
(
[a-z]{2,63} # basic tld
|
xn--[\w%]{2,59} # idna tld
)
|
([\w%-]{2,63}\.)+ # basic domain
(com|net|int|edu|gov|org|info|mil) # basic tld
|
(https?://) # scheme
(
(([\d]{1,3})(\.[\d]{1,3}){3}) # IPv4
|
(\[([\da-f]{0,4}:){2}([\da-f]{0,4}:?){1,6}]) # IPv6
)
)
(?::[\d]{1,5})? # port
(?:[/?#]\S*)? # path, query, and fragment
$
""",
re.IGNORECASE | re.VERBOSE,
)
_email_re = re.compile(r"^\S+@\w[\w.-]*\.\w+$")
def urlize(
text: str,
trim_url_limit: t.Optional[int] = None,
rel: t.Optional[str] = None,
target: t.Optional[str] = None,
extra_schemes: t.Optional[t.Iterable[str]] = None,
) -> str:
"""Convert URLs in text into clickable links.
This may not recognize links in some situations. Usually, a more
comprehensive formatter, such as a Markdown library, is a better
choice.
Works on ``http://``, ``https://``, ``www.``, ``mailto:``, and email
addresses. Links with trailing punctuation (periods, commas, closing
parentheses) and leading punctuation (opening parentheses) are
recognized excluding the punctuation. Email addresses that include
header fields are not recognized (for example,
``mailto:address@example.com?cc=copy@example.com``).
:param text: Original text containing URLs to link.
:param trim_url_limit: Shorten displayed URL values to this length.
:param target: Add the ``target`` attribute to links.
:param rel: Add the ``rel`` attribute to links.
:param extra_schemes: Recognize URLs that start with these schemes
in addition to the default behavior.
.. versionchanged:: 3.0
The ``extra_schemes`` parameter was added.
.. versionchanged:: 3.0
Generate ``https://`` links for URLs without a scheme.
.. versionchanged:: 3.0
The parsing rules were updated. Recognize email addresses with
or without the ``mailto:`` scheme. Validate IP addresses. Ignore
parentheses and brackets in more cases.
"""
if trim_url_limit is not None:
def trim_url(x: str) -> str:
if len(x) > trim_url_limit: # type: ignore
return f"{x[:trim_url_limit]}..."
return x
else:
def trim_url(x: str) -> str:
return x
words = re.split(r"(\s+)", str(markupsafe.escape(text)))
rel_attr = f' rel="{markupsafe.escape(rel)}"' if rel else ""
target_attr = f' target="{markupsafe.escape(target)}"' if target else ""
for i, word in enumerate(words):
head, middle, tail = "", word, ""
match = re.match(r"^([(<]|&lt;)+", middle)
if match:
head = match.group()
middle = middle[match.end() :]
# Unlike head, which is anchored to the start of the string, we
# need to check that the string ends with any of the characters
# before trying to match all of them, to avoid backtracking.
if middle.endswith((")", ">", ".", ",", "\n", "&gt;")):
match = re.search(r"([)>.,\n]|&gt;)+$", middle)
if match:
tail = match.group()
middle = middle[: match.start()]
# Prefer balancing parentheses in URLs instead of ignoring a
# trailing character.
for start_char, end_char in ("(", ")"), ("<", ">"), ("&lt;", "&gt;"):
start_count = middle.count(start_char)
if start_count <= middle.count(end_char):
# Balanced, or lighter on the left
continue
# Move as many as possible from the tail to balance
for _ in range(min(start_count, tail.count(end_char))):
end_index = tail.index(end_char) + len(end_char)
# Move anything in the tail before the end char too
middle += tail[:end_index]
tail = tail[end_index:]
if _http_re.match(middle):
if middle.startswith("https://") or middle.startswith("http://"):
middle = (
f'<a href="{middle}"{rel_attr}{target_attr}>{trim_url(middle)}</a>'
)
else:
middle = (
f'<a href="https://{middle}"{rel_attr}{target_attr}>'
f"{trim_url(middle)}</a>"
)
elif middle.startswith("mailto:") and _email_re.match(middle[7:]):
middle = f'<a href="{middle}">{middle[7:]}</a>'
elif (
"@" in middle
and not middle.startswith("www.")
and ":" not in middle
and _email_re.match(middle)
):
middle = f'<a href="mailto:{middle}">{middle}</a>'
elif extra_schemes is not None:
for scheme in extra_schemes:
if middle != scheme and middle.startswith(scheme):
middle = f'<a href="{middle}"{rel_attr}{target_attr}>{middle}</a>'
words[i] = f"{head}{middle}{tail}"
return "".join(words)
def generate_lorem_ipsum(
n: int = 5, html: bool = True, min: int = 20, max: int = 100
) -> str:
"""Generate some lorem ipsum for the template."""
from .constants import LOREM_IPSUM_WORDS
words = LOREM_IPSUM_WORDS.split()
result = []
for _ in range(n):
next_capitalized = True
last_comma = last_fullstop = 0
word = None
last = None
p = []
# each paragraph consists of 20 to 100 words.
for idx, _ in enumerate(range(randrange(min, max))):
while True:
word = choice(words)
if word != last:
last = word
break
if next_capitalized:
word = word.capitalize()
next_capitalized = False
# add commas
if idx - randrange(3, 8) > last_comma:
last_comma = idx
last_fullstop += 2
word += ","
# add end of sentences
if idx - randrange(10, 20) > last_fullstop:
last_comma = last_fullstop = idx
word += "."
next_capitalized = True
p.append(word)
# ensure that the paragraph ends with a dot.
p_str = " ".join(p)
if p_str.endswith(","):
p_str = p_str[:-1] + "."
elif not p_str.endswith("."):
p_str += "."
result.append(p_str)
if not html:
return "\n\n".join(result)
return markupsafe.Markup(
"\n".join(f"<p>{markupsafe.escape(x)}</p>" for x in result)
)
def url_quote(obj: t.Any, charset: str = "utf-8", for_qs: bool = False) -> str:
"""Quote a string for use in a URL using the given charset.
:param obj: String or bytes to quote. Other types are converted to
string then encoded to bytes using the given charset.
:param charset: Encode text to bytes using this charset.
:param for_qs: Quote "/" and use "+" for spaces.
"""
if not isinstance(obj, bytes):
if not isinstance(obj, str):
obj = str(obj)
obj = obj.encode(charset)
safe = b"" if for_qs else b"/"
rv = quote_from_bytes(obj, safe)
if for_qs:
rv = rv.replace("%20", "+")
return rv
def unicode_urlencode(obj: t.Any, charset: str = "utf-8", for_qs: bool = False) -> str:
import warnings
warnings.warn(
"'unicode_urlencode' has been renamed to 'url_quote'. The old"
" name will be removed in Jinja 3.1.",
DeprecationWarning,
stacklevel=2,
)
return url_quote(obj, charset=charset, for_qs=for_qs)
@abc.MutableMapping.register
class LRUCache:
"""A simple LRU Cache implementation."""
# this is fast for small capacities (something below 1000) but doesn't
# scale. But as long as it's only used as storage for templates this
# won't do any harm.
def __init__(self, capacity: int) -> None:
self.capacity = capacity
self._mapping: t.Dict[t.Any, t.Any] = {}
self._queue: "te.Deque[t.Any]" = deque()
self._postinit()
def _postinit(self) -> None:
# alias all queue methods for faster lookup
self._popleft = self._queue.popleft
self._pop = self._queue.pop
self._remove = self._queue.remove
self._wlock = Lock()
self._append = self._queue.append
def __getstate__(self) -> t.Mapping[str, t.Any]:
return {
"capacity": self.capacity,
"_mapping": self._mapping,
"_queue": self._queue,
}
def __setstate__(self, d: t.Mapping[str, t.Any]) -> None:
self.__dict__.update(d)
self._postinit()
def __getnewargs__(self) -> t.Tuple:
return (self.capacity,)
def copy(self) -> "LRUCache":
"""Return a shallow copy of the instance."""
rv = self.__class__(self.capacity)
rv._mapping.update(self._mapping)
rv._queue.extend(self._queue)
return rv
def get(self, key: t.Any, default: t.Any = None) -> t.Any:
"""Return an item from the cache dict or `default`"""
try:
return self[key]
except KeyError:
return default
def setdefault(self, key: t.Any, default: t.Any = None) -> t.Any:
"""Set `default` if the key is not in the cache otherwise
leave unchanged. Return the value of this key.
"""
try:
return self[key]
except KeyError:
self[key] = default
return default
def clear(self) -> None:
"""Clear the cache."""
with self._wlock:
self._mapping.clear()
self._queue.clear()
def __contains__(self, key: t.Any) -> bool:
"""Check if a key exists in this cache."""
return key in self._mapping
def __len__(self) -> int:
"""Return the current size of the cache."""
return len(self._mapping)
def __repr__(self) -> str:
return f"<{type(self).__name__} {self._mapping!r}>"
def __getitem__(self, key: t.Any) -> t.Any:
"""Get an item from the cache. Moves the item up so that it has the
highest priority then.
Raise a `KeyError` if it does not exist.
"""
with self._wlock:
rv = self._mapping[key]
if self._queue[-1] != key:
try:
self._remove(key)
except ValueError:
# if something removed the key from the container
# when we read, ignore the ValueError that we would
# get otherwise.
pass
self._append(key)
return rv
def __setitem__(self, key: t.Any, value: t.Any) -> None:
"""Sets the value for an item. Moves the item up so that it
has the highest priority then.
"""
with self._wlock:
if key in self._mapping:
self._remove(key)
elif len(self._mapping) == self.capacity:
del self._mapping[self._popleft()]
self._append(key)
self._mapping[key] = value
def __delitem__(self, key: t.Any) -> None:
"""Remove an item from the cache dict.
Raise a `KeyError` if it does not exist.
"""
with self._wlock:
del self._mapping[key]
try:
self._remove(key)
except ValueError:
pass
def items(self) -> t.Iterable[t.Tuple[t.Any, t.Any]]:
"""Return a list of items."""
result = [(key, self._mapping[key]) for key in list(self._queue)]
result.reverse()
return result
def values(self) -> t.Iterable[t.Any]:
"""Return a list of all values."""
return [x[1] for x in self.items()]
def keys(self) -> t.Iterable[t.Any]:
"""Return a list of all keys ordered by most recent usage."""
return list(self)
def __iter__(self) -> t.Iterator[t.Any]:
return reversed(tuple(self._queue))
def __reversed__(self) -> t.Iterator[t.Any]:
"""Iterate over the keys in the cache dict, oldest items
coming first.
"""
return iter(tuple(self._queue))
__copy__ = copy
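def _demo_lru_cache() -> None:
    # Sketch of the eviction order: a read promotes its key, so the least
    # recently used entry is the one dropped once capacity is exceeded.
    cache = LRUCache(2)
    cache["a"] = 1
    cache["b"] = 2
    cache["a"]  # promote "a" above "b"
    cache["c"] = 3  # evicts "b", the least recently used key
    assert "b" not in cache and cache.keys() == ["c", "a"]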
def select_autoescape(
enabled_extensions: t.Collection[str] = ("html", "htm", "xml"),
disabled_extensions: t.Collection[str] = (),
default_for_string: bool = True,
default: bool = False,
) -> t.Callable[[t.Optional[str]], bool]:
"""Intelligently sets the initial value of autoescaping based on the
filename of the template. This is the recommended way to configure
autoescaping if you do not want to write a custom function yourself.
If you want to enable it for all templates created from strings or
for all templates with `.html` and `.xml` extensions::
from jinja2 import Environment, select_autoescape
env = Environment(autoescape=select_autoescape(
enabled_extensions=('html', 'xml'),
default_for_string=True,
))
Example configuration to turn it on at all times except if the template
ends with `.txt`::
from jinja2 import Environment, select_autoescape
env = Environment(autoescape=select_autoescape(
disabled_extensions=('txt',),
default_for_string=True,
default=True,
))
The `enabled_extensions` is an iterable of all the extensions that
autoescaping should be enabled for. Likewise `disabled_extensions` is
a list of all templates it should be disabled for. If a template is
loaded from a string then the default from `default_for_string` is used.
If nothing matches then the initial value of autoescaping is set to the
value of `default`.
For security reasons this function operates case-insensitively.
.. versionadded:: 2.9
"""
enabled_patterns = tuple(f".{x.lstrip('.').lower()}" for x in enabled_extensions)
disabled_patterns = tuple(f".{x.lstrip('.').lower()}" for x in disabled_extensions)
def autoescape(template_name: t.Optional[str]) -> bool:
if template_name is None:
return default_for_string
template_name = template_name.lower()
if template_name.endswith(enabled_patterns):
return True
if template_name.endswith(disabled_patterns):
return False
return default
return autoescape
def htmlsafe_json_dumps(
obj: t.Any, dumps: t.Optional[t.Callable[..., str]] = None, **kwargs: t.Any
) -> markupsafe.Markup:
"""Serialize an object to a string of JSON with :func:`json.dumps`,
then replace HTML-unsafe characters with Unicode escapes and mark
the result safe with :class:`~markupsafe.Markup`.
This is available in templates as the ``|tojson`` filter.
The following characters are escaped: ``<``, ``>``, ``&``, ``'``.
The returned string is safe to render in HTML documents and
``<script>`` tags. The exception is in HTML attributes that are
double quoted; either use single quotes or the ``|forceescape``
filter.
:param obj: The object to serialize to JSON.
:param dumps: The ``dumps`` function to use. Defaults to
``env.policies["json.dumps_function"]``, which defaults to
:func:`json.dumps`.
:param kwargs: Extra arguments to pass to ``dumps``. Merged onto
``env.policies["json.dumps_kwargs"]``.
.. versionchanged:: 3.0
The ``dumper`` parameter is renamed to ``dumps``.
.. versionadded:: 2.9
"""
if dumps is None:
dumps = json.dumps
return markupsafe.Markup(
dumps(obj, **kwargs)
.replace("<", "\\u003c")
.replace(">", "\\u003e")
.replace("&", "\\u0026")
.replace("'", "\\u0027")
)
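def _demo_htmlsafe_json_dumps() -> None:
    # Sketch of the escaping contract documented above: the HTML-significant
    # characters come back as Unicode escapes inside otherwise valid JSON.
    rv = htmlsafe_json_dumps({"msg": "<script>alert('x')</script>"})
    assert "<" not in rv and "\\u003cscript\\u003e" in rv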
class Cycler:
"""Cycle through values by yield them one at a time, then restarting
once the end is reached. Available as ``cycler`` in templates.
Similar to ``loop.cycle``, but can be used outside loops or across
multiple loops. For example, render a list of folders and files in a
list, alternating giving them "odd" and "even" classes.
.. code-block:: html+jinja
{% set row_class = cycler("odd", "even") %}
<ul class="browser">
{% for folder in folders %}
<li class="folder {{ row_class.next() }}">{{ folder }}
{% endfor %}
{% for file in files %}
<li class="file {{ row_class.next() }}">{{ file }}
{% endfor %}
</ul>
:param items: Each positional argument will be yielded in the order
given for each cycle.
.. versionadded:: 2.1
"""
def __init__(self, *items: t.Any) -> None:
if not items:
raise RuntimeError("at least one item has to be provided")
self.items = items
self.pos = 0
def reset(self) -> None:
"""Resets the current item to the first item."""
self.pos = 0
@property
def current(self) -> t.Any:
"""Return the current item. Equivalent to the item that will be
returned next time :meth:`next` is called.
"""
return self.items[self.pos]
def next(self) -> t.Any:
"""Return the current item, then advance :attr:`current` to the
next item.
"""
rv = self.current
self.pos = (self.pos + 1) % len(self.items)
return rv
__next__ = next
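def _demo_cycler() -> None:
    # Sketch: next() returns the current item and advances, wrapping around
    # once the last item has been returned.
    row_class = Cycler("odd", "even")
    assert [row_class.next() for _ in range(3)] == ["odd", "even", "odd"]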
class Joiner:
"""A joining helper for templates."""
def __init__(self, sep: str = ", ") -> None:
self.sep = sep
self.used = False
def __call__(self) -> str:
if not self.used:
self.used = True
return ""
return self.sep
class Namespace:
"""A namespace object that can hold arbitrary attributes. It may be
initialized from a dictionary or with keyword arguments."""
def __init__(*args: t.Any, **kwargs: t.Any) -> None: # noqa: B902
self, args = args[0], args[1:]
self.__attrs = dict(*args, **kwargs)
def __getattribute__(self, name: str) -> t.Any:
# __class__ is needed for the awaitable check in async mode
if name in {"_Namespace__attrs", "__class__"}:
return object.__getattribute__(self, name)
try:
return self.__attrs[name]
except KeyError:
raise AttributeError(name) from None
def __setitem__(self, name: str, value: t.Any) -> None:
self.__attrs[name] = value
def __repr__(self) -> str:
return f"<Namespace {self.__attrs!r}>"
class Markup(markupsafe.Markup):
def __new__(cls, base="", encoding=None, errors="strict"): # type: ignore
warnings.warn(
"'jinja2.Markup' is deprecated and will be removed in Jinja"
" 3.1. Import 'markupsafe.Markup' instead.",
DeprecationWarning,
stacklevel=2,
)
return super().__new__(cls, base, encoding, errors)
def escape(s: t.Any) -> str:
warnings.warn(
"'jinja2.escape' is deprecated and will be removed in Jinja"
" 3.1. Import 'markupsafe.escape' instead.",
DeprecationWarning,
stacklevel=2,
)
return markupsafe.escape(s)

View file

@ -0,0 +1,92 @@
"""API for traversing the AST nodes. Implemented by the compiler and
meta introspection.
"""
import typing as t
from .nodes import Node
if t.TYPE_CHECKING:
import typing_extensions as te
class VisitCallable(te.Protocol):
def __call__(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.Any:
...
class NodeVisitor:
"""Walks the abstract syntax tree and call visitor functions for every
node found. The visitor functions may return values which will be
forwarded by the `visit` method.
Per default the visitor functions for the nodes are ``'visit_'`` +
class name of the node. So a `TryFinally` node visit function would
be `visit_TryFinally`. This behavior can be changed by overriding
the `get_visitor` function. If no visitor function exists for a node
(return value `None`) the `generic_visit` visitor is used instead.
"""
def get_visitor(self, node: Node) -> "t.Optional[VisitCallable]":
"""Return the visitor function for this node or `None` if no visitor
exists for this node. In that case the generic visit function is
used instead.
"""
return getattr(self, f"visit_{type(node).__name__}", None) # type: ignore
def visit(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.Any:
"""Visit a node."""
f = self.get_visitor(node)
if f is not None:
return f(node, *args, **kwargs)
return self.generic_visit(node, *args, **kwargs)
def generic_visit(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.Any:
"""Called if no explicit visitor function exists for a node."""
for node in node.iter_child_nodes():
self.visit(node, *args, **kwargs)
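def _demo_node_visitor() -> None:
    # Sketch, not part of the vendored module: a method named visit_Name is
    # picked up automatically for jinja2.nodes.Name nodes, per the naming
    # convention described above.
    from jinja2 import Environment

    class NameCollector(NodeVisitor):
        def __init__(self) -> None:
            self.names: t.List[str] = []

        def visit_Name(self, node: Node) -> None:
            self.names.append(node.name)  # type: ignore[attr-defined]

    collector = NameCollector()
    collector.visit(Environment().parse("{{ a }} and {{ b }}"))
    assert collector.names == ["a", "b"]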
class NodeTransformer(NodeVisitor):
"""Walks the abstract syntax tree and allows modifications of nodes.
The `NodeTransformer` will walk the AST and use the return value of the
visitor functions to replace or remove the old node. If the return
value of the visitor function is `None` the node will be removed
from the previous location otherwise it's replaced with the return
value. The return value may be the original node in which case no
replacement takes place.
"""
def generic_visit(self, node: Node, *args: t.Any, **kwargs: t.Any) -> Node:
for field, old_value in node.iter_fields():
if isinstance(old_value, list):
new_values = []
for value in old_value:
if isinstance(value, Node):
value = self.visit(value, *args, **kwargs)
if value is None:
continue
elif not isinstance(value, Node):
new_values.extend(value)
continue
new_values.append(value)
old_value[:] = new_values
elif isinstance(old_value, Node):
new_node = self.visit(old_value, *args, **kwargs)
if new_node is None:
delattr(node, field)
else:
setattr(node, field, new_node)
return node
def visit_list(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.List[Node]:
"""As transformers may return lists in some places this method
can be used to enforce a list as return value.
"""
rv = self.visit(node, *args, **kwargs)
if not isinstance(rv, list):
return [rv]
return rv
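def _demo_node_transformer() -> None:
    # Sketch, not part of the vendored module: the value returned from a
    # visit_* method replaces the visited node in its parent, per the
    # semantics described above.
    from jinja2 import Environment, nodes

    class UpperConsts(NodeTransformer):
        def visit_Const(self, node: "nodes.Const") -> "nodes.Const":
            if isinstance(node.value, str):
                node.value = node.value.upper()
            return node

    ast = Environment().parse("{{ 'hi' }}")
    UpperConsts().visit(ast)
    assert next(ast.find_all(nodes.Const)).value == "HI"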

View file

@ -0,0 +1 @@
from jsonschema import *

View file

@ -0,0 +1,19 @@
Copyright (c) 2013 Julian Berman
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

View file

@ -0,0 +1,31 @@
"""
An implementation of JSON Schema for Python
The main functionality is provided by the validator classes for each of the
supported JSON Schema versions.
Most commonly, `validate` is the quickest way to simply validate a given
instance under a schema, and will create a validator for you.
"""
from jsonschema.exceptions import (
ErrorTree, FormatError, RefResolutionError, SchemaError, ValidationError
)
from jsonschema._format import (
FormatChecker,
draft3_format_checker,
draft4_format_checker,
draft6_format_checker,
draft7_format_checker,
)
from jsonschema._types import TypeChecker
from jsonschema.validators import (
Draft3Validator,
Draft4Validator,
Draft6Validator,
Draft7Validator,
RefResolver,
validate,
)
__version__ = "3.2.0"

View file

@ -0,0 +1,2 @@
from jsonschema.cli import main
main()

View file

@ -0,0 +1,425 @@
import datetime
import re
import socket
import struct
from jsonschema.compat import str_types
from jsonschema.exceptions import FormatError
class FormatChecker(object):
"""
A ``format`` property checker.
JSON Schema does not mandate that the ``format`` property actually do any
validation. If validation is desired however, instances of this class can
be hooked into validators to enable format validation.
`FormatChecker` objects always return ``True`` when asked about
formats that they do not know how to validate.
To check a custom format using a function that takes an instance and
returns a ``bool``, use the `FormatChecker.checks` or
`FormatChecker.cls_checks` decorators.
Arguments:
formats (~collections.Iterable):
The known formats to validate. This argument can be used to
limit which formats will be used during validation.
"""
checkers = {}
def __init__(self, formats=None):
if formats is None:
self.checkers = self.checkers.copy()
else:
self.checkers = dict((k, self.checkers[k]) for k in formats)
def __repr__(self):
return "<FormatChecker checkers={}>".format(sorted(self.checkers))
def checks(self, format, raises=()):
"""
Register a decorated function as validating a new format.
Arguments:
format (str):
The format that the decorated function will check.
raises (Exception):
The exception(s) raised by the decorated function when an
invalid instance is found.
The exception object will be accessible as the
`jsonschema.exceptions.ValidationError.cause` attribute of the
resulting validation error.
"""
def _checks(func):
self.checkers[format] = (func, raises)
return func
return _checks
cls_checks = classmethod(checks)
def check(self, instance, format):
"""
Check whether the instance conforms to the given format.
Arguments:
instance (*any primitive type*, i.e. str, number, bool):
The instance to check
format (str):
The format that instance should conform to
Raises:
FormatError: if the instance does not conform to ``format``
"""
if format not in self.checkers:
return
func, raises = self.checkers[format]
result, cause = None, None
try:
result = func(instance)
except raises as e:
cause = e
if not result:
raise FormatError(
"%r is not a %r" % (instance, format), cause=cause,
)
def conforms(self, instance, format):
"""
Check whether the instance conforms to the given format.
Arguments:
instance (*any primitive type*, i.e. str, number, bool):
The instance to check
format (str):
The format that instance should conform to
Returns:
bool: whether it conformed
"""
try:
self.check(instance, format)
except FormatError:
return False
else:
return True
draft3_format_checker = FormatChecker()
draft4_format_checker = FormatChecker()
draft6_format_checker = FormatChecker()
draft7_format_checker = FormatChecker()
_draft_checkers = dict(
draft3=draft3_format_checker,
draft4=draft4_format_checker,
draft6=draft6_format_checker,
draft7=draft7_format_checker,
)
def _checks_drafts(
name=None,
draft3=None,
draft4=None,
draft6=None,
draft7=None,
raises=(),
):
draft3 = draft3 or name
draft4 = draft4 or name
draft6 = draft6 or name
draft7 = draft7 or name
def wrap(func):
if draft3:
func = _draft_checkers["draft3"].checks(draft3, raises)(func)
if draft4:
func = _draft_checkers["draft4"].checks(draft4, raises)(func)
if draft6:
func = _draft_checkers["draft6"].checks(draft6, raises)(func)
if draft7:
func = _draft_checkers["draft7"].checks(draft7, raises)(func)
# Oy. This is bad global state, but relied upon for now, until
# deprecation. See https://github.com/Julian/jsonschema/issues/519
# and test_format_checkers_come_with_defaults
FormatChecker.cls_checks(draft7 or draft6 or draft4 or draft3, raises)(
func,
)
return func
return wrap
@_checks_drafts(name="idn-email")
@_checks_drafts(name="email")
def is_email(instance):
if not isinstance(instance, str_types):
return True
return "@" in instance
_ipv4_re = re.compile(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$")
@_checks_drafts(
draft3="ip-address", draft4="ipv4", draft6="ipv4", draft7="ipv4",
)
def is_ipv4(instance):
if not isinstance(instance, str_types):
return True
if not _ipv4_re.match(instance):
return False
return all(0 <= int(component) <= 255 for component in instance.split("."))
if hasattr(socket, "inet_pton"):
# FIXME: Really this only should raise struct.error, but see the sadness
# that is https://twistedmatrix.com/trac/ticket/9409
@_checks_drafts(
name="ipv6", raises=(socket.error, struct.error, ValueError),
)
def is_ipv6(instance):
if not isinstance(instance, str_types):
return True
return socket.inet_pton(socket.AF_INET6, instance)
_host_name_re = re.compile(r"^[A-Za-z0-9][A-Za-z0-9\.\-]{1,255}$")
@_checks_drafts(
draft3="host-name",
draft4="hostname",
draft6="hostname",
draft7="hostname",
)
def is_host_name(instance):
if not isinstance(instance, str_types):
return True
if not _host_name_re.match(instance):
return False
components = instance.split(".")
for component in components:
if len(component) > 63:
return False
return True
try:
# The built-in `idna` codec only implements RFC 3490, so we go elsewhere.
import idna
except ImportError:
pass
else:
@_checks_drafts(draft7="idn-hostname", raises=idna.IDNAError)
def is_idn_host_name(instance):
if not isinstance(instance, str_types):
return True
idna.encode(instance)
return True
try:
import rfc3987
except ImportError:
try:
from rfc3986_validator import validate_rfc3986
except ImportError:
pass
else:
@_checks_drafts(name="uri")
def is_uri(instance):
if not isinstance(instance, str_types):
return True
return validate_rfc3986(instance, rule="URI")
@_checks_drafts(
draft6="uri-reference",
draft7="uri-reference",
raises=ValueError,
)
def is_uri_reference(instance):
if not isinstance(instance, str_types):
return True
return validate_rfc3986(instance, rule="URI_reference")
else:
@_checks_drafts(draft7="iri", raises=ValueError)
def is_iri(instance):
if not isinstance(instance, str_types):
return True
return rfc3987.parse(instance, rule="IRI")
@_checks_drafts(draft7="iri-reference", raises=ValueError)
def is_iri_reference(instance):
if not isinstance(instance, str_types):
return True
return rfc3987.parse(instance, rule="IRI_reference")
@_checks_drafts(name="uri", raises=ValueError)
def is_uri(instance):
if not isinstance(instance, str_types):
return True
return rfc3987.parse(instance, rule="URI")
@_checks_drafts(
draft6="uri-reference",
draft7="uri-reference",
raises=ValueError,
)
def is_uri_reference(instance):
if not isinstance(instance, str_types):
return True
return rfc3987.parse(instance, rule="URI_reference")
try:
from strict_rfc3339 import validate_rfc3339
except ImportError:
try:
from rfc3339_validator import validate_rfc3339
except ImportError:
validate_rfc3339 = None
if validate_rfc3339:
@_checks_drafts(name="date-time")
def is_datetime(instance):
if not isinstance(instance, str_types):
return True
return validate_rfc3339(instance)
@_checks_drafts(draft7="time")
def is_time(instance):
if not isinstance(instance, str_types):
return True
return is_datetime("1970-01-01T" + instance)
@_checks_drafts(name="regex", raises=re.error)
def is_regex(instance):
if not isinstance(instance, str_types):
return True
return re.compile(instance)
@_checks_drafts(draft3="date", draft7="date", raises=ValueError)
def is_date(instance):
if not isinstance(instance, str_types):
return True
return datetime.datetime.strptime(instance, "%Y-%m-%d")
@_checks_drafts(draft3="time", raises=ValueError)
def is_draft3_time(instance):
if not isinstance(instance, str_types):
return True
return datetime.datetime.strptime(instance, "%H:%M:%S")
try:
import webcolors
except ImportError:
pass
else:
def is_css_color_code(instance):
return webcolors.normalize_hex(instance)
@_checks_drafts(draft3="color", raises=(ValueError, TypeError))
def is_css21_color(instance):
if (
not isinstance(instance, str_types) or
instance.lower() in webcolors.css21_names_to_hex
):
return True
return is_css_color_code(instance)
def is_css3_color(instance):
if instance.lower() in webcolors.css3_names_to_hex:
return True
return is_css_color_code(instance)
try:
import jsonpointer
except ImportError:
pass
else:
@_checks_drafts(
draft6="json-pointer",
draft7="json-pointer",
raises=jsonpointer.JsonPointerException,
)
def is_json_pointer(instance):
if not isinstance(instance, str_types):
return True
return jsonpointer.JsonPointer(instance)
# TODO: I don't want to maintain this, so it
# needs to go either into jsonpointer (pending
# https://github.com/stefankoegl/python-json-pointer/issues/34) or
# into a new external library.
@_checks_drafts(
draft7="relative-json-pointer",
raises=jsonpointer.JsonPointerException,
)
def is_relative_json_pointer(instance):
# Definition taken from:
# https://tools.ietf.org/html/draft-handrews-relative-json-pointer-01#section-3
if not isinstance(instance, str_types):
return True
non_negative_integer, rest = [], ""
for i, character in enumerate(instance):
if character.isdigit():
non_negative_integer.append(character)
continue
if not non_negative_integer:
return False
rest = instance[i:]
break
return (rest == "#") or jsonpointer.JsonPointer(rest)
try:
import uritemplate.exceptions
except ImportError:
pass
else:
@_checks_drafts(
draft6="uri-template",
draft7="uri-template",
raises=uritemplate.exceptions.InvalidTemplate,
)
def is_uri_template(
instance,
template_validator=uritemplate.Validator().force_balanced_braces(),
):
template = uritemplate.URITemplate(instance)
return template_validator.validate(template)
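if __name__ == "__main__":
    # Usage sketch (not part of upstream jsonschema): registering a custom
    # format with the FormatChecker defined above; the "even" format name
    # is hypothetical.
    demo_checker = FormatChecker()

    @demo_checker.checks("even", raises=ValueError)
    def is_even(value):
        # int() raises ValueError for non-numeric strings; the checker
        # attaches that as the resulting FormatError's cause.
        return int(value) % 2 == 0

    print(demo_checker.conforms("2", "even"))  # True
    print(demo_checker.conforms("3", "even"))  # False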

View file

@ -0,0 +1,141 @@
from jsonschema import _utils
from jsonschema.compat import iteritems
from jsonschema.exceptions import ValidationError
def dependencies_draft3(validator, dependencies, instance, schema):
if not validator.is_type(instance, "object"):
return
for property, dependency in iteritems(dependencies):
if property not in instance:
continue
if validator.is_type(dependency, "object"):
for error in validator.descend(
instance, dependency, schema_path=property,
):
yield error
elif validator.is_type(dependency, "string"):
if dependency not in instance:
yield ValidationError(
"%r is a dependency of %r" % (dependency, property)
)
else:
for each in dependency:
if each not in instance:
message = "%r is a dependency of %r"
yield ValidationError(message % (each, property))
def disallow_draft3(validator, disallow, instance, schema):
for disallowed in _utils.ensure_list(disallow):
if validator.is_valid(instance, {"type": [disallowed]}):
yield ValidationError(
"%r is disallowed for %r" % (disallowed, instance)
)
def extends_draft3(validator, extends, instance, schema):
if validator.is_type(extends, "object"):
for error in validator.descend(instance, extends):
yield error
return
for index, subschema in enumerate(extends):
for error in validator.descend(instance, subschema, schema_path=index):
yield error
def items_draft3_draft4(validator, items, instance, schema):
if not validator.is_type(instance, "array"):
return
if validator.is_type(items, "object"):
for index, item in enumerate(instance):
for error in validator.descend(item, items, path=index):
yield error
else:
for (index, item), subschema in zip(enumerate(instance), items):
for error in validator.descend(
item, subschema, path=index, schema_path=index,
):
yield error
def minimum_draft3_draft4(validator, minimum, instance, schema):
if not validator.is_type(instance, "number"):
return
if schema.get("exclusiveMinimum", False):
failed = instance <= minimum
cmp = "less than or equal to"
else:
failed = instance < minimum
cmp = "less than"
if failed:
yield ValidationError(
"%r is %s the minimum of %r" % (instance, cmp, minimum)
)
def maximum_draft3_draft4(validator, maximum, instance, schema):
if not validator.is_type(instance, "number"):
return
if schema.get("exclusiveMaximum", False):
failed = instance >= maximum
cmp = "greater than or equal to"
else:
failed = instance > maximum
cmp = "greater than"
if failed:
yield ValidationError(
"%r is %s the maximum of %r" % (instance, cmp, maximum)
)
def properties_draft3(validator, properties, instance, schema):
if not validator.is_type(instance, "object"):
return
for property, subschema in iteritems(properties):
if property in instance:
for error in validator.descend(
instance[property],
subschema,
path=property,
schema_path=property,
):
yield error
elif subschema.get("required", False):
error = ValidationError("%r is a required property" % property)
error._set(
validator="required",
validator_value=subschema["required"],
instance=instance,
schema=schema,
)
error.path.appendleft(property)
error.schema_path.extend([property, "required"])
yield error
def type_draft3(validator, types, instance, schema):
types = _utils.ensure_list(types)
all_errors = []
for index, type in enumerate(types):
if validator.is_type(type, "object"):
errors = list(validator.descend(instance, type, schema_path=index))
if not errors:
return
all_errors.extend(errors)
else:
if validator.is_type(instance, type):
return
else:
yield ValidationError(
_utils.types_msg(instance, types), context=all_errors,
)

View file

@ -0,0 +1,155 @@
# -*- test-case-name: twisted.test.test_reflect -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Standardized versions of various cool and/or strange things that you can do
with Python's reflection capabilities.
"""
import sys
from jsonschema.compat import PY3
class _NoModuleFound(Exception):
"""
No module was found because none exists.
"""
class InvalidName(ValueError):
"""
The given name is not a dot-separated list of Python objects.
"""
class ModuleNotFound(InvalidName):
"""
The module associated with the given name doesn't exist and it can't be
imported.
"""
class ObjectNotFound(InvalidName):
"""
The object associated with the given name doesn't exist and it can't be
imported.
"""
if PY3:
def reraise(exception, traceback):
raise exception.with_traceback(traceback)
else:
exec("""def reraise(exception, traceback):
raise exception.__class__, exception, traceback""")
reraise.__doc__ = """
Re-raise an exception, with an optional traceback, in a way that is compatible
with both Python 2 and Python 3.
Note that on Python 3, re-raised exceptions will be mutated, with their
C{__traceback__} attribute being set.
@param exception: The exception instance.
@param traceback: The traceback to use, or C{None} indicating a new traceback.
"""
def _importAndCheckStack(importName):
"""
Import the given name as a module, then walk the stack to determine whether
the failure was the module not existing, or some code in the module (for
example a dependent import) failing. This can be helpful to determine
whether any actual application code was run. For example, to distinguish
administrative error (entering the wrong module name) from programmer
error (writing buggy code in a module that fails to import).
@param importName: The name of the module to import.
@type importName: C{str}
@raise Exception: if something bad happens. This can be any type of
exception, since nobody knows what loading some arbitrary code might
do.
@raise _NoModuleFound: if no module was found.
"""
try:
return __import__(importName)
except ImportError:
excType, excValue, excTraceback = sys.exc_info()
while excTraceback:
execName = excTraceback.tb_frame.f_globals["__name__"]
# in Python 2 execName is None when an ImportError is encountered,
# whereas in Python 3 execName is equal to the importName.
if execName is None or execName == importName:
reraise(excValue, excTraceback)
excTraceback = excTraceback.tb_next
raise _NoModuleFound()
def namedAny(name):
"""
Retrieve a Python object by its fully qualified name from the global Python
module namespace. The first part of the name, that describes a module,
will be discovered and imported. Each subsequent part of the name is
treated as the name of an attribute of the object specified by all of the
name which came before it. For example, the fully-qualified name of this
object is 'twisted.python.reflect.namedAny'.
@type name: L{str}
@param name: The name of the object to return.
@raise InvalidName: If the name is an empty string, starts or ends with
a '.', or is otherwise syntactically incorrect.
@raise ModuleNotFound: If the name is syntactically correct but the
module it specifies cannot be imported because it does not appear to
exist.
@raise ObjectNotFound: If the name is syntactically correct, includes at
least one '.', but the module it specifies cannot be imported because
it does not appear to exist.
@raise AttributeError: If an attribute of an object along the way cannot be
accessed, or a module along the way is not found.
@return: the Python object identified by 'name'.
"""
if not name:
raise InvalidName('Empty module name')
names = name.split('.')
# if the name starts or ends with a '.' or contains '..', the __import__
# would raise an 'Empty module name' error; raising InvalidName here
# provides a better error message.
if '' in names:
raise InvalidName(
"name must be a string giving a '.'-separated list of Python "
"identifiers, not %r" % (name,))
topLevelPackage = None
moduleNames = names[:]
while not topLevelPackage:
if moduleNames:
trialname = '.'.join(moduleNames)
try:
topLevelPackage = _importAndCheckStack(trialname)
except _NoModuleFound:
moduleNames.pop()
else:
if len(names) == 1:
raise ModuleNotFound("No module named %r" % (name,))
else:
raise ObjectNotFound('%r does not name an object' % (name,))
obj = topLevelPackage
for n in names[1:]:
obj = getattr(obj, n)
return obj
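if __name__ == "__main__":
    # Usage sketch (not part of upstream jsonschema): resolving a dotted
    # name to the object it identifies.
    print(namedAny("os.path.join"))  # <function join at 0x...>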

View file

@ -0,0 +1,188 @@
import numbers
from pyrsistent import pmap
import attr
from jsonschema.compat import int_types, str_types
from jsonschema.exceptions import UndefinedTypeCheck
def is_array(checker, instance):
return isinstance(instance, list)
def is_bool(checker, instance):
return isinstance(instance, bool)
def is_integer(checker, instance):
# bool inherits from int, so ensure bools aren't reported as ints
if isinstance(instance, bool):
return False
return isinstance(instance, int_types)
def is_null(checker, instance):
return instance is None
def is_number(checker, instance):
# bool inherits from int, so ensure bools aren't reported as ints
if isinstance(instance, bool):
return False
return isinstance(instance, numbers.Number)
def is_object(checker, instance):
return isinstance(instance, dict)
def is_string(checker, instance):
return isinstance(instance, str_types)
def is_any(checker, instance):
return True
@attr.s(frozen=True)
class TypeChecker(object):
"""
A ``type`` property checker.
A `TypeChecker` performs type checking for an `IValidator`. Type
checks to perform are updated using `TypeChecker.redefine` or
`TypeChecker.redefine_many` and removed via `TypeChecker.remove`.
Each of these return a new `TypeChecker` object.
Arguments:
type_checkers (dict):
The initial mapping of types to their checking functions.
"""
_type_checkers = attr.ib(default=pmap(), converter=pmap)
def is_type(self, instance, type):
"""
Check if the instance is of the appropriate type.
Arguments:
instance (object):
The instance to check
type (str):
The name of the type that is expected.
Returns:
bool: Whether it conformed.
Raises:
`jsonschema.exceptions.UndefinedTypeCheck`:
if type is unknown to this object.
"""
try:
fn = self._type_checkers[type]
except KeyError:
raise UndefinedTypeCheck(type)
return fn(self, instance)
def redefine(self, type, fn):
"""
Produce a new checker with the given type redefined.
Arguments:
type (str):
The name of the type to check.
fn (collections.Callable):
A function taking exactly two parameters - the type
checker calling the function and the instance to check.
The function should return true if instance is of this
type and false otherwise.
Returns:
A new `TypeChecker` instance.
"""
return self.redefine_many({type: fn})
def redefine_many(self, definitions=()):
"""
Produce a new checker with the given types redefined.
Arguments:
definitions (dict):
A dictionary mapping types to their checking functions.
Returns:
A new `TypeChecker` instance.
"""
return attr.evolve(
self, type_checkers=self._type_checkers.update(definitions),
)
def remove(self, *types):
"""
Produce a new checker with the given types forgotten.
Arguments:
types (~collections.Iterable):
the names of the types to remove.
Returns:
A new `TypeChecker` instance
Raises:
`jsonschema.exceptions.UndefinedTypeCheck`:
if any given type is unknown to this object
"""
checkers = self._type_checkers
for each in types:
try:
checkers = checkers.remove(each)
except KeyError:
raise UndefinedTypeCheck(each)
return attr.evolve(self, type_checkers=checkers)
draft3_type_checker = TypeChecker(
{
u"any": is_any,
u"array": is_array,
u"boolean": is_bool,
u"integer": is_integer,
u"object": is_object,
u"null": is_null,
u"number": is_number,
u"string": is_string,
},
)
draft4_type_checker = draft3_type_checker.remove(u"any")
draft6_type_checker = draft4_type_checker.redefine(
u"integer",
lambda checker, instance: (
is_integer(checker, instance) or
isinstance(instance, float) and instance.is_integer()
),
)
draft7_type_checker = draft6_type_checker
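if __name__ == "__main__":
    # Usage sketch (not part of upstream jsonschema): redefining "array"
    # on the draft-7 checker so that tuples also count as arrays.
    demo_checker = draft7_type_checker.redefine(
        u"array",
        lambda checker, instance: isinstance(instance, (list, tuple)),
    )
    print(demo_checker.is_type((1, 2), u"array"))           # True
    print(draft7_type_checker.is_type((1, 2), u"array"))    # False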

View file

@ -0,0 +1,212 @@
import itertools
import json
import pkgutil
import re
from jsonschema.compat import MutableMapping, str_types, urlsplit
class URIDict(MutableMapping):
"""
Dictionary which uses normalized URIs as keys.
"""
def normalize(self, uri):
return urlsplit(uri).geturl()
def __init__(self, *args, **kwargs):
self.store = dict()
self.store.update(*args, **kwargs)
def __getitem__(self, uri):
return self.store[self.normalize(uri)]
def __setitem__(self, uri, value):
self.store[self.normalize(uri)] = value
def __delitem__(self, uri):
del self.store[self.normalize(uri)]
def __iter__(self):
return iter(self.store)
def __len__(self):
return len(self.store)
def __repr__(self):
return repr(self.store)
class Unset(object):
"""
An as-yet-unset attribute or unprovided default parameter.
"""
def __repr__(self):
return "<unset>"
def load_schema(name):
"""
Load a schema from ./schemas/``name``.json and return it.
"""
data = pkgutil.get_data("jsonschema", "schemas/{0}.json".format(name))
return json.loads(data.decode("utf-8"))
def indent(string, times=1):
"""
A dumb version of `textwrap.indent` from Python 3.3.
"""
return "\n".join(" " * (4 * times) + line for line in string.splitlines())
def format_as_index(indices):
"""
Construct a single string containing indexing operations for the indices.
For example, [1, 2, "foo"] -> [1][2]["foo"]
Arguments:
indices (sequence):
The indices to format.
"""
if not indices:
return ""
return "[%s]" % "][".join(repr(index) for index in indices)
def find_additional_properties(instance, schema):
"""
Return the set of additional properties for the given ``instance``.
Weeds out properties that should have been validated by ``properties`` and
/ or ``patternProperties``.
Assumes ``instance`` is dict-like already.
"""
properties = schema.get("properties", {})
patterns = "|".join(schema.get("patternProperties", {}))
for property in instance:
if property not in properties:
if patterns and re.search(patterns, property):
continue
yield property
def extras_msg(extras):
"""
Create an error message for extra items or properties.
"""
if len(extras) == 1:
verb = "was"
else:
verb = "were"
return ", ".join(repr(extra) for extra in extras), verb
def types_msg(instance, types):
"""
Create an error message for a failure to match the given types.
If the ``instance`` is an object and contains a ``name`` property, it will
be considered to be a description of that object and used as its type.
Otherwise the message is simply the reprs of the given ``types``.
"""
reprs = []
for type in types:
try:
reprs.append(repr(type["name"]))
except Exception:
reprs.append(repr(type))
return "%r is not of type %s" % (instance, ", ".join(reprs))
def flatten(suitable_for_isinstance):
"""
isinstance() can accept a bunch of really annoying different types:
* a single type
* a tuple of types
* an arbitrary nested tree of tuples
Return a flattened tuple of the given argument.
"""
types = set()
if not isinstance(suitable_for_isinstance, tuple):
suitable_for_isinstance = (suitable_for_isinstance,)
for thing in suitable_for_isinstance:
if isinstance(thing, tuple):
types.update(flatten(thing))
else:
types.add(thing)
return tuple(types)
def ensure_list(thing):
"""
Wrap ``thing`` in a list if it's a single str.
Otherwise, return it unchanged.
"""
if isinstance(thing, str_types):
return [thing]
return thing
def equal(one, two):
"""
Check if two things are equal, but evade booleans and ints being equal.
"""
return unbool(one) == unbool(two)
def unbool(element, true=object(), false=object()):
"""
A hack to make True and 1 and False and 0 unique for ``uniq``.
"""
if element is True:
return true
elif element is False:
return false
return element
def uniq(container):
"""
Check if all of a container's elements are unique.
First tries to rely on the elements being hashable, then falls back
on them being sortable, and finally falls back on brute force.
"""
try:
return len(set(unbool(i) for i in container)) == len(container)
except TypeError:
try:
sort = sorted(unbool(i) for i in container)
sliced = itertools.islice(sort, 1, None)
for i, j in zip(sort, sliced):
if i == j:
return False
except (NotImplementedError, TypeError):
seen = []
for e in container:
e = unbool(e)
if e in seen:
return False
seen.append(e)
return True
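if __name__ == "__main__":
    # Usage sketch (not part of upstream jsonschema) for two helpers above.
    store = URIDict()
    store["HTTP://example.com/schema"] = {"type": "object"}
    # urlsplit().geturl() lower-cases the scheme, so lookups normalize too.
    print(store["http://example.com/schema"])  # {'type': 'object'}

    # ``uniq`` and ``equal`` keep booleans distinct from equal-valued ints.
    print(uniq([1, True]))  # True -- 1 and True count as distinct elements
    print(equal(0, False))  # False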

View file

@ -0,0 +1,373 @@
import re
from jsonschema._utils import (
ensure_list,
equal,
extras_msg,
find_additional_properties,
types_msg,
unbool,
uniq,
)
from jsonschema.exceptions import FormatError, ValidationError
from jsonschema.compat import iteritems
def patternProperties(validator, patternProperties, instance, schema):
if not validator.is_type(instance, "object"):
return
for pattern, subschema in iteritems(patternProperties):
for k, v in iteritems(instance):
if re.search(pattern, k):
for error in validator.descend(
v, subschema, path=k, schema_path=pattern,
):
yield error
def propertyNames(validator, propertyNames, instance, schema):
if not validator.is_type(instance, "object"):
return
for property in instance:
for error in validator.descend(
instance=property,
schema=propertyNames,
):
yield error
def additionalProperties(validator, aP, instance, schema):
if not validator.is_type(instance, "object"):
return
extras = set(find_additional_properties(instance, schema))
if validator.is_type(aP, "object"):
for extra in extras:
for error in validator.descend(instance[extra], aP, path=extra):
yield error
elif not aP and extras:
if "patternProperties" in schema:
patterns = sorted(schema["patternProperties"])
if len(extras) == 1:
verb = "does"
else:
verb = "do"
error = "%s %s not match any of the regexes: %s" % (
", ".join(map(repr, sorted(extras))),
verb,
", ".join(map(repr, patterns)),
)
yield ValidationError(error)
else:
error = "Additional properties are not allowed (%s %s unexpected)"
yield ValidationError(error % extras_msg(extras))
def items(validator, items, instance, schema):
if not validator.is_type(instance, "array"):
return
if validator.is_type(items, "array"):
for (index, item), subschema in zip(enumerate(instance), items):
for error in validator.descend(
item, subschema, path=index, schema_path=index,
):
yield error
else:
for index, item in enumerate(instance):
for error in validator.descend(item, items, path=index):
yield error
def additionalItems(validator, aI, instance, schema):
if (
not validator.is_type(instance, "array") or
validator.is_type(schema.get("items", {}), "object")
):
return
len_items = len(schema.get("items", []))
if validator.is_type(aI, "object"):
for index, item in enumerate(instance[len_items:], start=len_items):
for error in validator.descend(item, aI, path=index):
yield error
elif not aI and len(instance) > len_items:
error = "Additional items are not allowed (%s %s unexpected)"
yield ValidationError(error % extras_msg(instance[len_items:]))
def const(validator, const, instance, schema):
if not equal(instance, const):
yield ValidationError("%r was expected" % (const,))
def contains(validator, contains, instance, schema):
if not validator.is_type(instance, "array"):
return
if not any(validator.is_valid(element, contains) for element in instance):
yield ValidationError(
"None of %r are valid under the given schema" % (instance,)
)
def exclusiveMinimum(validator, minimum, instance, schema):
if not validator.is_type(instance, "number"):
return
if instance <= minimum:
yield ValidationError(
"%r is less than or equal to the minimum of %r" % (
instance, minimum,
),
)
def exclusiveMaximum(validator, maximum, instance, schema):
if not validator.is_type(instance, "number"):
return
if instance >= maximum:
yield ValidationError(
"%r is greater than or equal to the maximum of %r" % (
instance, maximum,
),
)
def minimum(validator, minimum, instance, schema):
if not validator.is_type(instance, "number"):
return
if instance < minimum:
yield ValidationError(
"%r is less than the minimum of %r" % (instance, minimum)
)
def maximum(validator, maximum, instance, schema):
if not validator.is_type(instance, "number"):
return
if instance > maximum:
yield ValidationError(
"%r is greater than the maximum of %r" % (instance, maximum)
)
def multipleOf(validator, dB, instance, schema):
if not validator.is_type(instance, "number"):
return
if isinstance(dB, float):
quotient = instance / dB
failed = int(quotient) != quotient
else:
failed = instance % dB
if failed:
yield ValidationError("%r is not a multiple of %r" % (instance, dB))
def minItems(validator, mI, instance, schema):
if validator.is_type(instance, "array") and len(instance) < mI:
yield ValidationError("%r is too short" % (instance,))
def maxItems(validator, mI, instance, schema):
if validator.is_type(instance, "array") and len(instance) > mI:
yield ValidationError("%r is too long" % (instance,))
def uniqueItems(validator, uI, instance, schema):
if (
uI and
validator.is_type(instance, "array") and
not uniq(instance)
):
yield ValidationError("%r has non-unique elements" % (instance,))
def pattern(validator, patrn, instance, schema):
if (
validator.is_type(instance, "string") and
not re.search(patrn, instance)
):
yield ValidationError("%r does not match %r" % (instance, patrn))
def format(validator, format, instance, schema):
if validator.format_checker is not None:
try:
validator.format_checker.check(instance, format)
except FormatError as error:
yield ValidationError(error.message, cause=error.cause)
def minLength(validator, mL, instance, schema):
if validator.is_type(instance, "string") and len(instance) < mL:
yield ValidationError("%r is too short" % (instance,))
def maxLength(validator, mL, instance, schema):
if validator.is_type(instance, "string") and len(instance) > mL:
yield ValidationError("%r is too long" % (instance,))
def dependencies(validator, dependencies, instance, schema):
if not validator.is_type(instance, "object"):
return
for property, dependency in iteritems(dependencies):
if property not in instance:
continue
if validator.is_type(dependency, "array"):
for each in dependency:
if each not in instance:
message = "%r is a dependency of %r"
yield ValidationError(message % (each, property))
else:
for error in validator.descend(
instance, dependency, schema_path=property,
):
yield error
def enum(validator, enums, instance, schema):
if instance == 0 or instance == 1:
unbooled = unbool(instance)
if all(unbooled != unbool(each) for each in enums):
yield ValidationError("%r is not one of %r" % (instance, enums))
elif instance not in enums:
yield ValidationError("%r is not one of %r" % (instance, enums))
def ref(validator, ref, instance, schema):
resolve = getattr(validator.resolver, "resolve", None)
if resolve is None:
with validator.resolver.resolving(ref) as resolved:
for error in validator.descend(instance, resolved):
yield error
else:
scope, resolved = validator.resolver.resolve(ref)
validator.resolver.push_scope(scope)
try:
for error in validator.descend(instance, resolved):
yield error
finally:
validator.resolver.pop_scope()
def type(validator, types, instance, schema):
types = ensure_list(types)
if not any(validator.is_type(instance, type) for type in types):
yield ValidationError(types_msg(instance, types))
def properties(validator, properties, instance, schema):
if not validator.is_type(instance, "object"):
return
for property, subschema in iteritems(properties):
if property in instance:
for error in validator.descend(
instance[property],
subschema,
path=property,
schema_path=property,
):
yield error
def required(validator, required, instance, schema):
if not validator.is_type(instance, "object"):
return
for property in required:
if property not in instance:
yield ValidationError("%r is a required property" % property)
def minProperties(validator, mP, instance, schema):
if validator.is_type(instance, "object") and len(instance) < mP:
yield ValidationError(
"%r does not have enough properties" % (instance,)
)
def maxProperties(validator, mP, instance, schema):
if not validator.is_type(instance, "object"):
return
if validator.is_type(instance, "object") and len(instance) > mP:
yield ValidationError("%r has too many properties" % (instance,))
def allOf(validator, allOf, instance, schema):
for index, subschema in enumerate(allOf):
for error in validator.descend(instance, subschema, schema_path=index):
yield error
def anyOf(validator, anyOf, instance, schema):
all_errors = []
for index, subschema in enumerate(anyOf):
errs = list(validator.descend(instance, subschema, schema_path=index))
if not errs:
break
all_errors.extend(errs)
else:
yield ValidationError(
"%r is not valid under any of the given schemas" % (instance,),
context=all_errors,
)
def oneOf(validator, oneOf, instance, schema):
subschemas = enumerate(oneOf)
all_errors = []
for index, subschema in subschemas:
errs = list(validator.descend(instance, subschema, schema_path=index))
if not errs:
first_valid = subschema
break
all_errors.extend(errs)
else:
yield ValidationError(
"%r is not valid under any of the given schemas" % (instance,),
context=all_errors,
)
more_valid = [s for i, s in subschemas if validator.is_valid(instance, s)]
if more_valid:
more_valid.append(first_valid)
reprs = ", ".join(repr(schema) for schema in more_valid)
yield ValidationError(
"%r is valid under each of %s" % (instance, reprs)
)
def not_(validator, not_schema, instance, schema):
if validator.is_valid(instance, not_schema):
yield ValidationError(
"%r is not allowed for %r" % (not_schema, instance)
)
def if_(validator, if_schema, instance, schema):
if validator.is_valid(instance, if_schema):
if u"then" in schema:
then = schema[u"then"]
for error in validator.descend(instance, then, schema_path="then"):
yield error
elif u"else" in schema:
else_ = schema[u"else"]
for error in validator.descend(instance, else_, schema_path="else"):
yield error
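if __name__ == "__main__":
    # Usage sketch (not part of upstream jsonschema): the draft-7
    # ``if``/``then`` behavior implemented by ``if_`` above.
    from jsonschema import Draft7Validator

    demo = Draft7Validator({"if": {"minimum": 10}, "then": {"multipleOf": 5}})
    print([e.message for e in demo.iter_errors(12)])  # one multipleOf error
    print(list(demo.iter_errors(3)))  # [] -- the "if" schema did not match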

View file

@ -0,0 +1,5 @@
"""
Benchmarks for validation.
This package is *not* public API.
"""

View file

@ -0,0 +1,26 @@
#!/usr/bin/env python
"""
A performance benchmark using the example from issue #232.
See https://github.com/Julian/jsonschema/pull/232.
"""
from twisted.python.filepath import FilePath
from pyperf import Runner
from pyrsistent import m
from jsonschema.tests._suite import Version
import jsonschema
issue232 = Version(
path=FilePath(__file__).sibling("issue232"),
remotes=m(),
name="issue232",
)
if __name__ == "__main__":
issue232.benchmark(
runner=Runner(),
Validator=jsonschema.Draft4Validator,
)

View file

@ -0,0 +1,14 @@
#!/usr/bin/env python
"""
A performance benchmark using the official test suite.
This benchmarks jsonschema using every valid example in the
JSON-Schema-Test-Suite. It will take some time to complete.
"""
from pyperf import Runner
from jsonschema.tests._suite import Suite
if __name__ == "__main__":
Suite().benchmark(runner=Runner())

View file

@ -0,0 +1,90 @@
"""
The ``jsonschema`` command line.
"""
from __future__ import absolute_import
import argparse
import json
import sys
from jsonschema import __version__
from jsonschema._reflect import namedAny
from jsonschema.validators import validator_for
def _namedAnyWithDefault(name):
if "." not in name:
name = "jsonschema." + name
return namedAny(name)
def _json_file(path):
with open(path) as file:
return json.load(file)
parser = argparse.ArgumentParser(
description="JSON Schema Validation CLI",
)
parser.add_argument(
"-i", "--instance",
action="append",
dest="instances",
type=_json_file,
help=(
"a path to a JSON instance (i.e. filename.json) "
"to validate (may be specified multiple times)"
),
)
parser.add_argument(
"-F", "--error-format",
default="{error.instance}: {error.message}\n",
help=(
"the format to use for each error output message, specified in "
"a form suitable for passing to str.format, which will be called "
"with 'error' for each error"
),
)
parser.add_argument(
"-V", "--validator",
type=_namedAnyWithDefault,
help=(
"the fully qualified object name of a validator to use, or, for "
"validators that are registered with jsonschema, simply the name "
"of the class."
),
)
parser.add_argument(
"--version",
action="version",
version=__version__,
)
parser.add_argument(
"schema",
help="the JSON Schema to validate with (i.e. schema.json)",
type=_json_file,
)
def parse_args(args):
arguments = vars(parser.parse_args(args=args or ["--help"]))
if arguments["validator"] is None:
arguments["validator"] = validator_for(arguments["schema"])
return arguments
def main(args=sys.argv[1:]):
sys.exit(run(arguments=parse_args(args=args)))
def run(arguments, stdout=sys.stdout, stderr=sys.stderr):
error_format = arguments["error_format"]
validator = arguments["validator"](schema=arguments["schema"])
validator.check_schema(arguments["schema"])
errored = False
for instance in arguments["instances"] or ():
for error in validator.iter_errors(instance):
stderr.write(error_format.format(error=error))
errored = True
return errored
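# Typical invocation (hypothetical file names), via the package's
# ``python -m jsonschema`` entry point defined in __main__.py:
#
#     python -m jsonschema -i sample.json schema.json
#
# The exit status is truthy when any instance failed validation.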

View file

@ -0,0 +1,55 @@
"""
Python 2/3 compatibility helpers.
Note: This module is *not* public API.
"""
import contextlib
import operator
import sys
try:
from collections.abc import MutableMapping, Sequence # noqa
except ImportError:
from collections import MutableMapping, Sequence # noqa
PY3 = sys.version_info[0] >= 3
if PY3:
zip = zip
from functools import lru_cache
from io import StringIO as NativeIO
from urllib.parse import (
unquote, urljoin, urlunsplit, SplitResult, urlsplit
)
from urllib.request import pathname2url, urlopen
str_types = str,
int_types = int,
iteritems = operator.methodcaller("items")
else:
from itertools import izip as zip # noqa
from io import BytesIO as NativeIO
from urlparse import urljoin, urlunsplit, SplitResult, urlsplit
from urllib import pathname2url, unquote # noqa
import urllib2 # noqa
def urlopen(*args, **kwargs):
return contextlib.closing(urllib2.urlopen(*args, **kwargs))
str_types = basestring
int_types = int, long
iteritems = operator.methodcaller("iteritems")
from functools32 import lru_cache
def urldefrag(url):
if "#" in url:
s, n, p, q, frag = urlsplit(url)
defrag = urlunsplit((s, n, p, q, ""))
else:
defrag = url
frag = ""
return defrag, frag
# flake8: noqa

View file

@ -0,0 +1,374 @@
"""
Validation errors, and some surrounding helpers.
"""
from collections import defaultdict, deque
import itertools
import pprint
import textwrap
import attr
from jsonschema import _utils
from jsonschema.compat import PY3, iteritems
WEAK_MATCHES = frozenset(["anyOf", "oneOf"])
STRONG_MATCHES = frozenset()
_unset = _utils.Unset()
class _Error(Exception):
def __init__(
self,
message,
validator=_unset,
path=(),
cause=None,
context=(),
validator_value=_unset,
instance=_unset,
schema=_unset,
schema_path=(),
parent=None,
):
super(_Error, self).__init__(
message,
validator,
path,
cause,
context,
validator_value,
instance,
schema,
schema_path,
parent,
)
self.message = message
self.path = self.relative_path = deque(path)
self.schema_path = self.relative_schema_path = deque(schema_path)
self.context = list(context)
self.cause = self.__cause__ = cause
self.validator = validator
self.validator_value = validator_value
self.instance = instance
self.schema = schema
self.parent = parent
for error in context:
error.parent = self
def __repr__(self):
return "<%s: %r>" % (self.__class__.__name__, self.message)
def __unicode__(self):
essential_for_verbose = (
self.validator, self.validator_value, self.instance, self.schema,
)
if any(m is _unset for m in essential_for_verbose):
return self.message
pschema = pprint.pformat(self.schema, width=72)
pinstance = pprint.pformat(self.instance, width=72)
return self.message + textwrap.dedent("""
Failed validating %r in %s%s:
%s
On %s%s:
%s
""".rstrip()
) % (
self.validator,
self._word_for_schema_in_error_message,
_utils.format_as_index(list(self.relative_schema_path)[:-1]),
_utils.indent(pschema),
self._word_for_instance_in_error_message,
_utils.format_as_index(self.relative_path),
_utils.indent(pinstance),
)
if PY3:
__str__ = __unicode__
else:
def __str__(self):
return unicode(self).encode("utf-8")
@classmethod
def create_from(cls, other):
return cls(**other._contents())
@property
def absolute_path(self):
parent = self.parent
if parent is None:
return self.relative_path
path = deque(self.relative_path)
path.extendleft(reversed(parent.absolute_path))
return path
@property
def absolute_schema_path(self):
parent = self.parent
if parent is None:
return self.relative_schema_path
path = deque(self.relative_schema_path)
path.extendleft(reversed(parent.absolute_schema_path))
return path
def _set(self, **kwargs):
for k, v in iteritems(kwargs):
if getattr(self, k) is _unset:
setattr(self, k, v)
def _contents(self):
attrs = (
"message", "cause", "context", "validator", "validator_value",
"path", "schema_path", "instance", "schema", "parent",
)
return dict((attr, getattr(self, attr)) for attr in attrs)
class ValidationError(_Error):
"""
An instance was invalid under a provided schema.
"""
_word_for_schema_in_error_message = "schema"
_word_for_instance_in_error_message = "instance"
class SchemaError(_Error):
"""
A schema was invalid under its corresponding metaschema.
"""
_word_for_schema_in_error_message = "metaschema"
_word_for_instance_in_error_message = "schema"
@attr.s(hash=True)
class RefResolutionError(Exception):
"""
A ref could not be resolved.
"""
_cause = attr.ib()
def __str__(self):
return str(self._cause)
class UndefinedTypeCheck(Exception):
"""
A type checker was asked to check a type it did not have registered.
"""
def __init__(self, type):
self.type = type
def __unicode__(self):
return "Type %r is unknown to this type checker" % self.type
if PY3:
__str__ = __unicode__
else:
def __str__(self):
return unicode(self).encode("utf-8")
class UnknownType(Exception):
"""
A validator was asked to validate an instance against an unknown type.
"""
def __init__(self, type, instance, schema):
self.type = type
self.instance = instance
self.schema = schema
def __unicode__(self):
pschema = pprint.pformat(self.schema, width=72)
pinstance = pprint.pformat(self.instance, width=72)
return textwrap.dedent("""
Unknown type %r for validator with schema:
%s
While checking instance:
%s
""".rstrip()
) % (self.type, _utils.indent(pschema), _utils.indent(pinstance))
if PY3:
__str__ = __unicode__
else:
def __str__(self):
return unicode(self).encode("utf-8")
class FormatError(Exception):
"""
Validating a format failed.
"""
def __init__(self, message, cause=None):
super(FormatError, self).__init__(message, cause)
self.message = message
self.cause = self.__cause__ = cause
def __unicode__(self):
return self.message
if PY3:
__str__ = __unicode__
else:
def __str__(self):
return self.message.encode("utf-8")
class ErrorTree(object):
"""
ErrorTrees make it easier to check which validations failed.
"""
_instance = _unset
def __init__(self, errors=()):
self.errors = {}
self._contents = defaultdict(self.__class__)
for error in errors:
container = self
for element in error.path:
container = container[element]
container.errors[error.validator] = error
container._instance = error.instance
def __contains__(self, index):
"""
Check whether ``instance[index]`` has any errors.
"""
return index in self._contents
def __getitem__(self, index):
"""
Retrieve the child tree one level down at the given ``index``.
If the index is not in the instance that this tree corresponds to and
is not known by this tree, whatever error would be raised by
``instance.__getitem__`` will be propagated (usually this is some
subclass of `exceptions.LookupError`).
"""
if self._instance is not _unset and index not in self:
self._instance[index]
return self._contents[index]
def __setitem__(self, index, value):
"""
Add an error to the tree at the given ``index``.
"""
self._contents[index] = value
def __iter__(self):
"""
Iterate (non-recursively) over the indices in the instance with errors.
"""
return iter(self._contents)
def __len__(self):
"""
Return the `total_errors`.
"""
return self.total_errors
def __repr__(self):
return "<%s (%s total errors)>" % (self.__class__.__name__, len(self))
@property
def total_errors(self):
"""
The total number of errors in the entire tree, including children.
"""
child_errors = sum(len(tree) for _, tree in iteritems(self._contents))
return len(self.errors) + child_errors
def by_relevance(weak=WEAK_MATCHES, strong=STRONG_MATCHES):
"""
Create a key function that can be used to sort errors by relevance.
Arguments:
weak (set):
a collection of validator names to consider to be "weak".
If there are two errors at the same level of the instance
and one is in the set of weak validator names, the other
error will take priority. By default, :validator:`anyOf` and
:validator:`oneOf` are considered weak validators and will
be superseded by other same-level validation errors.
strong (set):
a collection of validator names to consider to be "strong"
"""
def relevance(error):
validator = error.validator
return -len(error.path), validator not in weak, validator in strong
return relevance
relevance = by_relevance()
def best_match(errors, key=relevance):
"""
Try to find an error that appears to be the best match among given errors.
In general, errors that are higher up in the instance (i.e. for which
`ValidationError.path` is shorter) are considered better matches,
since they indicate "more" is wrong with the instance.
If the resulting match is either :validator:`oneOf` or :validator:`anyOf`,
the *opposite* assumption is made -- i.e. the deepest error is picked,
since these validators only need to match once, and any other errors may
not be relevant.
Arguments:
errors (collections.Iterable):
the errors to select from. Do not provide a mixture of
errors from different validation attempts (i.e. from
different instances or schemas), since it won't produce
sensible output.
key (collections.Callable):
the key to use when sorting errors. See `relevance` and
transitively `by_relevance` for more details (the default is
to sort with the defaults of that function). Changing the
default is only useful if you want to change the function
that rates errors but still want the error context descent
done by this function.
Returns:
the best matching error, or ``None`` if the iterable was empty
.. note::
This function is a heuristic. Its return value may change for a given
set of inputs from version to version if better heuristics are added.
"""
errors = iter(errors)
best = next(errors, None)
if best is None:
return
best = max(itertools.chain([best], errors), key=key)
while best.context:
best = min(best.context, key=key)
return best
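if __name__ == "__main__":
    # Usage sketch (not part of upstream jsonschema): ``best_match``
    # descends into anyOf/oneOf error contexts to surface the most
    # relevant failure for the instance ``1``.
    from jsonschema import Draft7Validator

    demo = Draft7Validator({"anyOf": [{"type": "string"}, {"minimum": 3}]})
    print(best_match(demo.iter_errors(1)).message)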

View file

@ -0,0 +1,199 @@
{
"$schema": "http://json-schema.org/draft-03/schema#",
"dependencies": {
"exclusiveMaximum": "maximum",
"exclusiveMinimum": "minimum"
},
"id": "http://json-schema.org/draft-03/schema#",
"properties": {
"$ref": {
"format": "uri",
"type": "string"
},
"$schema": {
"format": "uri",
"type": "string"
},
"additionalItems": {
"default": {},
"type": [
{
"$ref": "#"
},
"boolean"
]
},
"additionalProperties": {
"default": {},
"type": [
{
"$ref": "#"
},
"boolean"
]
},
"default": {
"type": "any"
},
"dependencies": {
"additionalProperties": {
"items": {
"type": "string"
},
"type": [
"string",
"array",
{
"$ref": "#"
}
]
},
"default": {},
"type": [
"string",
"array",
"object"
]
},
"description": {
"type": "string"
},
"disallow": {
"items": {
"type": [
"string",
{
"$ref": "#"
}
]
},
"type": [
"string",
"array"
],
"uniqueItems": true
},
"divisibleBy": {
"default": 1,
"exclusiveMinimum": true,
"minimum": 0,
"type": "number"
},
"enum": {
"type": "array"
},
"exclusiveMaximum": {
"default": false,
"type": "boolean"
},
"exclusiveMinimum": {
"default": false,
"type": "boolean"
},
"extends": {
"default": {},
"items": {
"$ref": "#"
},
"type": [
{
"$ref": "#"
},
"array"
]
},
"format": {
"type": "string"
},
"id": {
"format": "uri",
"type": "string"
},
"items": {
"default": {},
"items": {
"$ref": "#"
},
"type": [
{
"$ref": "#"
},
"array"
]
},
"maxDecimal": {
"minimum": 0,
"type": "number"
},
"maxItems": {
"minimum": 0,
"type": "integer"
},
"maxLength": {
"type": "integer"
},
"maximum": {
"type": "number"
},
"minItems": {
"default": 0,
"minimum": 0,
"type": "integer"
},
"minLength": {
"default": 0,
"minimum": 0,
"type": "integer"
},
"minimum": {
"type": "number"
},
"pattern": {
"format": "regex",
"type": "string"
},
"patternProperties": {
"additionalProperties": {
"$ref": "#"
},
"default": {},
"type": "object"
},
"properties": {
"additionalProperties": {
"$ref": "#",
"type": "object"
},
"default": {},
"type": "object"
},
"required": {
"default": false,
"type": "boolean"
},
"title": {
"type": "string"
},
"type": {
"default": "any",
"items": {
"type": [
"string",
{
"$ref": "#"
}
]
},
"type": [
"string",
"array"
],
"uniqueItems": true
},
"uniqueItems": {
"default": false,
"type": "boolean"
}
},
"type": "object"
}

View file

@ -0,0 +1,222 @@
{
"$schema": "http://json-schema.org/draft-04/schema#",
"default": {},
"definitions": {
"positiveInteger": {
"minimum": 0,
"type": "integer"
},
"positiveIntegerDefault0": {
"allOf": [
{
"$ref": "#/definitions/positiveInteger"
},
{
"default": 0
}
]
},
"schemaArray": {
"items": {
"$ref": "#"
},
"minItems": 1,
"type": "array"
},
"simpleTypes": {
"enum": [
"array",
"boolean",
"integer",
"null",
"number",
"object",
"string"
]
},
"stringArray": {
"items": {
"type": "string"
},
"minItems": 1,
"type": "array",
"uniqueItems": true
}
},
"dependencies": {
"exclusiveMaximum": [
"maximum"
],
"exclusiveMinimum": [
"minimum"
]
},
"description": "Core schema meta-schema",
"id": "http://json-schema.org/draft-04/schema#",
"properties": {
"$schema": {
"format": "uri",
"type": "string"
},
"additionalItems": {
"anyOf": [
{
"type": "boolean"
},
{
"$ref": "#"
}
],
"default": {}
},
"additionalProperties": {
"anyOf": [
{
"type": "boolean"
},
{
"$ref": "#"
}
],
"default": {}
},
"allOf": {
"$ref": "#/definitions/schemaArray"
},
"anyOf": {
"$ref": "#/definitions/schemaArray"
},
"default": {},
"definitions": {
"additionalProperties": {
"$ref": "#"
},
"default": {},
"type": "object"
},
"dependencies": {
"additionalProperties": {
"anyOf": [
{
"$ref": "#"
},
{
"$ref": "#/definitions/stringArray"
}
]
},
"type": "object"
},
"description": {
"type": "string"
},
"enum": {
"type": "array"
},
"exclusiveMaximum": {
"default": false,
"type": "boolean"
},
"exclusiveMinimum": {
"default": false,
"type": "boolean"
},
"format": {
"type": "string"
},
"id": {
"format": "uri",
"type": "string"
},
"items": {
"anyOf": [
{
"$ref": "#"
},
{
"$ref": "#/definitions/schemaArray"
}
],
"default": {}
},
"maxItems": {
"$ref": "#/definitions/positiveInteger"
},
"maxLength": {
"$ref": "#/definitions/positiveInteger"
},
"maxProperties": {
"$ref": "#/definitions/positiveInteger"
},
"maximum": {
"type": "number"
},
"minItems": {
"$ref": "#/definitions/positiveIntegerDefault0"
},
"minLength": {
"$ref": "#/definitions/positiveIntegerDefault0"
},
"minProperties": {
"$ref": "#/definitions/positiveIntegerDefault0"
},
"minimum": {
"type": "number"
},
"multipleOf": {
"exclusiveMinimum": true,
"minimum": 0,
"type": "number"
},
"not": {
"$ref": "#"
},
"oneOf": {
"$ref": "#/definitions/schemaArray"
},
"pattern": {
"format": "regex",
"type": "string"
},
"patternProperties": {
"additionalProperties": {
"$ref": "#"
},
"default": {},
"type": "object"
},
"properties": {
"additionalProperties": {
"$ref": "#"
},
"default": {},
"type": "object"
},
"required": {
"$ref": "#/definitions/stringArray"
},
"title": {
"type": "string"
},
"type": {
"anyOf": [
{
"$ref": "#/definitions/simpleTypes"
},
{
"items": {
"$ref": "#/definitions/simpleTypes"
},
"minItems": 1,
"type": "array",
"uniqueItems": true
}
]
},
"uniqueItems": {
"default": false,
"type": "boolean"
}
},
"type": "object"
}

View file

@ -0,0 +1,153 @@
{
"$schema": "http://json-schema.org/draft-06/schema#",
"$id": "http://json-schema.org/draft-06/schema#",
"title": "Core schema meta-schema",
"definitions": {
"schemaArray": {
"type": "array",
"minItems": 1,
"items": { "$ref": "#" }
},
"nonNegativeInteger": {
"type": "integer",
"minimum": 0
},
"nonNegativeIntegerDefault0": {
"allOf": [
{ "$ref": "#/definitions/nonNegativeInteger" },
{ "default": 0 }
]
},
"simpleTypes": {
"enum": [
"array",
"boolean",
"integer",
"null",
"number",
"object",
"string"
]
},
"stringArray": {
"type": "array",
"items": { "type": "string" },
"uniqueItems": true,
"default": []
}
},
"type": ["object", "boolean"],
"properties": {
"$id": {
"type": "string",
"format": "uri-reference"
},
"$schema": {
"type": "string",
"format": "uri"
},
"$ref": {
"type": "string",
"format": "uri-reference"
},
"title": {
"type": "string"
},
"description": {
"type": "string"
},
"default": {},
"examples": {
"type": "array",
"items": {}
},
"multipleOf": {
"type": "number",
"exclusiveMinimum": 0
},
"maximum": {
"type": "number"
},
"exclusiveMaximum": {
"type": "number"
},
"minimum": {
"type": "number"
},
"exclusiveMinimum": {
"type": "number"
},
"maxLength": { "$ref": "#/definitions/nonNegativeInteger" },
"minLength": { "$ref": "#/definitions/nonNegativeIntegerDefault0" },
"pattern": {
"type": "string",
"format": "regex"
},
"additionalItems": { "$ref": "#" },
"items": {
"anyOf": [
{ "$ref": "#" },
{ "$ref": "#/definitions/schemaArray" }
],
"default": {}
},
"maxItems": { "$ref": "#/definitions/nonNegativeInteger" },
"minItems": { "$ref": "#/definitions/nonNegativeIntegerDefault0" },
"uniqueItems": {
"type": "boolean",
"default": false
},
"contains": { "$ref": "#" },
"maxProperties": { "$ref": "#/definitions/nonNegativeInteger" },
"minProperties": { "$ref": "#/definitions/nonNegativeIntegerDefault0" },
"required": { "$ref": "#/definitions/stringArray" },
"additionalProperties": { "$ref": "#" },
"definitions": {
"type": "object",
"additionalProperties": { "$ref": "#" },
"default": {}
},
"properties": {
"type": "object",
"additionalProperties": { "$ref": "#" },
"default": {}
},
"patternProperties": {
"type": "object",
"additionalProperties": { "$ref": "#" },
"propertyNames": { "format": "regex" },
"default": {}
},
"dependencies": {
"type": "object",
"additionalProperties": {
"anyOf": [
{ "$ref": "#" },
{ "$ref": "#/definitions/stringArray" }
]
}
},
"propertyNames": { "$ref": "#" },
"const": {},
"enum": {
"type": "array"
},
"type": {
"anyOf": [
{ "$ref": "#/definitions/simpleTypes" },
{
"type": "array",
"items": { "$ref": "#/definitions/simpleTypes" },
"minItems": 1,
"uniqueItems": true
}
]
},
"format": { "type": "string" },
"allOf": { "$ref": "#/definitions/schemaArray" },
"anyOf": { "$ref": "#/definitions/schemaArray" },
"oneOf": { "$ref": "#/definitions/schemaArray" },
"not": { "$ref": "#" }
},
"default": {}
}

View file

@ -0,0 +1,166 @@
{
"$schema": "http://json-schema.org/draft-07/schema#",
"$id": "http://json-schema.org/draft-07/schema#",
"title": "Core schema meta-schema",
"definitions": {
"schemaArray": {
"type": "array",
"minItems": 1,
"items": { "$ref": "#" }
},
"nonNegativeInteger": {
"type": "integer",
"minimum": 0
},
"nonNegativeIntegerDefault0": {
"allOf": [
{ "$ref": "#/definitions/nonNegativeInteger" },
{ "default": 0 }
]
},
"simpleTypes": {
"enum": [
"array",
"boolean",
"integer",
"null",
"number",
"object",
"string"
]
},
"stringArray": {
"type": "array",
"items": { "type": "string" },
"uniqueItems": true,
"default": []
}
},
"type": ["object", "boolean"],
"properties": {
"$id": {
"type": "string",
"format": "uri-reference"
},
"$schema": {
"type": "string",
"format": "uri"
},
"$ref": {
"type": "string",
"format": "uri-reference"
},
"$comment": {
"type": "string"
},
"title": {
"type": "string"
},
"description": {
"type": "string"
},
"default": true,
"readOnly": {
"type": "boolean",
"default": false
},
"examples": {
"type": "array",
"items": true
},
"multipleOf": {
"type": "number",
"exclusiveMinimum": 0
},
"maximum": {
"type": "number"
},
"exclusiveMaximum": {
"type": "number"
},
"minimum": {
"type": "number"
},
"exclusiveMinimum": {
"type": "number"
},
"maxLength": { "$ref": "#/definitions/nonNegativeInteger" },
"minLength": { "$ref": "#/definitions/nonNegativeIntegerDefault0" },
"pattern": {
"type": "string",
"format": "regex"
},
"additionalItems": { "$ref": "#" },
"items": {
"anyOf": [
{ "$ref": "#" },
{ "$ref": "#/definitions/schemaArray" }
],
"default": true
},
"maxItems": { "$ref": "#/definitions/nonNegativeInteger" },
"minItems": { "$ref": "#/definitions/nonNegativeIntegerDefault0" },
"uniqueItems": {
"type": "boolean",
"default": false
},
"contains": { "$ref": "#" },
"maxProperties": { "$ref": "#/definitions/nonNegativeInteger" },
"minProperties": { "$ref": "#/definitions/nonNegativeIntegerDefault0" },
"required": { "$ref": "#/definitions/stringArray" },
"additionalProperties": { "$ref": "#" },
"definitions": {
"type": "object",
"additionalProperties": { "$ref": "#" },
"default": {}
},
"properties": {
"type": "object",
"additionalProperties": { "$ref": "#" },
"default": {}
},
"patternProperties": {
"type": "object",
"additionalProperties": { "$ref": "#" },
"propertyNames": { "format": "regex" },
"default": {}
},
"dependencies": {
"type": "object",
"additionalProperties": {
"anyOf": [
{ "$ref": "#" },
{ "$ref": "#/definitions/stringArray" }
]
}
},
"propertyNames": { "$ref": "#" },
"const": true,
"enum": {
"type": "array",
"items": true
},
"type": {
"anyOf": [
{ "$ref": "#/definitions/simpleTypes" },
{
"type": "array",
"items": { "$ref": "#/definitions/simpleTypes" },
"minItems": 1,
"uniqueItems": true
}
]
},
"format": { "type": "string" },
"contentMediaType": { "type": "string" },
"contentEncoding": { "type": "string" },
"if": {"$ref": "#"},
"then": {"$ref": "#"},
"else": {"$ref": "#"},
"allOf": { "$ref": "#/definitions/schemaArray" },
"anyOf": { "$ref": "#/definitions/schemaArray" },
"oneOf": { "$ref": "#/definitions/schemaArray" },
"not": { "$ref": "#" }
},
"default": true
}
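
A quick aside on how this vendored meta-schema gets used: jsonschema validates schemas themselves against it. A minimal sketch, assuming the vendored jsonschema (>= 3.0) is importable:

from jsonschema import Draft7Validator

# check_schema() validates a schema against the draft-07 meta-schema
# above; it raises SchemaError if the schema itself is malformed.
Draft7Validator.check_schema(
    {"type": "object", "properties": {"name": {"type": "string"}}}
)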


@@ -0,0 +1,5 @@
def bug(issue=None):
message = "A known bug."
if issue is not None:
message += " See issue #{issue}.".format(issue=issue)
return message
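
A hedged usage sketch for this helper: bug() builds a human-readable skip reason for the suite runner's to_unittest_method(skip=...) hook defined later in this commit. The subject name and issue number below are invented for illustration:

# Hypothetical: skip a known-failing subject, citing a made-up issue.
def skip(test):
    if test.subject == "refRemote":  # assumed subject name
        return bug(1234)             # "A known bug. See issue #1234."
    return None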


@@ -0,0 +1,239 @@
"""
Python representations of the JSON Schema Test Suite tests.
"""
from functools import partial
import json
import os
import re
import subprocess
import sys
import unittest
from twisted.python.filepath import FilePath
import attr
from jsonschema.compat import PY3
from jsonschema.validators import validators
import jsonschema
def _find_suite():
root = os.environ.get("JSON_SCHEMA_TEST_SUITE")
if root is not None:
return FilePath(root)
root = FilePath(jsonschema.__file__).parent().sibling("json")
if not root.isdir(): # pragma: no cover
raise ValueError(
(
"Can't find the JSON-Schema-Test-Suite directory. "
"Set the 'JSON_SCHEMA_TEST_SUITE' environment "
"variable or run the tests from alongside a checkout "
"of the suite."
),
)
return root
@attr.s(hash=True)
class Suite(object):
_root = attr.ib(default=attr.Factory(_find_suite))
def _remotes(self):
jsonschema_suite = self._root.descendant(["bin", "jsonschema_suite"])
remotes = subprocess.check_output(
[sys.executable, jsonschema_suite.path, "remotes"],
)
return {
"http://localhost:1234/" + name: schema
for name, schema in json.loads(remotes.decode("utf-8")).items()
}
def benchmark(self, runner): # pragma: no cover
for name in validators:
self.version(name=name).benchmark(runner=runner)
def version(self, name):
return Version(
name=name,
path=self._root.descendant(["tests", name]),
remotes=self._remotes(),
)
@attr.s(hash=True)
class Version(object):
_path = attr.ib()
_remotes = attr.ib()
name = attr.ib()
def benchmark(self, runner, **kwargs): # pragma: no cover
for suite in self.tests():
for test in suite:
runner.bench_func(
test.fully_qualified_name,
partial(test.validate_ignoring_errors, **kwargs),
)
def tests(self):
return (
test
for child in self._path.globChildren("*.json")
for test in self._tests_in(
subject=child.basename()[:-5],
path=child,
)
)
def format_tests(self):
path = self._path.descendant(["optional", "format"])
return (
test
for child in path.globChildren("*.json")
for test in self._tests_in(
subject=child.basename()[:-5],
path=child,
)
)
def tests_of(self, name):
return self._tests_in(
subject=name,
path=self._path.child(name + ".json"),
)
def optional_tests_of(self, name):
return self._tests_in(
subject=name,
path=self._path.descendant(["optional", name + ".json"]),
)
def to_unittest_testcase(self, *suites, **kwargs):
name = kwargs.pop("name", "Test" + self.name.title())
methods = {
test.method_name: test.to_unittest_method(**kwargs)
for suite in suites
for tests in suite
for test in tests
}
cls = type(name, (unittest.TestCase,), methods)
try:
cls.__module__ = _someone_save_us_the_module_of_the_caller()
except Exception: # pragma: no cover
# We're doing crazy things, so if they go wrong, like a function
# behaving differently on some other interpreter, just make them
# not happen.
pass
return cls
def _tests_in(self, subject, path):
for each in json.loads(path.getContent().decode("utf-8")):
yield (
_Test(
version=self,
subject=subject,
case_description=each["description"],
schema=each["schema"],
remotes=self._remotes,
**test
) for test in each["tests"]
)
@attr.s(hash=True, repr=False)
class _Test(object):
version = attr.ib()
subject = attr.ib()
case_description = attr.ib()
description = attr.ib()
data = attr.ib()
schema = attr.ib(repr=False)
valid = attr.ib()
_remotes = attr.ib()
def __repr__(self): # pragma: no cover
return "<Test {}>".format(self.fully_qualified_name)
@property
def fully_qualified_name(self): # pragma: no cover
return " > ".join(
[
self.version.name,
self.subject,
self.case_description,
self.description,
]
)
@property
def method_name(self):
delimiters = r"[\W\- ]+"
name = "test_%s_%s_%s" % (
re.sub(delimiters, "_", self.subject),
re.sub(delimiters, "_", self.case_description),
re.sub(delimiters, "_", self.description),
)
if not PY3: # pragma: no cover
name = name.encode("utf-8")
return name
def to_unittest_method(self, skip=lambda test: None, **kwargs):
if self.valid:
def fn(this):
self.validate(**kwargs)
else:
def fn(this):
with this.assertRaises(jsonschema.ValidationError):
self.validate(**kwargs)
fn.__name__ = self.method_name
reason = skip(self)
return unittest.skipIf(reason is not None, reason)(fn)
def validate(self, Validator, **kwargs):
resolver = jsonschema.RefResolver.from_schema(
schema=self.schema,
store=self._remotes,
id_of=Validator.ID_OF,
)
jsonschema.validate(
instance=self.data,
schema=self.schema,
cls=Validator,
resolver=resolver,
**kwargs
)
def validate_ignoring_errors(self, Validator): # pragma: no cover
try:
self.validate(Validator=Validator)
except jsonschema.ValidationError:
pass
def _someone_save_us_the_module_of_the_caller():
"""
The FQON of the module two stack frames up from here.
This is intended to allow us to dynamically return test case classes that
are indistinguishable from being defined in the module that wants them.
Otherwise, trial will mis-print the FQON, and copy pasting it won't re-run
the class that really is running.
Save us all, this is all so so so so so terrible.
"""
return sys._getframe(2).f_globals["__name__"]
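
Putting the pieces above together, a minimal sketch of how a test module might consume this file (it mirrors how jsonschema's own suite tests use it; the "draft4" name assumes that directory exists in the checked-out JSON-Schema-Test-Suite):

SUITE = Suite()
DRAFT4 = SUITE.version(name="draft4")

# Builds a unittest.TestCase subclass with one generated method per test.
TestDraft4 = DRAFT4.to_unittest_testcase(DRAFT4.tests())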


@@ -0,0 +1,151 @@
from unittest import TestCase
import json
import subprocess
import sys
from jsonschema import Draft4Validator, ValidationError, cli, __version__
from jsonschema.compat import NativeIO
from jsonschema.exceptions import SchemaError
def fake_validator(*errors):
errors = list(reversed(errors))
class FakeValidator(object):
def __init__(self, *args, **kwargs):
pass
def iter_errors(self, instance):
if errors:
return errors.pop()
return []
def check_schema(self, schema):
pass
return FakeValidator
class TestParser(TestCase):
FakeValidator = fake_validator()
instance_file = "foo.json"
schema_file = "schema.json"
def setUp(self):
cli.open = self.fake_open
self.addCleanup(delattr, cli, "open")
def fake_open(self, path):
if path == self.instance_file:
contents = ""
elif path == self.schema_file:
contents = {}
else: # pragma: no cover
self.fail("What is {!r}".format(path))
return NativeIO(json.dumps(contents))
def test_find_validator_by_fully_qualified_object_name(self):
arguments = cli.parse_args(
[
"--validator",
"jsonschema.tests.test_cli.TestParser.FakeValidator",
"--instance", self.instance_file,
self.schema_file,
]
)
self.assertIs(arguments["validator"], self.FakeValidator)
def test_find_validator_in_jsonschema(self):
arguments = cli.parse_args(
[
"--validator", "Draft4Validator",
"--instance", self.instance_file,
self.schema_file,
]
)
self.assertIs(arguments["validator"], Draft4Validator)
class TestCLI(TestCase):
def test_draft3_schema_draft4_validator(self):
stdout, stderr = NativeIO(), NativeIO()
with self.assertRaises(SchemaError):
cli.run(
{
"validator": Draft4Validator,
"schema": {
"anyOf": [
{"minimum": 20},
{"type": "string"},
{"required": True},
],
},
"instances": [1],
"error_format": "{error.message}",
},
stdout=stdout,
stderr=stderr,
)
def test_successful_validation(self):
stdout, stderr = NativeIO(), NativeIO()
exit_code = cli.run(
{
"validator": fake_validator(),
"schema": {},
"instances": [1],
"error_format": "{error.message}",
},
stdout=stdout,
stderr=stderr,
)
self.assertFalse(stdout.getvalue())
self.assertFalse(stderr.getvalue())
self.assertEqual(exit_code, 0)
def test_unsuccessful_validation(self):
error = ValidationError("I am an error!", instance=1)
stdout, stderr = NativeIO(), NativeIO()
exit_code = cli.run(
{
"validator": fake_validator([error]),
"schema": {},
"instances": [1],
"error_format": "{error.instance} - {error.message}",
},
stdout=stdout,
stderr=stderr,
)
self.assertFalse(stdout.getvalue())
self.assertEqual(stderr.getvalue(), "1 - I am an error!")
self.assertEqual(exit_code, 1)
def test_unsuccessful_validation_multiple_instances(self):
first_errors = [
ValidationError("9", instance=1),
ValidationError("8", instance=1),
]
second_errors = [ValidationError("7", instance=2)]
stdout, stderr = NativeIO(), NativeIO()
exit_code = cli.run(
{
"validator": fake_validator(first_errors, second_errors),
"schema": {},
"instances": [1, 2],
"error_format": "{error.instance} - {error.message}\t",
},
stdout=stdout,
stderr=stderr,
)
self.assertFalse(stdout.getvalue())
self.assertEqual(stderr.getvalue(), "1 - 9\t1 - 8\t2 - 7\t")
self.assertEqual(exit_code, 1)
def test_version(self):
version = subprocess.check_output(
[sys.executable, "-m", "jsonschema", "--version"],
stderr=subprocess.STDOUT,
)
version = version.decode("utf-8").strip()
self.assertEqual(version, __version__)
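
For reference, cli.run() takes the same argument dict these tests construct by hand; a minimal sketch of driving it directly, with a placeholder schema and instance:

import sys

from jsonschema import Draft4Validator, cli

exit_code = cli.run(
    {
        "validator": Draft4Validator,
        "schema": {"type": "object"},
        "instances": [{"name": "x"}],
        "error_format": "{error.message}\n",
    },
    stdout=sys.stdout,
    stderr=sys.stderr,
)  # 0 when every instance validates, 1 otherwise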


@@ -0,0 +1,462 @@
from unittest import TestCase
import textwrap
from jsonschema import Draft4Validator, exceptions
from jsonschema.compat import PY3
class TestBestMatch(TestCase):
def best_match(self, errors):
errors = list(errors)
best = exceptions.best_match(errors)
reversed_best = exceptions.best_match(reversed(errors))
msg = "Didn't return a consistent best match!\nGot: {0}\n\nThen: {1}"
self.assertEqual(
best._contents(), reversed_best._contents(),
msg=msg.format(best, reversed_best),
)
return best
def test_shallower_errors_are_better_matches(self):
validator = Draft4Validator(
{
"properties": {
"foo": {
"minProperties": 2,
"properties": {"bar": {"type": "object"}},
},
},
},
)
best = self.best_match(validator.iter_errors({"foo": {"bar": []}}))
self.assertEqual(best.validator, "minProperties")
def test_oneOf_and_anyOf_are_weak_matches(self):
"""
A property you *must* match is probably better than one you have to
match a part of.
"""
validator = Draft4Validator(
{
"minProperties": 2,
"anyOf": [{"type": "string"}, {"type": "number"}],
"oneOf": [{"type": "string"}, {"type": "number"}],
}
)
best = self.best_match(validator.iter_errors({}))
self.assertEqual(best.validator, "minProperties")
def test_if_the_most_relevant_error_is_anyOf_it_is_traversed(self):
"""
If the most relevant error is an anyOf, then we traverse its context
and select the otherwise *least* relevant error, since in this case
that means the most specific, deep, error inside the instance.
I.e. since only one of the schemas must match, we look for the most
relevant one.
"""
validator = Draft4Validator(
{
"properties": {
"foo": {
"anyOf": [
{"type": "string"},
{"properties": {"bar": {"type": "array"}}},
],
},
},
},
)
best = self.best_match(validator.iter_errors({"foo": {"bar": 12}}))
self.assertEqual(best.validator_value, "array")
def test_if_the_most_relevant_error_is_oneOf_it_is_traversed(self):
"""
If the most relevant error is a oneOf, then we traverse its context
and select the otherwise *least* relevant error, since in this case
that means the most specific, deep, error inside the instance.
I.e. since only one of the schemas must match, we look for the most
relevant one.
"""
validator = Draft4Validator(
{
"properties": {
"foo": {
"oneOf": [
{"type": "string"},
{"properties": {"bar": {"type": "array"}}},
],
},
},
},
)
best = self.best_match(validator.iter_errors({"foo": {"bar": 12}}))
self.assertEqual(best.validator_value, "array")
def test_if_the_most_relevant_error_is_allOf_it_is_traversed(self):
"""
Now, if the error is allOf, we traverse but select the *most* relevant
error from the context, because all schemas here must match anyways.
"""
validator = Draft4Validator(
{
"properties": {
"foo": {
"allOf": [
{"type": "string"},
{"properties": {"bar": {"type": "array"}}},
],
},
},
},
)
best = self.best_match(validator.iter_errors({"foo": {"bar": 12}}))
self.assertEqual(best.validator_value, "string")
def test_nested_context_for_oneOf(self):
validator = Draft4Validator(
{
"properties": {
"foo": {
"oneOf": [
{"type": "string"},
{
"oneOf": [
{"type": "string"},
{
"properties": {
"bar": {"type": "array"},
},
},
],
},
],
},
},
},
)
best = self.best_match(validator.iter_errors({"foo": {"bar": 12}}))
self.assertEqual(best.validator_value, "array")
def test_one_error(self):
validator = Draft4Validator({"minProperties": 2})
error, = validator.iter_errors({})
self.assertEqual(
exceptions.best_match(validator.iter_errors({})).validator,
"minProperties",
)
def test_no_errors(self):
validator = Draft4Validator({})
self.assertIsNone(exceptions.best_match(validator.iter_errors({})))
class TestByRelevance(TestCase):
def test_short_paths_are_better_matches(self):
shallow = exceptions.ValidationError("Oh no!", path=["baz"])
deep = exceptions.ValidationError("Oh yes!", path=["foo", "bar"])
match = max([shallow, deep], key=exceptions.relevance)
self.assertIs(match, shallow)
match = max([deep, shallow], key=exceptions.relevance)
self.assertIs(match, shallow)
def test_global_errors_are_even_better_matches(self):
shallow = exceptions.ValidationError("Oh no!", path=[])
deep = exceptions.ValidationError("Oh yes!", path=["foo"])
errors = sorted([shallow, deep], key=exceptions.relevance)
self.assertEqual(
[list(error.path) for error in errors],
[["foo"], []],
)
errors = sorted([deep, shallow], key=exceptions.relevance)
self.assertEqual(
[list(error.path) for error in errors],
[["foo"], []],
)
def test_weak_validators_are_lower_priority(self):
weak = exceptions.ValidationError("Oh no!", path=[], validator="a")
normal = exceptions.ValidationError("Oh yes!", path=[], validator="b")
best_match = exceptions.by_relevance(weak="a")
match = max([weak, normal], key=best_match)
self.assertIs(match, normal)
match = max([normal, weak], key=best_match)
self.assertIs(match, normal)
def test_strong_validators_are_higher_priority(self):
weak = exceptions.ValidationError("Oh no!", path=[], validator="a")
normal = exceptions.ValidationError("Oh yes!", path=[], validator="b")
strong = exceptions.ValidationError("Oh fine!", path=[], validator="c")
best_match = exceptions.by_relevance(weak="a", strong="c")
match = max([weak, normal, strong], key=best_match)
self.assertIs(match, strong)
match = max([strong, normal, weak], key=best_match)
self.assertIs(match, strong)
class TestErrorTree(TestCase):
def test_it_knows_how_many_total_errors_it_contains(self):
# FIXME: https://github.com/Julian/jsonschema/issues/442
errors = [
exceptions.ValidationError("Something", validator=i)
for i in range(8)
]
tree = exceptions.ErrorTree(errors)
self.assertEqual(tree.total_errors, 8)
def test_it_contains_an_item_if_the_item_had_an_error(self):
errors = [exceptions.ValidationError("a message", path=["bar"])]
tree = exceptions.ErrorTree(errors)
self.assertIn("bar", tree)
def test_it_does_not_contain_an_item_if_the_item_had_no_error(self):
errors = [exceptions.ValidationError("a message", path=["bar"])]
tree = exceptions.ErrorTree(errors)
self.assertNotIn("foo", tree)
def test_validators_that_failed_appear_in_errors_dict(self):
error = exceptions.ValidationError("a message", validator="foo")
tree = exceptions.ErrorTree([error])
self.assertEqual(tree.errors, {"foo": error})
def test_it_creates_a_child_tree_for_each_nested_path(self):
errors = [
exceptions.ValidationError("a bar message", path=["bar"]),
exceptions.ValidationError("a bar -> 0 message", path=["bar", 0]),
]
tree = exceptions.ErrorTree(errors)
self.assertIn(0, tree["bar"])
self.assertNotIn(1, tree["bar"])
def test_children_have_their_errors_dicts_built(self):
e1, e2 = (
exceptions.ValidationError("1", validator="foo", path=["bar", 0]),
exceptions.ValidationError("2", validator="quux", path=["bar", 0]),
)
tree = exceptions.ErrorTree([e1, e2])
self.assertEqual(tree["bar"][0].errors, {"foo": e1, "quux": e2})
def test_multiple_errors_with_instance(self):
e1, e2 = (
exceptions.ValidationError(
"1",
validator="foo",
path=["bar", "bar2"],
instance="i1"),
exceptions.ValidationError(
"2",
validator="quux",
path=["foobar", 2],
instance="i2"),
)
exceptions.ErrorTree([e1, e2])
def test_it_does_not_contain_subtrees_that_are_not_in_the_instance(self):
error = exceptions.ValidationError("123", validator="foo", instance=[])
tree = exceptions.ErrorTree([error])
with self.assertRaises(IndexError):
tree[0]
def test_if_its_in_the_tree_anyhow_it_does_not_raise_an_error(self):
"""
If a validator is dumb (like :validator:`required` in draft 3) and
refers to a path that isn't in the instance, the tree still properly
returns a subtree for that path.
"""
error = exceptions.ValidationError(
"a message", validator="foo", instance={}, path=["foo"],
)
tree = exceptions.ErrorTree([error])
self.assertIsInstance(tree["foo"], exceptions.ErrorTree)
class TestErrorInitReprStr(TestCase):
def make_error(self, **kwargs):
defaults = dict(
message=u"hello",
validator=u"type",
validator_value=u"string",
instance=5,
schema={u"type": u"string"},
)
defaults.update(kwargs)
return exceptions.ValidationError(**defaults)
def assertShows(self, expected, **kwargs):
if PY3: # pragma: no cover
expected = expected.replace("u'", "'")
expected = textwrap.dedent(expected).rstrip("\n")
error = self.make_error(**kwargs)
message_line, _, rest = str(error).partition("\n")
self.assertEqual(message_line, error.message)
self.assertEqual(rest, expected)
def test_it_calls_super_and_sets_args(self):
error = self.make_error()
self.assertGreater(len(error.args), 1)
def test_repr(self):
self.assertEqual(
repr(exceptions.ValidationError(message="Hello!")),
"<ValidationError: %r>" % "Hello!",
)
def test_unset_error(self):
error = exceptions.ValidationError("message")
self.assertEqual(str(error), "message")
kwargs = {
"validator": "type",
"validator_value": "string",
"instance": 5,
"schema": {"type": "string"},
}
# Just the message should show if any of the attributes are unset
for attr in kwargs:
k = dict(kwargs)
del k[attr]
error = exceptions.ValidationError("message", **k)
self.assertEqual(str(error), "message")
def test_empty_paths(self):
self.assertShows(
"""
Failed validating u'type' in schema:
{u'type': u'string'}
On instance:
5
""",
path=[],
schema_path=[],
)
def test_one_item_paths(self):
self.assertShows(
"""
Failed validating u'type' in schema:
{u'type': u'string'}
On instance[0]:
5
""",
path=[0],
schema_path=["items"],
)
def test_multiple_item_paths(self):
self.assertShows(
"""
Failed validating u'type' in schema[u'items'][0]:
{u'type': u'string'}
On instance[0][u'a']:
5
""",
path=[0, u"a"],
schema_path=[u"items", 0, 1],
)
def test_uses_pprint(self):
self.assertShows(
"""
Failed validating u'maxLength' in schema:
{0: 0,
1: 1,
2: 2,
3: 3,
4: 4,
5: 5,
6: 6,
7: 7,
8: 8,
9: 9,
10: 10,
11: 11,
12: 12,
13: 13,
14: 14,
15: 15,
16: 16,
17: 17,
18: 18,
19: 19}
On instance:
[0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24]
""",
instance=list(range(25)),
schema=dict(zip(range(20), range(20))),
validator=u"maxLength",
)
def test_str_works_with_instances_having_overriden_eq_operator(self):
"""
Check for https://github.com/Julian/jsonschema/issues/164 which
rendered exceptions unusable when a `ValidationError` involved
instances with an `__eq__` method that returned truthy values.
"""
class DontEQMeBro(object):
def __eq__(this, other): # pragma: no cover
self.fail("Don't!")
def __ne__(this, other): # pragma: no cover
self.fail("Don't!")
instance = DontEQMeBro()
error = exceptions.ValidationError(
"a message",
validator="foo",
instance=instance,
validator_value="some",
schema="schema",
)
self.assertIn(repr(instance), str(error))
class TestHashable(TestCase):
def test_hashable(self):
set([exceptions.ValidationError("")])
set([exceptions.SchemaError("")])
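
The behaviors exercised above boil down to two public helpers; a minimal sketch of using them directly (schema and instance are placeholders):

from jsonschema import Draft4Validator, exceptions

validator = Draft4Validator(
    {"properties": {"foo": {"properties": {"bar": {"type": "array"}}}}}
)
errors = list(validator.iter_errors({"foo": {"bar": 12}}))

best = exceptions.best_match(errors)  # the single most relevant error
tree = exceptions.ErrorTree(errors)   # errors indexed by instance path
assert "foo" in tree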


@@ -0,0 +1,89 @@
"""
Tests for the parts of jsonschema related to the :validator:`format` property.
"""
from unittest import TestCase
from jsonschema import FormatError, ValidationError, FormatChecker
from jsonschema.validators import Draft4Validator
BOOM = ValueError("Boom!")
BANG = ZeroDivisionError("Bang!")
def boom(thing):
if thing == "bang":
raise BANG
raise BOOM
class TestFormatChecker(TestCase):
def test_it_can_validate_no_formats(self):
checker = FormatChecker(formats=())
self.assertFalse(checker.checkers)
def test_it_raises_a_key_error_for_unknown_formats(self):
with self.assertRaises(KeyError):
FormatChecker(formats=["o noes"])
def test_it_can_register_cls_checkers(self):
original = dict(FormatChecker.checkers)
self.addCleanup(FormatChecker.checkers.pop, "boom")
FormatChecker.cls_checks("boom")(boom)
self.assertEqual(
FormatChecker.checkers,
dict(original, boom=(boom, ())),
)
def test_it_can_register_checkers(self):
checker = FormatChecker()
checker.checks("boom")(boom)
self.assertEqual(
checker.checkers,
dict(FormatChecker.checkers, boom=(boom, ()))
)
def test_it_catches_registered_errors(self):
checker = FormatChecker()
checker.checks("boom", raises=type(BOOM))(boom)
with self.assertRaises(FormatError) as cm:
checker.check(instance=12, format="boom")
self.assertIs(cm.exception.cause, BOOM)
self.assertIs(cm.exception.__cause__, BOOM)
# Unregistered errors should not be caught
with self.assertRaises(type(BANG)):
checker.check(instance="bang", format="boom")
def test_format_error_causes_become_validation_error_causes(self):
checker = FormatChecker()
checker.checks("boom", raises=ValueError)(boom)
validator = Draft4Validator({"format": "boom"}, format_checker=checker)
with self.assertRaises(ValidationError) as cm:
validator.validate("BOOM")
self.assertIs(cm.exception.cause, BOOM)
self.assertIs(cm.exception.__cause__, BOOM)
def test_format_checkers_come_with_defaults(self):
# This is bad :/ but relied upon.
# The docs for quite a while recommended people do things like
# validate(..., format_checker=FormatChecker())
# We should change that, but we can't without deprecation...
checker = FormatChecker()
with self.assertRaises(FormatError):
checker.check(instance="not-an-ipv4", format="ipv4")
def test_repr(self):
checker = FormatChecker(formats=())
checker.checks("foo")(lambda thing: True)
checker.checks("bar")(lambda thing: True)
checker.checks("baz")(lambda thing: True)
self.assertEqual(
repr(checker),
"<FormatChecker checkers=['bar', 'baz', 'foo']>",
)
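
Rounding out the format machinery shown above, a minimal sketch of registering a custom format and wiring it into validation; the "even" format is invented for illustration:

from jsonschema import Draft4Validator, FormatChecker

checker = FormatChecker()

@checker.checks("even", raises=ValueError)
def is_even(value):
    # Returning False (or raising ValueError) marks the value invalid.
    return isinstance(value, int) and value % 2 == 0

Draft4Validator({"format": "even"}, format_checker=checker).validate(4)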

Some files were not shown because too many files have changed in this diff.