apple/swift 50670

The Swift Programming Language

apple/swift-evolution 10807

This maintains proposals for changes and user-visible enhancements to the Swift Programming Language.

apple/swift-package-manager 7539

The Package Manager for the Swift Programming Language

apple/swift-corelibs-foundation 3766

The Foundation Project, providing core utilities, internationalization, and OS independence

llvm/llvm-project 3740

The LLVM Project is a collection of modular and reusable compiler and toolchain technologies. Note: the repository does not accept github pull requests at this moment. Please submit your patches at http://reviews.llvm.org.

apple/swift-corelibs-libdispatch 1737

The libdispatch Project, (a.k.a. Grand Central Dispatch), for concurrency on multicore hardware

apple/swift-corelibs-xctest 791

The XCTest Project, A Swift core library for providing unit test support

apple/swift-llbuild 748

A low-level build system, used by Xcode and the Swift Package Manager

pull request comment llvm/mlir-www

Add PlaidML to MLIR users list

nice, welcome!

dgkutnic

comment created time in 7 days

push event llvm/mlir-www

Denise Kutnick

commit sha 18088f749385b6b88ebe500d1890e89aaf617229

Add PlaidML to MLIR users list (#12)

push time in 7 days

PR merged llvm/mlir-www

Add PlaidML to MLIR users list
+4 -0

0 comment

1 changed file

dgkutnic

pr closed time in 7 days

issue comment SwiftCommunityPodcast/podcast

💡 Topic: Swift for Good

5pm GMT+1 looks like it is 8am Pacific time. For me, the best time would be Saturday or Sunday, Feb 15 or 16. I can do any time 5-11pm GMT on those days.

BasThomas

comment created time in 9 days

pull request comment apple/swift-evolution

Remove a double colon

thanks!

BasThomas

comment created time in a month

push event apple/swift-evolution

Bas Broek

commit sha b17d85fcaf38598fd2ea19641d0e9c26c96747ec

Remove a double colon (#1108)

push time in a month

PR merged apple/swift-evolution

Remove a double colon
+1 -1

0 comment

1 changed file

BasThomas

pr closed time in a month

started JohnSundell/Publish

started time in 2 months

Pull request review comment apple/swift

[AutoDiff upstream] Add `@transpose` attribute.

+// RUN: %target-swift-frontend -parse -verify %s
+
+/// Good
+
+@transpose(of: foo)
+func transpose(v: Float) -> Float
+
+@transpose(of: foo(_:_:))
+func transpose(v: Float) -> Float
+
+@transpose(of: wrt, wrt: 0)
+func transpose(v: Float) -> Float
+
+@transpose(of: foo, wrt: 0)
+func transpose(v: Float) -> Float
+
+@transpose(of: foo, wrt: (0, 1))
+func transpose(v: Float) -> (Float, Float)
+
+@transpose(of: foo, wrt: (self, 0, 1, 2))
+func transpose(v: Float) -> (Float, Float, Float, Float)
+
+// Qualified declaration.
+@transpose(of: A.B.C.foo(x:y:_:z:))
+func transpose(v: Float) -> Float
+
+// Qualified operator.
+// TODO(TF-1065): Consider disallowing qualified operators.

FWIW, I think that your implemented behavior is correct and nothing needs to change here. "TF-1065" is really about the existing Swift compiler being buggy and not allowing qualified names for operators like `..<`. I'd suggest splitting that off into a Swift bug, and I'd recommend requiring that operator to be spelled `Swift...<`.

saeta

comment created time in 2 months

Pull request review comment apple/swift

[AutoDiff upstream] Add `@transpose` attribute.

 class DerivativeAttr final
   }
 };
+/// The `@transpose` attribute registers a function as a transpose of another

nit: @transpose(of:)

saeta

comment created time in 2 months

Pull request review comment apple/swift

[AutoDiff upstream] Upstream `@derivative` attribute type-checking.

+//===--------- AutoDiff.cpp - Swift Differentiable Programming ------------===//

nit: please use the standard file header format, which left-aligns the filename, and please change the text after it to be descriptive, not a marketing thing :-)

dan-zheng

comment created time in 2 months

issue comment onnx/onnx

[RFC] ONNX Dialect in MLIR

Super exciting - It would be great to see a proper ONNX dialect in upstream MLIR when it is baked out! This could help provide a nice reference implementation for ONNX

tjingrant

comment created time in 2 months

pull request comment apple/swift

[AutoDiff upstream] Conform floating-point types to `Differentiable`.

I consider you to be the code owners of the stdlib/public/Differentiation module, so feel free to get someone to do code reviews if they touch that code.

dan-zheng

comment created time in 2 months

pull request comment apple/swift-evolution

Fix a typo in the example code.

Thanks!

yilei

comment created time in 3 months

push event apple/swift-evolution

Yilei (Dolee) Yang

commit sha 497f3166c9baac4480a2400a795f2ddfdf67cd5a

Fix a typo in the example code. (#1096)

push time in 3 months

PR merged apple/swift-evolution

Fix a typo in the example code.
+1 -1

0 comment

1 changed file

yilei

pr closed time in 3 months

Pull request review comment apple/swift

[AutoDiff upstream] Add `@differentiating` attribute.

 DECL_ATTR(differentiable, Differentiable,
   ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove,
   91)
+DECL_ATTR(differentiating, Differentiating,

+1 for @derivative(of: ...)

dan-zheng

comment created time in 3 months

Pull request review comment apple/swift

[AutoDiff upstream] Add `@differentiating` attribute.

 bool Parser::parseDifferentiableAttributeArguments(
   return false;
 }
+ParserResult<DifferentiatingAttr>

I'd recommend adding a doc comment with the grammar for this production, like many of the other methods in this file do.

dan-zheng

comment created time in 3 months

Pull request review comment apple/swift

[AutoDiff upstream] Add `@differentiating` attribute.

 bool Parser::parseDifferentiableAttributeArguments(
   return false;
 }
+ParserResult<DifferentiatingAttr>
+Parser::parseDifferentiatingAttribute(SourceLoc atLoc, SourceLoc loc) {
+  StringRef AttrName = "differentiating";
+  SourceLoc lParenLoc = loc, rParenLoc = loc;
+  DeclNameWithLoc original;
+  bool linear = false;
+  SmallVector<ParsedAutoDiffParameter, 8> params;
+
+  // Parse trailing comma, if it exists, and check for errors.
+  auto consumeIfTrailingComma = [&]() -> bool {
+    if (!consumeIf(tok::comma)) return false;
+    // Diagnose trailing comma before ')'.
+    if (Tok.is(tok::r_paren)) {
+      diagnose(Tok, diag::unexpected_separator, ",");
+      return true;
+    }
+    // Check that token after comma is 'linear' or 'wrt:'.
+    if (!Tok.is(tok::identifier) ||
+        !(Tok.getText() == "linear" || Tok.getText() == "wrt")) {

This logic is very difficult to read with all the negation happening, I'd recommend inverting this to be "if (positive test) return false; emit error"

Also, please use Tok.isNot where appropriate in general across this file (I'd recommend doing this on the TF branch globally).
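The inversion suggested above can be sketched roughly as follows. This is only an illustration under stated assumptions: `TokKind`, `Token`, and `diagnoseUnexpectedLabel` are hypothetical stand-ins written for this example, not the real `swift::Token` or parser API, and the actual `diagnose(...)` call is left as a comment.

```cpp
#include <string>

// Hypothetical stand-ins for the parser's token types (assumption: the real
// swift::Token and tok enum carry far more state than this).
enum class TokKind { identifier, comma, r_paren, other };

struct Token {
  TokKind kind;
  std::string text;

  bool is(TokKind k) const { return kind == k; }
  const std::string &getText() const { return text; }
};

// Returns true on error, matching the convention used in the diff.
// The positive test comes first and returns early, so the error path is
// the plain fall-through instead of a doubly negated condition.
bool diagnoseUnexpectedLabel(const Token &tok) {
  if (tok.is(TokKind::identifier) &&
      (tok.getText() == "linear" || tok.getText() == "wrt"))
    return false;
  // The real parser would emit a diagnostic here before returning.
  return true;
}
```

With this shape the accepted labels read as a single positive condition, and adding another label only touches that one expression.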

dan-zheng

comment created time in 3 months

Pull request review comment apple/swift

[AutoDiff upstream] Add `@differentiating` attribute.

 DECL_ATTR(differentiable, Differentiable,
   ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove,
   91)
+DECL_ATTR(differentiating, Differentiating,

Bikeshed language design issue - feel free to punt to the formal review process, but this attribute is textually very similar to @differentiable. Did you consider something like @derivativeOf as the name of this attribute?

dan-zheng

comment created time in 3 months

Pull request review comment apple/swift

[AutoDiff upstream] Add `@differentiating` attribute.

 bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
     break;
   }
+  case DAK_Differentiating: {
+    Printer.printAttrName("@differentiating");
+    Printer << '(';
+    auto *attr = cast<DifferentiatingAttr>(this);
+    auto *derivative = dyn_cast_or_null<AbstractFunctionDecl>(D);
+    Printer << attr->getOriginal().Name;
+    auto diffParamsString = getDifferentiationParametersClauseString(

Relatedly, in this case, can dyn_cast_or_null<AbstractFunctionDecl>(D); ever return null? My understanding is that this attribute is only OnFunc - so this should only be on FuncDecls. If this check always returns non-null, then please change this to cast<> instead of dyn_cast_or_null

dan-zheng

comment created time in 3 months

Pull request review comment apple/swift

[AutoDiff upstream] Add `@differentiating` attribute.

 bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
     break;
   }
+  case DAK_Differentiating: {
+    Printer.printAttrName("@differentiating");
+    Printer << '(';
+    auto *attr = cast<DifferentiatingAttr>(this);
+    auto *derivative = dyn_cast_or_null<AbstractFunctionDecl>(D);
+    Printer << attr->getOriginal().Name;
+    auto diffParamsString = getDifferentiationParametersClauseString(

It is a bit unrelated to this patch, but the contract around the first parameter to getDifferentiationParametersClauseString being non-null is very unclear to me, and getDifferentiationParametersClauseString contains several unchecked dereferences of the first parameter.

Possibly as a follow-up patch, I think getDifferentiationParametersClauseString needs to be improved - the doc comment itself should acknowledge the first parameter being null, and I think an audit of this behavior is in order as well. It might be reasonable to split this function into two different things - one that takes a non-null function and one that never takes one.

dan-zheng

comment created time in 3 months

pull request comment apple/swift

[AutoDiff upstreaming] add @noDerivative to AnyFunctionType params

+1 for @noDerivative !

marcrasi

comment created time in 3 months

pull request comment apple/swift

[AutoDiff upstreaming] add @nondiff to AnyFunctionType params

Well, that isn't right either, because it is possibly differentiable but you don't want it in this case - better names still desired :)

marcrasi

comment created time in 3 months

pull request comment apple/swift

[AutoDiff upstreaming] add @nondiff to AnyFunctionType params

The name @nondiff is a bit weird. Was @nondifferentiable considered?

marcrasi

comment created time in 3 months

Pull request review comment apple/swift

[AutoDiff upstream] AST bits for @differentiable fn ty

 Type ASTBuilder::createImplFunctionType(
     break;
   }
-  auto einfo = SILFunctionType::ExtInfo(representation,
-                                        flags.isPseudogeneric(),
-                                        !flags.isEscaping());
+  auto einfo = SILFunctionType::ExtInfo(

Is there a mangling for these flags? If so, it seems like it should be plumbed in here for the demangler to use.

marcrasi

comment created time in 3 months

Pull request review comment apple/swift

[AutoDiff upstream] AST bits for @differentiable fn ty

 class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
     // Is this function guaranteed to be no-escape by the type system?
     bool isNoEscape() const { return Bits & NoEscapeMask; }
+    bool isDifferentiable() const {
+      return getDifferentiabilityKind() >

nit, but it seems like != would be more natural here than >
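To illustrate the nit, here is a minimal sketch of an inequality test against a "not differentiable" enumerator. The enum and its enumerator names are assumptions made for this example; the right-hand side of the comparison in the diff above is cut off, so this does not claim to reproduce the patch.

```cpp
#include <cstdint>

// Hypothetical enum for illustration only; enumerator names and values are
// assumptions, not taken from the patch under review.
enum class DifferentiabilityKind : std::uint8_t {
  NonDifferentiable = 0,
  Normal = 1,
  Linear = 2,
};

// Using != states the intent directly ("anything other than
// NonDifferentiable") and does not rely on the numeric ordering of the
// enumerators the way a > comparison would.
constexpr bool isDifferentiable(DifferentiabilityKind kind) {
  return kind != DifferentiabilityKind::NonDifferentiable;
}
```

A `>` comparison gives the same answer only while `NonDifferentiable` remains the smallest value, which is an ordering detail a reader has to go and verify.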

marcrasi

comment created time in 3 months

Pull request review comment apple/swift

[AutoDiff upstream] Introduce @differentiable attribute to mark functions as differentiable.

 Parser::parseImplementsAttribute(SourceLoc AtLoc, SourceLoc Loc) {                            ProtocolType.get(), MemberName, MemberNameLoc)); } +ParserResult<DifferentiableAttr>+Parser::parseDifferentiableAttribute(SourceLoc atLoc, SourceLoc loc) {+  StringRef AttrName = "differentiable";+  SourceLoc lParenLoc = loc, rParenLoc = loc;+  bool linear = false;+  SmallVector<ParsedAutoDiffParameter, 8> params;+  Optional<DeclNameWithLoc> jvpSpec;+  Optional<DeclNameWithLoc> vjpSpec;+  TrailingWhereClause *whereClause = nullptr;++  // Parse '('.+  if (consumeIf(tok::l_paren, lParenLoc)) {+    // Parse @differentiable attribute arguments.+    if (parseDifferentiableAttributeArguments(linear, params, jvpSpec, vjpSpec,+                                              whereClause))+      return makeParserError();+    // Parse ')'.+    if (!consumeIf(tok::r_paren, rParenLoc)) {+      diagnose(getEndOfPreviousLoc(), diag::attr_expected_rparen, AttrName,+               /*DeclModifier=*/false);+      return makeParserError();+    }+  }++  return ParserResult<DifferentiableAttr>(+      DifferentiableAttr::create(Context, /*implicit*/ false, atLoc,+                                 SourceRange(loc, rParenLoc), linear,+                                 params, jvpSpec, vjpSpec, whereClause));+}++bool Parser::parseDifferentiationParametersClause(+    SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName) {+  // Set parse error, skip until ')' and parse it.+  auto errorAndSkipToEnd = [&](int parenDepth = 1) -> bool {+    for (int i = 0; i < parenDepth; i++) {+      skipUntil(tok::r_paren);+      if (!consumeIf(tok::r_paren))+        diagnose(Tok, diag::attr_expected_rparen, attrName,+                 /*DeclModifier=*/false);+    }+    return true;+  };++  SyntaxParsingContext DiffParamsClauseContext(+       SyntaxContext, SyntaxKind::DifferentiationParamsClause);+  consumeToken(tok::identifier);+  if (!consumeIf(tok::colon)) {+    diagnose(Tok, 
diag::expected_colon_after_label, "wrt");+    return errorAndSkipToEnd();+  }++  // Function that parses a parameter into `params`. Returns true if error+  // occurred.+  auto parseParam = [&](bool parseTrailingComma = true) -> bool {+    SyntaxParsingContext DiffParamContext(+        SyntaxContext, SyntaxKind::DifferentiationParam);+    SourceLoc paramLoc;+    switch (Tok.getKind()) {+      case tok::identifier: {+        Identifier paramName;+        if (parseIdentifier(paramName, paramLoc,+                            diag::diff_params_clause_expected_parameter))+          return true;+        params.push_back(ParsedAutoDiffParameter::getNamedParameter(+            paramLoc, paramName));+        break;+      }+      case tok::integer_literal: {+        unsigned paramNum;+        if (parseUnsignedInteger(+                paramNum, paramLoc,+                diag::diff_params_clause_expected_parameter))+          return true;++        params.push_back(ParsedAutoDiffParameter::getOrderedParameter(+            paramLoc, paramNum));+        break;+      }+      case tok::kw_self: {+        paramLoc = consumeToken(tok::kw_self);+        params.push_back(ParsedAutoDiffParameter::getSelfParameter(paramLoc));+        break;+      }+      default:+        diagnose(Tok, diag::diff_params_clause_expected_parameter);+        return true;+    }+    if (parseTrailingComma && Tok.isNot(tok::r_paren))+      return parseToken(tok::comma, diag::attr_expected_comma, attrName,+                        /*isDeclModifier=*/false);+    return false;+  };++  // Parse opening '(' of the parameter list.+  if (Tok.is(tok::l_paren)) {+    SyntaxParsingContext DiffParamsContext(+        SyntaxContext, SyntaxKind::DifferentiationParams);+    consumeToken(tok::l_paren);+    // Parse first parameter. 
At least one is required.+    if (parseParam())+      return errorAndSkipToEnd(2);+    // Parse remaining parameters until ')'.+    while (Tok.isNot(tok::r_paren))+      if (parseParam())+        return errorAndSkipToEnd(2);+    SyntaxContext->collectNodesInPlace(SyntaxKind::DifferentiationParamList);+    // Parse closing ')' of the parameter list.+    consumeToken(tok::r_paren);+  }+  // If no opening '(' for parameter list, parse a single parameter.+  else {+    if (parseParam(/*parseTrailingComma*/ false))+      return errorAndSkipToEnd();+  }+  return false;+}++bool Parser::parseDifferentiableAttributeArguments(+    bool &linear, SmallVectorImpl<ParsedAutoDiffParameter> &params,+    Optional<DeclNameWithLoc> &jvpSpec, Optional<DeclNameWithLoc> &vjpSpec,+    TrailingWhereClause *&whereClause) {+  StringRef AttrName = "differentiable";++  // Set parse error, skip until ')' and parse it.+  auto errorAndSkipToEnd = [&](int parenDepth = 1) -> bool {+    for (int i = 0; i < parenDepth; i++) {+      skipUntil(tok::r_paren);+      if (!consumeIf(tok::r_paren))+        diagnose(Tok, diag::attr_expected_rparen, AttrName,+                 /*DeclModifier=*/false);+    }+    return true;+  };++  // Parse trailing comma, if it exists, and check for errors.+  auto consumeIfTrailingComma = [&]() -> bool {+    if (!consumeIf(tok::comma)) return false;+    // Diagnose trailing comma before 'where' or ')'.+    if (Tok.is(tok::kw_where) || Tok.is(tok::r_paren)) {+      diagnose(Tok, diag::unexpected_separator, ",");+      return true;+    }+    // Check that token after comma is 'wrt:' or a function specifier label.+    if (!Tok.is(tok::identifier) || !(Tok.getText() == "wrt" ||+                                      Tok.getText() == "jvp" ||+                                      Tok.getText() == "vjp")) {

Good to know, but it is important to keep master consistent and following best practices. If these aren't important, then it would be fine to remove them from the patch. If they need to be in the patch, then please consider implementing them in a nice way :).

I'm not saying that classifyLabel is an appropriate thing, just saying that the rationale doesn't make sense.

bgogul

comment created time in 3 months

Pull request review comment apple/swift

[AutoDiff upstream] Introduce @differentiable attribute to mark functions as differentiable.

 bool Parser::parseMatchingToken(tok K, SourceLoc &TokLoc, Diag<> ErrorDiag,
   return false;
 }
+bool Parser::parseUnsignedInteger(unsigned &Result, SourceLoc &Loc,

Why do you need this new method? DAK_Alignment handling in ParseDecl.cpp does this:

    StringRef alignmentText = Tok.getText();
    unsigned alignmentValue;
    if (alignmentText.getAsInteger(0, alignmentValue)) {
      diagnose(Loc, diag::alignment_must_be_positive_integer);
      return false;
    }

Won't this work for you at the callsite?

bgogul

comment created time in 3 months

Pull request review comment apple/swift

[AutoDiff upstream] Introduce @differentiable attribute to mark functions as differentiable.

 bool Parser::parseMatchingToken(tok K, SourceLoc &TokLoc, Diag<> ErrorDiag,
   return false;
 }
+bool Parser::parseUnsignedInteger(unsigned &Result, SourceLoc &Loc,

If you're trying to introduce a new generally useful utility, I'd recommend doing that as a separate patch which also migrates the existing callsites to it.

bgogul

comment created time in 3 months

Pull request review comment apple/swift

[AutoDiff upstream] Introduce @differentiable attribute to mark functions as differentiable.

+//===--- AutoDiff.h - Swift Automatic Differentiation ---------------------===//+//+// This source file is part of the Swift.org open source project+//+// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors+// Licensed under Apache License v2.0 with Runtime Library Exception+//+// See https://swift.org/LICENSE.txt for license information+// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors+//+//===----------------------------------------------------------------------===//+//+//  SWIFT_ENABLE_TENSORFLOW+//  This file defines AST support for automatic differentiation.+//+//===----------------------------------------------------------------------===//++#ifndef SWIFT_AST_AUTODIFF_H+#define SWIFT_AST_AUTODIFF_H++#include "ASTContext.h"+#include "llvm/ADT/SmallBitVector.h"+#include "swift/Basic/Range.h"++namespace swift {++class ParsedAutoDiffParameter {+public:+  enum class Kind { Named, Ordered, Self };++private:+  SourceLoc Loc;+  Kind Kind;+  union Value {+    struct { Identifier Name; } Named;+    struct { unsigned Index; } Ordered;+    struct {} Self;+    Value(Identifier name) : Named({name}) {}+    Value(unsigned index) : Ordered({index}) {}+    Value() {}+  } V;++public:+  ParsedAutoDiffParameter(SourceLoc loc, enum Kind kind, Value value)+    : Loc(loc), Kind(kind), V(value) {}++  ParsedAutoDiffParameter(SourceLoc loc, enum Kind kind, unsigned index)+  : Loc(loc), Kind(kind), V(index) {}++  static ParsedAutoDiffParameter getNamedParameter(SourceLoc loc,+                                                   Identifier name) {+    return { loc, Kind::Named, name };+  }++  static ParsedAutoDiffParameter getOrderedParameter(SourceLoc loc,+                                                     unsigned index) {+    return { loc, Kind::Ordered, index };+  }++  static ParsedAutoDiffParameter getSelfParameter(SourceLoc loc) {+    return { loc, Kind::Self, {} };+  }++  Identifier getName() const {+    assert(Kind == Kind::Named);+   
 return V.Named.Name;+  }++  unsigned getIndex() const {+    return V.Ordered.Index;+  }++  enum Kind getKind() const {

nit: s/enum Kind/Kind/

bgogul

comment created time in 3 months

Pull request review comment apple/swift

[AutoDiff upstream] Introduce @differentiable attribute to mark functions as differentiable.

 Parser::parseImplementsAttribute(SourceLoc AtLoc, SourceLoc Loc) {                            ProtocolType.get(), MemberName, MemberNameLoc)); } +ParserResult<DifferentiableAttr>+Parser::parseDifferentiableAttribute(SourceLoc atLoc, SourceLoc loc) {+  StringRef AttrName = "differentiable";+  SourceLoc lParenLoc = loc, rParenLoc = loc;+  bool linear = false;+  SmallVector<ParsedAutoDiffParameter, 8> params;+  Optional<DeclNameWithLoc> jvpSpec;+  Optional<DeclNameWithLoc> vjpSpec;+  TrailingWhereClause *whereClause = nullptr;++  // Parse '('.+  if (consumeIf(tok::l_paren, lParenLoc)) {+    // Parse @differentiable attribute arguments.+    if (parseDifferentiableAttributeArguments(linear, params, jvpSpec, vjpSpec,+                                              whereClause))+      return makeParserError();+    // Parse ')'.+    if (!consumeIf(tok::r_paren, rParenLoc)) {+      diagnose(getEndOfPreviousLoc(), diag::attr_expected_rparen, AttrName,+               /*DeclModifier=*/false);+      return makeParserError();+    }+  }++  return ParserResult<DifferentiableAttr>(+      DifferentiableAttr::create(Context, /*implicit*/ false, atLoc,+                                 SourceRange(loc, rParenLoc), linear,+                                 params, jvpSpec, vjpSpec, whereClause));+}++bool Parser::parseDifferentiationParametersClause(+    SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName) {+  // Set parse error, skip until ')' and parse it.+  auto errorAndSkipToEnd = [&](int parenDepth = 1) -> bool {+    for (int i = 0; i < parenDepth; i++) {+      skipUntil(tok::r_paren);+      if (!consumeIf(tok::r_paren))+        diagnose(Tok, diag::attr_expected_rparen, attrName,+                 /*DeclModifier=*/false);+    }+    return true;+  };++  SyntaxParsingContext DiffParamsClauseContext(+       SyntaxContext, SyntaxKind::DifferentiationParamsClause);+  consumeToken(tok::identifier);+  if (!consumeIf(tok::colon)) {+    diagnose(Tok, 
diag::expected_colon_after_label, "wrt");+    return errorAndSkipToEnd();+  }++  // Function that parses a parameter into `params`. Returns true if error+  // occurred.+  auto parseParam = [&](bool parseTrailingComma = true) -> bool {+    SyntaxParsingContext DiffParamContext(+        SyntaxContext, SyntaxKind::DifferentiationParam);+    SourceLoc paramLoc;+    switch (Tok.getKind()) {+      case tok::identifier: {+        Identifier paramName;+        if (parseIdentifier(paramName, paramLoc,+                            diag::diff_params_clause_expected_parameter))+          return true;+        params.push_back(ParsedAutoDiffParameter::getNamedParameter(+            paramLoc, paramName));+        break;+      }+      case tok::integer_literal: {+        unsigned paramNum;+        if (parseUnsignedInteger(+                paramNum, paramLoc,+                diag::diff_params_clause_expected_parameter))+          return true;++        params.push_back(ParsedAutoDiffParameter::getOrderedParameter(+            paramLoc, paramNum));+        break;+      }+      case tok::kw_self: {+        paramLoc = consumeToken(tok::kw_self);+        params.push_back(ParsedAutoDiffParameter::getSelfParameter(paramLoc));+        break;+      }+      default:+        diagnose(Tok, diag::diff_params_clause_expected_parameter);+        return true;+    }+    if (parseTrailingComma && Tok.isNot(tok::r_paren))+      return parseToken(tok::comma, diag::attr_expected_comma, attrName,+                        /*isDeclModifier=*/false);+    return false;+  };++  // Parse opening '(' of the parameter list.+  if (Tok.is(tok::l_paren)) {+    SyntaxParsingContext DiffParamsContext(+        SyntaxContext, SyntaxKind::DifferentiationParams);+    consumeToken(tok::l_paren);+    // Parse first parameter. 
At least one is required.+    if (parseParam())+      return errorAndSkipToEnd(2);+    // Parse remaining parameters until ')'.+    while (Tok.isNot(tok::r_paren))+      if (parseParam())+        return errorAndSkipToEnd(2);+    SyntaxContext->collectNodesInPlace(SyntaxKind::DifferentiationParamList);+    // Parse closing ')' of the parameter list.+    consumeToken(tok::r_paren);+  }+  // If no opening '(' for parameter list, parse a single parameter.+  else {+    if (parseParam(/*parseTrailingComma*/ false))+      return errorAndSkipToEnd();+  }+  return false;+}++bool Parser::parseDifferentiableAttributeArguments(+    bool &linear, SmallVectorImpl<ParsedAutoDiffParameter> &params,+    Optional<DeclNameWithLoc> &jvpSpec, Optional<DeclNameWithLoc> &vjpSpec,+    TrailingWhereClause *&whereClause) {+  StringRef AttrName = "differentiable";++  // Set parse error, skip until ')' and parse it.+  auto errorAndSkipToEnd = [&](int parenDepth = 1) -> bool {+    for (int i = 0; i < parenDepth; i++) {+      skipUntil(tok::r_paren);+      if (!consumeIf(tok::r_paren))+        diagnose(Tok, diag::attr_expected_rparen, AttrName,+                 /*DeclModifier=*/false);+    }+    return true;+  };++  // Parse trailing comma, if it exists, and check for errors.+  auto consumeIfTrailingComma = [&]() -> bool {+    if (!consumeIf(tok::comma)) return false;+    // Diagnose trailing comma before 'where' or ')'.+    if (Tok.is(tok::kw_where) || Tok.is(tok::r_paren)) {+      diagnose(Tok, diag::unexpected_separator, ",");+      return true;+    }+    // Check that token after comma is 'wrt:' or a function specifier label.+    if (!Tok.is(tok::identifier) || !(Tok.getText() == "wrt" ||+                                      Tok.getText() == "jvp" ||+                                      Tok.getText() == "vjp")) {+      diagnose(Tok, diag::attr_differentiable_expected_label);+      return true;+    }+    return false;+  };++  // Store starting parser position.+  auto startingLoc = 
Tok.getLoc();+  SyntaxParsingContext ContentContext(+      SyntaxContext, SyntaxKind::DifferentiableAttributeArguments);++  // Parse optional differentiation parameters.+  // Parse 'linear' label (optional).+  linear = false;+  if (Tok.is(tok::identifier) && Tok.getText() == "linear") {+    linear = true;+    consumeToken(tok::identifier);+    // If no trailing comma or 'where' clause, terminate parsing arguments.+    if (Tok.isNot(tok::comma) && Tok.isNot(tok::kw_where))+      return false;+    if (consumeIfTrailingComma())+      return errorAndSkipToEnd();+  }++  // If 'withRespectTo' is used, make the user change it to 'wrt'.+  if (Tok.is(tok::identifier) && Tok.getText() == "withRespectTo") {+    SourceRange withRespectToRange(Tok.getLoc(), peekToken().getLoc());+    diagnose(Tok, diag::attr_differentiable_use_wrt_not_withrespectto)+        .highlight(withRespectToRange)+        .fixItReplace(withRespectToRange, "wrt:");+    return errorAndSkipToEnd();+  }+  if (Tok.is(tok::identifier) && Tok.getText() == "wrt") {+    if (parseDifferentiationParametersClause(params, AttrName))+      return true;+    // If no trailing comma or 'where' clause, terminate parsing arguments.+    if (Tok.isNot(tok::comma) && Tok.isNot(tok::kw_where))+      return false;+    if (consumeIfTrailingComma())+      return errorAndSkipToEnd();+  }++  // Function that parses a label and a function specifier, e.g. 
'vjp: foo(_:)'.+  // Return true on error.+  auto parseFuncSpec = [&](StringRef label, DeclNameWithLoc &result,+                           bool &terminateParsingArgs) -> bool {+    // Parse label.+    if (parseSpecificIdentifier(label,+            diag::attr_differentiable_missing_label, label) ||+        parseToken(tok::colon, diag::expected_colon_after_label, label))+      return true;+    // Parse the name of the function.+    SyntaxParsingContext FuncDeclNameContext(+         SyntaxContext, SyntaxKind::FunctionDeclName);+    Diagnostic funcDiag(diag::attr_differentiable_expected_function_name.ID,+                        { label });+    result.Name =+        parseUnqualifiedDeclName(/*afterDot=*/false, result.Loc,+                                 funcDiag, /*allowOperators=*/true,+                                 /*allowZeroArgCompoundNames=*/true);+    // If no trailing comma or 'where' clause, terminate parsing arguments.+    if (Tok.isNot(tok::comma) && Tok.isNot(tok::kw_where))

Tok.isNot takes multiple arguments, please use that instead of && here and anywhere else this comes up.
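As a rough picture of what a multi-argument `isNot` buys, the sketch below defines a variadic `isNot` on a tiny stand-in token type. `TokKind` and `Token` here are hypothetical types written for this example and are not the real `swift::Token`; only the calling pattern is the point.

```cpp
// Hypothetical stand-in token type for illustration (assumption: the real
// swift::Token provides an equivalent multi-argument isNot, as the review
// comment states).
enum class TokKind { comma, kw_where, r_paren, identifier };

struct Token {
  TokKind kind;

  bool is(TokKind k) const { return kind == k; }

  // True if the token matches any of the listed kinds.
  bool isAny(TokKind k) const { return is(k); }
  template <typename... Rest>
  bool isAny(TokKind k, Rest... rest) const {
    return is(k) || isAny(rest...);
  }

  // True if the token matches none of the listed kinds.
  template <typename... Kinds>
  bool isNot(TokKind k, Kinds... rest) const {
    return !isAny(k, rest...);
  }
};
```

With such a helper, `Tok.isNot(tok::comma) && Tok.isNot(tok::kw_where)` collapses to `Tok.isNot(tok::comma, tok::kw_where)`, which keeps the list of terminating tokens in one place.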

bgogul

comment created time in 3 months

Pull request review comment apple/swift

[AutoDiff upstream] Introduce @differentiable attribute to mark functions as differentiable.

 Parser::parseImplementsAttribute(SourceLoc AtLoc, SourceLoc Loc) {
                            ProtocolType.get(), MemberName, MemberNameLoc));
 }
 
+ParserResult<DifferentiableAttr>
+Parser::parseDifferentiableAttribute(SourceLoc atLoc, SourceLoc loc) {
+  StringRef AttrName = "differentiable";
+  SourceLoc lParenLoc = loc, rParenLoc = loc;
+  bool linear = false;
+  SmallVector<ParsedAutoDiffParameter, 8> params;
+  Optional<DeclNameWithLoc> jvpSpec;
+  Optional<DeclNameWithLoc> vjpSpec;
+  TrailingWhereClause *whereClause = nullptr;
+
+  // Parse '('.
+  if (consumeIf(tok::l_paren, lParenLoc)) {
+    // Parse @differentiable attribute arguments.
+    if (parseDifferentiableAttributeArguments(linear, params, jvpSpec, vjpSpec,
+                                              whereClause))
+      return makeParserError();
+    // Parse ')'.
+    if (!consumeIf(tok::r_paren, rParenLoc)) {
+      diagnose(getEndOfPreviousLoc(), diag::attr_expected_rparen, AttrName,
+               /*DeclModifier=*/false);
+      return makeParserError();
+    }
+  }
+
+  return ParserResult<DifferentiableAttr>(
+      DifferentiableAttr::create(Context, /*implicit*/ false, atLoc,
+                                 SourceRange(loc, rParenLoc), linear,
+                                 params, jvpSpec, vjpSpec, whereClause));
+}
+
+bool Parser::parseDifferentiationParametersClause(
+    SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName) {
+  // Set parse error, skip until ')' and parse it.
+  auto errorAndSkipToEnd = [&](int parenDepth = 1) -> bool {
+    for (int i = 0; i < parenDepth; i++) {
+      skipUntil(tok::r_paren);
+      if (!consumeIf(tok::r_paren))
+        diagnose(Tok, diag::attr_expected_rparen, attrName,
+                 /*DeclModifier=*/false);
+    }
+    return true;
+  };
+
+  SyntaxParsingContext DiffParamsClauseContext(
+       SyntaxContext, SyntaxKind::DifferentiationParamsClause);
+  consumeToken(tok::identifier);
+  if (!consumeIf(tok::colon)) {
+    diagnose(Tok, diag::expected_colon_after_label, "wrt");
+    return errorAndSkipToEnd();
+  }
+
+  // Function that parses a parameter into `params`. Returns true if error
+  // occurred.
+  auto parseParam = [&](bool parseTrailingComma = true) -> bool {
+    SyntaxParsingContext DiffParamContext(
+        SyntaxContext, SyntaxKind::DifferentiationParam);
+    SourceLoc paramLoc;
+    switch (Tok.getKind()) {
+      case tok::identifier: {
+        Identifier paramName;
+        if (parseIdentifier(paramName, paramLoc,
+                            diag::diff_params_clause_expected_parameter))
+          return true;
+        params.push_back(ParsedAutoDiffParameter::getNamedParameter(
+            paramLoc, paramName));
+        break;
+      }
+      case tok::integer_literal: {
+        unsigned paramNum;
+        if (parseUnsignedInteger(
+                paramNum, paramLoc,
+                diag::diff_params_clause_expected_parameter))
+          return true;
+
+        params.push_back(ParsedAutoDiffParameter::getOrderedParameter(
+            paramLoc, paramNum));
+        break;
+      }
+      case tok::kw_self: {
+        paramLoc = consumeToken(tok::kw_self);
+        params.push_back(ParsedAutoDiffParameter::getSelfParameter(paramLoc));
+        break;
+      }
+      default:
+        diagnose(Tok, diag::diff_params_clause_expected_parameter);
+        return true;
+    }
+    if (parseTrailingComma && Tok.isNot(tok::r_paren))
+      return parseToken(tok::comma, diag::attr_expected_comma, attrName,
+                        /*isDeclModifier=*/false);
+    return false;
+  };
+
+  // Parse opening '(' of the parameter list.
+  if (Tok.is(tok::l_paren)) {
+    SyntaxParsingContext DiffParamsContext(
+        SyntaxContext, SyntaxKind::DifferentiationParams);
+    consumeToken(tok::l_paren);
+    // Parse first parameter. At least one is required.
+    if (parseParam())
+      return errorAndSkipToEnd(2);
+    // Parse remaining parameters until ')'.
+    while (Tok.isNot(tok::r_paren))
+      if (parseParam())
+        return errorAndSkipToEnd(2);
+    SyntaxContext->collectNodesInPlace(SyntaxKind::DifferentiationParamList);
+    // Parse closing ')' of the parameter list.
+    consumeToken(tok::r_paren);
+  }
+  // If no opening '(' for parameter list, parse a single parameter.
+  else {
+    if (parseParam(/*parseTrailingComma*/ false))
+      return errorAndSkipToEnd();
+  }
+  return false;
+}
+
+bool Parser::parseDifferentiableAttributeArguments(
+    bool &linear, SmallVectorImpl<ParsedAutoDiffParameter> &params,
+    Optional<DeclNameWithLoc> &jvpSpec, Optional<DeclNameWithLoc> &vjpSpec,
+    TrailingWhereClause *&whereClause) {
+  StringRef AttrName = "differentiable";
+
+  // Set parse error, skip until ')' and parse it.
+  auto errorAndSkipToEnd = [&](int parenDepth = 1) -> bool {
+    for (int i = 0; i < parenDepth; i++) {
+      skipUntil(tok::r_paren);
+      if (!consumeIf(tok::r_paren))
+        diagnose(Tok, diag::attr_expected_rparen, AttrName,
+                 /*DeclModifier=*/false);
+    }
+    return true;
+  };
+
+  // Parse trailing comma, if it exists, and check for errors.
+  auto consumeIfTrailingComma = [&]() -> bool {
+    if (!consumeIf(tok::comma)) return false;
+    // Diagnose trailing comma before 'where' or ')'.
+    if (Tok.is(tok::kw_where) || Tok.is(tok::r_paren)) {
+      diagnose(Tok, diag::unexpected_separator, ",");
+      return true;
+    }
+    // Check that token after comma is 'wrt:' or a function specifier label.
+    if (!Tok.is(tok::identifier) || !(Tok.getText() == "wrt" ||
+                                      Tok.getText() == "jvp" ||
+                                      Tok.getText() == "vjp")) {

Would it help to have a function like:

enum { wrt, jvp, vjp, invalid } classifyLabel(StringRef str);

to keep all the classification logic in sync and move the magic string parsing into a single place?
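
For illustration, a minimal standalone sketch of what such a classifier could look like. The names `DiffLabel` and `isKnownLabel` are hypothetical, and `std::string_view` stands in for `llvm::StringRef` so the snippet compiles without LLVM headers; this is not the code that landed in the PR.

```cpp
#include <string_view>

// Hypothetical enum naming the labels accepted inside @differentiable(...).
enum class DiffLabel { Wrt, Jvp, Vjp, Invalid };

// The one place where the label spellings live. std::string_view is an
// assumption made for self-containment; the parser itself uses StringRef.
constexpr DiffLabel classifyLabel(std::string_view str) {
  if (str == "wrt") return DiffLabel::Wrt;
  if (str == "jvp") return DiffLabel::Jvp;
  if (str == "vjp") return DiffLabel::Vjp;
  return DiffLabel::Invalid;
}

// Example caller: the "is this a valid label after a comma?" check becomes
// a single comparison instead of three string comparisons at each site.
constexpr bool isKnownLabel(std::string_view str) {
  return classifyLabel(str) != DiffLabel::Invalid;
}
```

With a helper like this, the condition in the trailing-comma check could plausibly be reduced to something like `!Tok.is(tok::identifier) || !isKnownLabel(Tok.getText())`, and any later label would only need to be added in one function.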

bgogul

comment created time in 3 months

Pull request review commentapple/swift

[AutoDiff upstream] Introduce @differentiable attribute to mark functions as differentiable.

+// RUN: %target-swift-frontend -parse -verify %s
+
+/// Good

nice, thank you for the testcases!

bgogul

comment created time in 3 months

Pull request review commentapple/swift

[AutoDiff upstream] Introduce @differentiable attribute to mark functions as differentiable.

+//===--- AutoDiff.h - Swift Automatic Differentiation ---------------------===//
+//
+// This source file is part of the Swift.org open source project
+//
+// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
+// Licensed under Apache License v2.0 with Runtime Library Exception
+//
+// See https://swift.org/LICENSE.txt for license information
+// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
+//
+//===----------------------------------------------------------------------===//
+//
+//  SWIFT_ENABLE_TENSORFLOW
+//  This file defines AST support for automatic differentiation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SWIFT_AST_AUTODIFF_H
+#define SWIFT_AST_AUTODIFF_H
+
+#include "ASTContext.h"
+#include "llvm/ADT/SmallBitVector.h"

Also, is ASTContext.h necessary here? Can you forward declare or include smaller headers? ASTContext.h includes the world.
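
As a rough sketch of the forward-declaration approach: a header that only stores or passes pointers and references to a type can declare the type instead of including its full definition. `AutoDiffInfo`, `getLangVersion`, and the `LangVersion` member below are made up for the demo, and the stand-in `ASTContext` is not the real Swift class.

```cpp
// --- sketch of a header (names are illustrative) ---
// A forward declaration is enough when the header only uses pointers or
// references to the type; no #include of the full definition is needed.
class ASTContext;

class AutoDiffInfo {
  ASTContext *Ctx;  // a pointer to an incomplete type is fine

public:
  explicit AutoDiffInfo(ASTContext &ctx) : Ctx(&ctx) {}
  ASTContext &getContext() const { return *Ctx; }
};

// --- sketch of the matching .cpp file ---
// Only code that looks inside the type needs the full definition; in a real
// tree this would come from including the header that defines it.
class ASTContext {
public:
  int LangVersion = 5;  // placeholder member, assumption for the demo
};

int getLangVersion(const AutoDiffInfo &info) {
  return info.getContext().LangVersion;
}
```

The trade-off is the usual one: files that include the lighter header rebuild less often, while any file that actually dereferences the context has to include the full definition itself.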

bgogul

comment created time in 3 months

Pull request review commentapple/swift

[AutoDiff upstream] Introduce @differentiable attribute to mark functions as differentiable.

+//===--- AutoDiff.h - Swift Automatic Differentiation ---------------------===//
+//
+// This source file is part of the Swift.org open source project
+//
+// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
+// Licensed under Apache License v2.0 with Runtime Library Exception
+//
+// See https://swift.org/LICENSE.txt for license information
+// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
+//
+//===----------------------------------------------------------------------===//
+//
+//  SWIFT_ENABLE_TENSORFLOW
+//  This file defines AST support for automatic differentiation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SWIFT_AST_AUTODIFF_H
+#define SWIFT_AST_AUTODIFF_H
+
+#include "ASTContext.h"
+#include "llvm/ADT/SmallBitVector.h"

SmallBitVector.h looks unused.

bgogul

comment created time in 3 months

Pull request review commentapple/swift

[AutoDiff upstream] Introduce @differentiable attribute to mark functions as differentiable.

 SIMPLE_DECL_ATTR(_nonEphemeral, NonEphemeral,
   ABIStableToAdd | ABIStableToRemove | APIBreakingToAdd | APIStableToRemove,
   90)
 
+DECL_ATTR(differentiable, Differentiable,
+  OnAccessor | OnConstructor | OnFunc | OnVar | OnSubscript | LongAttribute |
+  AllowMultipleAttributes |
+  ABIStableToAdd | ABIStableToRemove | APIStableToAdd | APIStableToRemove,

I believe this affects ABI, so I'd recommend removing the "ABIStableToAdd | ABIStableToRemove" tags, as well as APIStableToRemove which seems inappropriate.

bgogul

comment created time in 3 months

startedllvm/llvm-project

started time in 4 months

startedgoogle-research/swift-tfp

started time in 4 months

pull request commentapple/swift

[AutoDiff upstream] Add `-enable-experimental-differentiable-programming` frontend flag.

@jrose-apple @jckarter @rjmccall Anyone have an opinion on the driver flag naming? Ted approved landing this work in master under a flag, but we need a flag :-).

dan-zheng

comment created time in 5 months
