
google/jax 6843

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

hawkinsp/ZTopo 5

Topographic Map Viewer

hawkinsp/legion 0

The Legion Parallel Programming System

hawkinsp/numpy 0

The fundamental package for scientific computing with Python.

hawkinsp/pybind11 0

Seamless operability between C++11 and Python

hawkinsp/tensorflow 0

Computation using data flow graphs for scalable machine learning

PR opened google/jax

Fix abstract evaluation rule for lax.top_k.
+11 -3

0 comments

2 changed files

pr created time in 16 hours

push event hawkinsp/jax

Peter Hawkins

commit sha 067e45fe9f3cff50f781b431014cdc5f10fc6a4a

Fix abstract evaluation rule for lax.top_k.

view details

push time in 16 hours

push event hawkinsp/jax

Peter Hawkins

commit sha 5f77202d7ce6a7be79766b2683ca7ab1e57e4c86

Fix abstract evaluation rule for lax.top_k.

view details

push time in 16 hours

push event hawkinsp/jax

Peter Hawkins

commit sha af0967fdbf1960d4f830c888103aa8624479c23d

Add an experimental lax.top_k operator. (#2280)

view details

Stephan Hoyer

commit sha 8c3e3b2dae0e4e9b7f6064499317349da3e57a70

Always jit scipy.ndimage.map_coordinates (#2286) Fixes GH2282

view details

Peter Hawkins

commit sha 80abdf0c5307f4b917c281428009e32c66f9f1a9

Unbreak build and update XLA. (#2289) * raise minimum Bazel version to 2.0.0 to match TensorFlow. * set --experimental_repo_remote_exec since it is required by the TF build. * bump TF/XLA version. * use the --config=short_logs trick from TF to suppress build warnings.

view details

Peter Hawkins

commit sha 05fd2037d44db6117525ba43098150b9292ea41a

Fix abstract evaluation rule for lax.top_k.

view details

push time in 16 hours

push event google/jax

Peter Hawkins

commit sha 80abdf0c5307f4b917c281428009e32c66f9f1a9

Unbreak build and update XLA. (#2289) * raise minimum Bazel version to 2.0.0 to match TensorFlow. * set --experimental_repo_remote_exec since it is required by the TF build. * bump TF/XLA version. * use the --config=short_logs trick from TF to suppress build warnings.

view details

push time in 16 hours

PR merged google/jax

Unbreak build and update XLA. cla: yes
  • raise minimum Bazel version to 2.0.0 to match TensorFlow.
  • set --experimental_repo_remote_exec since it is required by the TF build.
  • bump TF/XLA version.
  • use the --config=short_logs trick from TF to suppress build warnings.

@skye as an FYI.

+16 -9

0 comments

2 changed files

hawkinsp

pr closed time in 16 hours

PR opened google/jax

Unbreak build and update XLA.
  • raise minimum Bazel version to 2.0.0 to match TensorFlow.
  • set --experimental_repo_remote_exec since it is required by the TF build.
  • bump TF/XLA version.
  • use the --config=short_logs trick from TF to suppress build warnings.

@skye as an FYI.

+16 -9

0 comments

2 changed files

pr created time in 16 hours

push event hawkinsp/jax

Roy Frostig

commit sha afb8af19ff7474561c3c904e03c63dbf8f57de3f

implement JVP of while loop Co-authored-by: Matthew Johnson <mattjj@google.com>

view details

Peter Hawkins

commit sha 653001aa64f8de1a6b020a4c4bf17949cea0584b

Update references to bazel repositories in WORKSPACE to match TF head. (#2005)

view details

Stephan Hoyer

commit sha a5644edbbcbd4093cc8cfa993145f471336fa0b6

Defer to unrecognized types in arithmetic (#1942) This is useful for building higher level array libraries around JAX, because it makes it possible to override operations like `jax_array + other`. I think I covered all the array types that JAX should be able to handle: - Python builtin numbers int, float and complex - NumPy scalars - NumPy arrays - JAX array types and tracers Did I miss anything? Maybe bfloat16 scalars?

view details

Trevor Cai

commit sha 12975bbcc84ae0303f243b3ed727340bbd514321

[pmap] Add support for nested pmaps on multihost platforms via axis_size (#2002) One issue with nested pmaps on multihost platforms is inferring the global pmap axis size without communication. This commit sidesteps the issue by adding an `axis_size` argument to manually provide this information. This change only enables a single cross-host pmap; all inner pmaps must be single-host. Addressing: #1753

view details

Srinivas Vasudevan

commit sha 80b35dd4e52e0b19e3eaf555c8b8fbdf8ea3e9e8

Add betainc to JAX (#1998) Adds betaln, a wrapper for the Beta function (scipy.special.betaln).

view details

Julius Kunze

commit sha 55c971e47fbfad18bdb2248efc72249221bffae9

Implement shapecheck for more primitives (#1990) * shapecheck of jit, device_put, broadcast_in_dim, better error for unsupported ops, parse multi-digit integer literals * WIP shapecheck np.pad * Implement shapecheck of gather, pad * Fix shapecheck of pad * Implement polymorphic shape rule for (strided/dilated) convolution, refactor * Cleanup * Fix * Remove all polymorphic shape rules, reuse shape rules instead. * Register shape_rule for all standard_primitives * Remove ShapeExpr, canonicalize_poly, renames * Complete shapecheck(binop) implementation, remove special cases for polymorphic shapes * Allow Poly of form d*poly + k to be divided by d * Fix bug, inline poly_without_zeros.

view details

Roy Frostig

commit sha 335ecb97b838ea0185c26ec744d63c3a096a858b

test JVP of while loop, and fix the nonzero tangent calculation in the JVP rule

view details

Roy Frostig

commit sha 28f70cc8f8ac4fc4b98ce85f51b389cec8635704

Merge pull request #1980 from google/jvp-while implement JVP of while loop. closes #650

view details

Surya Bulusu

commit sha 71323b5d023a33bf8c06d435c4a6e406dea3c0a8

changes loop_mjp(f, x, M) (#2013) a minor change: we iterate over M and not S

view details

Jamie Townsend

commit sha 3974df0aeeeb89d74ee6832894ab153406626266

[docs] Pmap compiles functions with XLA (#2021)

view details

Jamie Townsend

commit sha 371001aad1d6235e86e8825286b7fa2eb1f0fdb2

Fix README typo (#2020)

view details

Skye Wanderman-Milne

commit sha 19fb494adbd4cefc984f08edf1982e644fe4a2b1

Add jax changelog (#2022)

view details

Peter Hawkins

commit sha 7dbc8dc1bc1bec13408c02d9415f001eabc7595e

Minimal changes to make Jax pass a pytype check. (#2024)

view details

Matthew Johnson

commit sha 17b5fe11d03e35bcebe662fc14d2decaeca81b95

add test for issue #553

view details

Mu Li

commit sha 2c80cd3d88c7393a9bfaa80da626be1a868c0049

Fix Sysml paper link in README (#2036) The original ULR was broken as sysml updated their links.

view details

Skye Wanderman-Milne

commit sha f04348ed534dc88e0041b753d369a3c7a39ec60b

Bump jaxlib version to 0.1.38 and update WORKSPACE.

view details

Sri Hari Krishna Narayanan

commit sha 03b2ae6d5907e76472c0d43e5d7793e80523bb95

Issue1635 expm (#1940) * Issue1635 expm Implemented expm using Pade approximation. The implmentation is wrapped using custom_transforms. Frechet derivatives are provided using defvjp. * Issue1635 expm Implemented expm using Pade approximation based on tf.linalg.expm. * Revert "Revert "Merge remote-tracking branch 'origin/Issue1635' into Issue1635"" This reverts commit dd26c6eeeb60fa556f55abc8acb2f5969b64a2f5, reversing changes made to b63c190c7671ebb9b911a52dcc203285c56a8051. * Issue1635 expm testing Add a test that compares numerical output of scipy.linalg.expm against jax.scipy.linalg.expm * travis build Issue1635 branch * Issue1635 expm testing Use rand_small to get numerical agreeming * Issue1635 expm testing Use @jit to prevent recompilation * Issue1635 expm testing Use rand_small to get numerical agreement * Revert "travis build Issue1635 branch" This reverts commit 6139772555e3af79dc0307fce88838a480e42d38. * Issue1635 Replace construct with jax.numpy.select * Issue1635 Restructure to support the docstring from SciPy * Issue1635 Restructure to support the docstring from SciPy * Issue1635 Remove the note that sparsity is not exploited because JAX does not support sparsity. * Issue1635 expm Support for the case where A is upper triangular. Instead of autodetection, the option is specified explicitly. * Issue1635 Rename argument, make it positional. Update documentation Co-authored-by: Jan <j.hueckelheim@imperial.ac.uk>

view details

Roy Frostig

commit sha 8449c4af9bd5c4d9f58a1232087eb0fd0d11f14e

implement JVP of cond Co-authored-by: Matthew Johnson <mattjj@google.com>

view details

brett koonce

commit sha e18d697ac6250f6cf541a9ec8653312fc35dfd61

minor spelling tweaks (#2043)

view details

Matthew Johnson

commit sha 07260f6572ac436317558a4b78b4b0931e0b30ad

remove hasing methods from core.Literal (#2038)

view details

push time in 16 hours

Pull request review comment tensorflow/tensorflow

[Features] DLPack functions

[Inlined diff excerpt from the TensorFlow eager DLPack implementation elided.]

Here you are importing a DLPack tensor that may have non-trivial strides, even though TF does not support them. So you must do something, and I think the only reasonable thing to do is to detect that case and fail. (i.e., where the strides in the DLPack tensor do not imply the canonical major-to-minor order of elements that TF expects.)

VoVAllen

comment created time in a day

issue closed google/jax

Jax index_update does not update

import jax
import jax.numpy as np  # assumed: `np` in the report is jax.numpy

dt=1e-3
N=100
x0 = 1.5
v0 = 0


tab_x = np.zeros(N)
tab_v = np.zeros(N)
tab_t = np.arange(N)*dt
c=4*np.pi**2
x = x0
v = v0
for i in range(N):
  x = x+v*dt
  v = v-c*np.sin(x)*dt
  jax.ops.index_update(tab_x, i, x)
  jax.ops.index_update(tab_v, i, v)

closed time in a day

kasteiner

issue comment google/jax

Jax index_update does not update

That's right. Please read: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-In-Place-Updates

jax.ops.index_update returns the updated array; it does not modify its argument in place.
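For example, a minimal sketch (simplified from the report above) showing the result being bound back to the array:

import jax
import jax.numpy as np

tab_x = np.zeros(100)
x = 1.5
for i in range(100):
    x = x + 1e-3
    # index_update is functional: bind the returned array back to tab_x,
    # otherwise the "update" is silently discarded.
    tab_x = jax.ops.index_update(tab_x, i, x)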

kasteiner

comment created time in a day

issue comment google/jax

np.var does not match onp.var in some cases

I think we should do one of two things: (a) match NumPy, or (b) issue an error, if the behavior isn't sensible.

I can't really decide whether NumPy's behavior is sensible here; it might make more sense to issue an error.

yurodiviy

comment created time in 2 days

issue comment google/jax

jax.numpy.squeeze: error for nonexistent / nonsingleton axis

Sounds reasonable. Contributions welcome!

GeorgOstrovski

comment created time in 2 days

push event google/jax

Peter Hawkins

commit sha af0967fdbf1960d4f830c888103aa8624479c23d

Add an experimental lax.top_k operator. (#2280)

view details

push time in 2 days

PR merged google/jax

Add an experimental lax.top_k operator. cla: yes
+42 -1

0 comments

2 changed files

hawkinsp

pr closed time in 2 days

PR opened google/jax

Add an experimental lax.top_k operator.
+42 -1

0 comments

2 changed files

pr created time in 2 days

push event hawkinsp/jax

Peter Hawkins

commit sha a8c0c7f7f62ffbe72966c0b5f37afa8dd2281506

Add an experimental lax.top_k operator.

view details

push time in 2 days

create branch hawkinsp/jax

branch : topk

created branch time in 2 days

Pull request review comment tensorflow/tensorflow

[Features] DLPack functions

[Inlined diff excerpt from the TensorFlow eager DLPack implementation elided.]

Isn't this size inconsistent between your export and import functions? Your export size of a bool is 8 bits, whereas your import size is 1 bit. So if nothing else, this won't round trip.

(I couldn't decide the right representation in JAX either, so I made import/export of booleans an error there.)

VoVAllen

comment created time in 3 days

Pull request review comment tensorflow/tensorflow

[Features] DLPack functions

[Inlined diff excerpt from the TensorFlow eager DLPack implementation elided.]

Don't you need to check that the strides are major-to-minor here if not nullptr?

VoVAllen

comment created time in 3 days

Pull request review comment tensorflow/tensorflow

[Features] DLPack functions

[Inlined diff excerpt from the TensorFlow eager DLPack implementation elided.]

This appears to be leaked? I don't see any logic to free the shape.

VoVAllen

comment created time in 3 days

Pull request review comment tensorflow/tensorflow

[Features] DLPack functions

[Inlined diff excerpt from the TensorFlow eager DLPack implementation elided.]

You could probably save an allocation and just keep a TensorReference rather than a TensorReference* in the manager context object?

It also appears you leak the TensorReference object at the moment?

VoVAllen

comment created time in 3 days

issue closed google/jax

`assert np.all(np.isfinite(x))` not working under jit

Dear jax team,

I'd like to check that an intermediate result x inside a jit'ed function does not contain any NaNs via assert np.all(np.isfinite(x)).

I found that this does not work under jit. It does work, though:

  1. without jit or
  2. for return np.all(np.isfinite(x)) under jit.

If I can return but not assert under jit, what is the problem here?


repro:

import jax
import jax.numpy as np

a = np.array([1., np.nan, 3.])

@jax.jit
def test_nan_return_jit(x):
    return np.all(np.isfinite(x))

print(test_nan_return_jit(a)) # works (False)



def test_nan_assert_nojit(x):
    assert np.all(np.isfinite(x))

print(test_nan_assert_nojit(a)) # works (AssertionError)



@jax.jit
def test_nan_assert_jit(x):
    assert np.all(np.isfinite(x))

print(test_nan_assert_jit(a)) # doesn't work (TypeError: Abstract value passed to `bool`)

closed time in 3 days

clemisch

issue comment google/jax

`assert np.all(np.isfinite(x))` not working under jit

This is working as intended, and it's a fairly fundamental consequence of the design of JAX.

When you wrap a function with jit, we trace the function with symbolic ("abstract") values. However, whenever tracing encounters control flow (here, the if condition that is implicitly inside assert), in order to know which branch to take, we must have a concrete value for the if-condition. If we don't, you get an exception, much like the one you see there.

Taking a step back, you can't really call assert inside a jit-decorated function unless you are asserting properties that would be known at jit trace time, not run time. Here, the result of the assertion is clearly known only at runtime because it depends on x.

The workarounds are (the first is sketched below):

  • return a boolean value from jit and do the assertion outside (i.e., your first version)
  • don't use jit around that function (i.e., your second version)
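For instance, a minimal sketch of the first workaround (hypothetical function names):

import jax
import jax.numpy as np

@jax.jit
def all_finite(x):
    # The reduction runs under jit and is returned as a value...
    return np.all(np.isfinite(x))

def check(x):
    # ...and the assert happens outside the jit boundary, where the
    # result is concrete.
    assert all_finite(x), "x contains non-finite values"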

Does that help?

clemisch

comment created time in 3 days

issue comment google/jax

Failed to load Starlark extension '@com_github_grpc_grpc//bazel:grpc_deps.bzl'.

I think a recent PR broke the build. Sorry for the breakage. We'll look into it. For now, can you sync back to https://github.com/google/jax/commit/96b66ac9762225da4d3977e82dfd35c1160827bc ?

(Alternatively, just use a prebuilt jaxlib, as described in the README!)

oliviermattelaer

comment created time in 3 days

commit comment event

issue comment google/jax

Is there a way to store (serialize) expensive first-time use JIT'ted methods?

As a general proposition, if JAX builds a very large XLA computation, XLA may take a long time to compile it. That's probably what's happening here. It's hard to say why your computation is so large without a runnable reproduction. The most common source is Python loops, which often have the effect of unrolling the computation and should be replaced with lax loop constructs.
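As a toy illustration (not the poster's code), a Python loop traced under jit is unrolled into one staged operation per iteration, while lax.fori_loop keeps the loop rolled:

import jax
from jax import lax

@jax.jit
def unrolled(x):
    # Every iteration is traced separately, so the XLA computation
    # (and hence compile time) grows with the trip count.
    for _ in range(1000):
        x = x * 0.5 + 1.0
    return x

@jax.jit
def rolled(x):
    # A single loop construct is staged out, independent of trip count.
    return lax.fori_loop(0, 1000, lambda i, x: x * 0.5 + 1.0, x)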

It would be possible to say more with a small runnable reproduction.

larsgeb

comment created time in 5 days

issue comment google/jax

Slow JIT compilation when involving index_update and index_add

One other thing I will note is that there is an in-Python LU decomposition that JAX uses as its TPU implementation, here: https://github.com/google/jax/blob/b6e834117616977c85dae1e166124e3254304ca4/jax/lax_linalg.py#L497

The inner block kernel is rolled, which helps avoid the size explosion.

goingtosleep

comment created time in 6 days

push event google/jax

Peter Hawkins

commit sha b6e834117616977c85dae1e166124e3254304ca4

Improve developer documentation. (#2247) Add Python version test to build.py.

view details

push time in 6 days

PR merged google/jax

Improve developer documentation. cla: yes

Add Python version test to build.py.

Documentation improvements are thanks to nairb774.

+23 -5

0 comments

2 changed files

hawkinsp

pr closed time in 6 days

PR opened google/jax

Improve developer documentation.

Add Python version test to build.py.

Documentation improvements are thanks to nairb774.

+23 -5

0 comments

2 changed files

pr created time in 6 days

push event hawkinsp/jax

Colin

commit sha d6489103f754674eb5f16ded961bbbbc2c5817e5

Bump cell execution timeout (#2147) Looking at the recent [doc failures](https://readthedocs.org/projects/jax/builds/), a few are due to - Cell timeouts (which this tries to fix), - Execution timeout (readthedocs gives 900seconds to build, total -- most of the time for jax is in executing the notebooks), - Other somewhat random/inscrutable errors (and I could imagine a world in which one of the timeouts ends up triggering an inscrutable error in the execution).

view details

Roman Novak

commit sha 1022573b26a1996db524229de10fb84dbe6e08b3

Make stax pooling layers accept `spec=None` (#2145) Currently pooling layers have a default channel-last spec that is explicitly 2D. This change will make this default work for arbitrary input dimensionality.

view details

Stephan Hoyer

commit sha 0644f5c56175104d862cf7e03fe6f7cd14cdba88

Better batching rule for triangular_solve (#2138) * Better batching rule for triangular_solve Now, if only the right hand side argument `b` is batched, we leverage triangular solve's builtin batching for handling multiple right-hand-side vectors. This makes the performance of `vmap` over only the second argument of linear solves equivalent to relying on builtin batching:: rs = onp.random.RandomState(0) a = rs.randn(500, 500) + 0.1 * np.eye(500) b_mat = jax.device_put(rs.randn(500, 10)) solve1 = jax.jit(np.linalg.solve) solve2 = jax.jit(jax.vmap(np.linalg.solve, in_axes=(None, 1), out_axes=1)) Before:: In [6]: %timeit jax.device_get(solve1(a, b_mat)) 3.88 ms ± 293 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) # 8x slower :( In [9]: %timeit jax.device_get(solve2(a, b_mat)) 23.5 ms ± 1.33 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) Now:: In [2]: %timeit jax.device_get(solve1(a, b_mat)) 3.76 ms ± 304 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) # same speed :) In [3]: %timeit jax.device_get(solve2(a, b_mat)) 3.72 ms ± 296 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) * Test failures * Check b.ndim == 2 in triangular solve shape rule

view details

Peter Hawkins

commit sha 0b1d2fc3d187f779934cfaeb9188e1fcb208a6fc

Avoid accidental type promotion in gamma sampler gradient. (#2150) Reformat gamma sampler to use 2 space indent, consistent with the rest of JAX.

view details

Peter Hawkins

commit sha 3c9ae5e221316c82f1dda34aa7f12173b12e3a21

Add jax.scipy.stats.logistic to documentation. (#2149)

view details

George Necula

commit sha 4f5987ccd9c4447955d8dc3463613ac3bc44b6a3

Simplify Jaxpr: remove freevars. Freevars played a very small role, and they can be folded with the invars. This simplifies the Jaxpr data structure.We remove the `freevars` field from Jaxpr and from the bound_subjaxprs. The only non-trivial change is for xla_pmap, where we need to carry one extra parameter `mapped_invars` with a bitmap to encode which invars are mapped and which are broadcast. Previously, the freevars were broadcast.

view details

George Necula

commit sha a955fd9deee29f6de5023bf4077d678a56084785

Updated notebook that refered to freevars

view details

Stephan Hoyer

commit sha 2d0b8c2c609829e9745c1a1bc64c0fcf777fc899

Fix precision in triangular solve batching test for TPUs (#2159)

view details

Skye Wanderman-Milne

commit sha 7404e88b358a377cdf8d8e580349311185184af6

Adjust scipy_stats_test.py tolerance.

view details

Skye Wanderman-Milne

commit sha b19f7e935781bd848525c418c5a093b1d200dca5

WIP sharded_jit implementation (#2158)

view details

Anselm Levskaya

commit sha ffc55ee6008c054a2e58d01f64ba0ced36b36048

Update linspace edgecase to match numpy fix. (#2162) * Update linspace edgecase to match numpy fix. * only test fixed linspace behavior against newer numpy * remove unneeded version pkg

view details

George Necula

commit sha 272620e66cd2f9c686a45306f469df873f535fc7

Added note to CHANGELOG.md

view details

Jonas Adler

commit sha 4080a1c2ce95dc4a90f899fe4bf9ad5ac6a7b8b3

Add np.fft.fftshift/ifftshift (#1850)

view details

Lukas Prediger

commit sha ddc83e093778e227c7688f6ad16888b211a554ef

Added dtype arg for NN initializer factory methods (#2034) * Added dtype arg for NN initializer factory methods Initializer factories in jax/nn/initializers.py (such as uniform(), normal(), glorot_normal(), etc) now have an optional `dtype` argument. The value passed in that argument becomes the default value for the same `dtype` argument of the initializer function returned by the factory. * fixed failed test for delta_orthogonal in d12cdc47

view details

George Necula

commit sha 862a1d594b67845f480bba7b342afb134ffc1a14

Moved the mapped_invars parameter setting to the process_map

view details

George Necula

commit sha d01210e9e338b8051da78fcc104b404b82ffd8a0

Merge pull request #1959 from gnecula/no_freevars An attempt to remove freevars from JAXPR.

view details

George Necula

commit sha b18a4d8583c0e11e228a0792793d6f6e99292766

Disabled tests known to fail on Mac, and optionally slow tests. Issue: #2166 Added JAX_SKIP_SLOW_TESTS environment variable to skip tests known to be slow.

view details

Pavel Sountsov

commit sha b2ef5bc09552e8ed39759df4ff49ea97e32db708

Canonicalize the shape in the wrapper functions in random.py. (#2165) * Canonicalize the shape in the wrapper functions in random.py. This lets the user be more sloppy in using numpy arrays and statically known DeviceArrays for shapes, and still hit the jit cache. When they are not, the error is improved. * Fix some errors. * No need for the Poly workaround. * Bypass canonicalization for None shapes in random.py.

view details

Skye Wanderman-Milne

commit sha 13316f35705fee2f43376655313e698b52d1965f

Fix type error in partial_eval.py. (#2171)

view details

George Necula

commit sha ae3003e9d42a8df41d6d8bbd62c0ba2b4c2c13ce

Simplify bound_subjaxprs. Before, bound_subjaxprs was a tuple (0 or 1 values) of a pair of a Jaxpr and its constant values. Now we close up all such Jaxprs such that they do not take constvars and their constant values are part of the arguments. We also rename bound_subjaxprs to bound_subjaxpr (an optional Jaxpr) This is first part of a simplification. In a subsequent PR I will move the bound_subjaxpr into params, as for most higher-order primitives.

view details

push time in 6 days

issue comment google/jax

JAX/XLA compiling for one device and running on another

@mktal I just hit this issue myself.

Trax uses TF for its input pipelines. I think it might be some sort of bad interaction between JAX and TensorFlow on GPU. Can you try installing a CPU-only tensorflow wheel from here: https://www.tensorflow.org/install/pip#package-location

Note the default wheels include GPU support, so you'll need to explicitly install a tensorflow-cpu wheel.

jlebar

comment created time in 7 days

push event google/jax

Peter Hawkins

commit sha 9b362380039ba6eaa72ec4122c9544b71c5b5665

Fix sha256 sum for XLA release. (#2230)

view details

push time in 9 days

PR merged google/jax

Fix sha256 sum for XLA release. cla: yes
+1 -1

0 comments

1 changed file

hawkinsp

pr closed time in 9 days

PR opened google/jax

Fix sha256 sum for XLA release.
+1 -1

0 comments

1 changed file

pr created time in 9 days

push event hawkinsp/jax

Peter Hawkins

commit sha 71adf8cf7cda2644e25440d79a9b5899beffa970

Update XLA. (#2229) .

view details

Peter Hawkins

commit sha 1c2609bbfecfb75acc003e87a4ee7cc7a93f4a42

Fix sha256 sum for XLA release.

view details

push time in 9 days

push event google/jax

Peter Hawkins

commit sha 71adf8cf7cda2644e25440d79a9b5899beffa970

Update XLA. (#2229) .

view details

push time in 9 days

PR merged google/jax

Update XLA. cla: yes

.

+3 -3

0 comment

1 changed file

hawkinsp

pr closed time in 9 days

PR opened google/jax

Update XLA.

.

+3 -3

0 comment

1 changed file

pr created time in 9 days

push eventhawkinsp/jax

Roy Frostig

commit sha afb8af19ff7474561c3c904e03c63dbf8f57de3f

implement JVP of while loop Co-authored-by: Matthew Johnson <mattjj@google.com>

view details

Peter Hawkins

commit sha 64bf55dc6fb76e68c92bea52b0ff1dc6f9fa0894

Update XLA. (#1997) Drop six dependency from jaxlib, since xla_client.py no longer uses six.

view details

AmKhan

commit sha dcda87d0e7f79ada2ca72e4054b92c799971c543

added batching to LAPACK triangular_solve (#1985) * Added batching to cpu triangular_solver * addressed comments about int overflows and returned triangular solve to use XLA over LAPACK * add todo to benchmark LAPACK vs XLA

view details

Peter Hawkins

commit sha 938a7f801255d1d9969c68131d9a894819589ce9

Remove :libjax alias from BUILD file. (#1996)

view details

Peter Hawkins

commit sha 11224bd2b19a011cd65577940ac26eb80be04ba7

Use a uniform rng rather than a normal rng to defeat CSE. (#2000) The normal distribution is relatively expensive to compute.

view details

Peter Hawkins

commit sha 653001aa64f8de1a6b020a4c4bf17949cea0584b

Update references to bazel repositories in WORKSPACE to match TF head. (#2005)

view details

Stephan Hoyer

commit sha a5644edbbcbd4093cc8cfa993145f471336fa0b6

Defer to unrecognized types in arithmetic (#1942) This is useful for building higher level array libraries around JAX, because it makes it possible to override operations like `jax_array + other`. I think I covered all the array types that JAX should be able to handle: - Python builtin numbers int, float and complex - NumPy scalars - NumPy arrays - JAX array types and tracers Did I miss anything? Maybe bfloat16 scalars?

view details

Trevor Cai

commit sha 12975bbcc84ae0303f243b3ed727340bbd514321

[pmap] Add support for nested pmaps on multihost platforms via axis_size (#2002) One issue with nested pmaps on multihost platforms is inferring the global pmap axis size without communication. This commit sidesteps the issue by adding an `axis_size` argument to manually provide this information. This change only enables a single cross-host pmap; all inner pmaps must be single-host. Addressing: #1753

view details

Srinivas Vasudevan

commit sha 80b35dd4e52e0b19e3eaf555c8b8fbdf8ea3e9e8

Add betainc to JAX (#1998) Adds betaln, a wrapper for the Beta function (scipy.special.betaln).

view details

Julius Kunze

commit sha 55c971e47fbfad18bdb2248efc72249221bffae9

Implement shapecheck for more primitives (#1990) * shapecheck of jit, device_put, broadcast_in_dim, better error for unsupported ops, parse multi-digit integer literals * WIP shapecheck np.pad * Implement shapecheck of gather, pad * Fix shapecheck of pad * Implement polymorphic shape rule for (strided/dilated) convolution, refactor * Cleanup * Fix * Remove all polymorphic shape rules, reuse shape rules instead. * Register shape_rule for all standard_primitives * Remove ShapeExpr, canonicalize_poly, renames * Complete shapecheck(binop) implementation, remove special cases for polymorphic shapes * Allow Poly of form d*poly + k to be divided by d * Fix bug, inline poly_without_zeros.

view details

Roy Frostig

commit sha 335ecb97b838ea0185c26ec744d63c3a096a858b

test JVP of while loop, and fix the nonzero tangent calculation in the JVP rule

view details

Roy Frostig

commit sha 28f70cc8f8ac4fc4b98ce85f51b389cec8635704

Merge pull request #1980 from google/jvp-while implement JVP of while loop. closes #650

view details

Surya Bulusu

commit sha 71323b5d023a33bf8c06d435c4a6e406dea3c0a8

changes loop_mjp(f, x, M) (#2013) a minor change: we iterate over M and not S

view details

Jamie Townsend

commit sha 3974df0aeeeb89d74ee6832894ab153406626266

[docs] Pmap compiles functions with XLA (#2021)

view details

Jamie Townsend

commit sha 371001aad1d6235e86e8825286b7fa2eb1f0fdb2

Fix README typo (#2020)

view details

Skye Wanderman-Milne

commit sha 19fb494adbd4cefc984f08edf1982e644fe4a2b1

Add jax changelog (#2022)

view details

Peter Hawkins

commit sha 7dbc8dc1bc1bec13408c02d9415f001eabc7595e

Minimal changes to make Jax pass a pytype check. (#2024)

view details

Matthew Johnson

commit sha 17b5fe11d03e35bcebe662fc14d2decaeca81b95

add test for issue #553

view details

Mu Li

commit sha 2c80cd3d88c7393a9bfaa80da626be1a868c0049

Fix Sysml paper link in README (#2036) The original URL was broken as sysml updated their links.

view details

Skye Wanderman-Milne

commit sha f04348ed534dc88e0041b753d369a3c7a39ec60b

Bump jaxlib version to 0.1.38 and update WORKSPACE.

view details

push time in 9 days

push eventgoogle/jax

Peter Hawkins

commit sha 33f8600acfd7c429adf98f3f6a5fd9525ab8cd95

Disable PRED dlpack test. (#2227)

view details

push time in 9 days

PR merged google/jax

Disable PRED dlpack test. cla: yes
+5 -5

0 comment

1 changed file

hawkinsp

pr closed time in 9 days

PR opened google/jax

Disable PRED dlpack test.
+5 -5

0 comment

1 changed file

pr created time in 9 days

push eventhawkinsp/jax

Peter Hawkins

commit sha 91cd20b1734112e1005f8a9b00007b0b5a79a527

Update documentation and changelog to mention DLPack and array interface support. (#2134)

view details

Skye Wanderman-Milne

commit sha efbdaf66bfa584cc635092919a23b684c7fb2247

Adjust scipy_stats_test.py tolerance.

view details

Matthew Johnson

commit sha ae1d6b875fcc2b23909b360c73db00489f32068e

fix remat with nontrivial env (#2136) fixes #2030

view details

Ruizhe Zhao

commit sha 8c7fc3919d3e131da6a2121158084ed480dbec2a

Upgrade bazel from 0.29.1 to 1.2.1 (#2137)

view details

Peter Hawkins

commit sha fe041c75900023c1774bc81b59b094f38250f3b8

Set minimum Bazel version to 1.2.1.

view details

Colin

commit sha d6489103f754674eb5f16ded961bbbbc2c5817e5

Bump cell execution timeout (#2147) Looking at the recent [doc failures](https://readthedocs.org/projects/jax/builds/), a few are due to - Cell timeouts (which this tries to fix), - Execution timeout (readthedocs gives 900 seconds to build, total -- most of the time for jax is in executing the notebooks), - Other somewhat random/inscrutable errors (and I could imagine a world in which one of the timeouts ends up triggering an inscrutable error in the execution).

view details

Roman Novak

commit sha 1022573b26a1996db524229de10fb84dbe6e08b3

Make stax pooling layers accept `spec=None` (#2145) Currently pooling layers have a default channel-last spec that is explicitly 2D. This change will make this default work for arbitrary input dimensionality.

view details

Stephan Hoyer

commit sha 0644f5c56175104d862cf7e03fe6f7cd14cdba88

Better batching rule for triangular_solve (#2138) * Better batching rule for triangular_solve Now, if only the right hand side argument `b` is batched, we leverage triangular solve's builtin batching for handling multiple right-hand-side vectors. This makes the performance of `vmap` over only the second argument of linear solves equivalent to relying on builtin batching:: rs = onp.random.RandomState(0) a = rs.randn(500, 500) + 0.1 * np.eye(500) b_mat = jax.device_put(rs.randn(500, 10)) solve1 = jax.jit(np.linalg.solve) solve2 = jax.jit(jax.vmap(np.linalg.solve, in_axes=(None, 1), out_axes=1)) Before:: In [6]: %timeit jax.device_get(solve1(a, b_mat)) 3.88 ms ± 293 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) # 8x slower :( In [9]: %timeit jax.device_get(solve2(a, b_mat)) 23.5 ms ± 1.33 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) Now:: In [2]: %timeit jax.device_get(solve1(a, b_mat)) 3.76 ms ± 304 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) # same speed :) In [3]: %timeit jax.device_get(solve2(a, b_mat)) 3.72 ms ± 296 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) * Test failures * Check b.ndim == 2 in triangular solve shape rule

view details

Peter Hawkins

commit sha 0b1d2fc3d187f779934cfaeb9188e1fcb208a6fc

Avoid accidental type promotion in gamma sampler gradient. (#2150) Reformat gamma sampler to use 2 space indent, consistent with the rest of JAX.

view details

Peter Hawkins

commit sha 3c9ae5e221316c82f1dda34aa7f12173b12e3a21

Add jax.scipy.stats.logistic to documentation. (#2149)

view details

George Necula

commit sha 4f5987ccd9c4447955d8dc3463613ac3bc44b6a3

Simplify Jaxpr: remove freevars. Freevars played a very small role, and they can be folded with the invars. This simplifies the Jaxpr data structure. We remove the `freevars` field from Jaxpr and from the bound_subjaxprs. The only non-trivial change is for xla_pmap, where we need to carry one extra parameter `mapped_invars` with a bitmap to encode which invars are mapped and which are broadcast. Previously, the freevars were broadcast.

view details

George Necula

commit sha a955fd9deee29f6de5023bf4077d678a56084785

Updated notebook that referred to freevars

view details

Stephan Hoyer

commit sha 2d0b8c2c609829e9745c1a1bc64c0fcf777fc899

Fix precision in triangular solve batching test for TPUs (#2159)

view details

Skye Wanderman-Milne

commit sha 7404e88b358a377cdf8d8e580349311185184af6

Adjust scipy_stats_test.py tolerance.

view details

Skye Wanderman-Milne

commit sha b19f7e935781bd848525c418c5a093b1d200dca5

WIP sharded_jit implementation (#2158)

view details

Anselm Levskaya

commit sha ffc55ee6008c054a2e58d01f64ba0ced36b36048

Update linspace edgecase to match numpy fix. (#2162) * Update linspace edgecase to match numpy fix. * only test fixed linspace behavior against newer numpy * remove unneeded version pkg

view details

George Necula

commit sha 272620e66cd2f9c686a45306f469df873f535fc7

Added note to CHANGELOG.md

view details

Jonas Adler

commit sha 4080a1c2ce95dc4a90f899fe4bf9ad5ac6a7b8b3

Add np.fft.fftshift/ifftshift (#1850)
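A small usage sketch of the new wrappers (the commented values are what NumPy's own fftshift produces for this input):

import jax.numpy as jnp

x = jnp.arange(8)
shifted = jnp.fft.fftshift(x)           # [4 5 6 7 0 1 2 3]: zero-frequency term moved to the center
restored = jnp.fft.ifftshift(shifted)   # undoes the shift, recovering [0 1 ... 7]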

view details

Lukas Prediger

commit sha ddc83e093778e227c7688f6ad16888b211a554ef

Added dtype arg for NN initializer factory methods (#2034) * Added dtype arg for NN initializer factory methods Initializer factories in jax/nn/initializers.py (such as uniform(), normal(), glorot_normal(), etc) now have an optional `dtype` argument. The value passed in that argument becomes the default value for the same `dtype` argument of the initializer function returned by the factory. * fixed failed test for delta_orthogonal in d12cdc47

view details

George Necula

commit sha 862a1d594b67845f480bba7b342afb134ffc1a14

Moved the mapped_invars parameter setting to the process_map

view details

push time in 9 days

Pull request review commentgoogle/jax

Allow ShardedDeviceArrays to represent arbitrary data shardings.

 def block_until_ready(self):

   @property
   def _value(self):
     if self._npy_value is None:
-      ids = self._ids()
+      # TODO(skye): remove this to avoid transferring replicated buffers?

It's triggering the transfers for all of the device buffers before blocking on the transfer for any of them.

skye

comment created time in 9 days

issue commentgoogle/jax

"Could not find any cudnn.h" when building from source

Please open a new issue. I don't think this is the same problem. a) if you haven't considered it, you should prefer to use a prebuilt jaxlib. b) everything that patch did is present in master (i.e., --cudnn_path is optional, and if you don't specify it, it acts the same way as that patch did.)

nottombrown

comment created time in 10 days

push eventgoogle/jax

Tom Hennigan

commit sha 9797ea2485540ac7bf9e0f48cba4a4c7f0a6d8bc

Implement size/ndim/__len__/repr/str/eq/hash for ShapeDtypeStruct. (#2206)

view details

push time in 12 days

PR merged google/jax

Implement size/ndim/__len__/repr/str/eq/hash for ShapeDtypeStruct. cla: yes

Per #2179 we considered and rejected NamedTuple for this case; the alternatives (dataclasses or attrs) would require changing system requirements or adding a dependency, which I have deferred to another pull request. For now I've hand-rolled the missing methods.

Fixes #2179.

+46 -0

2 comments

3 changed files

tomhennigan

pr closed time in 12 days

issue closedgoogle/jax

Should ShapeDtypeStruct.size exist?

ShapeDtypeStruct currently has dtype and shape. We could add numpy's size property via

import numpy as np

class ShapeDtypeStruct(object):
  __slots__ = ["shape", "dtype"]
  def __init__(self, shape, dtype):
    self.shape = shape
    self.dtype = dtype

  @property
  def size(self):
    return np.prod(self.shape, dtype=int)

Do you want this patch?

closed time in 12 days

girving

pull request commentgoogle/jax

Implement size/ndim/__len__/repr/str/eq/hash for ShapeDtypeStruct.

Sure! Sorry, I was waiting for the presubmits to finish but I never came back to do the merge yesterday.

tomhennigan

comment created time in 12 days

push eventgoogle/jax

Tom Hennigan

commit sha 4c682b46bb05301212b0cfe8affa970e1feccde4

Add missing sources to jax build. (#2208)

view details

push time in 13 days

PR merged google/jax

Add missing sources to jax build. cla: yes
+2 -0

0 comment

1 changed file

tomhennigan

pr closed time in 13 days

Pull request review commentgoogle/jax

Implement size/ndim/__len__/repr/str/eq/hash for ShapeDtypeStruct.

 def __init__(self, shape, dtype):
     self.shape = shape
     self.dtype = dtype

+  size = property(lambda self: onp.prod(self.shape))
+  ndim = property(lambda self: len(self.shape))
+
+  def __len__(self):
+    try:
+      return self.shape[0]
+    except IndexError:
+      raise TypeError("len() of unsized object")  # same as numpy error
+
+  def __repr__(self):
+    dtype = self.dtype.dtype
+    return f"{type(self).__name__}(shape={self.shape}, dtype={dtype.name})"

We haven't yet dropped Python 3.5 compatibility. I don't think you can use f-strings yet.
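For reference, a Python 3.5-compatible version could fall back to str.format; this is only a sketch mirroring the names in the snippet above:

def __repr__(self):
  # Same output as the f-string form, but valid on Python 3.5.
  dtype = self.dtype.dtype
  return "{}(shape={}, dtype={})".format(type(self).__name__, self.shape, dtype.name)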

tomhennigan

comment created time in 13 days

push eventgoogle/jax

Peter Hawkins

commit sha e9d06ecf53a5227bb324b682fab74b628430ff9d

Reenable convolution gradient tests on TPU that now pass. (#2207)

view details

push time in 13 days

PR merged google/jax

Reenable convolution gradient tests on TPU that now pass. cla: yes
+0 -1

0 comment

1 changed file

hawkinsp

pr closed time in 13 days

PR opened google/jax

Reenable convolution gradient tests on TPU that now pass.
+0 -1

0 comment

1 changed file

pr created time in 13 days

push eventhawkinsp/jax

Julius Kunze

commit sha 9d12a24b63ac10943006d588f601718c135c12ef

Add categorical sampler

view details

Julius Kunze

commit sha 6178755281cd2ff7c48d26c584ac4e1e1f474d1c

Remove safe zip/map

view details

Julius Kunze

commit sha 698327080be9fd4658291fc918d764b8f214f0d2

Clarify documentation

view details

Clemens Schmid

commit sha 592f167e5bfd7f4d6ff67bdaba94008895c5b5fa

Implement numpy.gradient

view details

Clemens Schmid

commit sha b15a27a7fc4b8dce0056c53e5ec45d97ab40dee8

Tests for jax.numpy.gradient and minor tweaks

view details

Clemens Schmid

commit sha 48cb6af6b4125b5874c6ffc2586dbe9c5c1565f0

Support None and negative indices in slice_in_dim

view details

Clemens Schmid

commit sha ac1aaedc4f2ec6acf352e7ccb718a9a4fb59ae06

Change from swapaxes to slice_in_dim in numpy.gradient

view details

clemisch

commit sha c907504078f9ee75ef60e8d21fd0b1e69c03884f

Merge branch 'master' into master

view details

Clemens Schmid

commit sha 9ef9b38b4e2274608d7449269af9c29780482ed1

Put axis in named_parameters for numpy.gradient test

view details

Matthew Johnson

commit sha 327dca8f769a723ee47c560d4f7b5ca044b390b8

Merge pull request #1944 from clemisch/master Implement numpy.gradient

view details

Peter Hawkins

commit sha facbe0d76a58c45f48bc9338827328d70ba4a76b

Handle 0D convolutions correctly in shape rule. (#1972)

view details

Matthew Johnson

commit sha 8bca2c90e7e1ae22ac645cfb7175a76937ca7b9c

fix urllib import for py3

view details

Julius Kunze

commit sha f36d858c4ef62f5ea2f9616701f33fb38b74b464

Require shape = sample_shape + batch_shape in random.categorical

view details

Matthew Johnson

commit sha 00be20bdfa8fef2764fda04576c71fb6d699a99c

Merge pull request #1855 from JuliusKunze/categorical Add categorical sampler

view details

Skye Wanderman-Milne

commit sha 0417e1e5c38addc898eaca73c43115a6f4523c92

Fix `jax.lax.axis_index` in multi-host setting. (#1976)

view details

Skye Wanderman-Milne

commit sha 773ebe1323e6cf9783e6a02bf3e189e5bf9aeb00

Adjust tolerance for LaxTest.testConv0DIsDot. (#1978) This was failing on TPU.

view details

Roy Frostig

commit sha afb8af19ff7474561c3c904e03c63dbf8f57de3f

implement JVP of while loop Co-authored-by: Matthew Johnson <mattjj@google.com>

view details

Skye Wanderman-Milne

commit sha 160cc43a5d098a935c2c8bb2157915a64e0e5582

Disable failing GPU test for now pending XLA fix.

view details

Skye Wanderman-Milne

commit sha ed33d102796e5d9d32fe6b222f2bfb06ecff293f

Add ppermute as an allowed multi-host collective. (#1981) I manually tested that this works as of 0417e1e. The indices used in ppermute correspond to those returned by `axis_index`.

view details

archis

commit sha 05f09fc93511c3e1a480ccc81f055d14488779bc

added rfftfreq, tests, and documentation link.

view details

push time in 13 days

Pull request review commentgoogle/jax

Removed copyright from third-party/numpy

+This sub-directory contains third-party code for which Google does not have
+copyright. Each sub-directory should correspond to a third-party library and
+must contain the appropriate LICENSE file.
+
+See [instructions](https://g3doc.corp.google.com/company/teams/opensource/releasing/preparing.md?cl=head#third-party-components).

This is an internal link, but happily Google's opensource documentation has itself been opensourced: https://opensource.google/docs/releasing/preparing/#third-party-components

gnecula

comment created time in 13 days

push eventgoogle/jax

lanctot

commit sha 051d7b895658f22e9ca64fc77961d61467e20e05

Fix broken link in README (#2196)

view details

push time in 15 days

PR merged google/jax

Fix broken link in README cla: yes

Hi all, just a quick fix to a link on your README.

+1 -1

2 comments

1 changed file

lanctot

pr closed time in 15 days

pull request commentgoogle/jax

Fix broken link in README

Thanks for the fix!

lanctot

comment created time in 15 days

push eventgoogle/jax

Du Phan

commit sha be5b24fa5d8dc2aa6128772571836949328128d0

relax the ndim>=1 condition of tensordot (#2191) * relax the ndim condition of tensordot * add test for scalar input with axes=0

view details

push time in 16 days

PR merged google/jax

relax the ndim>=1 condition of tensordot cla: yes

This condition is not necessary. NumPy allows scalar inputs with axes=0. In my opinion, the check

if axes > _min(a_ndim, b_ndim):

a few lines below is enough. Currently, jax.numpy.tensordot(np.ones(3), 1., axes=0) raises an error while numpy.tensordot(np.ones(3), 1., axes=0) returns array([1., 1., 1.]).

+1 -3

2 comments

2 changed files

fehiepsi

pr closed time in 16 days

pull request commentgoogle/jax

relax the ndim>=1 condition of tensordot

Thanks for the PR!

fehiepsi

comment created time in 16 days

pull request commentgoogle/jax

relax the ndim>=1 condition of tensordot

Can you add a test? If it isn't tested, it will inevitably regress.

Thanks!

fehiepsi

comment created time in 16 days

issue closedgoogle/jax

How to convert jax model to tensorflow model

I got some code using jax to train a neural net model. I would like to save the results and load it into tensorflow for further training. Could it be possible and how I can do it? Any suggestion would be helpful. Thanks a lot!

closed time in 16 days

ai-gamer

issue commentgoogle/jax

How to convert jax model to tensorflow model

It's certainly possible to do what you ask manually, but there's no library support for it. For example, you can save the network weights using pickle or numpy's array I/O features and load the result into another system, such as MXNet or PyTorch, provided you have an identical neural network definition in that system.
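For example, a minimal sketch of the pickle route (the `params` pytree and the filename here are stand-ins, not part of any JAX API):

import pickle
import numpy as np
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((3, 3)), "b": jnp.zeros((3,))}  # stand-in for trained weights

# Convert DeviceArrays to plain NumPy arrays so the file has no JAX dependency.
numpy_params = jax.tree_util.tree_map(np.asarray, params)
with open("weights.pkl", "wb") as f:
  pickle.dump(numpy_params, f)

# The other framework loads the file and copies values into its own variables.
with open("weights.pkl", "rb") as f:
  restored = pickle.load(f)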

In general, this is probably out of scope for the core of JAX itself: JAX really focuses on NumPy + transformations, and this is a question for a neural network library built on top of JAX, of which there are several. For example, I believe Trax, which is built on top of JAX, may be able to do something along these lines. But JAX itself doesn't really have an opinion about neural networks or weights.

(We do have jax.experimental.stax, but it's more of a proof of concept than anything else.)

I hope that helps!

(Closing because I think this is out of scope for JAX itself.)

ai-gamer

comment created time in 16 days

push eventgoogle/jax

Peter Hawkins

commit sha 2e8798dd16acfe6373bf66780ae7402f7766854f

Use 64-bit integers for indexing if any tensor dimension exceeds 2^31 elements. (#2182)

view details

push time in 16 days

PR merged google/jax

Use 64-bit integers for indexing if any tensor dimension exceeds 2^31… cla: yes

… elements.

Fixes #2178

I didn't add a test because our Travis instances are quite small so we couldn't actually run such a test in our regression suite.

+6 -5

0 comment

1 changed file

hawkinsp

pr closed time in 16 days

issue closedgoogle/jax

Bug concatenating large arrays on CPU

On CPU, I get the following result when doing a concatenation of a large array with a smaller one

In [1]: import jax.numpy as np

In [2]: a = np.concatenate((np.ones(1 << 32), np.array([2., 3., 4.])))

In [3]: a[-4:]
Out[3]: DeviceArray([1., 1., 1., 1.], dtype=float32)

The result of the final line should be [1., 2., 3., 4], not [1., 1., 1., 1.]. I've had a look at the implementation of jax.numpy.concatenate and it seems to delegate pretty directly to lax.concatenate, which is itself just a direct wrapper around the XLA op, so I wonder if this might be a bug in XLA.

The result seems to be incorrect in the same way for other dtypes:

In [4]: a = np.concatenate((np.ones(1 << 32, 'float16'), np.array([2., 3., 4.], 'float16')))

In [5]: a[-4:]
Out[5]: DeviceArray([1., 1., 1., 1.], dtype=float16)

In [6]: a = np.concatenate((np.ones(1 << 32, 'uint16'), np.array([2., 3., 4.], 'uint16')))

In [7]: a[-4:]
Out[7]: DeviceArray([1, 1, 1, 1], dtype=uint16)

I haven't had time to test whether this same issue exists on GPU or TPU.

closed time in 16 days

j-towns

PR opened google/jax

Use 64-bit integers for indexing if any tensor dimension exceeds 2^31…

… elements.

Fixes #2178

I didn't add a test because our Travis instances are quite small so we couldn't actually run such a test in our regression suite.

+6 -5

0 comment

1 changed file

pr created time in 17 days

create branchhawkinsp/jax

branch : overflow

created branch time in 17 days

Pull request review commentgoogle/jax

Disabled tests known to fail on Mac, and optionally slow tests.

 from contextlib import contextmanager
+from distutils.util import strtobool

Yes: don't take the dependency, implement it ourselves in, say, config.py. It's probably only ~10 lines of code. It's very odd to depend on what seems to be an implementation detail of the Python module distribution library that probably shouldn't have been a public API in the first place.
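A minimal sketch of such a helper (the name `bool_env` and the accepted spellings are illustrative, not necessarily what was ultimately committed):

import os

def bool_env(varname, default):
  # Illustrative replacement for distutils.util.strtobool, reading an env var.
  val = os.getenv(varname, str(default)).lower()
  if val in ("y", "yes", "t", "true", "on", "1"):
    return True
  if val in ("n", "no", "f", "false", "off", "0"):
    return False
  raise ValueError("invalid truth value {!r} for environment {!r}".format(val, varname))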

gnecula

comment created time in 17 days

Pull request review commentgoogle/jax

Disabled tests known to fail on Mac, and optionally slow tests.

 def _skip_if_unsupported_type(dtype):
       dtype in (onp.dtype('float64'), onp.dtype('complex128'))):
     raise unittest.SkipTest("--jax_enable_x64 is not set")

+# TODO(phawkins): bug https://github.com/google/jax/issues/2166

Update to issue #432

gnecula

comment created time in 17 days

Pull request review commentgoogle/jax

Disabled tests known to fail on Mac, and optionally slow tests.

 from contextlib import contextmanager
+from distutils.util import strtobool

It feels like this is an inappropriate dependency. While distutils is in the Python standard library, this is a very strange reason to use it.

(I suspect this is probably copy-paste from another flag, but just an observation.)

gnecula

comment created time in 17 days

issue commentgoogle/jax

Problem with installing jax: cannot import name 'xla_data_pb2

Can you please share: a) what platform/OS you are using? b) the errors that pip emits? i.e., exactly what you ran, and exactly what was printed.

sergueev

comment created time in 17 days

issue commentgoogle/jax

Problem with installing jax: cannot import name 'xla_data_pb2

What version of pip do you have installed (pip show pip?) And what platform/OS are you running on?

sergueev

comment created time in 18 days

issue commentgoogle/jax

Problem with installing jax: cannot import name 'xla_data_pb2

It looks like you have very old versions of both jax and jaxlib installed. The current versions are jax: 0.1.58, jaxlib: 0.1.38.

Try:

pip uninstall jax jaxlib
pip install --upgrade pip
pip install --upgrade jax jaxlib

The pip install --upgrade pip is important.

sergueev

comment created time in 18 days

issue commentgoogle/jax

Problem with installing jax: cannot import name 'xla_data_pb2

Also, can you verify what version of jaxlib was successfully installed (pip show jaxlib)?

sergueev

comment created time in 18 days

issue commentgoogle/jax

Problem with installing jax: cannot import name 'xla_data_pb2

What version of jax do you have? (e.g., pip show jax?)

sergueev

comment created time in 18 days

issue commentgoogle/jax

Some linalg tests fail on Mac with latest scipy

Note that it is Mac OS specific.

gnecula

comment created time in 18 days

issue closedgoogle/jax

Some linalg tests fail on Mac with latest scipy

There are 7 tests in tests/linalg_tests.py that fail on Mac with scipy vesion 1.4.1 installed.

The tests are: 'tests/linalg_test.py::NumpyLinalgTest::testEigvals_shape=complex64[50,50]' 'tests/linalg_test.py::NumpyLinalgTest::testPinv_shape=complex64[7,10000]' 'tests/linalg_test.py::ScipyLinalgTest::testLuFactor_n=complex64[200,200]' 'tests/linalg_test.py::ScipyLinalgTest::testExpm_n=complex64[50,50]' 'tests/linalg_test.py::NumpyLinalgTest::testInv_shape=float32[200,200]' 'tests/linalg_test.py::NumpyLinalgTest::testPinv_shape=float32[7,10000]' 'tests/linalg_test.py::ScipyLinalgTest::testExpm_n=float32[50,50]'

The failure is

worker 'gw2' crashed while running 'tests/linalg_test.py::NumpyLinalgTest::testEigvals_shape=complex64[50,50]'

A longer stack trace is:

Fatal Python error: Bus error

Thread 0x000070000693a000 (most recent call first):
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/execnet/gateway_base.py", line 400 in read
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/execnet/gateway_base.py", line 432 in from_io
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/execnet/gateway_base.py", line 967 in _thread_receiver
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/execnet/gateway_base.py", line 220 in run
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/execnet/gateway_base.py", line 285 in _perform_spawn

Thread 0x0000000106b4d5c0 (most recent call first):
  File "/Users/necula/Source/jax/jax/interpreters/xla.py", line 731 in _value
  File "/Users/necula/Source/jax/jax/interpreters/xla.py", line 826 in __array__
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/numpy/core/_asarray.py", line 85 in asarray
  File "/Users/necula/Source/jax/jax/test_util.py", line 682 in assertAllClose
  File "/Users/necula/Source/jax/jax/test_util.py", line 727 in _CompileAndCheck
  File "/Users/necula/Source/jax/tests/linalg_test.py", line 994 in testExpm
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/absl/testing/parameterized.py", line 263 in bound_param_test
  File "/Users/necula/homebrew/Cellar/python/3.7.4/Frameworks/Python.framework/Versions/3.7/lib/python3.7/unittest/case.py", line 628 in run
  File "/Users/necula/homebrew/Cellar/python/3.7.4/Frameworks/Python.framework/Versions/3.7/lib/python3.7/unittest/case.py", line 676 in __call__
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/_pytest/unittest.py", line 207 in runtest
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/_pytest/runner.py", line 117 in pytest_runtest_call
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/pluggy/callers.py", line 187 in _multicall
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/pluggy/manager.py", line 81 in <lambda>
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/pluggy/manager.py", line 87 in _hookexec
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/pluggy/hooks.py", line 289 in __call__
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/_pytest/runner.py", line 192 in <lambda>
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/_pytest/runner.py", line 220 in from_call
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/_pytest/runner.py", line 192 in call_runtest_hook
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/_pytest/runner.py", line 167 in call_and_report
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/_pytest/runner.py", line 87 in runtestprotocol
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/_pytest/runner.py", line 72 in pytest_runtest_protocol
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/pluggy/callers.py", line 187 in _multicall
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/pluggy/manager.py", line 81 in <lambda>
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/pluggy/manager.py", line 87 in _hookexec
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/pluggy/hooks.py", line 289 in __call__
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/xdist/remote.py", line 85 in run_one_test
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/xdist/remote.py", line 71 in pytest_runtestloop
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/pluggy/callers.py", line 187 in _multicall
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/pluggy/manager.py", line 81 in <lambda>
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/pluggy/manager.py", line 87 in _hookexec
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/pluggy/hooks.py", line 289 in __call__
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/_pytest/main.py", line 235 in _main
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/_pytest/main.py", line 191 in wrap_session
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/_pytest/main.py", line 228 in pytest_cmdline_main
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/pluggy/callers.py", line 187 in _multicall
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/pluggy/manager.py", line 81 in <lambda>
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/pluggy/manager.py", line 87 in _hookexec
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/pluggy/hooks.py", line 289 in __call__
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/xdist/remote.py", line 250 in <module>
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/execnet/gateway_base.py", line 1084 in executetask
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/execnet/gateway_base.py", line 220 in run
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/execnet/gateway_base.py", line 285 in _perform_spawn
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/execnet/gateway_base.py", line 267 in integrate_as_primary_thread
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/execnet/gateway_base.py", line 1060 in serve
  File "/Users/necula/.virtualenvs/jax/lib/python3.7/site-packages/execnet/gateway_base.py", line 1554 in serve
  File "<string>", line 8 in <module>
  File "<string>", line 1 in <module>
[gw0] node down: Not properly terminated
f
replacing crashed worker gw0

I will disable the tests for now.

closed time in 18 days

gnecula

issue commentgoogle/jax

Some linalg tests fail on Mac with latest scipy

This is a duplicate of https://github.com/google/jax/issues/432

gnecula

comment created time in 18 days

issue commentgoogle/jax

Using TPUs in coreless mode like with tf.device(None)

@shawwn Unfortunately that's a capability that is only available to TensorFlow at the moment, and not to other users of TPUs. It's possible that might change in the future, but we can't make any promises at this time.

Tenoke

comment created time in 20 days

push eventgoogle/jax

Peter Hawkins

commit sha 3c9ae5e221316c82f1dda34aa7f12173b12e3a21

Add jax.scipy.stats.logistic to documentation. (#2149)

view details

push time in 20 days

PR merged google/jax

Add jax.scipy.stats.logistic to documentation. cla: yes
+14 -0

0 comment

1 changed file

hawkinsp

pr closed time in 20 days

push eventgoogle/jax

Peter Hawkins

commit sha 0b1d2fc3d187f779934cfaeb9188e1fcb208a6fc

Avoid accidental type promotion in gamma sampler gradient. (#2150) Reformat gamma sampler to use 2 space indent, consistent with the rest of JAX.

view details

push time in 20 days

PR merged google/jax

Avoid accidental type promotion in gamma sampler gradient. cla: yes

Reformat gamma sampler to use 2 space indent, consistent with the rest of JAX.

Fixes #2130

+102 -93

0 comment

2 changed files

hawkinsp

pr closed time in 20 days

issue closedgoogle/jax

Gamma sample gradients has incorrect dtype in x64 mode

The following snippet of code fails in JAX:

import jax
import jax.numpy as jnp
import jax.random as jaxrand
jax.config.update('jax_enable_x64', True)

key = jaxrand.PRNGKey(0.)
a = jnp.array(1., dtype=jnp.float32)
b = jnp.array(3., dtype=jnp.float32)
f = lambda x, y: jaxrand.gamma(key=key, a=x, dtype=jnp.float32) / y
y, f_vjp = jax.vjp(f, a, b)

with

TypeError: body_fun output and input must have identical types, got
(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float64[]), ShapedArray(bool[]))
and
(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(bool[]))

Some dtype information is being lost along the way.

closed time in 20 days

srvasude

issue commentgoogle/jax

Reference implementations

In general I'd say: if you want to contribute more reference implementations to lax_reference.py, that'd be very welcome. You might find it hard to do in general, particularly for some of the more parametric operators.

I have a JAX transformation/interpreter sitting around in a branch that runs JAX programs using classic NumPy (as opposed to building an AST). I should probably clean it up and check it in at some point.
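For concreteness, a sketch of what such NumPy reference implementations tend to look like, assuming the same argument names as the corresponding jax.lax functions (illustrative only, not the contents of lax_reference.py):

import numpy as np

def rev(operand, dimensions):
  # Reference semantics of lax.rev: reverse the operand along the listed dimensions.
  return np.flip(operand, axis=tuple(dimensions))

def concatenate(operands, dimension):
  # Reference semantics of lax.concatenate: join arrays along one dimension.
  return np.concatenate(operands, axis=dimension)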

notEvil

comment created time in 20 days

PR opened google/jax

Avoid accidental type promotion in gamma sampler gradient.

Reformat gamma sampler to use 2 space indent, consistent with the rest of JAX.

Fixes #2130

+102 -93

0 comment

2 changed files

pr created time in 20 days

push eventhawkinsp/jax

Peter Hawkins

commit sha 32ac6b70ead08544c09b3c6ad51e51d68dbe572b

Avoid accidental type promotion in gamma sampler gradient. Reformat gamma sampler to use 2 space indent, consistent with the rest of JAX.

view details

push time in 20 days

create branchhawkinsp/jax

branch : gamma

created branch time in 20 days

PR opened google/jax

Add jax.scipy.stats.logistic to documentation.
+14 -0

0 comment

1 changed file

pr created time in 20 days

push eventhawkinsp/jax

Roman Novak

commit sha 6a4bb9516925f28f613387d7fdc45e2643fc1de6

Make the reverse operator work on empty list of dimensions Example that this fixes:
```
from jax import lax
import jax.numpy as np
from jax.api import jacrev

x = np.ones((3, 5))

def f(x):
  return lax.conv_general_dilated(lhs=x, rhs=np.ones((5, 2)),
                                  window_strides=(), padding='VALID',
                                  dimension_numbers=('NC', 'IO', 'NC'))

jacrev(f)(x)
```
currently fails: the transpose rule for conv_general_dilated calls rev(rhs, rhs_sdims) with an empty dimension list, and _rev_shape_rule in lax.py then raises
```
ValueError: max() arg is an empty sequence
```

view details

Roman Novak

commit sha 95ccaae8058f8fb49c81680f1f9061bf96d8d95e

Add test for empty dimension list for reversion

view details

Peter Hawkins

commit sha 9a0338d6aa1b6006be98983eb3d33c8507dcd383

Update README.md and CHANGELOG.md. (#2096)

view details

Peter Hawkins

commit sha 55f2d3be27eaf0f75aac9b2937e6fab87076315a

Update Jaxlib docker build. * work around https://github.com/bazelbuild/bazel/issues/9254 by setting BAZEL_LINKLIBS=-lstdc++ * drop CUDA 9.0 support, since we use a batched kernel only present in CUDA 9.2 or later. * drop Python 2.7 support.

view details

Peter Hawkins

commit sha 58f949f3168ebdc0264448022f90ba8271746356

Merge pull request #2098 from hawkinsp/jaxlib Update Jaxlib docker build.

view details

Peter Hawkins

commit sha b54c18efb4e30831c77cd8698dcdaa7864e74440

Use Device hash and equality instead of using a (class, id) pair. We couldn't figure out why we did it this way in the first place and all the tests we have pass.

view details

Daniel Johnson

commit sha b68d8b5c4fead01ba85da9a9574a686493a5b7ba

Clarify instructions for building from source. (#2093) Adds additional subsections of the `Building from source` documentation page to make it more obvious that you can install `jaxlib` from pip when doing Python-only development.

view details

Peter Hawkins

commit sha 126ae7fccfc9d2a93947f26fd427ed87c627bb08

Implement ndarray.tolist() on DeviceArray.

view details

Peter Hawkins

commit sha 35810c9dcd181b087959f6089deff320edef9872

Merge pull request #2101 from hawkinsp/tolist Implement ndarray.tolist() on DeviceArray.

view details

Matthew Johnson

commit sha 1afcac70dfeaa4ffc89d79dc64f72361e18c4a91

tweak readme not to have bad line wrap

view details

Skye Wanderman-Milne

commit sha 6aaf257d8a61e654bb7c42a44f1a37e6204d311c

Update WORKSPACE

view details

Matthew Johnson

commit sha 71811be3b9257ab1fc48fce8fb0512c0a384d901

tweak top-line announcement text in readme

view details

Matthew Johnson

commit sha d46e82d0abbb4e0d009d5b201a178871d5e2c672

tweak readme announcement text again

view details

Peter Hawkins

commit sha 9f7f161c5f9667c96b24db07325eddfce6dd26d4

Incorporate review comments.

view details

James Bradbury

commit sha 1a5d9c531a7ec37da5ccb0a904aea454f20df5ee

clear compilation cache before metadata tests (#2103)

view details

Peter Hawkins

commit sha 7b7c89db9833b1e05c56d65ab3806ae77159c4d7

Merge pull request #2086 from romanngg/patch-6 Make the reverse operator work on empty list of dimensions

view details

Tom Hennigan

commit sha 4e575e1492afc08c8307d77cf7c784a17a52016e

Support trees in lax parallel operations. (#1953) It is relatively common to apply collective operations to trees. For example in sync distributed training it is typical to sum all gradients across replicas `grads = jax.tree_map(partial(lax.psum, axis_name='i'), grads)`. We can make this a little more convenient by making lax parallel ops support trees directly: `grads = lax.psum(grads, 'i')`. There is room for improvement in this change. We should in some (all?) cases just pass a tuple of values to XLA (rather than bind the primitive n times, bind once with a tuple of n values), however this produced strange values when combined with pmap and a fix was not obvious. This is something we can follow up on without users having to change their code.

view details

Peter Hawkins

commit sha 102ce6f0acbfbc151accb62186c48b9c473271f0

Merge pull request #2100 from hawkinsp/devices Use Device hash and equality instead of using a (class, id) pair.

view details

Peter Hawkins

commit sha 04befac4f64e712772dfe5e3c248ce6cee7b618d

Fix error case in tensordot. (#2111)

view details

Peter Hawkins

commit sha 0904e5ff742d3ee76dbca4440e9372f0c2b5595b

Fix implementation of cumsum/cumprod for boolean inputs. (#2112) Check for number inputs in the reduce_window_sum dtype rule.

view details

push time in 20 days

issue commentgoogle/jax

JAX argsort has different behavior on ties compared to {np,tf}.argsort

Although on reflection I note that both NumPy and TF classic treat them as equal:

In [9]: np.argsort([-0., 0.])                        
Out[9]: array([0, 1])

In [10]: np.argsort([0., -0.])                       
Out[10]: array([0, 1])

In [15]: tf.argsort(np.array([0., -0.]))                                                                  
Out[15]: <tf.Tensor: id=32, shape=(2,), dtype=int32, numpy=array([0, 1], dtype=int32)>

In [16]: tf.argsort(np.array([-0., 0.]))                                                                  
Out[16]: <tf.Tensor: id=43, shape=(2,), dtype=int32, numpy=array([0, 1], dtype=int32)>

@srvasude Can you say more about what behavior you expected? It seems that neither TF nor NumPy guarantees anything about the relative ordering of -0. and 0.; the sort is neither stable nor ordered in any particular way.

srvasude

comment created time in 20 days

issue commentgoogle/jax

JAX argsort has different behavior on ties compared to {np,tf}.argsort

TF2XLA will sort -0 < 0, because it uses the following comparator: https://github.com/tensorflow/tensorflow/blob/97f3666d8c10d317a050d786c645d181cbb4d0bb/tensorflow/compiler/xla/client/lib/comparators.cc#L70

We can easily do the same in JAX.
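For concreteness, a NumPy sketch of the bit-reinterpretation trick such total-order comparators use (the helper name is made up, and this is not necessarily what JAX will adopt):

import numpy as np

def total_order_key(x):
  # Reinterpret float32 bits as int32, then flip the negative range so that
  # -0.0 keys strictly below +0.0 and more-negative floats key lower still.
  bits = np.asarray(x, dtype=np.float32).view(np.int32)
  flip = np.where(bits < 0, np.int32(0x7FFFFFFF), np.int32(0))
  return bits ^ flip

x = np.array([0.0, -0.0], dtype=np.float32)
print(np.argsort(total_order_key(x)))  # [1 0]: -0.0 now sorts before 0.0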

srvasude

comment created time in 20 days

issue closedgoogle/jax

Using TPUs in coreless mode like with tf.device(None)

Is there any way to use jax with TPUs in coreless mode?

In TensorFlow you can just use tf.device(None) to use the TPU's 300gb RAM + cpu for bigger operations but after looking at xla, the bridge, trax (which is where I am using jaxlib) and jax, I only seem to run into stuff like this error - 'JAX cannot work yet with n_devices != all devices: 1 != 8'.

closed time in 20 days

Tenoke

issue commentgoogle/jax

Using TPUs in coreless mode like with tf.device(None)

I think there's nothing we can do along those lines at the moment, because the way this works at present is specific to TensorFlow. However, it's possible future evolutions of the cloud TPU product might make it possible for JAX to make more use of the TPU VM, as TensorFlow does, in addition to the TPU devices themselves. Watch this space!

However, since there's no action we can take at the moment, I'm going to close this issue.

Tenoke

comment created time in 20 days
