
google/jax 9978

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

google/trax 5111

Trax — Deep Learning with Clear Code and Speed

google/flax 1100

Flax is a neural network library for JAX that is designed for flexibility.

levskaya/jslinux-deobfuscated 912

An old version of Mr. Bellard's JSLinux rewritten to be human readable, hand deobfuscated and annotated.

levskaya/eschersketch 188

A drawing program for exploring symmetrical designs

levskaya/polyhedronisme 133

A tool to construct and explore polyhedra.

broxtronix/Fiat-Lux 14

Lasers. On ice.

levskaya/OW 11

Source code for generating personal website based on Hyde.

levskaya/BioJSON 9

Parsers to wrangle the plethora of biological flatfile formats into standard JSON

levskaya/coffee-mode 1

Emacs Major Mode for CoffeeScript

PullRequestReviewEvent

PR opened google/flax

Fix transient module-state handling by transforms.

This should fix the tests failing with default named_call profiling.

+7 -0

0 comment

1 changed file

pr created time in 8 hours

create branch levskaya/flax

branch : trafofix

created branch time in 8 hours

PullRequestReviewEvent

pull request comment google/flax

use core reservation check in linen.

Well, we disabled it because we were doing our own name-reservation handling at the time, probably for some short-lived reason while we were prototyping things.

jheek

comment created time in 2 days

issue comment google/jax

Deepcopy of bfloat16 array messes up bfloat16 definition.

Yeah, the deepcopy issue was affecting a user I was helping just yesterday, and I was completely flummoxed by why things were ending up back on host!

frsong

comment created time in 4 days

PullRequestReviewEvent

Pull request review comment google/flax

Allow assignment of mixed Module pytrees in setup.

 def is_module_tree(in_tree: Any) -> bool:
   # reject trivial pytrees, {}, [], (), etc.
   if not tree_util.tree_leaves(in_tree):
     return False
-  reduce_fn = lambda prev, cur: prev and isinstance(cur, Module)
-  return jax.tree_util.tree_reduce(reduce_fn, in_tree, True)
+  reduce_fn = lambda prev, cur: prev or isinstance(cur, Module)

ok, I cleaned all this up

levskaya

comment created time in 10 days

PullRequestReviewEvent

Pull request review comment google/flax

Allow assignment of mixed Module pytrees in setup.

 def is_module_tree(in_tree: Any) -> bool:
   # reject trivial pytrees, {}, [], (), etc.

removed this fn.

levskaya

comment created time in 10 days

PullRequestReviewEvent

push event levskaya/flax

Anselm Levskaya

commit sha a1989410227f62636ce2db0a24abdc4f62fa64b2

Fix Module dataclass handling of inheritance. The dataclass transform looks into the base class __dataclass_fields__ to set inherited dataclass args. We were permanently mutating these base-class fields, which would break the handling of multiple separate subclasses of a Module. This fixes that edge case.

view details

Mohit Reddy

commit sha 4dcad3a1953f37e0a519bdefbe1867f6100d3edb

Create benchmark for full linen imagenet training. PiperOrigin-RevId: 337093018

view details

Anselm Levskaya

commit sha 5b808a8d372144e6b6373ed3be5cc80c1613625c

Update flax/linen/module.py Co-authored-by: Marc van Zee <marcvanzee@gmail.com>

view details

Anselm Levskaya

commit sha 8679b8a6659c99ee55693ceccc57c929144ddc1b

Update flax/linen/module.py Co-authored-by: Marc van Zee <marcvanzee@gmail.com>

view details

Flax Authors

commit sha 9fa2831d90adf44bcf1b05391b3ece55d3db835d

Merge pull request #535 from levskaya:inheritancebugfix PiperOrigin-RevId: 337172636

view details

Anselm Levskaya

commit sha 573a9084e3872e9351ebdbd0e7d80b578b592815

Allow assignment of mixed Module pytrees in setup. Previously, assignment of lists in setup() was restricted to lists containing only Module subclasses. But people would like to include plain functions like nn.relu as well, and there's no reason to force this strict convention, so it is relaxed here.

view details

Anselm Levskaya

commit sha c2f0fdc9145d2440f96071b98841f04ea405d0d9

Simplify __setattr__ handling of trees with Modules.

view details

push time in 10 days

Pull request review comment google/flax

Check shape consistency for params

 def f(scope):
      init(f)(random.PRNGKey(0))
 
-  
+  def test_inconsistent_param_shapes(self):
+    def f(scope):
+      scope.param('test', nn.initializers.ones, (4,))
+
+    msg = 'Inconisist shapes between value and initializer for paramater "test": (2,), (4,)'

"Inconsistent shapes between value and initializer for parameter" - the spelling fix is breaking the test :)

jheek

comment created time in 10 days

PullRequestReviewEvent

Pull request review comment google/flax

Fix Module dataclass handling of inheritance.

 def __init_subclass__(cls):
     cls.scope = None
 
   @classmethod
-  def _add_parent_and_name_attrs(cls):
-    """Add final optional dataclass attributes: `parent` and `name`."""
-    annotations = cls.__dict__.get('__annotations__', {})
+  def _customized_dataclass_transform(cls):
+    """Handle final optional dataclass attributes: `parent` and `name`."""
+    annotations = dict(cls.__dict__.get('__annotations__', {}))
     if 'parent' in annotations or 'name' in annotations:
       raise ValueError(
           f'properties `parent` and `name` are reserved: {annotations}')
     # Add `parent` and `name` default fields at end.
-    new_annotations = {}
-    new_annotations.update(annotations)
-    if 'parent' in getattr(cls, '__dataclass_fields__', {}):
+    # We temporarily modify base class __dataclass_fields__ to force desired
+    # argument behavior and ordering from dataclass class-transform.
+    parent_dataclass_fields = dict(getattr(cls, '__dataclass_fields__', {}))
+    if 'parent' in parent_dataclass_fields:
       cls.__dataclass_fields__.pop('parent')
-    new_annotations['parent'] = Union[Type["Module"], Type["Scope"],
-                                      Type["_Sentinel"], None]
-    cls.parent = dataclasses.field(repr=False, default=_unspecified_parent)
-    if 'name' in getattr(cls, '__dataclass_fields__', {}):
+    if 'name' in parent_dataclass_fields:
       cls.__dataclass_fields__.pop('name')
-    new_annotations['name'] = str
-    cls.__annotations__ = new_annotations
+    annotations['parent'] = Union[Type["Module"], Type["Scope"],
+                                  Type["_Sentinel"], None]
+    cls.parent = dataclasses.field(repr=False, default=_unspecified_parent)
+    annotations['name'] = str
     cls.name = None  # default value of name is None.
+    cls.__annotations__ = annotations
+    # Now apply dataclass transform (which operates in-place).
+    dataclasses.dataclass(cls)
+    # Restore original base class __dataclass_fields__.
+    if dataclasses.is_dataclass(cls.__bases__[0]):
+      cls.__bases__[0].__dataclass_fields__ = parent_dataclass_fields

When parent_dataclass_fields is grabbed and copied by dict(getattr(cls, '__dataclass_fields__', {})), we know there is no __dataclass_fields__ on the child, because the @dataclass transform hasn't been run on this class yet. So the only fields that can show up come from direct inheritance. (Note there's the wrinkle of trying to deal with multiple inheritance in this code, but that is always a buggy nightmare in Python and we should just explicitly forbid multiple inheritance - dataclasses themselves only barely work in a usable way with multiple inheritance.)

levskaya

comment created time in 10 days

PullRequestReviewEvent

push event levskaya/flax

Anselm Levskaya

commit sha 8679b8a6659c99ee55693ceccc57c929144ddc1b

Update flax/linen/module.py Co-authored-by: Marc van Zee <marcvanzee@gmail.com>

view details

push time in 10 days

push event levskaya/flax

Anselm Levskaya

commit sha 5b808a8d372144e6b6373ed3be5cc80c1613625c

Update flax/linen/module.py Co-authored-by: Marc van Zee <marcvanzee@gmail.com>

view details

push time in 10 days

Pull request review comment google/flax

Allow assignment of mixed Module pytrees in setup.

 def is_module_tree(in_tree: Any) -> bool:
   # reject trivial pytrees, {}, [], (), etc.
   if not tree_util.tree_leaves(in_tree):
     return False
-  reduce_fn = lambda prev, cur: prev and isinstance(cur, Module)
-  return jax.tree_util.tree_reduce(reduce_fn, in_tree, True)
+  reduce_fn = lambda prev, cur: prev or isinstance(cur, Module)

yeah this seems super dumb now, I'll try to remove this nonsense tomorrow.

levskaya

comment created time in 11 days

PullRequestReviewEvent

Pull request review comment google/flax

Check shape consistency for params

 def variable(self, col: str, name: str, init_fn: Callable[..., T],
 
   def param(self, name: str, init_fn: Callable[..., T], *init_args) -> T:
     """Create a paramater."""
-    s_init_fn = lambda *args: init_fn(self.make_rng('params'), *init_args)
-    v = self.variable('params', name, s_init_fn, *init_args)
-    return v.value
+    if self.has_variable('params', name):
+      abs_rng = jax.ShapeDtypeStruct((2,), jnp.uint32)
+      value = self.get_variable('params', name)
+      # validate shape of init_fn output is the same as the shape of the existing
+      # paramater.
+      abs_value = jax.eval_shape(lambda rng: init_fn(rng, *init_args), abs_rng)
+      abs_value_flat = jax.tree_leaves(abs_value)
+      value_flat = jax.tree_leaves(value)
+      for val, abs_val in zip(value_flat, abs_value_flat):
+        # NOTE: we could check dtype consistency here as well but it's usefuleness is less obvious.
+        # we might intentionally change the dtype for inference to a half float type for example.
+        if jnp.shape(val) != jnp.shape(abs_val):
+          raise ValueError('Inconisist shapes between value and initializer '
+                           f'for paramater "{name}": {jnp.shape(val)}, {jnp.shape(abs_val)}')

parameter

jheek

comment created time in 11 days

Pull request review comment google/flax

Check shape consistency for params

 def variable(self, col: str, name: str, init_fn: Callable[..., T],
 
   def param(self, name: str, init_fn: Callable[..., T], *init_args) -> T:
     """Create a paramater."""
-    s_init_fn = lambda *args: init_fn(self.make_rng('params'), *init_args)
-    v = self.variable('params', name, s_init_fn, *init_args)
-    return v.value
+    if self.has_variable('params', name):
+      abs_rng = jax.ShapeDtypeStruct((2,), jnp.uint32)
+      value = self.get_variable('params', name)
+      # validate shape of init_fn output is the same as the shape of the existing
+      # paramater.
+      abs_value = jax.eval_shape(lambda rng: init_fn(rng, *init_args), abs_rng)
+      abs_value_flat = jax.tree_leaves(abs_value)
+      value_flat = jax.tree_leaves(value)
+      for val, abs_val in zip(value_flat, abs_value_flat):
+        # NOTE: we could check dtype consistency here as well but it's usefuleness is less obvious.
+        # we might intentionally change the dtype for inference to a half float type for example.
+        if jnp.shape(val) != jnp.shape(abs_val):
+          raise ValueError('Inconisist shapes between value and initializer '

Inconsistent

jheek

comment created time in 11 days

PullRequestReviewEvent

Pull request review comment google/flax

Check shape consistency for params

 def variable(self, col: str, name: str, init_fn: Callable[..., T],
 
   def param(self, name: str, init_fn: Callable[..., T], *init_args) -> T:
     """Create a paramater."""
-    s_init_fn = lambda *args: init_fn(self.make_rng('params'), *init_args)
-    v = self.variable('params', name, s_init_fn, *init_args)
-    return v.value
+    if self.has_variable('params', name):
+      abs_rng = jax.ShapeDtypeStruct((2,), jnp.uint32)
+      value = self.get_variable('params', name)
+      # validate shape of init_fn output is the same as the shape of the existing
+      # paramater.
+      abs_value = jax.eval_shape(lambda rng: init_fn(rng, *init_args), abs_rng)
+      abs_value_flat = jax.tree_leaves(abs_value)
+      value_flat = jax.tree_leaves(value)
+      for val, abs_val in zip(value_flat, abs_value_flat):
+        # NOTE: we could check dtype consistency here as well but it's usefuleness is less obvious.
+        # we might intentionally change the dtype for inference to a half float type for example.
+        if jnp.shape(val) != jnp.shape(abs_val):
+          raise ValueError('Inconisist shapes between value and initializer '

Inconsistent

jheek

comment created time in 11 days

PullRequestReviewEvent

PR opened google/flax

Fix Module dataclass handling of inheritance.

The dataclass transform looks into the base class __dataclass_fields__ to set inherited dataclass args. We were permanently mutating these base-class fields, which would break the handling of multiple separate subclasses of a Module. This fixes that edge case.
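
As a minimal, hypothetical sketch of the failure mode (toy classes Base and Child invented for illustration, not the actual Flax Module code): mutating a base class's __dataclass_fields__ in place while transforming one subclass changes what every later subclass inherits.

import dataclasses

@dataclasses.dataclass
class Base:
    name: str = "base"

# Simulate the old behavior: pop an inherited field off the *base* class in place.
Base.__dataclass_fields__.pop("name")

@dataclasses.dataclass
class Child(Base):
    size: int = 0

# Child (and any subclass defined afterwards) no longer sees the inherited
# 'name' field, because the base class was permanently mutated.
print(list(Child.__dataclass_fields__))  # ['size']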

+63 -15

0 comment

2 changed files

pr created time in 11 days

create branch levskaya/flax

branch : inheritancebugfix

created branch time in 11 days

PullRequestReviewEvent

PR opened google/flax

Allow assignment of mixed Module pytrees in setup.

Previously, assignment of lists in setup() was restricted to lists containing only Module subclasses. But people would like to include plain functions like nn.relu as well, and there's no reason to force this strict convention, so it is relaxed here.
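
A rough usage sketch of what this enables (layer sizes are arbitrary, not taken from the PR): a setup() list may now mix Module instances with plain callables such as nn.relu.

import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
  def setup(self):
    # Mixed pytree: Dense submodules interleaved with the plain nn.relu function.
    self.layers = [nn.Dense(16), nn.relu, nn.Dense(1)]

  def __call__(self, x):
    for layer in self.layers:
      x = layer(x)
    return x

variables = MLP().init(jax.random.PRNGKey(0), jnp.ones((4, 8)))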

+47 -3

0 comment

2 changed files

pr created time in 11 days

create branch levskaya/flax

branch : modfix

created branch time in 11 days

issue comment google/jax

factor named_call primitive into jax core

I'd be a big fan of doing this! I tried suggesting it before, but I think people were too busy to think about it at the time. It would be much better to have such a simple, basic primitive live in a single place in JAX.

froystig

comment created time in 12 days

started rui314/chibicc

started time in 21 days

Pull request review comment google/flax

Use automasking in seq2seq example

 def __call__(self, carry, x):
     carry_rng, categorical_rng = jax.random.split(rng, 2)
     if not self.teacher_force:
       x = last_prediction
-    lstm_cell = nn.LSTMCell(name='lstm_cell')
-    projection = nn.Dense(features=self.vocab_size, name='projection')
-    lstm_state, y = lstm_cell(lstm_state, x)
-    logits = projection(y)
-    predicted_tokens = jax.random.categorical(categorical_rng, logits)
-    prediction = onehot(predicted_tokens, self.vocab_size)
+    lstm_state, y = nn.LSTMCell()(lstm_state, x)
+    logits = nn.Dense(features=CTABLE.vocab_size)(y)
+    predicted_token = jax.random.categorical(categorical_rng, logits)
+    prediction = jnp.array(predicted_token == jnp.arange(CTABLE.vocab_size),
+                           dtype=jnp.float32)

Oh sure, I was just curious about it!

marcvanzee

comment created time in 23 days

PullRequestReviewEvent

Pull request review comment google/flax

Use automasking in seq2seq example

 def compute_metrics(logits, labels):
   return metrics
 
+IN_SHAPES = [{'query': '(n, _)', 'answer': '(m, _)'}]
+OUT_ELEM = f'(m +- 1, {CTABLE.vocab_size})'

ok cool that makes sense.

marcvanzee

comment created time in 23 days

PullRequestReviewEvent

Pull request review comment google/flax

Use automasking in seq2seq example

 def encode_onehot(batch_inputs, max_len):
 
   def encode_str(s):
     tokens = CTABLE.encode(s)
-    if len(tokens) > max_len:
+    unpadded_len = len(tokens)
+    if unpadded_len > max_len:
       raise ValueError(f'Sequence too long ({len(tokens)}>88,492): \'{s}\'')
     tokens = np.pad(tokens, [(0, max_len-len(tokens))], mode='constant')
-    return onehot(tokens, CTABLE.vocab_size)
+    return onehot(tokens, CTABLE.vocab_size), unpadded_len
 
-  return np.array([encode_str(inp) for inp in batch_inputs])
+  return np.array([encode_str(inp) for inp in batch_inputs], dtype=object)

but why wrap this in an ndarray at all?? why not just keep it as a list?

marcvanzee

comment created time in 23 days

PullRequestReviewEvent

Pull request review comment google/flax

Use automasking in seq2seq example

 def __call__(self, carry, x):
     carry_rng, categorical_rng = jax.random.split(rng, 2)
     if not self.teacher_force:
       x = last_prediction
-    lstm_cell = nn.LSTMCell(name='lstm_cell')
-    projection = nn.Dense(features=self.vocab_size, name='projection')
-    lstm_state, y = lstm_cell(lstm_state, x)
-    logits = projection(y)
-    predicted_tokens = jax.random.categorical(categorical_rng, logits)
-    prediction = onehot(predicted_tokens, self.vocab_size)
+    lstm_state, y = nn.LSTMCell()(lstm_state, x)
+    logits = nn.Dense(features=CTABLE.vocab_size)(y)
+    predicted_token = jax.random.categorical(categorical_rng, logits)
+    prediction = jnp.array(predicted_token == jnp.arange(CTABLE.vocab_size),
+                           dtype=jnp.float32)

just curious: why return things in one-hot form instead of just integer ids? It's probably obvious I'm just dumb about this.

marcvanzee

comment created time in 23 days

Pull request review comment google/flax

Use automasking in seq2seq example

 def encode_onehot(batch_inputs, max_len):
 
   def encode_str(s):
     tokens = CTABLE.encode(s)
-    if len(tokens) > max_len:
+    unpadded_len = len(tokens)
+    if unpadded_len > max_len:
       raise ValueError(f'Sequence too long ({len(tokens)}>88,492): \'{s}\'')
     tokens = np.pad(tokens, [(0, max_len-len(tokens))], mode='constant')
-    return onehot(tokens, CTABLE.vocab_size)
+    return onehot(tokens, CTABLE.vocab_size), unpadded_len
 
-  return np.array([encode_str(inp) for inp in batch_inputs])
+  return np.array([encode_str(inp) for inp in batch_inputs], dtype=object)

What is the point of wrapping this in a np.array(..., dtype=object)? Don't the below uses work just as well if this were just a list?

marcvanzee

comment created time in 23 days

Pull request review comment google/flax

Use automasking in seq2seq example

 def compute_metrics(logits, labels):
   return metrics
 
+IN_SHAPES = [{'query': '(n, _)', 'answer': '(m, _)'}]
+OUT_ELEM = f'(m +- 1, {CTABLE.vocab_size})'

curious: what does m +- 1 syntax mean here?

marcvanzee

comment created time in 23 days

PullRequestReviewEvent

started google-research/sputnik

started time in 24 days

Pull request review comment huggingface/transformers

Integrate Bert-like model on Flax runtime.

+from typing import Callable, Dict++import numpy as np++import flax.nn as nn+import jax+import jax.numpy as jnp+from transformers import BertConfig+from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP+from transformers.modeling_flax_utils import FlaxPreTrainedModel+++@jax.jit+def gelu(x):+    r"""Gaussian error linear unit activation function.++    Computes the element-wise function:++    .. math::+      \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left(+        \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)++    We explicitly use the approximation rather than the exact formulation for+    speed. For more information, see `Gaussian Error Linear Units (GELUs)+    <https://arxiv.org/abs/1606.08415>`_, section 2.+    """+    return x * 0.5 * (1.0 + jax.lax.erf(x / jnp.sqrt(2.0)))+++ACT2FN = {+    "gelu": gelu,+    "relu": nn.relu,+    "swish": nn.swish,+    "gelu_new": gelu,+}+++class BertLayerNorm(nn.Module):+    """Layer normalization (https://arxiv.org/abs/1607.06450).+    Operates on the last axis of the input data.+    """++    def apply(+        self,+        x,+        epsilon=1e-6,+        dtype=jnp.float32,+        bias=True,+        scale=True,+        bias_init=nn.initializers.zeros,+        scale_init=nn.initializers.ones,+    ):+        """Applies layer normalization on the input.+        It normalizes the activations of the layer for each given example in a+        batch independently, rather than across a batch like Batch Normalization.+        i.e. applies a transformation that maintains the mean activation within+        each example close to 0 and the activation standard deviation close to 1.+        Args:+          x: the inputs+          epsilon: A small float added to variance to avoid dividing by zero.+          dtype: the dtype of the computation (default: float32).+          bias:  If True, bias (beta) is added.+          scale: If True, multiply by scale (gamma). When the next layer is linear+            (also e.g. 
nn.relu), this can be disabled since the scaling will be done+            by the next layer.+          bias_init: Initializer for bias, by default, zero.+          scale_init: Initializer for scale, by default, one.+        Returns:+          Normalized inputs (the same shape as inputs).+        """+        features = x.shape[-1]+        mean = jnp.mean(x, axis=-1, keepdims=True)+        mean2 = jnp.mean(jnp.lax.square(x), axis=-1, keepdims=True)+        var = mean2 - jnp.lax.square(mean)+        mul = jnp.lax.rsqrt(var + epsilon)+        if scale:+            mul = mul * jnp.asarray(self.param("gamma", (features,), scale_init), dtype)+        y = (x - mean) * mul+        if bias:+            y = y + jnp.asarray(self.param("beta", (features,), bias_init), dtype)+        return y+++class BertEmbedding(nn.Module):+    """+    Specify a new class for doing the embedding stuff+    as Flax's one use 'embedding' for the parameter name+    and PyTorch use 'weight'+    """++    def apply(+        self,+        input,+        vocab_size: int,+        hidden_size: int,+        emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1),+    ):++        embedding = self.param("weight", (vocab_size, hidden_size), emb_init)+        return jnp.take(embedding, input, axis=0)+++class BertEmbeddings(nn.Module):+    def apply(+        self,+        input_ids,+        token_type_ids,+        position_ids,+        attention_mask,+        vocab_size: int,+        hidden_size: int,+        type_vocab_size: int,+        max_length: int,+    ):++        # Embed+        w_emb = BertEmbedding(jnp.atleast_2d(input_ids.astype("i4")), vocab_size, hidden_size, name="word_embeddings")+        p_emb = BertEmbedding(+            jnp.atleast_2d(position_ids.astype("i4")), max_length, hidden_size, name="position_embeddings"+        )+        t_emb = BertEmbedding(+            jnp.atleast_2d(token_type_ids.astype("i4")), type_vocab_size, hidden_size, name="token_type_embeddings"+        )++        # Sum all embeddings+        summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb++        # Layer Norm+        norm = BertLayerNorm(summed_emb, name="layer_norm")++        return norm+++class BertAttention(nn.Module):+    def apply(self, hidden_state, attention_mask, num_heads: int, head_size: int):+        self_att = nn.attention.SelfAttention(+            hidden_state, num_heads=num_heads, qkv_features=head_size, padding_mask=attention_mask, name="self"+        )++        return BertLayerNorm(self_att + hidden_state, name="layer_norm")+++class BertIntermediate(nn.Module):+    def apply(self, hidden_state, output_size: int):+        # TODO: Add ACT2FN reference to change activation function+        h = nn.Dense(hidden_state, features=output_size, name="dense")+        return gelu(h)+++class BertOutput(nn.Module):+    def apply(self, intermediate_output, attention_output):+        h = nn.Dense(intermediate_output, attention_output.shape[-1], name="dense")+        h = BertLayerNorm(h + attention_output, name="layer_norm")++        return h+++class BertLayer(nn.Module):+    def apply(self, hidden_state, attention_mask, num_heads: int, head_size: int, intermediate_size: int):+        attention = BertAttention(hidden_state, attention_mask, num_heads, head_size, name="attention")+        intermediate = BertIntermediate(attention, intermediate_size, name="intermediate")+        output = BertOutput(intermediate, attention, name="output")++        return output+++class BertLayerCollection(nn.Module):+    """+    
Stores N BertLayer(s)+    """++    def apply(self, input, attention_mask, num_layers: int, num_heads: int, head_size: int, intermediate_size: int):+        assert num_layers > 0, "num_layers should be >= 1, got ({})".format(num_layers)++        # Initialize input / output+        input_i = output_i = input++        # Forward over all encoders+        for i in range(num_layers):+            output_i = BertLayer(input_i, attention_mask, num_heads, head_size, intermediate_size, name="{}".format(i))+            input_i = output_i+        return output_i+++class BertEncoder(nn.Module):+    def apply(+        self, hidden_state, attention_mask, num_layers: int, num_heads: int, head_size: int, intermediate_size: int+    ):+        encodings = BertLayerCollection(+            hidden_state, attention_mask, num_layers, num_heads, head_size, intermediate_size, name="layer"+        )+        return encodings+++class BertPooler(nn.Module):+    def apply(self, hidden_state):+        first_token = hidden_state[:, 0]+        out = nn.Dense(first_token, hidden_state.shape[-1], name="dense")+        return jax.lax.tanh(out)+++class BertModel(nn.Module):+    def apply(+        self,+        input_ids,+        token_type_ids,+        position_ids,+        attention_mask,+        vocab_size: int,+        hidden_size: int,+        type_vocab_size: int,+        max_length: int,+        num_encoder_layers: int,+        num_heads: int,+        head_size: int,+        intermediate_size: int,+        padding_idx: int,+        emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1),+    ):++        # Embedding+        embeddings = BertEmbeddings(+            input_ids,+            token_type_ids,+            position_ids,+            attention_mask,+            vocab_size,+            hidden_size,+            type_vocab_size,+            max_length,+            name="embeddings",+        )++        # N stacked encoding layers+        encoder = BertEncoder(+            embeddings, attention_mask, num_encoder_layers, num_heads, head_size, intermediate_size, name="encoder"+        )++        pooled = BertPooler(encoder, name="pooler")+        return encoder, pooled+++class FlaxBertModel(FlaxPreTrainedModel):+    """+    BERT implementation using JAX/Flax as backend+    """++    model_class = BertModel+    config_class = BertConfig+    pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP+    base_model_prefix = "bert"++    @staticmethod+    def convert_from_pytorch(pt_state: Dict, config: BertConfig) -> Dict:+        jax_state = dict(pt_state)++        # Need to change some parameters name to match Flax names so that we don't have to fork any layer+        for key, tensor in pt_state.items():+            # Key parts+            key_parts = set(key.split("."))++            # Every dense layer have a "kernel" parameters instead of "weight"+            if "dense.weight" in key:+                del jax_state[key]+                key = key.replace("weight", "kernel")+                jax_state[key] = tensor++            # SelfAttention needs also to replace "weight" by "kernel"+            if {"query", "key", "value"} & key_parts:++                # Flax SelfAttention decomposes the heads (num_head, size // num_heads)+                if "bias" in key:+                    jax_state[key] = tensor.reshape((config.num_attention_heads, -1))+                elif "weight":+                    del jax_state[key]+                    key = key.replace("weight", "kernel")+                    tensor = 
tensor.reshape((config.num_attention_heads, -1, config.hidden_size)).transpose((2, 0, 1))+                    jax_state[key] = tensor++            # SelfAttention output is not a separate layer, remove one nesting+            if "attention.output.dense" in key:+                del jax_state[key]+                key = key.replace("attention.output.dense", "attention.self.out")+                jax_state[key] = tensor++            # SelfAttention output is not a separate layer, remove nesting on layer norm+            if "attention.output.LayerNorm" in key:+                del jax_state[key]+                key = key.replace("attention.output.LayerNorm", "attention.LayerNorm")+                jax_state[key] = tensor++            # There are some transposed parameters w.r.t their PyTorch counterpart+            if "intermediate.dense.kernel" in key or "output.dense.kernel" in key:+                jax_state[key] = tensor.T++            # Self Attention output projection needs to be transposed+            if "out.kernel" in key:+                jax_state[key] = tensor.reshape((config.hidden_size, config.num_attention_heads, -1)).transpose(+                    1, 2, 0+                )++            # Pooler needs to transpose its kernel+            if "pooler.dense.kernel" in key:+                jax_state[key] = tensor.T++            # Handle LayerNorm conversion+            if "LayerNorm" in key:+                del jax_state[key]++                # Replace LayerNorm by layer_norm+                new_key = key.replace("LayerNorm", "layer_norm")++                if "weight" in key:+                    new_key = new_key.replace("weight", "gamma")+                elif "bias" in key:+                    new_key = new_key.replace("bias", "beta")++                jax_state[new_key] = tensor++        return jax_state++    def __init__(self, config: BertConfig, state: dict, seed: int = 0, **kwargs):+        model_def = BertModel.partial(+            vocab_size=config.vocab_size,+            hidden_size=config.hidden_size,+            type_vocab_size=config.type_vocab_size,+            max_length=config.max_position_embeddings,+            num_encoder_layers=config.num_hidden_layers,+            num_heads=config.num_attention_heads,+            head_size=config.hidden_size,+            intermediate_size=config.intermediate_size,+            padding_idx=config.pad_token_id,+        )++        super().__init__(config, model_def, state, seed)++    @property+    def module(self) -> BertModel:+        return self._module++    @property+    def config(self) -> BertConfig:+        return self._config++    def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):+        @jax.jit

yeah you definitely don't want to use jit inside a function like this, unless perhaps it's only called once anyway.
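
A small sketch of the pitfall (toy functions, not the transformers code above): jitting a closure that is recreated inside every call hands jax.jit a brand-new function object each time, so it re-traces and re-compiles instead of reusing its cache.

import jax
import jax.numpy as jnp

def forward_slow(w, x):
    @jax.jit
    def apply(w, x):  # new closure per call -> new cache entry per call
        return x @ w
    return apply(w, x)

@jax.jit  # defined once at module scope -> compiled once, then cached
def forward_fast(w, x):
    return x @ w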

mfuntowicz

comment created time in 25 days

Pull request review comment huggingface/transformers

Integrate Bert-like model on Flax runtime.

     _tf_available = False  # pylint: disable=invalid-name
 
 
+try:
+    USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
+
+    if USE_JAX in ENV_VARS_TRUE_VALUES:
+        import flax
+        import jax
+        from jax.config import config
+        # TODO(marcvanzee): Flax Linen requires JAX omnistaging. Remove this
+        # once JAX enables it by default.
+        config.enable_omnistaging()

If we cut a new release of flax and pin to newer jax/flax versions, this is no longer necessary, as omnistaging is now (recently) the default.

mfuntowicz

comment created time in 25 days

PullRequestReviewEvent

issue comment google/jax

Support autodiff of Eigendecomposition with repeated eigenvalues

As you note, the general problem is quite tricky. To use physics parlance, there are two cases: a matrix with degenerate eigenvalues where the perturbation (gradient direction) "breaks the symmetry" and causes the degenerate eigenvalues to split, and the case where the perturbation preserves the degeneracy, which generally makes eigenvector derivatives very tricky / ill-defined with simple approaches. That is especially true in the general complex case, where the eigenvector phase has additional freedom. There are a few papers that seem to offer general algorithmic approaches, but they're complicated enough that no one has sat down to implement them to see how they'd work...
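
For concreteness, the standard first-order perturbation formulas for a symmetric matrix A with eigenpairs (\lambda_i, u_i) make the difficulty explicit:

d\lambda_i = u_i^\top \, dA \, u_i, \qquad
du_i = \sum_{j \neq i} \frac{u_j^\top \, dA \, u_i}{\lambda_i - \lambda_j} \, u_j

The eigenvector formula divides by \lambda_i - \lambda_j, so it is only well-defined when the eigenvalues are distinct; for a repeated eigenvalue the eigenvectors themselves are not unique, which is exactly the degenerate case described above.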

sethaxen

comment created time in a month

pull request comment google/flax

design unit tests

Ah, you're right. I think some other test runner I used, which actually runs things as main, bit me once a long while ago. ;)

jheek

comment created time in a month

PullRequestReviewEvent

Pull request review comment google/flax

core docstrings

 def find_length(axis, x):
       return ()
     # split rngs
     lengths = jax.tree_multimap(find_length, in_axes, args)
-    if length is None:
-      d_length, = set(jax.tree_leaves(lengths))
+    lengths = set(jax.tree_leaves(lengths))
+    if length is None and len(lengths) == 1:
+      d_length, = lengths
+    elif len(lengths) > 1:
+      raise ValueError(f'Incosistent scan lengths: {lengths}')

typo: Inconsistent

jheek

comment created time in a month

PullRequestReviewEvent

Pull request review comment google/flax

Examples : Hide GPUs from TensorFlow and use default v2.

  import numpy as onp
 
+import tensorflow as tf
+tf.config.experimental.set_visible_devices([], "GPU")

ditto

andsteing

comment created time in a month

Pull request review comment google/flax

Examples : Hide GPUs from TensorFlow and use default v2.

  import numpy as onp
 
+import tensorflow as tf
+tf.config.experimental.set_visible_devices([], "GPU")

Same comment as above: this needs to go in the main file, not the lib file.

andsteing

comment created time in a month

Pull request review comment google/flax

Examples : Hide GPUs from TensorFlow and use default v2.

 import jax.numpy as jnp
 
 import tensorflow as tf
+tf.config.experimental.set_visible_devices([], "GPU")

Realized after seeing some jax2tf test failures: we can only call this after absl flag parsing / absl lib init happens, so this should instead be put in "imagenet_main.py" and not in the example training lib file.

andsteing

comment created time in a month

PullRequestReviewEvent

issue closed google/flax

I think linen_linear_test.py is outdated

Hi! I was trying to use the nn.Embed part from here https://github.com/google/flax/blob/0132b3f234a9868b47df491efde870bdc58e97a9/tests/linen/linen_linear_test.py and it seems outdated.

Thanks!

closed time in a month

LysSanzMoreta

issue comment google/flax

I think linen_linear_test.py is outdated

I'll close this issue for the moment (with the understanding that we will be trying to improve documentation for the linen API over the coming weeks). If you have further questions or concerns about this, feel free to reopen the issue.

LysSanzMoreta

comment created time in a month

issue comment google/flax

I think linen_linear_test.py is outdated

init_with_output is a supported, canonical function: it's used when you want both the result of the module computation and its initialized variables, which is the sort of thing you often want in test code. For a more canonical usage example of Embed, look at the "wmt" example, which implements a transformer encoder-decoder with embeddings.
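
A minimal usage sketch of the two entry points (vocabulary and feature sizes are arbitrary):

import jax
import jax.numpy as jnp
import flax.linen as nn

model = nn.Embed(num_embeddings=100, features=8)
ids = jnp.array([[1, 2, 3]])

variables = model.init(jax.random.PRNGKey(0), ids)                   # variables only
out, variables = model.init_with_output(jax.random.PRNGKey(0), ids)  # output and variables
out = model.apply(variables, ids)                                    # normal forward pass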

LysSanzMoreta

comment created time in a month

issue closed google/flax

vmap-friendly Conv

We should make a vmap-friendly conv, i.e. a Conv that can work on single examples (XLA's Conv assumes a batch dimension).

Perhaps something like this modified version of nn.Conv that detects the single-example case, adds a singleton batch dimension, then removes it at the end:

from jax import lax
class Conv(nn.base.Module):
  def apply(self,
            inputs,
            features,
            kernel_size,
            strides=None,
            padding='SAME',
            input_dilation=None,
            kernel_dilation=None,
            feature_group_count=1,
            bias=True,
            dtype=jnp.float32,
            precision=None,
            kernel_init=nn.linear.default_kernel_init,
            bias_init=nn.initializers.zeros):
    inputs = jnp.asarray(inputs, dtype)
    # For single-example inputs, fake the batch dimension.
    single_example_conv = False
    if inputs.ndim == len(kernel_size) + 1:
      single_example_conv = True
      inputs = jnp.expand_dims(inputs, axis=0)

    if strides is None:
      strides = (1,) * (inputs.ndim - 2)

    in_features = inputs.shape[-1]
    assert in_features % feature_group_count == 0
    kernel_shape = kernel_size + (in_features // feature_group_count, features)
    kernel = self.param('kernel', kernel_shape, kernel_init)
    kernel = jnp.asarray(kernel, dtype)

    dimension_numbers = nn.linear._conv_dimension_numbers(inputs.shape)

    y = lax.conv_general_dilated(
        inputs,
        kernel,
        strides,
        padding,
        lhs_dilation=input_dilation,
        rhs_dilation=kernel_dilation,
        dimension_numbers=dimension_numbers,
        feature_group_count=feature_group_count,
        precision=precision)
    # Remove the added batch-dimension for single-example inputs.
    if single_example_conv:
      y = jnp.squeeze(y, axis=0)
    if bias:
      bias = self.param('bias', (features,), bias_init)
      bias = jnp.asarray(bias, dtype)
      y = y + bias
    return y
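
The payoff, as a hedged sketch with a plain per-example function invented for illustration (not the module above): once a conv handles unbatched inputs by faking the batch dimension internally, batching over examples is just jax.vmap.

import jax
import jax.numpy as jnp

def per_example(kernel, image):
    # image: (H, W, C); add a singleton batch dim, convolve, then drop it again.
    return jax.lax.conv_general_dilated(
        image[None], kernel, window_strides=(1, 1), padding='SAME',
        dimension_numbers=('NHWC', 'HWIO', 'NHWC'))[0]

batched = jax.vmap(per_example, in_axes=(None, 0))  # (B, H, W, C) -> (B, H, W, features)

kernel = jnp.zeros((3, 3, 3, 8))   # (kh, kw, in_features, out_features)
images = jnp.ones((4, 32, 32, 3))
out = batched(kernel, images)      # shape (4, 32, 32, 8)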

closed time in a month

levskaya

issue comment google/flax

vmap-friendly Conv

Implemented by #444

levskaya

comment created time in a month

issue comment google/flax

I think linen_linear_test.py is outdated

@LysSanzMoreta could you provide more detail on what's wrong? I just double-checked that code and nothing seems wrong there. Note that this test invokes the nn.Embed layer in a very unusual way (using an initializer that returns a constant embed table) that you wouldn't normally do.

LysSanzMoreta

comment created time in a month

Pull request review comment google/flax

Update attention layers with simpler flat forms, update wmt example.

  PRNGKey = Any
 Shape = Tuple[int]
-Dtype = Any  # this could be a real type?
+Dtype = Any
 Array = Any
 
 
-def dot_product_attention(query,
-                          key,
-                          value,
-                          dtype=jnp.float32,
-                          bias=None,
-                          axis=None,
-                          broadcast_dropout=True,
-                          dropout_rng=None,
-                          dropout_rate=0.,
-                          deterministic=False,
-                          precision=None):
+def dot_product_attention(query: Array,
+                          key: Array,
+                          value: Array,
+                          bias: Optional[Array] = None,
+                          broadcast_dropout: bool = True,

I agree that we should decouple it more. It's already a dynamic attribute, but the expected signature of the Callable is too specialized; it should just take q, k, v, bias. I'd rather do that change in a separate PR, though.

levskaya

comment created time in 2 months

PullRequestReviewEvent

Pull request review comment google/flax

Update attention layers with simpler flat forms, update wmt example.

 def __call__(self,
         query,
         key,
         value,
-        dtype=self.dtype,
-        axis=attention_axis,
         bias=attention_bias,
-        precision=self.precision,
         dropout_rng=dropout_rng,
         dropout_rate=self.dropout_rate,
         broadcast_dropout=self.broadcast_dropout,
-        deterministic=self.deterministic)
+        deterministic=self.deterministic,
+        dtype=self.dtype,
+        precision=self.precision)
 
     # back to the original inputs dimensions
     out = DenseGeneral(features=features,

Hmm, I think the use of DenseGeneral here and above is nice, actually. What you often do otherwise is define two "split_heads", "join_heads" util functions that do the two reshapes... and DenseGeneral is sort of just a nice concise way to do that. Since we used it in mlperf I also know that it performs fine.
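
Roughly what that pattern looks like (sizes here are illustrative): DenseGeneral projects directly to and from a (num_heads, head_dim) shape, standing in for explicit split_heads/join_heads reshapes.

import flax.linen as nn

num_heads, head_dim, features = 8, 64, 512

# hidden -> per-head projections: (..., features) -> (..., num_heads, head_dim)
to_heads = nn.DenseGeneral(features=(num_heads, head_dim))
# back to the original dimensions: (..., num_heads, head_dim) -> (..., features)
from_heads = nn.DenseGeneral(features=features, axis=(-2, -1))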

levskaya

comment created time in 2 months

PullRequestReviewEvent

issue opened google/flax

Recent TF breaks JAX on GPU by eating all memory

Some update to TF is again causing it to preemptively grab all GPU memory. This manifests as strange errors such as CUBLAS_STATUS_NOT_INITIALIZED.

The workaround is to declare explicitly to TF that GPUs are off-limits; we should add this to all of our examples:

import tensorflow as tf
tf.config.experimental.set_visible_devices([], "GPU")

created time in 2 months

PullRequestReviewEvent

push event levskaya/flax

Anselm Levskaya

commit sha c312382d250080d3247ade4cb1791f7781914af1

Update attention layers with simpler flat forms, update wmt example.

view details

push time in 2 months

push event levskaya/flax

Anselm Levskaya

commit sha 5854942a7ee60e652fabe9641b2d26c338b82a6b

Update attention layers with simpler flat forms, update wmt example.

view details

push time in 2 months

push event levskaya/flax

Anselm Levskaya

commit sha 5989f3d83e4c4d737eb0ab1ead40232a3f6eb846

Update attention layers with simpler flat forms, update wmt example.

view details

push time in 2 months

push event levskaya/flax

Anselm Levskaya

commit sha 0cf0c13b565dfc73271206653edd8a114d58f0ad

Update attention layers with simpler flat forms, update wmt example.

view details

push time in 2 months

Pull request review comment google/flax

add scope aliasing to lift.pack

  ScanVariableMode = Union[str, Tuple[str, str]]
 
 
+def _dedup_scopes(scopes):
+  paths = []
+  minimal_set = set(scopes)
+  for leave in scopes:

I think you mean "leaf"?

jheek

comment created time in 2 months

PullRequestReviewEvent

PR opened google/flax

Update attention layers with simpler flat forms, update wmt example.

Refactoring of the wmt example is in progress.

+240 -404

0 comment

5 changed files

pr created time in 2 months

create branch levskaya/flax

branch : linen_wmt

created branch time in 2 months

PR opened google/flax

Update run_all_tests.sh to run linen_examples tests.

run_all_tests.sh wasn't running the linen_examples tests; this fixes that.

+5 -0

0 comment

1 changed file

pr created time in 2 months

create branch google/flax

branch : levskaya-patch-1

created branch time in 2 months

issue comment google/flax

JAX update broke my random initialization of a Flax network

Sorry about this! A fairly large internal change to how JAX handles jit compilation has landed recently, a behavior referred to as "omnistaging". It mostly doesn't affect Flax, but in some cases partial evaluation (which init_by_shape uses) can break - we're still trying to debug why this is happening.

For now, could you try enabling the new JAX behavior with jax.config.enable_omnistaging() right after your imports and see if the problem persists?

stanislavfort

comment created time in 2 months

started open-sdr/openwifi

started time in 2 months

PullRequestReviewEvent

create branch levskaya/flax

branch : harmonized-scan

created branch time in 2 months

push event levskaya/flax

push time in 2 months

push event levskaya/flax

Anselm Levskaya

commit sha 979a5b98f5ded03c2d7c3106787c78a3ac85fcb5

remove tie_in from linen

view details

Anselm Levskaya

commit sha f8fac0eac5965a3d296cc4dd5439849982614f03

Remove obsolete python3 linting directives.

view details

Anselm Levskaya

commit sha 14f32472551c07d526d020d8e533705f05f2a076

Update attention layers with simpler flat forms, update wmt example.

view details

push time in 2 months

issue comment google/jax

Add ability to place computation on a specific GPU

@BoyuanJackChen - there's nothing special about flax optimizers: they just operate on jax arrays like any other computation, so you should be able to target them if needed.

afrozenator

comment created time in 2 months

push event levskaya/flax

Anselm Levskaya

commit sha c82fa6da2f697c398e593234c80320bb7e2e530d

Add automatic named_call wrapping for profiling. Because named_call wraps every method in a model, we keep it off by default (out of caution about python overheads) until we've measured the cost of having it on. One can enable this mode by calling nn.enable_named_call() and later disable it with nn.disable_named_call(). The wrapping occurs at subclass creation time.
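
A usage sketch based on the commit message (the toggles are global, so wrap only the region you want profiled):

import flax.linen as nn

nn.enable_named_call()    # wrap Module methods in named_call so profiler traces get readable scopes
# ... run the model under the profiler here ...
nn.disable_named_call()   # switch the wrapping back off to avoid the extra python overhead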

view details

push time in 2 months
